I'm using TensorFlow's "pix2pix: Image-to-image translation with a conditional GAN" notebook to train a model on my dataset, which consists of multispectral satellite images with 12 bands and shape 512 x 512 x 12.
Because the original notebook is written for 256 x 256 x 3 images, I had to change a few of its key functions: load(), Generator(), Discriminator() and generate_images().
My images have 12 bands (channels) rather than 3, so I stored each group of 3 bands as an RGB (.png) image with a bit depth of 16. My load function therefore reads all four images that make up a 12-band image and builds inputs with shape 512 x 512 x 12.
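The load function itself is not shown, so the following is only a minimal sketch of the idea described above, not the actual implementation. It assumes four 3-channel, 16-bit PNG files per sample and a rescale to [-1, 1] (the range the notebook's tanh output uses); the name load_12_band and the paths argument are hypothetical.

```python
import tensorflow as tf

def load_12_band(paths):
    """Read four 3-channel 16-bit PNGs and stack them into one (H, W, 12) tensor.

    `paths` is a hypothetical list of four file paths, one per band triplet.
    """
    parts = []
    for path in paths:
        raw = tf.io.read_file(path)
        # dtype=tf.uint16 keeps the full 16-bit range instead of truncating to 8 bits.
        parts.append(tf.io.decode_png(raw, channels=3, dtype=tf.uint16))
    image = tf.concat(parts, axis=-1)      # (H, W, 12)
    image = tf.cast(image, tf.float32)
    return image / 32767.5 - 1.0           # map [0, 65535] to [-1, 1]
```

Because this uses only TensorFlow ops, the returned tensor has a fully defined rank, which makes shape problems easier to spot early.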
I changed the Generator() function as follows:
def Generator():
    input_shape = (512, 512, 12)
    inputs = tf.keras.layers.Input(shape=input_shape)
    # Downsampling layers
    down_stack = [
        downsample(64, 4, apply_batchnorm=False),  # (batch_size, 256, 256, 64)
        downsample(128, 4),  # (batch_size, 128, 128, 128)
        downsample(256, 4),  # (batch_size, 64, 64, 256)
        downsample(512, 4),  # (batch_size, 32, 32, 512)
        downsample(512, 4),  # (batch_size, 16, 16, 512)
        downsample(512, 4),  # (batch_size, 8, 8, 512)
        downsample(512, 4),  # (batch_size, 4, 4, 512)
        downsample(512, 4)   # (batch_size, 2, 2, 512)
    ]
    # Upsampling layers (shapes are after concatenation with the skip connection)
    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
        upsample(512, 4, apply_dropout=True),  # (batch_size, 16, 16, 1024)
        upsample(512, 4),  # (batch_size, 32, 32, 1024)
        upsample(256, 4),  # (batch_size, 64, 64, 512)
        upsample(128, 4),  # (batch_size, 128, 128, 256)
        upsample(64, 4)    # (batch_size, 256, 256, 128)
    ]
    initializer = tf.random_normal_initializer(0., 0.02)
    # OUTPUT_CHANNELS must equal the number of target bands (12 here)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')  # (batch_size, 512, 512, 12)
    x = inputs
    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    # Drop the bottleneck output and walk the remaining skips from deepest to shallowest
    skips = reversed(skips[:-1])
    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        if skip is not None:
            x = tf.keras.layers.Concatenate()([x, skip])
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)
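As a sanity check on this architecture, the spatial sizes and channel counts can be traced in plain Python without building the model. This is a sketch that assumes every downsample halves and every upsample doubles the spatial size (stride 2 with padding='same', as in the notebook's helpers); trace_unet_shapes is a hypothetical helper, not part of the notebook.

```python
def trace_unet_shapes(size=512,
                      down_filters=(64, 128, 256, 512, 512, 512, 512, 512),
                      up_filters=(512, 512, 512, 512, 256, 128, 64),
                      out_channels=12):
    """Return the (height, width, channels) shape after each stage of the U-Net."""
    shapes = []
    skips = []
    for filters in down_filters:
        size //= 2                      # stride-2 convolution halves height and width
        skips.append((size, filters))
        shapes.append((size, size, filters))
    # The bottleneck output is not used as a skip connection.
    for filters, (skip_size, skip_filters) in zip(up_filters, reversed(skips[:-1])):
        size *= 2                       # stride-2 transposed convolution doubles the size
        assert size == skip_size, "skip connection and upsampled tensor must match"
        shapes.append((size, size, filters + skip_filters))  # shape after Concatenate
    size *= 2                           # final Conv2DTranspose
    shapes.append((size, size, out_channels))
    return shapes

for shape in trace_unet_shapes():
    print(shape)
```

With a 512 x 512 input the bottleneck is 2 x 2 rather than the 1 x 1 of the original 256 x 256 notebook, and every skip connection lines up, so the layer stack itself is consistent.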
I changed Discriminator() as follows:
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)
    # Input and target images will have 12 channels each
    inp = tf.keras.layers.Input(shape=[512, 512, 12], name='input_image')
    tar = tf.keras.layers.Input(shape=[512, 512, 12], name='target_image')
    # Concatenate input and target images along the channel axis
    x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 512, 512, 24)
    down1 = downsample(64, 4, False)(x)  # (batch_size, 256, 256, 64)
    down2 = downsample(128, 4)(down1)    # (batch_size, 128, 128, 128)
    down3 = downsample(256, 4)(down2)    # (batch_size, 64, 64, 256)
    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 66, 66, 256)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)  # (batch_size, 63, 63, 512)
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 65, 65, 512)
    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)  # (batch_size, 62, 62, 1)
    return tf.keras.Model(inputs=[inp, tar], outputs=last)
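The shape comments in the discriminator can be double-checked with the usual convolution output-size arithmetic. The sketch below assumes the downsample helper uses stride 2 with padding='same' and that the two Conv2D layers use the Keras default padding='valid'; conv_out and patchgan_output_size are hypothetical helper names.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    if padding == 'same':
        return -(-size // stride)            # ceil(size / stride)
    return (size - kernel) // stride + 1     # 'valid' padding

def patchgan_output_size(size=512):
    """Spatial size of the discriminator's patch output for a square input."""
    for _ in range(3):                       # three stride-2 downsample blocks
        size = conv_out(size, 4, 2, 'same')
    size += 2                                # ZeroPadding2D adds 1 pixel per side
    size = conv_out(size, 4, 1, 'valid')
    size += 2                                # second ZeroPadding2D
    size = conv_out(size, 4, 1, 'valid')
    return size

print(patchgan_output_size(512))
print(patchgan_output_size(256))
```

For a 512 x 512 input this gives a 62 x 62 patch map, and for the notebook's original 256 x 256 input it gives 30 x 30, so the comments above agree with the arithmetic.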
Finally, I changed the generate_images() function as follows:
def generate_images(model, test_input, tar, save_dir='Some_Directory'):
    # training=True is intentional, as in the original notebook: it uses the
    # batch statistics of the test batch instead of the accumulated ones.
    prediction = model(test_input, training=True)
    num_bands = 12
    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)
    # Plot the 12 bands three at a time, as pseudo-RGB composites
    for i in range(0, num_bands, 3):
        bands = [i, i+1, i+2] if (i+2) < num_bands else [i, i+1, num_bands-1]
        plt.figure(figsize=(15, 5))
        display_list = [test_input[0], tar[0], prediction[0]]
        title = ['Input Image', 'Ground Truth', 'Predicted Image']
        for j in range(3):
            plt.subplot(1, 3, j+1)
            plt.title(title[j])
            # Select the three bands to visualize
            image_display = tf.stack([display_list[j][..., band] for band in bands], axis=-1)
            image_display = (image_display + 1) / 2  # Rescale from [-1, 1] to [0, 1]
            plt.imshow(image_display)
            plt.axis('off')
        plt.savefig(os.path.join(save_dir, f'generated_image_bands_{bands[0]}-{bands[1]}-{bands[2]}.png'))
        plt.close()
# Usage example
for example_input, example_target in test_dataset.take(1):
    generate_images(generator, example_input, example_target)
After these changes, I run fit() to train the model:
fit(train_dataset, test_dataset, steps=40000)
Usually after some number of steps (anywhere from 1 to a few thousand), I get errors that appear to be related to tensor shapes, for example:
 tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__StridedSlice_device_/job:localhost/replica:0/task:0/device:GPU:0}} slice index 1 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/ 
[[{{node EagerPyFunc}}]] [Op:IteratorGetNext] name: 
InvalidArgumentError: {{function_node __wrapped__StridedSlice_device_/job:localhost/replica:0/task:0/device:GPU:0}} Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1], [1], and [3] instead. [Op:StridedSlice] name: strided_slice/
InvalidArgumentError: Exception encountered when calling layer 'conv2d_transpose_8' (type Conv2DTranspose).
{{function_node __wrapped__StridedSlice_device_/job:localhost/replica:0/task:0/device:GPU:0}} Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1], [3], and [3] instead. [Op:StridedSlice] name: model/conv2d_transpose_8/strided_slice/
Call arguments received by layer 'conv2d_transpose_8' (type Conv2DTranspose):
  • inputs=tf.Tensor(shape=(1, 256, 256, 128), dtype=float32)
It is also worth mentioning that fit() sometimes emits warnings like the ones below. I don't think they are causing the issue, because I got similar warnings with other functions before and they did not affect the result.
WARNING:tensorflow:5 out of the last 5 calls to <function _BaseOptimizer._update_step_xla at 0x7927400fc1f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 6 calls to <function _BaseOptimizer._update_step_xla at 0x7927400fc1f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
I have tried changing BATCH_SIZE from 1 to 2, significantly reducing the number of samples, and changing some arguments inside Generator(), such as setting apply_batchnorm to True or False. How can I fix the issue and train my model on images with shape 512 x 512 x 12?
**Edit:** Here is the result of running generator.summary():
    Model: "model"
    __________________________________________________________________________________________________
    Layer (type)                Output Shape                 Param #   Connected to                  
    ==================================================================================================
    input_1 (InputLayer)        [(None, 512, 512, 12)]       0         []                            
                                                                                                    
    sequential_2 (Sequential)   (None, 256, 256, 64)         12288     ['input_1[0][0]']             
                                                                                                    
    sequential_3 (Sequential)   (None, 128, 128, 128)        131584    ['sequential_2[0][0]']        
                                                                                                    
    sequential_4 (Sequential)   (None, 64, 64, 256)          525312    ['sequential_3[0][0]']        
                                                                                                    
    sequential_5 (Sequential)   (None, 32, 32, 512)          2099200   ['sequential_4[0][0]']        
                                                                                                    
    sequential_6 (Sequential)   (None, 16, 16, 512)          4196352   ['sequential_5[0][0]']        
                                                                                                    
    sequential_7 (Sequential)   (None, 8, 8, 512)            4196352   ['sequential_6[0][0]']        
                                                                                                    
    sequential_8 (Sequential)   (None, 4, 4, 512)            4196352   ['sequential_7[0][0]']        
                                                                                                    
    sequential_9 (Sequential)   (None, 2, 2, 512)            4196352   ['sequential_8[0][0]']        
                                                                                                    
    sequential_10 (Sequential)  (None, 4, 4, 512)            4196352   ['sequential_9[0][0]']        
                                                                                                    
    concatenate (Concatenate)   (None, 4, 4, 1024)           0         ['sequential_10[0][0]',       
                                                                        'sequential_8[0][0]']        
                                                                                                    
    sequential_11 (Sequential)  (None, 8, 8, 512)            8390656   ['concatenate[0][0]']         
                                                                                                    
    concatenate_1 (Concatenate  (None, 8, 8, 1024)           0         ['sequential_11[0][0]',       
    )                                                                   'sequential_7[0][0]']        
                                                                                                    
    sequential_12 (Sequential)  (None, 16, 16, 512)          8390656   ['concatenate_1[0][0]']       
                                                                                                    
    concatenate_2 (Concatenate  (None, 16, 16, 1024)         0         ['sequential_12[0][0]',       
    )                                                                   'sequential_6[0][0]']        
                                                                                                    
    sequential_13 (Sequential)  (None, 32, 32, 512)          8390656   ['concatenate_2[0][0]']       
                                                                                                    
    concatenate_3 (Concatenate  (None, 32, 32, 1024)         0         ['sequential_13[0][0]',       
    )                                                                   'sequential_5[0][0]']        
                                                                                                    
    sequential_14 (Sequential)  (None, 64, 64, 256)          4195328   ['concatenate_3[0][0]']       
                                                                                                    
    concatenate_4 (Concatenate  (None, 64, 64, 512)          0         ['sequential_14[0][0]',       
    )                                                                   'sequential_4[0][0]']        
                                                                                                    
    sequential_15 (Sequential)  (None, 128, 128, 128)        1049088   ['concatenate_4[0][0]']       
                                                                                                    
    concatenate_5 (Concatenate  (None, 128, 128, 256)        0         ['sequential_15[0][0]',       
    )                                                                   'sequential_3[0][0]']        
                                                                                                    
    sequential_16 (Sequential)  (None, 256, 256, 64)         262400    ['concatenate_5[0][0]']       
                                                                                                    
    concatenate_6 (Concatenate  (None, 256, 256, 128)        0         ['sequential_16[0][0]',       
    )                                                                   'sequential_2[0][0]']        
                                                                                                    
    conv2d_transpose_8 (Conv2D  (None, 512, 512, 12)         24588     ['concatenate_6[0][0]']       
    Transpose)                                                                                       
                                                                                                    
    ==================================================================================================
    Total params: 54453516 (207.72 MB)
    Trainable params: 54442636 (207.68 MB)
    Non-trainable params: 10880 (42.50 KB)
    __________________________________________________________________________________________________
 
                        
This is only a guess, and it is too long for a comment, but maybe it helps.
I tested your code with the data from the pix2pix link, transformed to (512, 512, 12) to fit your model architecture. I could not replicate your error, but I ran into a different one, which makes sense: if the Dataset has not been batched, take(1) (or take in general) does not give you batches; it iterates over single samples. (take(4000) yields 4000 single samples, each with shape (512, 512, 12) and no batch dimension.)
My solution was to expand each single sample with a batch dimension.
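A minimal sketch of that fix, using tf.expand_dims on small stand-in tensors (the real samples would be (512, 512, 12)). The helper name add_batch_dim and the stand-in dataset are assumptions made for illustration only.

```python
import tensorflow as tf

def add_batch_dim(sample):
    """Turn an unbatched (H, W, C) tensor into a batch of one: (1, H, W, C)."""
    return tf.expand_dims(sample, axis=0)

# Small stand-in data so the example runs quickly.
images = tf.zeros((4, 8, 8, 12))
unbatched = tf.data.Dataset.from_tensor_slices((images, images))

for example_input, example_target in unbatched.take(1):
    # Each element of an unbatched Dataset is a single sample.
    single_shape = tuple(example_input.shape)
    batched_input = add_batch_dim(example_input)
    batched_target = add_batch_dim(example_target)
    batched_shape = tuple(batched_input.shape)

print(single_shape, batched_shape)
```

The batched tensors can then be passed to the generator or to generate_images(), which index the first axis as the batch axis.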
(Alternatively, you can batch the Dataset and then change your training loop to iterate over the batched Dataset directly. I think you'll lose your step counter, but it would be more efficient with prefetch, cache and so on. See here for performance tips.)
As for your warning: (1) is most likely out, because I did not get it. That leaves (2) and (3). You should check your input data to see whether all samples have the same shape (2) and whether some are Python objects rather than tensors (3) (pretty unlikely, but hey).
This link mentions variable-length input and single-sample input, where the fix is to expand the first dimension, just as I had to do to run the code.
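Both suggestions above, batching the Dataset and checking that every sample has the same shape, can be sketched as follows. This is only an illustration under assumptions: make_batched and find_bad_shapes are hypothetical helpers, and the stand-in data is much smaller than the real (512, 512, 12) samples.

```python
import tensorflow as tf

def make_batched(dataset, batch_size=1):
    """Batch an unbatched Dataset so each element gains a leading batch axis."""
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

def find_bad_shapes(dataset, expected=(512, 512, 12)):
    """Return the indices of samples whose input or target shape differs from `expected`."""
    bad = []
    for index, (inp, tar) in enumerate(dataset):
        if tuple(inp.shape) != expected or tuple(tar.shape) != expected:
            bad.append(index)
    return bad

# Small stand-in data so the example runs quickly.
images = tf.zeros((3, 8, 8, 12))
unbatched = tf.data.Dataset.from_tensor_slices((images, images))

mismatches = find_bad_shapes(unbatched, expected=(8, 8, 12))
batched_shapes = [tuple(inp.shape) for inp, _ in make_batched(unbatched, batch_size=1)]
print(mismatches, batched_shapes)
```

Running find_bad_shapes once over the full training set before calling fit() should show whether an occasional malformed sample explains errors that appear only after a varying number of steps.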
You could also test your network with dummy data. If this does not throw an error, the problem most likely lies in your input data.
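The dummy-data test mentioned above might look like the sketch below. To keep it self-contained and fast, it uses a tiny hypothetical stand-in model (build_stand_in_generator) with the same input and output shape as the question's generator; in practice you would pass the result of Generator() to smoke_test instead.

```python
import tensorflow as tf

def build_stand_in_generator(channels=12):
    """Tiny hypothetical stand-in for Generator(): one downsample, one upsample."""
    inputs = tf.keras.layers.Input(shape=(512, 512, channels))
    x = tf.keras.layers.Conv2D(8, 4, strides=2, padding='same')(inputs)
    x = tf.keras.layers.Conv2DTranspose(channels, 4, strides=2,
                                        padding='same', activation='tanh')(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

def smoke_test(model, batch_size=1):
    """Run one forward pass on random data and return the output shape."""
    dummy = tf.random.normal((batch_size, 512, 512, 12))
    output = model(dummy, training=True)
    return tuple(output.shape)

generator = build_stand_in_generator()  # replace with Generator() from the question
print(smoke_test(generator))
```

If the real generator handles random (1, 512, 512, 12) input repeatedly without errors, the model definition is unlikely to be the cause, which points back at the data pipeline.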