Building an MNIST classification model using federated learning?

81 views Asked by At

I am completely new to federated learning. Here is a simulation I am trying to run on my laptop. Suppose I have two clients, cl_1 and cl_2: cl_1 holds the subset of the MNIST dataset containing digits 0-4, and cl_2 holds the subset containing digits 5-9. How can I simulate federated training on my personal laptop with TensorFlow Federated (TFF) to train a model that classifies all MNIST digits 0-9? I am using the following functions to split the dataset between cl_1 and cl_2 and to define my CNN model.

def _split_by_digits(images, labels, digits):
    """Select the samples whose label is in *digits*, shaped for the CNN.

    Args:
        images: Image array as returned by ``mnist.load_data()``.
        labels: Integer label array aligned row-for-row with *images*.
        digits: Digit labels to keep, e.g. ``[0, 1, 2, 3, 4]``.

    Returns:
        ``(x, y)`` where ``x`` is reshaped to ``(N, 28, 28, 1)`` and ``y`` is
        one-hot encoded over all 10 digit classes.
    """
    mask = np.isin(labels, digits)
    x = images[mask]
    x = x.reshape((x.shape[0], 28, 28, 1))
    # Encode over all 10 classes (not just the 5 kept here) so every client's
    # labels line up with the same 10-unit softmax output layer.
    y = to_categorical(labels[mask], num_classes=10)
    return x, y


def load_dataset(*, include_test_x=False):
    """Load MNIST and split it into two label-disjoint client partitions.

    The first partition keeps only digits 0-4 and the second only digits 5-9.
    Each partition's images are reshaped to ``(N, 28, 28, 1)`` and its labels
    are one-hot encoded over all 10 classes.

    Args:
        include_test_x: If True, also return the reshaped full test images
            (``testX``) immediately before ``testY`` so the unsplit test set
            can actually be used for evaluation. Defaults to False, which
            keeps the original 9-value return for existing callers.

    Returns:
        ``(x_train_0_4, y_train_0_4, x_test_0_4, y_test_0_4,
        x_train_5_9, y_train_5_9, x_test_5_9, y_test_5_9, testY)``; with
        ``include_test_x=True`` the tuple is
        ``(..., x_test_5_9, y_test_5_9, testX, testY)``.
    """
    (trainX, trainY), (testX, testY) = mnist.load_data()

    # cl_1 owns digits 0-4; cl_2 owns digits 5-9.
    low_digits = [0, 1, 2, 3, 4]
    high_digits = [5, 6, 7, 8, 9]

    x_train_0_4, y_train_0_4 = _split_by_digits(trainX, trainY, low_digits)
    x_test_0_4, y_test_0_4 = _split_by_digits(testX, testY, low_digits)
    x_train_5_9, y_train_5_9 = _split_by_digits(trainX, trainY, high_digits)
    x_test_5_9, y_test_5_9 = _split_by_digits(testX, testY, high_digits)

    # Full (unsplit) test set for evaluating a model on all ten digits.
    testX = testX.reshape((testX.shape[0], 28, 28, 1))
    testY = to_categorical(testY, num_classes=10)

    # NOTE(review): pixel values are returned unscaled (no division is
    # applied here); consider normalizing to [0, 1] before training.

    client_splits = (
        x_train_0_4, y_train_0_4, x_test_0_4, y_test_0_4,
        x_train_5_9, y_train_5_9, x_test_5_9, y_test_5_9,
    )
    if include_test_x:
        return (*client_splits, testX, testY)
    return (*client_splits, testY)

 def define_model(num_classes=10):
     """Build the (uncompiled) CNN trained by each federated client.

     Args:
         num_classes: Number of units in the softmax output layer. Defaults
             to 10 so the head covers digits 0-9, which is the union of both
             clients' label sets, even though each client only holds five of
             those digits locally.

     Returns:
         An uncompiled ``tf.keras.Sequential`` model that takes
         ``(28, 28, 1)`` images and outputs a softmax over ``num_classes``.
     """
     model = tf.keras.Sequential([
         tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', input_shape=(28, 28, 1)),
         tf.keras.layers.MaxPooling2D((2, 2)),
         tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform'),
         tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform'),
         tf.keras.layers.MaxPooling2D((2, 2)),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(100, activation='relu', kernel_initializer='he_uniform'),
         # One unit per digit class across ALL clients, not per client.
         tf.keras.layers.Dense(num_classes, activation='softmax'),
     ])
     return model
0

There are 0 answers