Using a PyTorch SSIM loss function in my model

I am trying out the SSIM loss implemented in this repo for image restoration.
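For context, SSIM compares two images through Gaussian-weighted local means, variances and covariance, and it equals 1 only when the images are identical, so it is a similarity to maximize rather than an error to minimize. The sketch below is a self-contained re-implementation, not the repo's code; it assumes single-channel (N, 1, H, W) tensors scaled to [0, 1], the common 11×11 Gaussian window with σ = 1.5, and the usual constants C1 = (0.01·L)² and C2 = (0.03·L)², where L is the dynamic range (data_range is a parameter name chosen here).

```python
import torch
import torch.nn.functional as F


def gaussian_window(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    """Return a normalized 2-D Gaussian kernel of shape (1, 1, size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    return (g[:, None] * g[None, :]).unsqueeze(0).unsqueeze(0)


def ssim(img1: torch.Tensor, img2: torch.Tensor,
         window_size: int = 11, data_range: float = 1.0) -> torch.Tensor:
    """Mean SSIM of two (N, 1, H, W) batches; 1.0 means identical images."""
    window = gaussian_window(window_size).to(img1.device, img1.dtype)
    pad = window_size // 2
    # Gaussian-weighted local means
    mu1 = F.conv2d(img1, window, padding=pad)
    mu2 = F.conv2d(img2, window, padding=pad)
    # Local variances and covariance: E[xy] - E[x]E[y]
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad) - mu1 ** 2
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad) - mu2 ** 2
    sigma12 = F.conv2d(img1 * img2, window, padding=pad) - mu1 * mu2
    # Stabilizing constants (K1 = 0.01, K2 = 0.03) scale with the dynamic range
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    ssim_map = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / (
        (mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim_map.mean()
```

Because C1 and C2 scale with data_range, feeding images in [0, 255] to an implementation that assumes [0, 1] changes how the loss behaves, so the input range is worth checking.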

Following the sample code on the author's GitHub, I tried:

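import os

import cv2
import numpy as np

# model, criterion, optimizer, trainloader and epoch are defined elsewhere
model.train()
for epo in range(epoch):
    for i, data in enumerate(trainloader):
        # Variable is no longer needed (PyTorch >= 0.4); tensors track gradients directly
        inputs = data
        optimizer.zero_grad()
        # -1 instead of a fixed bs so the last, possibly smaller batch still reshapes
        inputs = inputs.view(-1, 1, 128, 128)
        top = model.upward(inputs)
        outputs = model.downward(top, shortcut=True)
        outputs = outputs.view(-1, 1, 128, 128)

        if i % 20 == 0:
            # Move to CPU, scale [0, 1] -> [0, 255] and cast to uint8 before writing the PNG
            out = outputs[0].view(128, 128).detach().cpu().numpy() * 255
            out = np.clip(out, 0, 255).astype(np.uint8)
            cv2.imwrite(os.path.join("/home/tk/Documents/recover/SSIM/", f"{epo}_{i}_re.png"), out)

        # Assuming criterion returns SSIM (higher = more similar), negating it lets the optimizer maximize SSIM
        loss = - criterion(inputs, outputs)
        ssim_value = -loss.item()
        print(ssim_value)
        loss.backward()
        optimizer.step()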

However, the results were not what I expected: after the first 10 epochs, the saved output images were all black.
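One thing worth ruling out (an assumption, since the post does not show the output values) is the image-writing step rather than the loss itself: an all-black PNG only means every written pixel is near zero. The hypothetical helpers below summarize an output's value range and convert a float image, assumed to lie in [0, 1], into the uint8 array that cv2.imwrite handles for PNG.

```python
import numpy as np


def to_uint8_image(arr: np.ndarray, value_range=(0.0, 1.0)) -> np.ndarray:
    """Map a float image from value_range to 0..255 and cast to uint8."""
    lo, hi = value_range
    scaled = (np.asarray(arr, dtype=np.float64) - lo) / (hi - lo) * 255.0
    # Clip before casting: NumPy does not define float-to-uint8 casts for out-of-range values
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


def describe(arr: np.ndarray) -> str:
    """Summarize a network output; a max close to the low end explains a black image."""
    return f"min={arr.min():.4f} max={arr.max():.4f} mean={arr.mean():.4f}"
```

Printing describe(out) before calling cv2.imwrite shows whether the network really outputs near-zero values or whether the conversion is at fault.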

loss = - criterion(inputs, outputs) is what the author proposes. However, in typical PyTorch training code the loss is written as loss = criterion(y_pred, target), so I thought it should be loss = criterion(inputs, outputs) here.
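Assuming criterion returns the SSIM value itself (as SSIM modules typically do, with 1 meaning identical images), the sign decides the training direction: minimizing -SSIM, or equivalently 1 - SSIM, pulls the output toward the target, while minimizing +SSIM pushes it away. The sketch below illustrates this by optimizing a free image directly, using a simplified SSIM computed over one whole-image window (an illustrative simplification, not the windowed SSIM used in practice); global_ssim and fit are hypothetical names.

```python
import torch


def global_ssim(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    """Simplified SSIM using a single window over the whole image (illustration only)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x = ((x - mu_x) ** 2).mean()
    var_y = ((y - mu_y) ** 2).mean()
    cov = ((x - mu_x) * (y - mu_y)).mean()
    return ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2))


def fit(target: torch.Tensor, maximize_ssim: bool, steps: int = 200) -> float:
    """Optimize a random image against target and return the final SSIM."""
    torch.manual_seed(0)
    pred = torch.rand_like(target, requires_grad=True)
    opt = torch.optim.Adam([pred], lr=0.05)
    for _ in range(steps):
        opt.zero_grad()
        s = global_ssim(pred, target)
        # 1 - SSIM is minimized by a perfect reconstruction; +SSIM by an anti-correlated image
        loss = 1 - s if maximize_ssim else s
        loss.backward()
        opt.step()
    return global_ssim(pred, target).item()
```

1 - SSIM has the same gradients as -SSIM but reaches 0 for a perfect reconstruction, which makes the printed loss easier to read.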

I also tried loss = criterion(inputs, outputs), but the results were the same.

Can anyone share some thoughts on how to use an SSIM loss properly? Thanks.
