Multiclass Image Segmentation

238 views Asked by At

I have trained a U-Net model for multiclass image segmentation with four classes. This is the code:

def _conv_bn_relu(x, filters, repeats):
    """Apply `repeats` rounds of 3x3 Conv2D -> BatchNormalization -> ReLU.

    padding='same' keeps the spatial size; the channel count becomes `filters`.
    """
    for _ in range(repeats):
        x = Conv2D(filters, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def Unet(input_shape=(128, 128, 3),
         num_classes=4,
         learning_rate=1e-4):
    """Build and compile a U-Net for per-pixel multiclass segmentation.

    Architecture:
      * Encoder: four stages (64, 128, 256, 512 filters) of two
        conv/BN/ReLU layers, each followed by 2x2 max pooling.
      * Bottleneck: two conv/BN/ReLU layers with 1024 filters.
      * Decoder: four stages (512, 256, 128, 64 filters). Each stage
        upsamples 2x, concatenates the matching encoder output on the
        channel axis ([skip, upsampled]) and applies three conv/BN/ReLU
        layers.
      * Head: a 1x1 convolution with softmax, giving one probability
        channel per class.

    Args:
        input_shape: (height, width, channels), channels-last (the skip
            concatenation uses axis=3). Known height/width values must be
            divisible by 16 (four 2x poolings) so each upsampled map lines
            up with its skip connection. ``None`` dimensions are not checked.
        num_classes: number of output channels, i.e. segmentation classes.
        learning_rate: learning rate passed to the Adam optimizer
            (default 1e-4).

    Returns:
        A compiled ``Model`` with output shape
        (batch, height, width, num_classes). Channel ``k`` holds the softmax
        probability of class ``k``. Because the loss is
        ``sparse_categorical_crossentropy``, training masks are expected to
        contain integer class indices 0..num_classes-1.

    Raises:
        ValueError: if a known height or width is not divisible by 16.
    """
    downsample_factor = 2 ** 4  # four MaxPooling2D stages, each halving H and W
    for dim in input_shape[:2]:
        if dim is not None and dim % downsample_factor:
            raise ValueError(
                f"input_shape height/width must be divisible by "
                f"{downsample_factor}, got {tuple(input_shape)}")

    inputs = Input(shape=input_shape)

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = _conv_bn_relu(x, filters, repeats=2)
        skips.append(x)
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)

    # Bottleneck.
    x = _conv_bn_relu(x, 1024, repeats=2)

    # Decoder: consume skips deepest-first. The skip tensor comes first in the
    # concatenation, so channel order is [skip, upsampled].
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = concatenate([skip, x], axis=3)
        x = _conv_bn_relu(x, filters, repeats=3)

    # Per-pixel softmax over classes: output channel k == class index k.
    classify = Conv2D(num_classes, (1, 1), activation='softmax')(x)

    model = Model(inputs=inputs, outputs=classify)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                  loss="sparse_categorical_crossentropy",
                  metrics=['accuracy'])
    return model

# Build the 4-class U-Net for 128x128 RGB inputs (these are Unet's defaults).
model = Unet(input_shape=(128, 128, 3), num_classes=4)

The model summary is:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 128, 128, 64) 1792        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_66 (BatchNo (None, 128, 128, 64) 256         conv2d_69[0][0]                  
__________________________________________________________________________________________________
activation_66 (Activation)      (None, 128, 128, 64) 0           batch_normalization_66[0][0]     
__________________________________________________________________________________________________
conv2d_70 (Conv2D)              (None, 128, 128, 64) 36928       activation_66[0][0]              
__________________________________________________________________________________________________
batch_normalization_67 (BatchNo (None, 128, 128, 64) 256         conv2d_70[0][0]                  
__________________________________________________________________________________________________
activation_67 (Activation)      (None, 128, 128, 64) 0           batch_normalization_67[0][0]     
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 64, 64, 64)   0           activation_67[0][0]              
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 64, 64, 128)  73856       max_pooling2d_12[0][0]           
__________________________________________________________________________________________________
batch_normalization_68 (BatchNo (None, 64, 64, 128)  512         conv2d_71[0][0]                  
__________________________________________________________________________________________________
activation_68 (Activation)      (None, 64, 64, 128)  0           batch_normalization_68[0][0]     
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 64, 64, 128)  147584      activation_68[0][0]              
__________________________________________________________________________________________________
batch_normalization_69 (BatchNo (None, 64, 64, 128)  512         conv2d_72[0][0]                  
__________________________________________________________________________________________________
activation_69 (Activation)      (None, 64, 64, 128)  0           batch_normalization_69[0][0]     
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 32, 32, 128)  0           activation_69[0][0]              
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 32, 32, 256)  295168      max_pooling2d_13[0][0]           
__________________________________________________________________________________________________
batch_normalization_70 (BatchNo (None, 32, 32, 256)  1024        conv2d_73[0][0]                  
__________________________________________________________________________________________________
activation_70 (Activation)      (None, 32, 32, 256)  0           batch_normalization_70[0][0]     
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 32, 32, 256)  590080      activation_70[0][0]              
__________________________________________________________________________________________________
batch_normalization_71 (BatchNo (None, 32, 32, 256)  1024        conv2d_74[0][0]                  
__________________________________________________________________________________________________
activation_71 (Activation)      (None, 32, 32, 256)  0           batch_normalization_71[0][0]     
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 16, 16, 256)  0           activation_71[0][0]              
__________________________________________________________________________________________________
conv2d_75 (Conv2D)              (None, 16, 16, 512)  1180160     max_pooling2d_14[0][0]           
__________________________________________________________________________________________________
batch_normalization_72 (BatchNo (None, 16, 16, 512)  2048        conv2d_75[0][0]                  
__________________________________________________________________________________________________
activation_72 (Activation)      (None, 16, 16, 512)  0           batch_normalization_72[0][0]     
__________________________________________________________________________________________________
conv2d_76 (Conv2D)              (None, 16, 16, 512)  2359808     activation_72[0][0]              
__________________________________________________________________________________________________
batch_normalization_73 (BatchNo (None, 16, 16, 512)  2048        conv2d_76[0][0]                  
__________________________________________________________________________________________________
activation_73 (Activation)      (None, 16, 16, 512)  0           batch_normalization_73[0][0]     
__________________________________________________________________________________________________
max_pooling2d_15 (MaxPooling2D) (None, 8, 8, 512)    0           activation_73[0][0]              
__________________________________________________________________________________________________
conv2d_77 (Conv2D)              (None, 8, 8, 1024)   4719616     max_pooling2d_15[0][0]           
__________________________________________________________________________________________________
batch_normalization_74 (BatchNo (None, 8, 8, 1024)   4096        conv2d_77[0][0]                  
__________________________________________________________________________________________________
activation_74 (Activation)      (None, 8, 8, 1024)   0           batch_normalization_74[0][0]     
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 8, 8, 1024)   9438208     activation_74[0][0]              
__________________________________________________________________________________________________
batch_normalization_75 (BatchNo (None, 8, 8, 1024)   4096        conv2d_78[0][0]                  
__________________________________________________________________________________________________
activation_75 (Activation)      (None, 8, 8, 1024)   0           batch_normalization_75[0][0]     
__________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, 16, 16, 1024) 0           activation_75[0][0]              
__________________________________________________________________________________________________
concatenate_12 (Concatenate)    (None, 16, 16, 1536) 0           activation_73[0][0]              
                                                                 up_sampling2d_12[0][0]           
__________________________________________________________________________________________________
conv2d_79 (Conv2D)              (None, 16, 16, 512)  7078400     concatenate_12[0][0]             
__________________________________________________________________________________________________
batch_normalization_76 (BatchNo (None, 16, 16, 512)  2048        conv2d_79[0][0]                  
__________________________________________________________________________________________________
activation_76 (Activation)      (None, 16, 16, 512)  0           batch_normalization_76[0][0]     
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 16, 16, 512)  2359808     activation_76[0][0]              
__________________________________________________________________________________________________
batch_normalization_77 (BatchNo (None, 16, 16, 512)  2048        conv2d_80[0][0]                  
__________________________________________________________________________________________________
activation_77 (Activation)      (None, 16, 16, 512)  0           batch_normalization_77[0][0]     
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 16, 16, 512)  2359808     activation_77[0][0]              
__________________________________________________________________________________________________
batch_normalization_78 (BatchNo (None, 16, 16, 512)  2048        conv2d_81[0][0]                  
__________________________________________________________________________________________________
activation_78 (Activation)      (None, 16, 16, 512)  0           batch_normalization_78[0][0]     
__________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, 32, 32, 512)  0           activation_78[0][0]              
__________________________________________________________________________________________________
concatenate_13 (Concatenate)    (None, 32, 32, 768)  0           activation_71[0][0]              
                                                                 up_sampling2d_13[0][0]           
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 32, 32, 256)  1769728     concatenate_13[0][0]             
__________________________________________________________________________________________________
batch_normalization_79 (BatchNo (None, 32, 32, 256)  1024        conv2d_82[0][0]                  
__________________________________________________________________________________________________
activation_79 (Activation)      (None, 32, 32, 256)  0           batch_normalization_79[0][0]     
__________________________________________________________________________________________________
conv2d_83 (Conv2D)              (None, 32, 32, 256)  590080      activation_79[0][0]              
__________________________________________________________________________________________________
batch_normalization_80 (BatchNo (None, 32, 32, 256)  1024        conv2d_83[0][0]                  
__________________________________________________________________________________________________
activation_80 (Activation)      (None, 32, 32, 256)  0           batch_normalization_80[0][0]     
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 32, 32, 256)  590080      activation_80[0][0]              
__________________________________________________________________________________________________
batch_normalization_81 (BatchNo (None, 32, 32, 256)  1024        conv2d_84[0][0]                  
__________________________________________________________________________________________________
activation_81 (Activation)      (None, 32, 32, 256)  0           batch_normalization_81[0][0]     
__________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, 64, 64, 256)  0           activation_81[0][0]              
__________________________________________________________________________________________________
concatenate_14 (Concatenate)    (None, 64, 64, 384)  0           activation_69[0][0]              
                                                                 up_sampling2d_14[0][0]           
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 64, 64, 128)  442496      concatenate_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_82 (BatchNo (None, 64, 64, 128)  512         conv2d_85[0][0]                  
__________________________________________________________________________________________________
activation_82 (Activation)      (None, 64, 64, 128)  0           batch_normalization_82[0][0]     
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 64, 64, 128)  147584      activation_82[0][0]              
__________________________________________________________________________________________________
batch_normalization_83 (BatchNo (None, 64, 64, 128)  512         conv2d_86[0][0]                  
__________________________________________________________________________________________________
activation_83 (Activation)      (None, 64, 64, 128)  0           batch_normalization_83[0][0]     
__________________________________________________________________________________________________
conv2d_87 (Conv2D)              (None, 64, 64, 128)  147584      activation_83[0][0]              
__________________________________________________________________________________________________
batch_normalization_84 (BatchNo (None, 64, 64, 128)  512         conv2d_87[0][0]                  
__________________________________________________________________________________________________
activation_84 (Activation)      (None, 64, 64, 128)  0           batch_normalization_84[0][0]     
__________________________________________________________________________________________________
up_sampling2d_15 (UpSampling2D) (None, 128, 128, 128 0           activation_84[0][0]              
__________________________________________________________________________________________________
concatenate_15 (Concatenate)    (None, 128, 128, 192 0           activation_67[0][0]              
                                                                 up_sampling2d_15[0][0]           
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 128, 128, 64) 110656      concatenate_15[0][0]             
__________________________________________________________________________________________________
batch_normalization_85 (BatchNo (None, 128, 128, 64) 256         conv2d_88[0][0]                  
__________________________________________________________________________________________________
activation_85 (Activation)      (None, 128, 128, 64) 0           batch_normalization_85[0][0]     
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 128, 128, 64) 36928       activation_85[0][0]              
__________________________________________________________________________________________________
batch_normalization_86 (BatchNo (None, 128, 128, 64) 256         conv2d_89[0][0]                  
__________________________________________________________________________________________________
activation_86 (Activation)      (None, 128, 128, 64) 0           batch_normalization_86[0][0]     
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 128, 128, 64) 36928       activation_86[0][0]              
__________________________________________________________________________________________________
batch_normalization_87 (BatchNo (None, 128, 128, 64) 256         conv2d_90[0][0]                  
__________________________________________________________________________________________________
activation_87 (Activation)      (None, 128, 128, 64) 0           batch_normalization_87[0][0]     
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 128, 128, 4)  260         activation_87[0][0]              
==================================================================================================
Total params: 34,540,932
Trainable params: 34,527,236
Non-trainable params: 13,696

I trained the model with my images and mask images. My question is: how can I identify the class number when the model makes predictions on test images? For example, after the model has predicted on the test image 'imageabc1.png', I need to find which of the 4 classes each part of the predicted mask belongs to. How can I do that?

1

There is 1 answer

2
LB95 On

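Because the last layer is a softmax over `num_classes` channels and the model is trained with `sparse_categorical_crossentropy`, output channel k holds the predicted probability of class k, i.e. of the integer label k in your masks. Suppose you run the model on an image batch of shape (N, 128, 128, 3). For a single image, add a leading batch axis first, e.g. `np.expand_dims(image, 0)`: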

import numpy as np

# pred has shape (N, 128, 128, 4): one softmax probability channel per class.
pred = model.predict(training_image)
# Predicted class index for every pixel = the channel with the highest
# probability. Shape (N, 128, 128), values 0..3, matching the mask labels.
class_map = np.argmax(pred, axis=-1)

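The probability map of the first class (label 0) is `pred[:, :, :, 0]`, that of the second class (label 1) is `pred[:, :, :, 1]`, and so on. To get the predicted class number of each pixel, take the channel with the highest probability: `class_map = np.argmax(pred, axis=-1)`. A binary mask for class k is then `class_map == k`.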