Skip connection layer sizes differ by one in a U-Net, unable to concatenate due to size mismatch in Keras


I am writing a discriminator architecture for a GAN that uses a U-Net. My encoder and decoder blocks look like this:

    # Downsample path
    for _ in range(process_blocks):
        x = tf.keras.layers.Conv3D(num_filters, 4, strides=num_strides, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        downsample_layers.append(x)
        num_filters *= 2

    # Bottleneck
    x = tf.keras.layers.Conv3D(num_filters, 4, strides=num_strides, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

    # Upsample path
    for i in reversed(range(process_blocks)):
        x = tf.keras.layers.Conv3DTranspose(num_filters, 4, strides=num_strides, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        # Skip connection: concatenate with the stored downsampling output
        x = tf.keras.layers.Concatenate()([x, downsample_layers[i]])

        num_filters //= 2

I use the downsample_layers list to store the output of each encoder (downsampling) block and attempt to concatenate those outputs later to the outputs of the decoder (upsampling) blocks. The shape of the input from the generator to the discriminator is (182, 218, 182, 1), which seems to cause problems: whenever a downsampling layer outputs a tensor with an odd spatial dimension, Concatenate raises a ValueError when building the skip connection.

(182, 218, 182, 1) -> first encoder block output -> (None, 91, 109, 91, 8)
(None, 91, 109, 91, 8) -> second encoder block -> (None, 46, 55, 46, 16)
(None, 46, 55, 46, 16) -> third encoder block -> (None, 23, 28, 23, 32)
(None, 23, 28, 23, 32) -> fourth encoder block -> (None, 12, 14, 12, 64)
(None, 12, 14, 12, 64) -> bottleneck output -> (None, 6, 7, 6, 64)
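
The off-by-one comes from the rounding in 'same' padding: a stride-2 Conv3D outputs ceil(input / 2) along each spatial axis, while a stride-2 Conv3DTranspose with 'same' padding outputs exactly input * 2. A plain-Python sketch of that arithmetic:

    import math

    def down(size, stride=2):
        # Conv3D(..., strides=2, padding='same') outputs ceil(input / stride)
        return math.ceil(size / stride)

    def up(size, stride=2):
        # Conv3DTranspose(..., strides=2, padding='same') outputs input * stride
        return size * stride

    # An odd size gains a voxel over one down/up round trip:
    assert down(23) == 12
    assert up(down(23)) == 24  # != 23 -> mismatch at the skip connection
    # Even sizes round-trip cleanly:
    assert up(down(28)) == 28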

When these tensors are upsampled again, their shapes should match the corresponding downsampling outputs along every axis except the last (channel) axis. However, you can see what happens when the decoder upsamples twice:

(None, 6, 7, 6, 64) -> first decoder block -> (None, 12, 14, 12, 64)
(None, 12, 14, 12, 64) -> second decoder block -> (None, 24, 28, 24, 32)

Now the (None, 24, 28, 24, 32) output tensor attempts to concatenate with the stored (None, 23, 28, 23, 32) tensor for the skip connection, and Concatenate raises a ValueError because of the shape mismatch.
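
For reference, the error is easy to reproduce in isolation, since Concatenate requires every axis except the concatenation axis to match (the shapes below are taken from the model above):

    import tensorflow as tf

    upsampled = tf.keras.Input(shape=(24, 28, 24, 32))  # decoder output
    skip = tf.keras.Input(shape=(23, 28, 23, 32))       # stored encoder output
    tf.keras.layers.Concatenate()([upsampled, skip])    # ValueError: shapes differ on axes 1 and 3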

Is there any way to deal with this off-by-one issue? I have tried different combinations of stride values (currently 2), kernel sizes, and numbers of encoder/decoder blocks, but no combination seems to prevent one of the downsampling outputs from being off by one relative to its skip connection in the upsampling path.
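
One standard workaround is to zero-pad the network input up front to a multiple of 2^5 = 32 (e.g. pad (182, 218, 182) to (192, 224, 192)) so all five stride-2 convolutions halve cleanly; another is to crop the upsampled tensor to the skip tensor's spatial shape before concatenating. Below is a minimal sketch of the cropping approach, assuming static spatial shapes as in this model; the crop_to_match helper is illustrative, not part of the original code:

    import tensorflow as tf

    def crop_to_match(x, skip):
        # Crop x (the upsampled tensor) so its spatial dims match skip's.
        # Assumes static spatial shapes and that x is never smaller than
        # skip along any spatial axis.
        crops = []
        for x_dim, s_dim in zip(x.shape[1:4], skip.shape[1:4]):
            excess = x_dim - s_dim
            crops.append((excess // 2, excess - excess // 2))
        return tf.keras.layers.Cropping3D(cropping=tuple(crops))(x)

    # In the upsample loop, before the skip concatenation:
    # x = crop_to_match(x, downsample_layers[i])
    # x = tf.keras.layers.Concatenate()([x, downsample_layers[i]])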
