flwr client and server communication


I am building a federated learning model using the flwr framework.

I want to send a message from the client to the server if the loss of the current communication round t is greater than that of round t-1, but I cannot find an example of this in the documentation.
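What I have in mind is roughly the following on the client side: compare the loss of round t with that of round t-1 and, if it has increased, put a flag into the metrics dict returned from fit(), since that dict is sent back to the server together with the updated weights. A minimal sketch of the idea (the DriftAwareClient name and the drift_detected key are just my own naming, not anything from the flwr API):

import flwr as fl

class DriftAwareClient(fl.client.NumPyClient):
    """Sketch: report a drift flag to the server via the fit() metrics."""

    def __init__(self, model, x_train, y_train):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.previous_loss = None  # loss from round t-1

    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        history = self.model.fit(self.x_train, self.y_train, epochs=1, verbose=0)
        current_loss = history.history["loss"][-1]

        # Drift check: loss of round t vs. round t-1
        drift_detected = (
            self.previous_loss is not None and current_loss > self.previous_loss
        )
        self.previous_loss = current_loss

        # This metrics dict is delivered to the server inside FitRes.metrics
        metrics = {"loss": current_loss, "drift_detected": drift_detected}
        return self.model.get_weights(), len(self.x_train), metrics

My current (full) code is below.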

Server-side code (server.py):

import sys
sys.path.append('../')

from Custom_Scripts.debug import debug
from Custom_Scripts.model_architecture import model_initialisation
from Custom_Scripts.preprocess import preprocess_data

# import necessary libraries
import flwr as fl
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from tensorflow import keras
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import json
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import f1_score, classification_report

DEBUG = False
num_rounds = 10

def main():
    debug("Reading server data...")
    server_data = pd.read_csv(r"C:\Adarsh work\Pronto dataset FL\Data\server.csv")

    server_data_features = server_data.iloc[:, :-1]
    server_data_labels = server_data.iloc[:, -1]

    n_timesteps = 10
    # Preprocess the data and store them as small arrays
    data_list = preprocess_data(server_data_features, server_data_labels, n_timesteps)

    # Extract the features and labels from the preprocessed data list
    X_train = np.array([item[0] for item in data_list])
    y_train = np.array([item[1] for item in data_list])

    # Split the data into train and test sets
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

    model = model_initialisation()
    # Create strategy
    strategy = fl.server.strategy.FedAdam(
        evaluate_fn=get_eval_fn(model, X_val, y_val),
        on_fit_config_fn=fit_config,
        initial_parameters=fl.common.ndarrays_to_parameters(model.get_weights()),
        fraction_evaluate=0.0,
        eta=0.1,
        beta_1=0.9,
        beta_2=0.99,
        tau=1e-4,
    )

    # Start the Flower server for num_rounds rounds of federated learning
    history_ = fl.server.start_server(
        server_address="localhost:8083",
        strategy=strategy,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
    )
    print("losses_ centralised", history_.losses_centralized)
    with open(r'C:\Adarsh work\Pronto dataset FL\Results\FL_results_loss.json', 'w') as f:
        json.dump(history_.losses_centralized, f)

    print("metrics_centralised", history_.metrics_centralized)
    with open(r'C:\Adarsh work\Pronto dataset FL\Results\FL_results_metrics.json', 'w') as f:
        json.dump(history_.metrics_centralized, f)

def fit_config(rnd: int):
    """Return training configuration dict for each round."""

    config = {
        "num_rounds": rnd,
    }
    return config
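# Note: whatever fit_config returns is passed to each client's fit() as the
# config argument, so config["num_rounds"] on the client side contains the
# current round number.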


def get_eval_fn(model, X_val, y_val):
    """Return an evaluation function for server-side evaluation."""

    # The `evaluate` function will be called after every round
    def evaluate(
        server_round: int,
        parameters_ndarrays: List[fl.common.NDArray],
        client_info: Dict[str, Union[bool, bytes, float, int, str]],
        message_from_client: Optional[str] = None  # Add this parameter for the received message
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        # Update the local model with the latest global weights
        model.set_weights(parameters_ndarrays)

        loss, accuracy = model.evaluate(X_val, y_val)
        y_pred = model.predict(X_val)
        y_pred_labels = np.argmax(y_pred, axis=1)  # Convert one-hot encoded predictions to integer labels
        
        # Convert integer labels to one-hot encoded format
        y_pred_onehot = to_categorical(y_pred_labels, num_classes=y_val.shape[1])

        f1 = f1_score(y_val, y_pred_onehot, average="weighted")
        print("f1 score is", f1)

        if message_from_client:
            print(f"Message received from client: {message_from_client}")
        return loss, {"accuracy": accuracy, "F1_score": f1}

    return evaluate
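# As far as I can tell, Flower only ever calls this evaluate function with
# (server_round, parameters, config), so the message_from_client parameter
# above is never filled in by the framework; that is exactly the part I do
# not know how to do.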


if __name__ == "__main__":
    main()
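
On the server side, my guess is that these per-client metrics would have to be read inside the strategy, for example by subclassing FedAdam and overriding aggregate_fit, roughly like this (DriftAwareFedAdam and the drift_detected key are my own naming, untested):

from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy

class DriftAwareFedAdam(fl.server.strategy.FedAdam):
    """Sketch: FedAdam that inspects the metrics each client returns from fit()."""

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        # fit_res.metrics is the dict returned by the corresponding client's fit()
        for client_proxy, fit_res in results:
            if fit_res.metrics.get("drift_detected"):
                print(f"Round {server_round}: drift reported by client {client_proxy.cid}")

        # Keep the normal FedAdam aggregation
        return super().aggregate_fit(server_round, results, failures)

main() would then construct DriftAwareFedAdam(...) with the same arguments instead of fl.server.strategy.FedAdam(...). Is this metrics-based route the intended way to send such a message from client to server, or does flwr offer something more direct?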

Client-side code (Client.py):

# Gives access to all the scripts in the parent directory
import sys
sys.path.append('../')


# Boilerplate code for all clients
from Custom_Scripts.debug import debug
from Custom_Scripts.model_architecture import model_initialisation
from Custom_Scripts.preprocess import preprocess_data
from Custom_Scripts.ClassPronto import Pronto

# import necessary libraries
import flwr as fl
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from typing import Optional
from tqdm import tqdm
data = pd.read_csv(r"C:\Adarsh work\FedDrfit\Data\client1.csv")
DEBUG = True

class Pronto(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train, x_test, y_test, drift_status = None):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test
        self.previous_loss = None  # Track previous loss
        self.drift_status = drift_status  # Drift detection callback function
    
    def get_parameters(self, config):
        return self.model.get_weights()

    def fit(self, parameters, config):
        self.model.set_weights(parameters)
        history = self.model.fit(self.x_train, self.y_train, batch_size=32, epochs=10, validation_data=(self.x_test, self.y_test))
        current_loss = history.history["loss"][-1]  # loss of the last local epoch in this round

        if self.previous_loss is not None:
            # Drift detection: Compare with previous loss
            if current_loss > self.previous_loss:
                print("Drift Detected! Reinitializing the local model...")
                #self.model = model_initialisation()  # Reinitialize the model
                self.drift_status = "Drifted"

        self.previous_loss = current_loss

        parameters_prime = self.model.get_weights()
        num_examples_train = len(self.x_train)
        results = {
            "loss": current_loss,  # Use the current loss for results
            "accuracy": history.history["accuracy"][-1],
            "val_loss": history.history["val_loss"][-1],
            "val_accuracy": history.history["val_accuracy"][-1],
        }
        return parameters_prime, num_examples_train, results

    def evaluate(self, parameters, config):
        self.model.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test)
        num_examples_test = len(self.x_test)
        return loss, num_examples_test, {"accuracy": accuracy}

def main(data):
    debug("Reading data...")
    # data = pd.read_csv(r"C:\Adarsh work\Pronto dataset FL\Data\client1.csv")

    features = data.iloc[:, :-1]
    labels = data.iloc[:, -1]

    n_timesteps = 10
    # Preprocess the data and store them as small arrays
    data_list = preprocess_data(features, labels, n_timesteps)

    # Extract the features and labels from the preprocessed data list
    X_train = np.array([item[0] for item in data_list])
    y_train = np.array([item[1] for item in data_list])

    # Split the data into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

    # Initialize the model with the new architecture
    model = model_initialisation()

    learning_rate = 0.001
    history = model.fit(X_train, y_train, 
                        epochs=50, 
                        batch_size=1000, 
                        validation_data=(X_test, y_test),  # Use validation_data for X_test and y_test
                        verbose=1)

    # Create the Flower client and connect it to the server
    client = Pronto(model, X_train, y_train, X_test, y_test)
    fl.client.start_numpy_client(server_address="localhost:8083", client=client)

if __name__ == '__main__':
    main(data)
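
For reference, I start the server first and then the client in a separate terminal:

python server.py
python Client.py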

There are 0 answers