Should reconstruction loss be computed as sum or average over image for variational autoencoders?


I am following this variational autoencoder tutorial: https://keras.io/examples/generative/vae/.

I know the VAE loss consists of the reconstruction loss, which compares the original image with the reconstruction, plus the KL loss. However, I'm a bit confused about whether the reconstruction loss is computed over the entire image (a sum of squared differences) or per pixel (an average of squared differences, i.e. MSE). My understanding is that the reconstruction loss should be per pixel (MSE), but the example code I am following multiplies the per-pixel loss by 28 x 28, the MNIST image dimensions. Is that correct? Furthermore, my assumption is that this makes the reconstruction loss term significantly larger than the KL loss, and I'm not sure we want that.

I tried removing the multiplication by 28 x 28, but this resulted in extremely poor reconstructions: essentially all the reconstructions looked the same regardless of the input. Can I use a lambda parameter to capture the tradeoff between the KL divergence and the reconstruction loss, or is that incorrect because the loss has a precise derivation (as opposed to just adding a regularization penalty)?

# Per-pixel binary cross-entropy, averaged over every pixel of every image in the batch
reconstruction_loss = tf.reduce_mean(
    keras.losses.binary_crossentropy(data, reconstruction)
)
# Scale the per-pixel average back up to a per-image sum (28 x 28 pixels)
reconstruction_loss *= 28 * 28
# KL divergence between the approximate posterior N(z_mean, exp(z_log_var)) and the standard normal prior
kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
kl_loss = tf.reduce_mean(kl_loss)
kl_loss *= -0.5
total_loss = reconstruction_loss + kl_loss

There are 2 answers

jcrudy

It isn't really necessary to multiply by the number of pixels. However, whether you do so will affect how your fitting algorithm behaves with respect to the other hyperparameters: your lambda parameter and the learning rate. In essence, if you want to remove the multiplication by 28 x 28 but retain the same fitting behavior, you should divide lambda by 28 x 28 and then multiply your learning rate by 28 x 28. I think you were already approaching this idea in your question, and the piece you were missing is the adjustment to the learning rate.
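
Here is a minimal sketch of that equivalence (the names scale and lam and the optimizer caveat are mine, not from the tutorial):

scale = 28 * 28   # number of pixels in an MNIST image
lam = 1.0         # hypothetical weight on the KL term

def loss_with_pixel_scaling(recon_per_pixel, kl):
    # Tutorial-style objective: per-pixel reconstruction loss scaled back up to a per-image sum
    return scale * recon_per_pixel + lam * kl

def loss_without_pixel_scaling(recon_per_pixel, kl):
    # The same objective divided by scale: drop the 28 x 28 factor and shrink lambda instead
    return recon_per_pixel + (lam / scale) * kl

The two losses differ only by the constant factor scale, so their gradients do as well; with plain SGD, multiplying the learning rate by scale recovers the same parameter updates, while adaptive optimizers such as Adam absorb much of that rescaling on their own.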

waTeim

The example

I'm familiar with that example, and I think the 28 x 28 multiplier is justified because the operation tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction)) takes the average loss over all the pixels in the image, which results in a small per-pixel number, and the multiplier then scales it back up by the number of pixels. Here's another take, with an external training loop, for creating a VAE.
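
As a quick sanity check (with random arrays standing in for MNIST images, so the numbers themselves are meaningless), the mean-then-scale form used by the tutorial and an explicit sum-per-image form give the same value:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Random stand-ins for a batch of 8 images and their reconstructions
data = np.random.rand(8, 28, 28).astype("float32")
reconstruction = np.random.rand(8, 28, 28).astype("float32")

# Tutorial-style: average the cross-entropy over every pixel, then scale by 28 * 28
mean_bce = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
scaled = mean_bce * 28 * 28

# Equivalent: element-wise cross-entropy, summed per image, averaged over the batch
elementwise = keras.backend.binary_crossentropy(data, reconstruction)  # shape (8, 28, 28)
summed = tf.reduce_mean(tf.reduce_sum(elementwise, axis=(1, 2)))

print(float(scaled), float(summed))  # the two values match up to float error

So the 28 x 28 factor simply turns a per-pixel average back into a per-image sum.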

The problem is posterior collapse

The above would not be an issue, since it's just multiplication by a constant, if not for the KL divergence term, as you point out. The KL loss acts as a regularizer that penalizes encoder distributions over the latent variables that drift away from the standard Gaussian prior they are sampled against. Naturally, the question arises: how much should be reconstruction loss and how much should be the penalty? This is an area of research. Consider β-VAE, which purportedly serves to disentangle representations by increasing the weight of the KL loss; on the other hand, increase β too much and you get a phenomenon known as posterior collapse. Re-balancing Variational Autoencoder Loss for Molecule Sequence Generation limits β to 0.1 to avoid the problem. But it may not even be that simple, as explained in The Usual Suspects? Reassessing Blame for VAE Posterior Collapse. A thorough solution is proposed in Diagnosing and Enhancing VAE Models, while Balancing reconstruction error and Kullback-Leibler divergence in Variational Autoencoders suggests that there is a simpler, deterministic (and better) way.
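
Concretely, a β-weighted version of the loss in the question could look like the sketch below (the helper function is mine, and β = 0.1 is just the cap used in the molecule-sequence paper, not something from the keras.io example):

import tensorflow as tf
from tensorflow import keras

def beta_vae_loss(data, reconstruction, z_mean, z_log_var, beta=0.1):
    # Same two terms as in the question's snippet, with an explicit weight on the KL term
    reconstruction_loss = 28 * 28 * tf.reduce_mean(
        keras.losses.binary_crossentropy(data, reconstruction)
    )
    kl_loss = -0.5 * tf.reduce_mean(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    )
    return reconstruction_loss + beta * kl_loss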

Experimentation and Extension

For something simple like MNIST, and that example in particular, try experimenting. Keep the 28 x 28 term, and arbitrarily multiply kl_loss by a constant B where 0 <= B < 28 * 28. Follow the KL loss term and the reconstruction loss term during training and compare them to the graphs in the first reference.
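
A rough sketch of that sweep (the train_vae helper and the history keys it returns are hypothetical stand-ins for the tutorial's training loop, with total_loss = reconstruction_loss + B * kl_loss inside its train step):

histories = {}
for B in [0.0, 1.0, 28.0, 196.0, 783.0]:  # a few values in the range 0 <= B < 28 * 28
    histories[B] = train_vae(kl_weight=B)  # hypothetical helper wrapping the tutorial's model.fit

for B, history in histories.items():
    # Compare how each term evolves for each B against the reference curves
    print(B, history["reconstruction_loss"][-1], history["kl_loss"][-1])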