Loss not evolving when using skip connections


I'm trying to implement this paper in Keras: https://arxiv.org/pdf/1603.09056.pdf. It uses Conv-Deconv layers with skip connections to build an image denoising network. My network works pretty well if I make symmetrical skip connections between corresponding Conv-Deconv layers, but if I add a connection between the input and the output (as in the paper), the network becomes impossible to train. Am I misunderstanding the paper?

"However, our network learns for the additive corruption from the input since there is a skip connection between the input and the output of the network"

Here is the network described in the paper:

[Figure: network architecture from the paper]

And here is my network:

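from keras.layers import Input, Convolution2D, Conv2DTranspose, Activation, add
from keras.models import Model

input_img = Input(shape=(None,None,3))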

############################
####### CONVOLUTIONS #######
############################

c1 = Convolution2D(64, (3, 3))(input_img)
a1 = Activation('relu')(c1)

c2 = Convolution2D(64, (3, 3))(a1)
a2 = Activation('relu')(c2)

c3 = Convolution2D(64, (3, 3))(a2)
a3 = Activation('relu')(c3)

c4 = Convolution2D(64, (3, 3))(a3)
a4 = Activation('relu')(c4)

c5 = Convolution2D(64, (3, 3))(a4)
a5 = Activation('relu')(c5)

############################
###### DECONVOLUTIONS ######
############################

d1 = Conv2DTranspose(64, (3, 3))(a5)
a6 = Activation('relu')(d1)

m1 = add([a4, a6])
a7 = Activation('relu')(m1)

d2 = Conv2DTranspose(64, (3, 3))(a7)
a8 = Activation('relu')(d2)

m2 = add([a3, a8])
a9 = Activation('relu')(m2)

d3 = Conv2DTranspose(64, (3, 3))(a9)
a10 = Activation('relu')(d3)

m3 = add([a2, a10])
a11 = Activation('relu')(m3)

d4 = Conv2DTranspose(64, (3, 3))(a11)
a12 = Activation('relu')(d4) 

m4 = add([a1, a12])
a13 = Activation('relu')(m4)

d5 = Conv2DTranspose(3, (3, 3))(a13)
a14 = Activation('relu')(d5)

m5 = add([input_img, a14]) # Everything goes well without this line
out = Activation('relu')(m5)

model = Model(input_img, out) 
model.compile(optimizer='adam', loss='mse')

When I train it, here is what I get:

Epoch 1/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 - val_loss: 0.0015
Epoch 2/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 -  val_loss: 0.0015
Epoch 3/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 -   val_loss: 0.0015
Epoch 4/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

31257/31257 [==============================] - 89s - loss: 0.0015 - val_loss: 0.0015
Epoch 5/10
31250/31257 [============================>.] - ETA: 0s - loss: 0.0015
Current PSNR: 28.1152534485

What is wrong with my network?

1 answer

Daniel Möller:

The activation 'relu' never returns a negative value.

Since you're adding the input to the output of the last layer (a14), and the goal is to denoise (remove noise), a14 needs to contain both positive and negative values: you want to darken spots that are too light and lighten spots that are too dark.
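To make this concrete, here is a small pure-Python sketch with made-up pixel values (the numbers and helper names are only illustrative, not taken from the question). It compares the correction a 'relu' output can supply with the one the denoising task needs:

```python
import math

def relu(x):
    """ReLU: negative inputs are clipped to 0."""
    return max(0.0, x)

def mse(a, b):
    """Mean squared error between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# Hypothetical 1-D "image": the clean value is 0.5 everywhere,
# and the noise pushes some pixels up and others down.
clean = [0.5, 0.5, 0.5, 0.5]
noisy = [0.6, 0.4, 0.7, 0.3]

# Correction the last layer must output so that noisy + correction == clean.
needed = [c - n for c, n in zip(clean, noisy)]

# Best a relu output can do: every negative correction is clipped to 0.
relu_correction = [relu(r) for r in needed]

# tanh can reach any value in (-1, 1); atanh(r) is the pre-activation
# that would produce exactly r.
tanh_correction = [math.tanh(math.atanh(r)) for r in needed]

relu_out = [n + r for n, r in zip(noisy, relu_correction)]
tanh_out = [n + t for n, t in zip(noisy, tanh_correction)]

print("needed correction:", needed)
print("MSE with relu correction:", mse(relu_out, clean))
print("MSE with tanh correction:", mse(tanh_out, clean))
```

The pixels that are too bright (0.6 and 0.7) stay too bright with 'relu', because the correction for them would have to be negative.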

Because of that, the activation in a14 cannot be 'relu'. It must be able to output both positive and negative values and to cover the range of the noise. Probably a 'tanh' or a custom activation. If your input goes from 0 to 1, a 'tanh' would probably be the best option.

(I'm not sure about the previous layers; perhaps having a few of them use 'tanh' would make the process easier.)
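As a rough sketch of the 'tanh' suggestion, the tail of the model could look like the code below. This is a shortened stand-in, not the asker's full network: the `build_denoiser` function, its `depth` and `filters` parameters, and the layer name `residual` are all made up for illustration, and the intermediate symmetric skip connections are left out to keep it short. The final 'relu' after the addition is kept as in the question.

```python
from keras.layers import Input, Conv2D, Conv2DTranspose, Activation, add
from keras.models import Model

def build_denoiser(depth=2, filters=16):
    """Shortened conv/deconv stack (hypothetical helper, for illustration only)."""
    input_img = Input(shape=(None, None, 3))

    # Encoder: each 3x3 'valid' convolution shrinks height and width by 2.
    x = input_img
    for _ in range(depth):
        x = Activation('relu')(Conv2D(filters, (3, 3))(x))

    # Decoder: each 3x3 transposed convolution grows height and width by 2.
    for _ in range(depth - 1):
        x = Activation('relu')(Conv2DTranspose(filters, (3, 3))(x))

    # Last deconvolution: 'tanh' instead of 'relu', so the correction that
    # is added to the input can be negative as well as positive.
    d_last = Conv2DTranspose(3, (3, 3))(x)
    residual = Activation('tanh', name='residual')(d_last)

    # Input-to-output skip connection, followed by the final 'relu'
    # exactly as in the question's code.
    out = Activation('relu')(add([input_img, residual]))
    return Model(input_img, out)

model = build_denoiser()
model.compile(optimizer='adam', loss='mse')
```

In the original code, the equivalent change would be to use `Activation('tanh')` for `a14` and leave everything else as it is.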


Sometimes these long convolutional networks do get stuck. I'm training a U-Net at the moment, and it took a while to get it to converge. When training gets stuck, it is sometimes better to build the model again (new weight initializations) and try again.

See details here: How to build a multi-class convolutional neural network with Keras