I am trying to build a federated learning model. In my scenario there are 3 workers and an orchestrator. The workers run the training, and at the end of each training round the models are sent to the orchestrator, which computes the federated average and sends the new model back; the workers then train on that new model, and so on. The custom network is an AutoEncoder that I built from scratch.
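For context, by "federated average" I mean the standard FedAvg parameter mean (PySyft's `utils.federated_avg` does this for me); a minimal sketch of the aggregation step, with illustrative names and assuming equally weighted workers:

```
import torch

def federated_average(state_dicts):
    """Parameter-wise mean of several model state_dicts (equal weights)."""
    avg = {}
    for key in state_dicts[0]:
        avg[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    return avg
```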
Unfortunately, the workers raise this error: `RuntimeError: forward() is missing value for argument 'inputs'. Declaration: forward(ClassType self, Tensor inputs, Tensor outputs) -> (Tensor)`, which is strange, because my forward function is defined as follows inside the AE class:
```
import torch
import torch.nn as nn
import torch.nn.functional as F

class AutoEncoder(nn.Module):
    def __init__(self, code_size):
        super().__init__()
        self.code_size = code_size

        # Encoder specification
        self.enc_cnn_1 = nn.Conv2d(3, 10, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
        self.enc_linear_1 = nn.Linear(53 * 53 * 20, 50)
        self.enc_linear_2 = nn.Linear(50, self.code_size)

        # Decoder specification
        self.dec_linear_1 = nn.Linear(self.code_size, 160)
        self.dec_linear_2 = nn.Linear(160, IMAGE_SIZE)

    def forward(self, images):
        code = self.encode(images)
        out = self.decode(code)
        return out, code

    def encode(self, images):
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))
        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))
        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        out = F.selu(self.dec_linear_1(code))
        out = torch.sigmoid(self.dec_linear_2(out))
        out = out.view([code.size(0), 3, IMAGE_WIDTH, IMAGE_HEIGHT])
        return out
```
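The shapes line up when I run the model locally on 224×224 inputs (a quick sanity check; `code_size=20` and the constant values here are examples consistent with the 53·53·20 flattened size — in my code the constants are defined elsewhere):

```
import torch

# Example values for illustration only:
IMAGE_WIDTH = IMAGE_HEIGHT = 224
IMAGE_SIZE = 3 * IMAGE_WIDTH * IMAGE_HEIGHT

model = AutoEncoder(code_size=20)
out, code = model(torch.rand(1, 3, 224, 224))
print(out.shape)   # torch.Size([1, 3, 224, 224])
print(code.shape)  # torch.Size([1, 20])
```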
Loss function (MSE):
```
@torch.jit.script
def loss_fn(inputs, outputs):
    return torch.nn.functional.mse_loss(input=inputs, target=outputs)


def set_gradients(model, finetuning):
    """Helper function to exclude all parameters from training.

    Used for transfer learning in feature-extract mode; see:
    https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

    Args:
        model (torch.nn.Module): model object.
        finetuning (bool): if True, nothing is changed (finetuning mode, all
            parameters are trained); if False, all parameters are excluded
            from training (feature-extract mode).
    """
    if not finetuning:
        for param in model.parameters():
            param.requires_grad = False
```
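For what it's worth, calling the scripted loss locally with dummy tensors runs without any signature complaint:

```
import torch

x = torch.rand(2, 3, 224, 224)
y = torch.rand(2, 3, 224, 224)
print(loss_fn(x, y))  # prints a scalar MSE tensor
```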
```
def initialize_model():
    model = AutoEncoder(code_size)
    set_gradients(model, False)
    return model
async def train_model_on_worker(
    worker: websocket_client.WebsocketClientWorker,
    traced_model: torch.jit.ScriptModule,
    dataset_key: str,
    batch_size: int,
    curr_round: int,
    lr: float,
):
    traced_model.train()
    print("train mode on")
    train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,
        batch_size=batch_size,
        shuffle=True,
        epochs=1,
        optimizer="Adam",
        optimizer_args={"lr": lr},
    )
    logger.info(worker.id + " send trainconfig")
    train_config.send(worker)
    print("Model sent to the worker")

    logger.info(worker.id + " start training")
    await worker.async_fit(dataset_key=dataset_key, return_ids=[0])
    logger.info(worker.id + " training done")

    results = dict()
    logger.info(worker.id + " get model")
    model = train_config.model_ptr.get().obj
    results["worker_id"] = worker.id
    results["model"] = model
    return results
def validate_model(identifier, model, dataloader, criterion):
    model.eval()  # evaluation mode, e.g. no dropout
    losses = []
    for i, (inputs, _) in enumerate(dataloader):
        print("validation mode on")
        # with torch.set_grad_enabled(False):
        outputs, code = model(Variable(inputs))  # reconstruction and latent code
        batch_loss = criterion(outputs, inputs).sqrt()
        losses.append(batch_loss.item())
        print("Loss = %.3f" % batch_loss.item())
    return sum(losses) / len(losses)
async def main():
    args = define_and_get_arguments()
    hook = sy.TorchHook(torch)  # lets PySyft override some PyTorch methods

    # Create WebsocketClientWorkers using the ID/port/host triples from the arguments
    worker_instances = []
    for i in range(len(args.workers) // 3):
        j = i * 3
        worker_instances.append(websocket_client.WebsocketClientWorker(
            id=args.workers[j], port=args.workers[j + 1], host=args.workers[j + 2],
            hook=hook, verbose=args.verbose))

    model = initialize_model()

    # Optional loading of predefined model weights (= state dict):
    if args.basic_model:
        model.load_state_dict(torch.load(args.basic_model))

    # Model serialization (creating a ScriptModule), so the model can be
    # serialized and sent to the workers:
    model.eval()
    traced_model = torch.jit.trace(model, torch.rand([1, 3, 224, 224], dtype=torch.float))
    # Data / picture transformation:
    data_transforms = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor()
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create validation dataset and dataloader
    validation_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'val'), data_transforms)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Create test dataset and dataloader
    test_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'test'), data_transforms)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Lists to plot loss and accuracy after training
    train_loss_values = []
    train_acc_values = []
    val_loss_values = []
    val_acc_values = []

    np.set_printoptions(formatter={"float": "{: .0f}".format})
    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)
        print("entered training")

        # Reduce the learning rate every 5 training rounds (adaptive learning rate)
        lr = args.lr * (0.1 ** ((curr_round - 1) // 5))

        completed, pending = await asyncio.wait(
            [
                train_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    dataset_key=DATASET_KEY,
                    batch_size=args.batch_size,
                    curr_round=curr_round,
                    lr=lr,
                )
                for worker in worker_instances
            ],
            timeout=40,
        )

        results = []
        for entry in completed:
            print("entry")
            print(entry)
            results.append(entry.result())
        for entry in pending:
            entry.cancel()

        # Keep only the workers that finished in time
        new_worker_instances = []
        for entry in results:
            for worker in worker_instances:
                if entry["worker_id"] == worker.id:
                    new_worker_instances.append(worker)
        worker_instances = new_worker_instances
        # Set up the loss function (no parentheses: we pass the function itself)
        criterion = torch.nn.functional.mse_loss
        # optimizer = optimizer_cls(autoencoder.parameters(), lr=lr)

        # Federate models (note that this will also change the model in models[0])
        models = {}
        for worker in results:
            if worker["model"] is not None:
                models[worker["worker_id"]] = worker["model"]

        logger.info("aggregation")
        traced_model = utils.federated_avg(models)
        logger.info("aggregation done")

        # Evaluate the federated model
        logger.info("Validate..")
        loss = validate_model("Federated", traced_model, validation_dataloader, criterion)
        logger.info("Validation done")
        val_loss_values.append(loss)
        # val_acc_values.append(acc)
if __name__ == "__main__":
# Logging setup
date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
FORMAT = "%(asctime)s | %(message)s"
logging.basicConfig(filename='logs/orchestrator_' + date_time + '.log', format=FORMAT)
logger = logging.getLogger("orchestrator")
logger.setLevel(level=logging.INFO)
asyncio.get_event_loop().run_until_complete(main())
The worker code:
```
def load_dataset(dataset_path):
    """Helper function for setting up the local dataset.

    Args:
        dataset_path (string): path to the dataset; images must be arranged like this:
            dataset_path/train/class1/xxx.jpg
            dataset_path/train/class2/yyy.jpg
    """
    data_transform = transforms.Compose([
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(os.path.join(dataset_path, 'train'), data_transform)
    return dataset
def start_websocket_server(id, port, dataset, verbose):
    """Helper function for spinning up a websocket server.

    Args:
        id (str or int): the unique id of the worker.
        port (int): the port on which the server should run.
        dataset: the dataset the worker should provide.
        verbose (bool): verbose option - prints all messages sent/received to stdout.
    """
    hook = sy.TorchHook(torch)
    server = WebsocketServerWorker(id=id, host="0.0.0.0", port=port, hook=hook, verbose=verbose)
    server.add_dataset(dataset, key=DATASET_KEY)
    server.start()
    return server
def _fit(self, model, dataset_key, loss_fn):
    logger = logging.getLogger("worker")
    logger.info(dataset_key)
    print("dataset key")

    model.train()
    # NOTE: `dataset` and `args` are module-level globals here
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)
    # data_loader = self._create_data_loader(
    #     dataset_key=dataset_key, shuffle=self.train_config.shuffle
    # )
    print("worker")
    print(data_loader)

    loss = None
    iteration_count = 0
    for _ in range(self.train_config.epochs):
        for data, _ in data_loader:  # ImageFolder yields (image, label) pairs
            # Set gradients to zero
            self.optimizer.zero_grad()

            # Update model
            output, code = model(data)
            logger.info(data)
            logger.info(output)
            loss = loss_fn(data, output)
            loss.backward()
            self.optimizer.step()

            # Update and check iteration count
            iteration_count += 1
            if iteration_count >= self.train_config.max_nr_batches >= 0:
                break
    return model
if __name__ == "__main__":
# Parse args
args = define_and_get_arguments()
# Logging setup
date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d, p:%(process)d) - %(message)s"
logging.basicConfig(filename='logs/worker_' + args.id + '_' + date_time + '.log', format=FORMAT)
logger = logging.getLogger("worker")
logger.setLevel(level=logging.INFO)
# Load train dataset
dataset = load_dataset(args.dataset_path)
# Start server
server = start_websocket_server(
id=args.id,
port=args.port,
dataset=dataset,
verbose=args.verbose,
)
Does anybody know what the problem is?