Keras: Which loss function to use when the label is a grayscale image

I am new to deep learning, Keras, and image processing. I am working on a project in which I try to compensate for motion artifacts in grayscale images using CNNs. The label is therefore a grayscale image without motion artifacts.

Now I am not sure which loss function and which error metric to use. Do I need some kind of 2D cross-correlation loss function, or does a loss such as mean squared error make sense? A first training run with 'mean squared logarithmic error' produced visually good results (the prediction looked a lot like the label image), but the accuracy of the CNN was around 0%.
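A minimal NumPy sketch of why an accuracy metric can sit near 0% while the prediction looks right. The arrays `label` and `pred` are made-up stand-ins for one image pair, and the helper functions are hypothetical illustrations, not Keras internals. Accuracy-style metrics count predictions that match the label (Keras' built-in variants compare class labels or thresholded values), which a continuous-valued pixel prediction almost never does exactly, whereas MSE and MSLE measure how far the intensities are off. The `ncc` helper is one possible reading of the "2D cross-correlation" idea: a global normalized cross-correlation, where `1 - ncc` could serve as a loss.

```python
import numpy as np

# Hypothetical 4x4 grayscale label with intensities scaled to [0, 1]
rng = np.random.default_rng(0)
label = rng.random((4, 4))
# Hypothetical prediction that looks almost identical: every pixel is off by 0.01
pred = label + 0.01

def exact_match_accuracy(y_true, y_pred):
    """Fraction of pixels whose predicted value equals the label exactly."""
    return float(np.mean(y_true == y_pred))

def mse(y_true, y_pred):
    """Mean squared error over all pixels."""
    return float(np.mean((y_true - y_pred) ** 2))

def msle(y_true, y_pred):
    """Mean squared logarithmic error using the log(1 + x) form
    (without the epsilon clipping Keras applies)."""
    return float(np.mean((np.log1p(y_pred) - np.log1p(y_true)) ** 2))

def ncc(y_true, y_pred):
    """Global normalized cross-correlation in [-1, 1]; 1 means identical up to
    a brightness offset and contrast scale. 1 - ncc could be minimized as a loss."""
    a = y_true - y_true.mean()
    b = y_pred - y_pred.mean()
    return float(np.sum(a * b) / np.sqrt(np.sum(a * a) * np.sum(b * b)))

print("accuracy:", exact_match_accuracy(label, pred))  # 0.0
print("MSE:     ", mse(label, pred))                   # about 1e-4
print("MSLE:    ", msle(label, pred))
print("NCC:     ", ncc(label, pred))                   # about 1.0
```

Because `pred` differs from `label` only by a constant offset, NCC reports a near-perfect match while MSE still registers the error. A correlation-based loss is insensitive to brightness offsets and contrast scaling, so it may need to be paired with a pixel-wise term such as MSE or MAE.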

Does anyone have experience in this area and can recommend some literature or suggest a suitable loss function and error metric?
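For reporting image quality instead of accuracy, full-reference metrics such as PSNR and SSIM are common in image restoration. The sketch below computes PSNR and a deliberately simplified SSIM over the whole image as a single window, assuming intensities scaled to [0, 1]; the function names and test arrays are hypothetical. Standard SSIM (for example TensorFlow's `tf.image.ssim`) averages the same expression over local Gaussian windows, so its values will differ.

```python
import numpy as np

def psnr(y_true, y_pred, max_val=1.0):
    """Peak signal-to-noise ratio in dB; higher is better."""
    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(max_val ** 2 / mse))

def global_ssim(y_true, y_pred, max_val=1.0):
    """Simplified SSIM computed over the whole image as one window.
    Standard SSIM averages this expression over local (Gaussian) windows."""
    c1 = (0.01 * max_val) ** 2
    c2 = (0.03 * max_val) ** 2
    mu_x, mu_y = y_true.mean(), y_pred.mean()
    var_x, var_y = y_true.var(), y_pred.var()
    cov_xy = np.mean((y_true - mu_x) * (y_pred - mu_y))
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return float(num / den)

# Hypothetical label and two predictions on a [0, 1] intensity scale
rng = np.random.default_rng(1)
label = rng.random((32, 32))
close = np.clip(label + rng.normal(0, 0.02, label.shape), 0, 1)
blurry = np.full_like(label, label.mean())  # right mean brightness, no detail

print("PSNR close/blurry:", psnr(label, close), psnr(label, blurry))
print("SSIM close/blurry:", global_ssim(label, close), global_ssim(label, blurry))
```

Both functions are higher-is-better. To use an SSIM-style score as a training loss, it would have to be turned into something to minimize (e.g. `1 - ssim`) and written with tensor operations rather than NumPy so that gradients can flow.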

If you need more detailed information, just let me know and I will be happy to provide it.

The CNN I am using (similar to a U-Net):

from keras.layers import (Input, Conv2D, BatchNormalization, Dropout,
                          Conv2DTranspose, concatenate)
from keras.models import Model

# X_train: motion-corrupted input images, shape (N, H, W, 1) for grayscale
input_1 = Input((X_train.shape[1], X_train.shape[2], X_train.shape[3]))

# Encoder: six stride-2 convolutions, each halving height and width
conv1 = Conv2D(16, (3,3), strides=(2,2), activation='relu', padding='same')(input_1)
batch1 = BatchNormalization(axis=3)(conv1)
conv2 = Conv2D(32, (3,3), strides=(2,2), activation='relu', padding='same')(batch1)
batch2 = BatchNormalization(axis=3)(conv2)
conv3 = Conv2D(64, (3,3), strides=(2,2), activation='relu', padding='same')(batch2)
batch3 = BatchNormalization(axis=3)(conv3)
conv4 = Conv2D(128, (3,3), strides=(2,2), activation='relu', padding='same')(batch3)
batch4 = BatchNormalization(axis=3)(conv4)
conv5 = Conv2D(256, (3,3), strides=(2,2), activation='relu', padding='same')(batch4)
batch5 = BatchNormalization(axis=3)(conv5)
conv6 = Conv2D(512, (3,3), strides=(2,2), activation='relu', padding='same')(batch5)
drop1 = Dropout(0.25)(conv6)

# Decoder: transposed convolutions with no activation argument (i.e. linear);
# upconv1 keeps the size, upconv2..upconv7 each double height and width
upconv1 = Conv2DTranspose(256, (3,3), strides=(1,1), padding='same')(drop1)
upconv2 = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same')(upconv1)
upconv3 = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same')(upconv2)
upconv4 = Conv2DTranspose(32, (3,3), strides=(2,2), padding='same')(upconv3)
upconv5 = Conv2DTranspose(16, (3,3), strides=(2,2), padding='same')(upconv4)
upconv5_1 = concatenate([upconv5, conv2], axis=3)  # skip connection at H/4
upconv6 = Conv2DTranspose(8, (3,3), strides=(2,2), padding='same')(upconv5_1)
upconv6_1 = concatenate([upconv6, conv1], axis=3)  # skip connection at H/2
# Single-channel linear output: the predicted artifact-free grayscale image
upconv7 = Conv2DTranspose(1, (3,3), strides=(2,2), activation='linear', padding='same')(upconv6_1)

model = Model(inputs=input_1, outputs=upconv7)
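The encoder above halves the height and width six times, and the two `concatenate` skip connections only line up if the decoder restores exactly the encoder's sizes. This small pure-Python sketch traces the spatial size through the model, assuming Keras' `padding='same'` size rules: a stride-2 `Conv2D` gives ceil(n / 2) and a stride-2 `Conv2DTranspose` gives 2n. The helpers `encoder_sizes` and `check_unet_size` are hypothetical names for illustration.

```python
import math

def encoder_sizes(size, n_down=6):
    """Spatial size after each stride-2 Conv2D with padding='same' (ceil(n / 2))."""
    sizes = []
    for _ in range(n_down):
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes

def check_unet_size(size):
    """Return True if the skip connections and output size of the model line up.

    conv1..conv6 downsample; upconv1 keeps the size; upconv2..upconv7 double it.
    upconv5 is concatenated with conv2 and upconv6 with conv1.
    """
    enc = encoder_sizes(size)  # sizes of [conv1, conv2, ..., conv6]
    up5 = enc[5] * 2 ** 4      # after upconv2..upconv5
    up6 = up5 * 2              # after upconv6
    up7 = up6 * 2              # after upconv7 (model output)
    return up5 == enc[1] and up6 == enc[0] and up7 == size

for s in (100, 128, 256, 320):
    print(s, check_unet_size(s))  # only multiples of 64 pass: 128, 256, 320
```

Because each `ceil` rounds up, the output size works out to ceil(n / 64) * 64, so height and width both need to be multiples of 64 (2 to the power of the six stride-2 layers). For other sizes, Keras would typically reject the `concatenate` calls when the model is built.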

Thanks for your help!
