I'm trying to create an image-denoising ConvNet in Keras and I want to write my own loss function. The network should take a noisy image as input and produce the noise as output. The loss is essentially an MSE loss, but it should make the network learn to remove the clean image, rather than the noise, from the noisy input image.
The loss function I want to implement, with y the noisy image, x the clean image and R(y) the predicted image, is the mean squared error between R(y) and the residual y - x.
I've tried to write it myself, but I don't know how to give the loss access to my noisy images, since they change all the time.
def residual_loss(noisy_img):
    def loss(y_true, y_pred):
        return np.mean(np.square(y_pred - (noisy_img - y_true)), axis=-1)
    return loss
Basically, what I need to do is something like this:
input_img = Input(shape=(None,None,3))
c1 = Convolution2D(64, (3, 3))(input_img)
a1 = Activation('relu')(c1)
c2 = Convolution2D(64, (3, 3))(a1)
a2 = Activation('relu')(c2)
c3 = Convolution2D(64, (3, 3))(a2)
a3 = Activation('relu')(c3)
c4 = Convolution2D(64, (3, 3))(a3)
a4 = Activation('relu')(c4)
c5 = Convolution2D(3, (3, 3))(a4)
out = Activation('relu')(c5)
model = Model(input_img, out)
model.compile(optimizer='adam', loss=residual_loss(input_img))
But if I try this, I get:
IndexError: tuple index out of range
What can I do?
Since it's quite unusual to use the "input" in the loss function (it's not meant for that), I think it's worth saying:
It's not the role of the loss function to separate the noise. The loss function is just a measure of "how far from right you are".
It's your model that will separate things, and the results you expect from your model are y_true. You should use a regular loss, with X_training = noisy images and Y_training = noises.
That said...
You can create a tensor for noisy_img outside the loss function and keep it stored. All operations inside a loss function must be tensor functions, so use the Keras backend for that.
But you must take batch sizes into account: because this variable lives outside the loss function, you will need to fit just one batch per epoch.
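A minimal sketch of the stored-tensor idea, assuming TensorFlow 2 with its bundled Keras. It uses plain TensorFlow ops (tf.Variable, tf.reduce_mean, tf.square) in place of the Keras backend functions, which is a substitution made here for illustration. X_training, batch_size and the random data are placeholders, not taken from the question.

```python
import numpy as np
import tensorflow as tf

batch_size = 4  # hypothetical batch size

# Placeholder noisy images; replace with real data.
X_training = np.random.rand(8, 16, 16, 3).astype("float32")

# Variable stored outside the loss. It must always hold the noisy
# images of the batch that is currently being trained on.
noisy_img = tf.Variable(X_training[:batch_size], trainable=False)

def residual_loss(y_true, y_pred):
    # y_true: clean images, y_pred: noise predicted by the model.
    # The target residual is (noisy - clean), built from the stored variable.
    residual = noisy_img - y_true
    return tf.reduce_mean(tf.square(y_pred - residual), axis=-1)
```

Because the loss reads noisy_img at call time, the variable has to be updated before every batch, otherwise the residual is computed against the wrong images.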
Training one batch per epoch:
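One way this loop could look, again assuming TensorFlow 2 and its bundled Keras. The two-layer network, the random data and the names noisy_img and residual_loss are hypothetical stand-ins. The sketch also assumes padding="same" so that the output has the same size as the input, and a dataset size that is a multiple of batch_size so the stored variable always has the same shape.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

rng = np.random.default_rng(0)
batch_size = 4

# Placeholder data: y_true is the clean image, the model input is the noisy one.
clean = rng.random((8, 16, 16, 3)).astype("float32")
X_training = clean + 0.1 * rng.standard_normal(clean.shape).astype("float32")
Y_training = clean

# Variable read by the loss; updated before each batch.
noisy_img = tf.Variable(X_training[:batch_size], trainable=False)

def residual_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_pred - (noisy_img - y_true)), axis=-1)

# Tiny stand-in network that keeps the spatial size unchanged.
inputs = keras.Input(shape=(None, None, 3))
hidden = keras.layers.Conv2D(8, (3, 3), padding="same", activation="relu")(inputs)
outputs = keras.layers.Conv2D(3, (3, 3), padding="same")(hidden)
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=residual_loss)

for start in range(0, len(X_training), batch_size):
    stop = start + batch_size
    # Store the current batch of noisy images, then train on exactly that batch.
    noisy_img.assign(X_training[start:stop])
    batch_loss = model.train_on_batch(X_training[start:stop], Y_training[start:stop])
```

The bookkeeping is the price of reading the input inside the loss: the variable and the batch passed to the model must stay in sync by hand.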
For using just a mean squared error, organize your data like this:
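A small NumPy sketch of that data layout. The random arrays stand in for real images, and the Gaussian noise model and the names original_images and noise are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder clean images; replace with your own loader.
original_images = rng.random((8, 16, 16, 3)).astype("float32")

# Assumed noise model: additive Gaussian noise.
noise = 0.1 * rng.standard_normal(original_images.shape).astype("float32")

X_training = original_images + noise        # model input: noisy images
Y_training = X_training - original_images   # model target: the noise itself
```

With the target defined this way, the clean image is recovered after prediction as the noisy input minus the predicted noise.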
Now you can just use "mse", or any other built-in loss.
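To see why a built-in loss is enough, here is a short NumPy check, a sketch on made-up arrays, that the custom residual loss from the question and a plain MSE against the noise target give the same number. The function names are hypothetical.

```python
import numpy as np

def residual_loss_value(noisy, clean, pred):
    # Loss from the question: compare the prediction with (noisy - clean).
    return np.mean(np.square(pred - (noisy - clean)))

def mse(y_true, y_pred):
    # Ordinary mean squared error.
    return np.mean(np.square(y_pred - y_true))

rng = np.random.default_rng(1)
clean = rng.random((2, 8, 8, 3))
noisy = clean + 0.1 * rng.standard_normal(clean.shape)
pred = rng.random(clean.shape)  # stand-in for the model output

noise_target = noisy - clean    # what Y_training holds
custom = residual_loss_value(noisy, clean, pred)
builtin = mse(noise_target, pred)
```

Both values agree because moving the subtraction into the targets changes where the residual is computed, not what is being compared.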