I have trained a U-Net model for multiclass image segmentation. I have four classes, and this is the code:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, Activation,
                                     MaxPooling2D, UpSampling2D, concatenate)
from tensorflow.keras.models import Model

def Unet(input_shape=(128, 128, 3), num_classes=4):
    inputs = Input(shape=input_shape)

    # Encoder
    down1 = Conv2D(64, (3, 3), padding='same')(inputs)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1 = Conv2D(64, (3, 3), padding='same')(down1)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)

    down2 = Conv2D(128, (3, 3), padding='same')(down1_pool)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2 = Conv2D(128, (3, 3), padding='same')(down2)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)

    down3 = Conv2D(256, (3, 3), padding='same')(down2_pool)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3 = Conv2D(256, (3, 3), padding='same')(down3)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3)

    down4 = Conv2D(512, (3, 3), padding='same')(down3_pool)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4 = Conv2D(512, (3, 3), padding='same')(down4)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4)

    # Bottleneck
    center = Conv2D(1024, (3, 3), padding='same')(down4_pool)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)
    center = Conv2D(1024, (3, 3), padding='same')(center)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)

    # Decoder with skip connections
    up4 = UpSampling2D((2, 2))(center)
    up4 = concatenate([down4, up4], axis=3)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)

    up3 = UpSampling2D((2, 2))(up4)
    up3 = concatenate([down3, up3], axis=3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)

    up2 = UpSampling2D((2, 2))(up3)
    up2 = concatenate([down2, up2], axis=3)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)

    up1 = UpSampling2D((2, 2))(up2)
    up1 = concatenate([down1, up1], axis=3)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)

    # Per-pixel softmax over the num_classes output channels
    classify = Conv2D(num_classes, (1, 1), activation='softmax')(up1)

    model = Model(inputs=inputs, outputs=classify)

    lr = 1e-4
    model.compile(optimizer=tf.keras.optimizers.Adam(lr),
                  loss="sparse_categorical_crossentropy",
                  metrics=['accuracy'])
    return model

model = Unet()
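Because the model is compiled with sparse_categorical_crossentropy, the training masks should be integer-encoded class maps (one class id per pixel, 0 to 3), not one-hot volumes. A minimal sketch of the shapes fit() expects, using random placeholder arrays purely for illustration:

import numpy as np

# Placeholder batch just to show the expected shapes.
images = np.random.rand(8, 128, 128, 3).astype('float32')                # RGB inputs
masks = np.random.randint(0, 4, size=(8, 128, 128, 1), dtype='int32')    # class ids 0..3
model.fit(images, masks, batch_size=2, epochs=1)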
The model summary is:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 128, 128, 3)] 0
__________________________________________________________________________________________________
conv2d_69 (Conv2D) (None, 128, 128, 64) 1792 input_4[0][0]
__________________________________________________________________________________________________
batch_normalization_66 (BatchNo (None, 128, 128, 64) 256 conv2d_69[0][0]
__________________________________________________________________________________________________
activation_66 (Activation) (None, 128, 128, 64) 0 batch_normalization_66[0][0]
__________________________________________________________________________________________________
conv2d_70 (Conv2D) (None, 128, 128, 64) 36928 activation_66[0][0]
__________________________________________________________________________________________________
batch_normalization_67 (BatchNo (None, 128, 128, 64) 256 conv2d_70[0][0]
__________________________________________________________________________________________________
activation_67 (Activation) (None, 128, 128, 64) 0 batch_normalization_67[0][0]
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 64, 64, 64) 0 activation_67[0][0]
__________________________________________________________________________________________________
conv2d_71 (Conv2D) (None, 64, 64, 128) 73856 max_pooling2d_12[0][0]
__________________________________________________________________________________________________
batch_normalization_68 (BatchNo (None, 64, 64, 128) 512 conv2d_71[0][0]
__________________________________________________________________________________________________
activation_68 (Activation) (None, 64, 64, 128) 0 batch_normalization_68[0][0]
__________________________________________________________________________________________________
conv2d_72 (Conv2D) (None, 64, 64, 128) 147584 activation_68[0][0]
__________________________________________________________________________________________________
batch_normalization_69 (BatchNo (None, 64, 64, 128) 512 conv2d_72[0][0]
__________________________________________________________________________________________________
activation_69 (Activation) (None, 64, 64, 128) 0 batch_normalization_69[0][0]
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 32, 32, 128) 0 activation_69[0][0]
__________________________________________________________________________________________________
conv2d_73 (Conv2D) (None, 32, 32, 256) 295168 max_pooling2d_13[0][0]
__________________________________________________________________________________________________
batch_normalization_70 (BatchNo (None, 32, 32, 256) 1024 conv2d_73[0][0]
__________________________________________________________________________________________________
activation_70 (Activation) (None, 32, 32, 256) 0 batch_normalization_70[0][0]
__________________________________________________________________________________________________
conv2d_74 (Conv2D) (None, 32, 32, 256) 590080 activation_70[0][0]
__________________________________________________________________________________________________
batch_normalization_71 (BatchNo (None, 32, 32, 256) 1024 conv2d_74[0][0]
__________________________________________________________________________________________________
activation_71 (Activation) (None, 32, 32, 256) 0 batch_normalization_71[0][0]
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 16, 16, 256) 0 activation_71[0][0]
__________________________________________________________________________________________________
conv2d_75 (Conv2D) (None, 16, 16, 512) 1180160 max_pooling2d_14[0][0]
__________________________________________________________________________________________________
batch_normalization_72 (BatchNo (None, 16, 16, 512) 2048 conv2d_75[0][0]
__________________________________________________________________________________________________
activation_72 (Activation) (None, 16, 16, 512) 0 batch_normalization_72[0][0]
__________________________________________________________________________________________________
conv2d_76 (Conv2D) (None, 16, 16, 512) 2359808 activation_72[0][0]
__________________________________________________________________________________________________
batch_normalization_73 (BatchNo (None, 16, 16, 512) 2048 conv2d_76[0][0]
__________________________________________________________________________________________________
activation_73 (Activation) (None, 16, 16, 512) 0 batch_normalization_73[0][0]
__________________________________________________________________________________________________
max_pooling2d_15 (MaxPooling2D) (None, 8, 8, 512) 0 activation_73[0][0]
__________________________________________________________________________________________________
conv2d_77 (Conv2D) (None, 8, 8, 1024) 4719616 max_pooling2d_15[0][0]
__________________________________________________________________________________________________
batch_normalization_74 (BatchNo (None, 8, 8, 1024) 4096 conv2d_77[0][0]
__________________________________________________________________________________________________
activation_74 (Activation) (None, 8, 8, 1024) 0 batch_normalization_74[0][0]
__________________________________________________________________________________________________
conv2d_78 (Conv2D) (None, 8, 8, 1024) 9438208 activation_74[0][0]
__________________________________________________________________________________________________
batch_normalization_75 (BatchNo (None, 8, 8, 1024) 4096 conv2d_78[0][0]
__________________________________________________________________________________________________
activation_75 (Activation) (None, 8, 8, 1024) 0 batch_normalization_75[0][0]
__________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, 16, 16, 1024) 0 activation_75[0][0]
__________________________________________________________________________________________________
concatenate_12 (Concatenate) (None, 16, 16, 1536) 0 activation_73[0][0]
up_sampling2d_12[0][0]
__________________________________________________________________________________________________
conv2d_79 (Conv2D) (None, 16, 16, 512) 7078400 concatenate_12[0][0]
__________________________________________________________________________________________________
batch_normalization_76 (BatchNo (None, 16, 16, 512) 2048 conv2d_79[0][0]
__________________________________________________________________________________________________
activation_76 (Activation) (None, 16, 16, 512) 0 batch_normalization_76[0][0]
__________________________________________________________________________________________________
conv2d_80 (Conv2D) (None, 16, 16, 512) 2359808 activation_76[0][0]
__________________________________________________________________________________________________
batch_normalization_77 (BatchNo (None, 16, 16, 512) 2048 conv2d_80[0][0]
__________________________________________________________________________________________________
activation_77 (Activation) (None, 16, 16, 512) 0 batch_normalization_77[0][0]
__________________________________________________________________________________________________
conv2d_81 (Conv2D) (None, 16, 16, 512) 2359808 activation_77[0][0]
__________________________________________________________________________________________________
batch_normalization_78 (BatchNo (None, 16, 16, 512) 2048 conv2d_81[0][0]
__________________________________________________________________________________________________
activation_78 (Activation) (None, 16, 16, 512) 0 batch_normalization_78[0][0]
__________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, 32, 32, 512) 0 activation_78[0][0]
__________________________________________________________________________________________________
concatenate_13 (Concatenate) (None, 32, 32, 768) 0 activation_71[0][0]
up_sampling2d_13[0][0]
__________________________________________________________________________________________________
conv2d_82 (Conv2D) (None, 32, 32, 256) 1769728 concatenate_13[0][0]
__________________________________________________________________________________________________
batch_normalization_79 (BatchNo (None, 32, 32, 256) 1024 conv2d_82[0][0]
__________________________________________________________________________________________________
activation_79 (Activation) (None, 32, 32, 256) 0 batch_normalization_79[0][0]
__________________________________________________________________________________________________
conv2d_83 (Conv2D) (None, 32, 32, 256) 590080 activation_79[0][0]
__________________________________________________________________________________________________
batch_normalization_80 (BatchNo (None, 32, 32, 256) 1024 conv2d_83[0][0]
__________________________________________________________________________________________________
activation_80 (Activation) (None, 32, 32, 256) 0 batch_normalization_80[0][0]
__________________________________________________________________________________________________
conv2d_84 (Conv2D) (None, 32, 32, 256) 590080 activation_80[0][0]
__________________________________________________________________________________________________
batch_normalization_81 (BatchNo (None, 32, 32, 256) 1024 conv2d_84[0][0]
__________________________________________________________________________________________________
activation_81 (Activation) (None, 32, 32, 256) 0 batch_normalization_81[0][0]
__________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, 64, 64, 256) 0 activation_81[0][0]
__________________________________________________________________________________________________
concatenate_14 (Concatenate) (None, 64, 64, 384) 0 activation_69[0][0]
up_sampling2d_14[0][0]
__________________________________________________________________________________________________
conv2d_85 (Conv2D) (None, 64, 64, 128) 442496 concatenate_14[0][0]
__________________________________________________________________________________________________
batch_normalization_82 (BatchNo (None, 64, 64, 128) 512 conv2d_85[0][0]
__________________________________________________________________________________________________
activation_82 (Activation) (None, 64, 64, 128) 0 batch_normalization_82[0][0]
__________________________________________________________________________________________________
conv2d_86 (Conv2D) (None, 64, 64, 128) 147584 activation_82[0][0]
__________________________________________________________________________________________________
batch_normalization_83 (BatchNo (None, 64, 64, 128) 512 conv2d_86[0][0]
__________________________________________________________________________________________________
activation_83 (Activation) (None, 64, 64, 128) 0 batch_normalization_83[0][0]
__________________________________________________________________________________________________
conv2d_87 (Conv2D) (None, 64, 64, 128) 147584 activation_83[0][0]
__________________________________________________________________________________________________
batch_normalization_84 (BatchNo (None, 64, 64, 128) 512 conv2d_87[0][0]
__________________________________________________________________________________________________
activation_84 (Activation) (None, 64, 64, 128) 0 batch_normalization_84[0][0]
__________________________________________________________________________________________________
up_sampling2d_15 (UpSampling2D) (None, 128, 128, 128) 0 activation_84[0][0]
__________________________________________________________________________________________________
concatenate_15 (Concatenate) (None, 128, 128, 192) 0 activation_67[0][0]
up_sampling2d_15[0][0]
__________________________________________________________________________________________________
conv2d_88 (Conv2D) (None, 128, 128, 64) 110656 concatenate_15[0][0]
__________________________________________________________________________________________________
batch_normalization_85 (BatchNo (None, 128, 128, 64) 256 conv2d_88[0][0]
__________________________________________________________________________________________________
activation_85 (Activation) (None, 128, 128, 64) 0 batch_normalization_85[0][0]
__________________________________________________________________________________________________
conv2d_89 (Conv2D) (None, 128, 128, 64) 36928 activation_85[0][0]
__________________________________________________________________________________________________
batch_normalization_86 (BatchNo (None, 128, 128, 64) 256 conv2d_89[0][0]
__________________________________________________________________________________________________
activation_86 (Activation) (None, 128, 128, 64) 0 batch_normalization_86[0][0]
__________________________________________________________________________________________________
conv2d_90 (Conv2D) (None, 128, 128, 64) 36928 activation_86[0][0]
__________________________________________________________________________________________________
batch_normalization_87 (BatchNo (None, 128, 128, 64) 256 conv2d_90[0][0]
__________________________________________________________________________________________________
activation_87 (Activation) (None, 128, 128, 64) 0 batch_normalization_87[0][0]
__________________________________________________________________________________________________
conv2d_91 (Conv2D) (None, 128, 128, 4) 260 activation_87[0][0]
==================================================================================================
Total params: 34,540,932
Trainable params: 34,527,236
Non-trainable params: 13,696
I trained the model on my images and mask images. My question is: how can I identify the class number when the model predicts on test images? For example, suppose the model has predicted on a test image 'imageabc1.png' and I need to find the masks for each of the 4 classes in that image. How can I do that?
As you are using softmax, the class number corresponds to the channel index in the output of the network. Suppose you have predicted an image: the probability map for class one is pred[:, :, :, 0], class two is pred[:, :, :, 1], and so on.
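To get a single predicted class per pixel, take the argmax over the channel axis. A minimal sketch, assuming 128x128 RGB inputs scaled to [0, 1] (the scaling and loading steps are assumptions; reuse whatever preprocessing you applied during training):

import numpy as np
import tensorflow as tf

# Load and preprocess one test image (preprocessing here is an assumption).
img = tf.keras.preprocessing.image.load_img('imageabc1.png', target_size=(128, 128))
x = tf.keras.preprocessing.image.img_to_array(img) / 255.0
x = np.expand_dims(x, axis=0)            # (1, 128, 128, 3)

pred = model.predict(x)                  # (1, 128, 128, 4) softmax probabilities

class_map = np.argmax(pred, axis=-1)[0]  # (128, 128), predicted class id 0..3 per pixel

# Binary mask for each of the four classes
masks = [(class_map == c).astype(np.uint8) for c in range(4)]

If you need the soft (probability) mask for a single class instead, pred[0, :, :, c] gives you the raw softmax map for class c.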