PySyft Worker overfitting


I am trying to train an image classification model (CIFAR-10) with PySyft. My training setup has 10 workers, where every worker gets between 800 and 1200 images of the dataset.

My problem is that after about 250-300 epochs the train loss is at about 0.005 and the model stops improving, while the test accuracy is only about 45% with a test loss that keeps increasing (1.5 -> 8.5). I tried the same with 100 workers and 500 images, where it stopped at 32%. Furthermore, the implementation is part of a comparison between models and FL frameworks, so the model can't be changed and the data has to be loaded locally and wrapped into a DataLoader. Since I'm very inexperienced with PyTorch and PySyft, I might have made some mistakes when training the model, although I tried to stay as close as possible to the example.

I trained the model without PySyft and it reached about 85%, so I don't think the dataloader or the model is the problem. To me it looks like the workers overfit on their own data during training.

Is there a way to prevent the workers from overfitting, or to calculate a loss for the global model instead of the individual workers?
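
What I mean by a loss/model on the global level is something in the direction of FedAvg: train a copy of the model on each worker and then average the weights into one global model between rounds. A rough sketch of that averaging step in plain PyTorch (the function and variable names are placeholders I made up, not code I already have):

import copy

import torch

def average_state_dicts(worker_models):
    """FedAvg-style aggregation: element-wise mean of the per-worker model weights."""
    state_dicts = [m.state_dict() for m in worker_models]
    avg_state = copy.deepcopy(state_dicts[0])
    for key in avg_state:
        # stack the same parameter tensor from every worker and average it
        avg_state[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    return avg_state

# usage after one round of local training on every worker:
# global_model.load_state_dict(average_state_dicts(worker_models))

I just don't see how something like this fits into the single model.send()/model.get() loop from the PySyft example that I use below.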

Trainer:

import time

import torch.nn as nn
import torch.nn.functional as F


def fl_train(args, model, device, federated_train_loader, optimizer, epoch, log):
    model.train()
    results = []
    metrics = []
    t1 = time.time()
    cel = nn.CrossEntropyLoss()  # unused; the loop below uses F.nll_loss instead
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
        t2 = time.time()
        model.send(data.location) # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target.long())
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            results.append(loss.item())
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * BATCH_SIZE, len(federated_train_loader) * BATCH_SIZE,
                100. * batch_idx / len(federated_train_loader), loss.item()))
    return log  # return the log dict; otherwise the assignment in the main loop sets it to None

Model:

class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.conv_layer = nn.Sequential(

            # Conv Layer block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2,2)),

            # Conv Layer block 2
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2,2)),

            # Conv Layer block 3
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 10)
        )


    def forward(self, x):
        # CNN layers
        x = self.conv_layer(x)

        # flatten
        x = x.view(-1, 1024)

        # NN layer
        x = self.fc_layer(x)
        return F.log_softmax(x, dim=1)
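
As a sanity check that the hard-coded flatten size of 1024 matches 32x32 CIFAR-10 inputs, a quick shape check (a separate snippet, not part of the training code):

import torch

# push a dummy CIFAR-10 sized batch through the conv stack:
# 32 -> 30 (conv 3x3) -> 15 (pool) -> 13 (conv 3x3) -> 6 (pool) -> 4 (conv 3x3)
dummy = torch.zeros(1, 3, 32, 32)
features = CNN().conv_layer(dummy)
print(features.shape)  # torch.Size([1, 64, 4, 4]) -> 64 * 4 * 4 = 1024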

Main:

model = CNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.02) # TODO momentum is not supported at the moment
log = {}
for epoch in range(1, args.epochs + 1):
    log = fl_train(args, model, device, f_dataloader, optimizer, epoch, log)
    if epoch % 20 == 0:
        log = test(args, model, device, test_loader, epoch, log)
    if epoch % 100 == 0:
        store_results(log, model)
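
For completeness, f_dataloader is a PySyft FederatedDataLoader built from the locally loaded CIFAR-10 tensors, roughly following the PySyft 0.2 example. The snippet below is only a sketch of that pattern (worker ids, tensor names and the batch size are placeholders, not my exact loading code):

import torch
import syft as sy

hook = sy.TorchHook(torch)  # hook PyTorch so tensors and models can be sent to workers
workers = [sy.VirtualWorker(hook, id=f"worker{i}") for i in range(10)]

# stand-ins for the locally loaded and transformed CIFAR-10 data
images = torch.randn(10000, 3, 32, 32)
labels = torch.randint(0, 10, (10000,))

BATCH_SIZE = 32  # placeholder value
federated_dataset = sy.BaseDataset(images, labels).federate(tuple(workers))
f_dataloader = sy.FederatedDataLoader(federated_dataset, batch_size=BATCH_SIZE, shuffle=True)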

Log:

....
Train Epoch: 317 [0/10400 (0%)] Loss: 0.005194
Train Epoch: 317 [3000/10400 (29%)] Loss: 0.003882
Train Epoch: 317 [6000/10400 (58%)] Loss: 0.003100
Train Epoch: 317 [9000/10400 (87%)] Loss: 0.004298
Train Epoch: 318 [0/10400 (0%)] Loss: 0.007426
Train Epoch: 318 [3000/10400 (29%)] Loss: 0.002255
Train Epoch: 318 [6000/10400 (58%)] Loss: 0.003835
Train Epoch: 318 [9000/10400 (87%)] Loss: 0.005277
Train Epoch: 319 [0/10400 (0%)] Loss: 0.006207
Train Epoch: 319 [3000/10400 (29%)] Loss: 0.003562
Train Epoch: 319 [6000/10400 (58%)] Loss: 0.001904
Train Epoch: 319 [9000/10400 (87%)] Loss: 0.002644
Train Epoch: 320 [0/10400 (0%)] Loss: 0.007491
Train Epoch: 320 [3000/10400 (29%)] Loss: 0.003794
Train Epoch: 320 [6000/10400 (58%)] Loss: 0.002643
Train Epoch: 320 [9000/10400 (87%)] Loss: 0.002981
Test set: Average loss: 9.1279, Accuracy: 458/1000 (46%)

Train Epoch: 321 [0/10400 (0%)] Loss: 0.007153
Train Epoch: 321 [3000/10400 (29%)] Loss: 0.004265
Train Epoch: 321 [6000/10400 (58%)] Loss: 0.002708
Train Epoch: 321 [9000/10400 (87%)] Loss: 0.002518
Train Epoch: 322 [0/10400 (0%)] Loss: 0.006285
Train Epoch: 322 [3000/10400 (29%)] Loss: 0.002357
Train Epoch: 322 [6000/10400 (58%)] Loss: 0.002465
Train Epoch: 322 [9000/10400 (87%)] Loss: 0.002406
Train Epoch: 323 [0/10400 (0%)] Loss: 0.005361
Train Epoch: 323 [3000/10400 (29%)] Loss: 0.004807
Train Epoch: 323 [6000/10400 (58%)] Loss: 0.001903
Train Epoch: 323 [9000/10400 (87%)] Loss: 0.003711
Train Epoch: 324 [0/10400 (0%)] Loss: 0.006609
....