SSIM function in TensorFlow 2.x


I'm trying to use tf.image.ssim() as the loss function for training my model, and I looked a little at how other people have implemented it. Here are two examples:

  1. Working with SSIM loss function in tensorflow for RGB images
  2. Use SSIM loss function with Keras

I have a couple of questions:

  1. In both of these threads, the suggested dynamic range (the max_val argument of tf.image.ssim) is 2 when the inputs are normalized to [-1, 1]. But I ran a small sanity check to see whether that works. Here is the code:
from PIL import Image
import numpy as np
from skimage.util import random_noise
import matplotlib.pyplot as plt
import tensorflow as tf

# assumes a 256x256 single-channel (grayscale) PNG, matching the reshapes below
im = Image.open('E:\\DATA\\train_image_(124).png')
im_arr = np.asarray(im) # convert PIL Image to ndarray

noise_img = random_noise(im_arr, mode='gaussian', var=0.0005) # random_noise() method will convert image in [0, 255] to [0, 1.0]
noise_img = (255*noise_img).astype(np.uint8)

img = Image.fromarray(noise_img)

#normalizing between 0 and 1 and reshaping for SSIM calculation
x = np.reshape((np.asarray(im)/255), [256, 256, 1])
y = np.reshape((np.asarray(img)/255), [256, 256, 1])

#normalizing between -1 and 1 and reshaping for SSIM calculation
x_a = np.reshape((2*(np.asarray(im)/255) - 1), [256, 256, 1])
y_a = np.reshape((2*(np.asarray(img)/255) - 1), [256, 256, 1])

print('No norm: ', str(tf.image.ssim(np.reshape(im_arr, [256, 256, 1]), np.reshape(noise_img, [256, 256, 1]), 255).numpy()))
print('Norm_01: ', str(tf.image.ssim(x, y, 1).numpy()))
print('Norm_11: ', str(tf.image.ssim(x_a, y_a, 2).numpy()))

To my understanding, all three print statements should give the same SSIM value, but they don't. With the [0, 255] and [0, 1] ranges the SSIM result is the same, but with the [-1, 1] range it is different. To double-check, I also computed SSIM in MATLAB, and it nearly agrees with the first two cases. So, is there another way to compute SSIM, or to use it as a loss function, in TF2? I ran the same experiment with compare_ssim from skimage, but it seems to give the same result. Am I missing something?
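One possible explanation, offered as a hedged sketch rather than a definitive answer: SSIM is the product of a luminance term l = (2·μx·μy + C1) / (μx² + μy² + C1) and a contrast-structure term cs = (2·σxy + C2) / (σx² + σy² + C2), with C1 = (0.01·L)² and C2 = (0.03·L)², where L is max_val. Scaling both images and L by the same factor a multiplies every numerator and denominator by a², so SSIM is unchanged; that is consistent with the [0, 255] and [0, 1] runs agreeing. The map 2x - 1 also shifts the means by -1. The σ terms do not care about the shift, but μx·μy and μx² + μy² change while C1 only scales, so the luminance term (and therefore SSIM) changes. The pure-Python sketch below checks this using one global window. That is a simplification: tf.image.ssim computes these statistics in local Gaussian windows (filter_size=11, filter_sigma=1.5 by default) and averages the results. The function name global_ssim and the toy values are hypothetical.

```python
def global_ssim(x, y, max_val, k1=0.01, k2=0.03):
    """SSIM computed from a single global window (hypothetical helper).

    Simplification: tf.image.ssim averages SSIM over 11x11 Gaussian-weighted
    local windows instead of using whole-signal statistics.
    """
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    # Luminance term: depends on the absolute means, so it is NOT shift-invariant
    luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
    # Contrast-structure term: depends only on (co)variances, so shifts cancel
    contrast_structure = (2 * cov_xy + c2) / (var_x + var_y + c2)
    return luminance * contrast_structure


# Made-up "clean" and "noisy" signals in [0, 1], for illustration only
x = [0.2, 0.4, 0.6, 0.8]
y = [0.25, 0.35, 0.7, 0.9]

ssim_01 = global_ssim(x, y, 1.0)                                           # range [0, 1]
ssim_255 = global_ssim([255 * a for a in x], [255 * b for b in y], 255.0)  # pure scaling
ssim_11 = global_ssim([2 * a - 1 for a in x], [2 * b - 1 for b in y], 2.0)  # scaling + shift

# The first two agree (up to float rounding); the third is much smaller
print(ssim_01, ssim_255, ssim_11)
```

If this reasoning carries over to the windowed version, SSIM is simply not defined to be shift-invariant, so no max_val choice makes the [-1, 1] result match. Mapping the data back to [0, 1], for example with (t + 1) / 2, and calling tf.image.ssim with max_val=1.0 would then be the way to reproduce the MATLAB numbers.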

  2. Also, when I use tf.reduce_mean(tf.keras.losses.mean_squared_error(target, gen_output)) as my loss function, everything is fine. But when I use tf.reduce_mean(tf.image.ssim(x, y, dynamic_range)) as the loss function, I get NaN values. Both threads mentioned above use either TensorFlow 1.x or model.fit in TensorFlow 2.x for training, whereas I use tf.GradientTape() to compute the gradients and update the weights. Is it possible that GradientTape is responsible for the NaN values? If so, why, and what could be a possible solution?
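For comparison, here is a minimal sketch of an SSIM loss inside a tf.GradientTape training step. It rests on several assumptions that are not from the question. The generator is a hypothetical one-layer stand-in ending in tanh, so its output and the target are in [-1, 1]. Both tensors are mapped to [0, 1] before tf.image.ssim is called with max_val=1.0. The loss is 1 - mean(SSIM) rather than mean(SSIM), because SSIM is a similarity (1 means identical) and minimizing the raw value pushes the output away from the target. tf.debugging.check_numerics makes the step fail with a clear message the moment the loss becomes NaN or Inf, which can help locate where NaN first appears. It does not by itself explain the cause. Note also that Keras' built-in Model.train_step (what model.fit runs) computes its gradients with tf.GradientTape as well, so the tape itself seems an unlikely culprit.

```python
import tensorflow as tf

# Hypothetical toy generator: one conv layer with tanh output in [-1, 1]
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(None, None, 1)),
    tf.keras.layers.Conv2D(1, 3, padding="same", activation="tanh"),
])
optimizer = tf.keras.optimizers.Adam(1e-4)


def ssim_loss(target, gen_output):
    # Map both tensors from [-1, 1] to [0, 1] so that max_val=1.0 matches the data
    target_01 = (target + 1.0) / 2.0
    output_01 = (gen_output + 1.0) / 2.0
    # SSIM is a similarity, so minimize 1 - SSIM instead of SSIM itself
    return 1.0 - tf.reduce_mean(tf.image.ssim(target_01, output_01, max_val=1.0))


def train_step(inputs, target):
    with tf.GradientTape() as tape:
        gen_output = generator(inputs, training=True)
        # check_numerics raises InvalidArgumentError if the loss is NaN or Inf;
        # it is applied inside the tape so the gradient path is preserved
        loss = tf.debugging.check_numerics(
            ssim_loss(target, gen_output), "SSIM loss is NaN/Inf")
    grads = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return loss


# Example usage with random data in [-1, 1]
inputs = tf.random.uniform((2, 32, 32, 1), -1.0, 1.0)
target = tf.random.uniform((2, 32, 32, 1), -1.0, 1.0)
print(float(train_step(inputs, target)))
```

The images are 32x32 here because tf.image.ssim expects height and width to be at least filter_size (11 by default). train_step can be wrapped in @tf.function once it behaves correctly in eager mode.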