Federated learning using custom model in Pytorch/Pysyft

630 views Asked by At

I am trying to build a federated learning model. In my scenario, I have 3 workers and an orchestrator. The workers start the training and at the end of each training round, the models are being sent to the orchestrator, the orchestrator calculates the federated average and sends back the new model, the workers train on that new model etc. The custom network is an AutoEncoder that I have built from scratch.

Unfortunately I am getting this error message from the workers: RuntimeError: forward() is missing value for argument 'inputs'. Declaration: forward(ClassType self, Tensor inputs, Tensor outputs) -> (Tensor) which is weird because my forward function is defined as follows, inside the AE class:

class AutoEncoder(nn.Module):

    def __init__(self, code_size):
        super().__init__()
        self.code_size = code_size

        # Encoder specification
        self.enc_cnn_1 = nn.Conv2d(3, 10, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
        self.enc_linear_1 = nn.Linear(53 * 53 * 20, 50)
        self.enc_linear_2 = nn.Linear(50, self.code_size)

        # Decoder specification
        self.dec_linear_1 = nn.Linear(self.code_size, 160)
        self.dec_linear_2 = nn.Linear(160, IMAGE_SIZE)

    def forward(self, images):
        code = self.encode(images)
        out = self.decode(code)
        return out, code

    def encode(self, images):
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))

        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))
        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))

        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        out = F.selu(self.dec_linear_1(code))
        out = F.sigmoid(self.dec_linear_2(out))
        out = out.view([code.size(0), 3, IMAGE_WIDTH, IMAGE_HEIGHT])
        return out


Loss function (cross entropy)

```
@torch.jit.script
def loss_fn(inputs, outputs):
    return torch.nn.functional.mse_loss(input=inputs, target=outputs)


def set_gradients(model, finetuning):
    """Helper function to exclude all gradients from training
    used for transfer learning in feature extract mode; See: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

    Args:
        model (torch.nn.Module): model object.
        finetuning (bool):  if true, nothing will be changed; transfer learning will be used in finetuning mode, i.e., all gradients are trained;
                            if false, all gradients get excluded from training, used in feature extract mode
    """

    if not finetuning:
        for param in model.parameters():
            param.requires_grad = False
```


```
def initialize_model():
    model = AutoEncoder(code_size)
    set_gradients(model, False)
    return model
async def train_model_on_worker(
    worker: websocket_client.WebsocketClientWorker,
    traced_model: torch.jit.ScriptModule,
    dataset_key: str,
    batch_size: int,
    curr_round: int,
    lr: float,
):

    traced_model.train()
    print("train mode on")
    train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,
        batch_size=batch_size,
        shuffle=True,
        epochs=1,
        optimizer="Adam",
        optimizer_args={"lr": lr}
    )
logger.info(worker.id + " send trainconfig")
    train_config.send(worker)
    print("Model sent to the worker")
    logger.info(worker.id + " start training")
    await worker.async_fit(dataset_key=DATASET_KEY, return_ids=[0])
    logger.info(worker.id + " training done")
    results = dict()
logger.info(worker.id + " get model")
    model = train_config.model_ptr.get().obj

    results["worker_id"] = worker.id
    results["model"] = model

    return results
def validate_model(identifier, model, dataloader, criterion):
model.eval() # changes the mode of the model, in evaluation mode we don't have dropout

    loss = []
    for i, (inputs,_) in enumerat(dataloader):
        print("validation mode on")
        #with torch.set_grad_enabled(False):
        outputs, code  = model(Variable(inputs)) #a tensor with 2 values: one for leak and one for no leak
        loss = criterion(outputs, inputs)

        loss = loss.sqrt()
        loss.append(loss.item())
    print("Loss = %.3f" % loss.data)
async def main():
    args = define_and_get_arguments()
    hook = sy.TorchHook(torch) #with this we can override some pytorch methods with pysyft

    # Create WebsocketClientWorkers using IDs, Ports and IP addresses from arguments
    worker_instances = []
    for i in range(len(args.workers) // 3):
        j = i * 3
        worker_instances.append(websocket_client.WebsocketClientWorker(
            id=args.workers[j], port=args.workers[j + 1], host=args.workers[j + 2], hook=hook, verbose=args.verbose))

    model = initialize_model()

    # optional loading of predefined model weights (= dictionary):
    if args.basic_model:
        model.load_state_dict(torch.load(args.basic_model))

    # model serialization (creating an object of type ScriptModule):
    model.eval()
    traced_model = torch.jit.trace(model, torch.rand([1, 3, 224, 224], dtype=torch.float)) #we need to change the form of the model in order to make it 
    #serialisable and send it to the workers

    # Data / picture transformation:
    data_transforms = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor()
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create validation dataset and dataloader
    validation_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'val'), data_transforms)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Create test dataset and dataloader
    test_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'test'), data_transforms)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Lists to plot loss and acc after training
    train_loss_values = []
    train_acc_values = []
    val_loss_values = []
    val_acc_values = []

    np.set_printoptions(formatter={"float": "{: .0f}".format})

    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)
        print("entered training ")
        # reduce learn rate every 5 training rounds (adaptive learnrate)
        lr = args.lr * (0.1 ** ((curr_round - 1) // 5))
        completed, pending = await asyncio.wait(
            [
                *[
                    train_model_on_worker(
                        worker=worker,
                        traced_model=traced_model,
                        dataset_key=DATASET_KEY,
                        batch_size=args.batch_size,
                        curr_round=curr_round,
                        lr=lr,
                    )
                    for worker in worker_instances
                ]
            ],
            timeout=40
        )

        results = []
        for entry in completed:
            print("entry")
            print(entry)
            results.append(entry.result())

        for entry in pending:
            entry.cancel()

        new_worker_instances = []
        for entry in results:
            for worker in worker_instances:
                if (entry["worker_id"] == worker.id):
                    new_worker_instances.append(worker)

        worker_instances = new_worker_instances


        # Setup the loss function
        criterion = torch.nn.functional.mse_loss()
        #optimizer = optimizer_cls(autoencoder.parameters(), lr=lr)

        # Federate models (note that this will also change the model in models[0]
        models = {}
        for worker in results:
            if worker["model"] is not None:
                models[worker["worker_id"]] = worker["model"]

        logger.info("aggregation")
        traced_model = utils.federated_avg(models)
        logger.info("aggregation done")

        # Evaluate federated model
        logger.info("Validate..")
        loss = validate_model("Federated", traced_model, validation_dataloader, criterion)
        logger.info("Validation done")
        val_loss_values.append(loss)
        #val_acc_values.append(acc)
if __name__ == "__main__":
    # Logging setup
    date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
    FORMAT = "%(asctime)s | %(message)s"
    logging.basicConfig(filename='logs/orchestrator_' + date_time + '.log', format=FORMAT)
    logger = logging.getLogger("orchestrator")
    logger.setLevel(level=logging.INFO)

    asyncio.get_event_loop().run_until_complete(main())

The code of the workers:
def load_dataset(dataset_path):
    """Helper function for setting up the local datasets.

    Args:
        dataset_path (string):  path to dataset, images must be arranged in this way
                                dataset_path/train/class1/xxx.jpg
                                dataset_path/train/class2/yyy.jpg
    """

    data_transform = transforms.Compose([
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(os.path.join(dataset_path, 'train'), data_transform)

    return dataset


def start_websocket_server(id, port, dataset, verbose):
    """Helper function for spinning up a websocket server.

    Args:
        id (str or id): the unique id of the worker.
        port (int): the port on which the server should be run.
        dataset: dataset, which the worker should provide.
        verbose (bool): a verbose option - will print all messages sent/received to stdout.
    """

    hook = sy.TorchHook(torch)
    server = WebsocketServerWorker(id=id, host="0.0.0.0", port=port, hook=hook, verbose=verbose)
    server.add_dataset(dataset, key=DATASET_KEY)
    server.start()

    return server

    def _fit(self, model, dataset_key, loss_fn):
        logger = logging.getLogger("worker")
        logger.info(dataset_key)
        print("dataset key")
        model.train()
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)
        #data_loader = self._create_data_loader(
            #dataset_key=dataset_key, shuffle=self.train_config.shuffle
        #)
        print("worker")
        print(data_loader)

        loss = None
        iteration_count = 0

        for _ in range(self.train_config.epochs):
            for data in enumerate(data_loader):
                # Set gradients to zero
                self.optimizer.zero_grad()

                # Update model
                output,code = model(data)
                logger.info(data)
                logger.info(output)
                loss = loss_fn(data, output)
                loss.backward()
                self.optimizer.step()

                # Update and check interation count
                iteration_count += 1
                if iteration_count >= self.train_config.max_nr_batches >= 0:
                    break

        return model



if __name__ == "__main__":

    # Parse args
    args = define_and_get_arguments()

    # Logging setup
    date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
    FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d, p:%(process)d) - %(message)s"
    logging.basicConfig(filename='logs/worker_' + args.id + '_' + date_time + '.log', format=FORMAT)
    logger = logging.getLogger("worker")
    logger.setLevel(level=logging.INFO)

    # Load train dataset
    dataset = load_dataset(args.dataset_path)

    # Start server
    server = start_websocket_server(
        id=args.id,
        port=args.port,
        dataset=dataset,
        verbose=args.verbose,
    )
Does anybody know what the problem is?
0

There are 0 answers