I am completely new to federated learning. Here is a simulation I am trying to run on my laptop. Suppose I have two clients, cl_1 and cl_2: cl_1 holds the subset of MNIST containing digits 0–4, and cl_2 holds the subset containing digits 5–9. How can I simulate federated training on my personal laptop with TensorFlow Federated (TFF) to train a single model that classifies all MNIST digits 0–9? I am using the following functions to split the dataset between cl_1 and cl_2 and to define my CNN model.
def load_dataset(*, normalize=False, return_full_test_x=False):
    """Load MNIST and split it into two label-disjoint client partitions.

    The first partition holds every sample labelled 0-4 and the second every
    sample labelled 5-9, mimicking two federated clients that each see only
    half of the classes. Images are reshaped to ``(n, 28, 28, 1)`` so they fit
    a Conv2D input, and all labels are one-hot encoded over 10 classes so both
    clients (and the full test set) share one output space.

    Args:
        normalize: If True, cast images to float32 and scale pixel values by
            1/255. Defaults to False, which keeps the raw image arrays
            returned by ``mnist.load_data()`` (only reshaped).
        return_full_test_x: If True, append the reshaped full test images as a
            10th element, so the returned ``testY`` can actually be used to
            evaluate the global model. Defaults to False so existing callers
            that unpack nine values keep working.

    Returns:
        ``(x_train_0_4, y_train_0_4, x_test_0_4, y_test_0_4,
        x_train_5_9, y_train_5_9, x_test_5_9, y_test_5_9, testY)``,
        with ``testX`` appended as a final element when
        ``return_full_test_x`` is True.
    """
    (trainX, trainY), (testX, testY) = mnist.load_data()

    def _prepare_images(images):
        # Add the trailing channel axis the Conv2D input layer expects.
        images = images.reshape((images.shape[0], 28, 28, 1))
        if normalize:
            images = images.astype("float32") / 255.0
        return images

    def _select_digits(images, labels, digits):
        # Keep only samples whose label is in `digits`. Encoding with
        # num_classes=10 (not 5) keeps both clients in the same label space.
        mask = np.isin(labels, digits)
        return (
            _prepare_images(images[mask]),
            to_categorical(labels[mask], num_classes=10),
        )

    low_digits = [0, 1, 2, 3, 4]
    high_digits = [5, 6, 7, 8, 9]

    x_train_0_4, y_train_0_4 = _select_digits(trainX, trainY, low_digits)
    x_test_0_4, y_test_0_4 = _select_digits(testX, testY, low_digits)
    x_train_5_9, y_train_5_9 = _select_digits(trainX, trainY, high_digits)
    x_test_5_9, y_test_5_9 = _select_digits(testX, testY, high_digits)

    # Full (all-digit) test set for evaluating the aggregated global model.
    # num_classes is explicit so the encoding matches the client partitions.
    testX = _prepare_images(testX)
    testY = to_categorical(testY, num_classes=10)

    result = (
        x_train_0_4, y_train_0_4, x_test_0_4, y_test_0_4,
        x_train_5_9, y_train_5_9, x_test_5_9, y_test_5_9,
        testY,
    )
    if return_full_test_x:
        return result + (testX,)
    return result
def define_model(num_classes=10, input_shape=(28, 28, 1)):
    """Build the CNN classifier shared by both federated clients.

    Architecture: Conv2D(32) -> MaxPool -> Conv2D(64) -> Conv2D(64) ->
    MaxPool -> Flatten -> Dense(100) -> Dense(num_classes, softmax).

    A brand-new, uncompiled model is constructed on every call. This is the
    form a TFF ``model_fn`` needs, since TFF requires the model to be built
    from scratch each time it is invoked.

    Args:
        num_classes: Width of the softmax output layer. Defaults to 10 so the
            single global model covers digits 0-9, even though each client
            only holds half of the classes.
        input_shape: Shape of a single input image. Defaults to (28, 28, 1).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    return tf.keras.Sequential([
        # An explicit Input layer replaces passing `input_shape=` to Conv2D,
        # which newer Keras versions warn against.
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu",
                               kernel_initializer="he_uniform"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu",
                               kernel_initializer="he_uniform"),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu",
                               kernel_initializer="he_uniform"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(100, activation="relu",
                              kernel_initializer="he_uniform"),
        # One output per digit class (10 by default, i.e. digits 0-9).
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])