Since I'm a novice with PyTorch, this question might be a very trivial one, but I'd like to ask for your help in solving it.
I've implemented a network from a paper, using all the hyperparameters and all the layers described in the paper.
But when training starts, even though I set the learning rate decay to 0.001, the error doesn't go down. The training error stays around 3.3~3.4 and the test error around 3.5~3.6 over 100 epochs!
I could change the hyperparameters to improve the model, but since the paper gives exact numbers, I'd like to check whether there's an error in the training code that I've implemented.
The code below is the code that I used for training.
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import json
import torch
import math
import time
import os
# --- Run configuration ------------------------------------------------------
batch = 32      # samples per mini-batch for both loaders
epochs = 100    # number of passes over the training set

# --- Data: each dataset is a list of file names; files are read later -------
train_path = '/content/mtrain'
test_path = '/content/mtest'
train_data = os.listdir(train_path)
test_data = os.listdir(test_path)
train_loader = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=batch, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data, batch_size=batch, shuffle=True
)

# --- Model, optimiser and learning-rate schedule ----------------------------
# `Baseline`, `Classification` and `device` are defined elsewhere.
model = nn.Sequential(Baseline(), Classification(40))
model = model.to(device)
optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), lr=0.001)
# StepLR scales the learning rate by `gamma` once every `step_size` calls to
# scheduler.step(); it has no effect unless step() is actually called.
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=20)

# --- Bookkeeping ------------------------------------------------------------
train_loss = []
val_loss = []
now = time.time()   # reference timestamp for the progress reports
print('training start!')
def _load_batch(root, names):
    """Read the JSON samples listed in `names` from `root` into one batch.

    Every file must provide a 'pts' entry and a 'label' entry.

    Args:
        root: directory that contains the sample files.
        names: iterable of file names (one mini-batch from the DataLoader).

    Returns:
        (points, labels) tensors moved to `device`; `points` has dims 1 and 2
        swapped relative to the stacked 'pts' lists.
    """
    pts_list, label_list = [], []
    for name in names:
        with open(os.path.join(root, name), 'r') as f:
            sample = json.load(f)
        pts_list.append(sample['pts'])
        label_list.append(sample['label'])
    points = torch.tensor(pts_list).transpose(1, 2).to(device)
    labels = torch.tensor(label_list).to(device)
    return points, labels


# Total-time reference. `now` is reset after every progress report, so it
# cannot be used to measure the duration of the whole run.
train_start = time.time()

for epoch in range(epochs):
    running_loss = 0.0
    for bidx, trainb32 in enumerate(train_loader):
        bpts, blabel = _load_batch(train_path, trainb32)
        inputs = data_aug(bpts).to(device)  # renamed: `input` shadows a builtin

        optimizer.zero_grad()
        y_pred, feat_stn, glob_feat = model(inputs)
        loss = F.nll_loss(y_pred, blabel) + 0.001 * regularizer(feat_stn)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Every 10th batch: run a full validation pass and report.
        if bidx % 10 == 9:
            vrunning_loss = 0.0
            correct = 0   # correctly classified validation samples
            seen = 0      # validation samples evaluated
            model.eval()
            with torch.no_grad():
                for testb32 in test_loader:
                    vpts, vlabel = _load_batch(test_path, testb32)
                    # NOTE(review): augmentation is applied to validation data
                    # as in the original code; confirm this is intended.
                    vinputs = data_aug(vpts).to(device)
                    vy_pred, vfeat_stn, vglob_feat = model(vinputs)
                    vloss = (F.nll_loss(vy_pred, vlabel)
                             + 0.001 * regularizer(vfeat_stn))
                    _, vy_max = torch.max(vy_pred, dim=1)
                    # Count against the real batch size: the last batch may be
                    # shorter than `batch`, and dividing integer tensors
                    # truncates to 0 on older PyTorch versions.
                    correct += (vy_max == vlabel).sum().item()
                    seen += vlabel.size(0)
                    # .item() keeps a plain float instead of a device tensor.
                    vrunning_loss += vloss.item()

            mean_train = running_loss / 10  # running_loss spans 10 batches
            mean_val = vrunning_loss / max(len(test_loader), 1)
            vacc = correct / max(seen, 1)
            # Histories now hold the same per-batch means that are printed.
            train_loss.append(mean_train)
            val_loss.append(mean_val)
            print(f"Epoch {epoch+1}/{epochs} {bidx}/{len(train_loader)}.. "
                  f"Train loss: {mean_train:.3f}.."
                  f"Val loss: {mean_val:.3f}.."
                  f"Val Accuracy: {vacc:.3f}.."
                  f"Time: {time.time() - now}")
            now = time.time()
            running_loss = 0
            model.train()

    # Advance the StepLR schedule once per epoch, after the optimizer updates.
    # Without this call the learning rate never decays.
    scheduler.step()

print(f'training finish! training time is {time.time() - train_start}')
# NOTE: parameters() returns a generator, so this prints the generator object
# rather than the weight values.
print(model.parameters())
savePath = '/content/modelpath.pth'
torch.save(model.state_dict(), savePath)
Sorry for the basic question. If there's no error in this training code, I'd be very grateful if you could let me know; and if there is, please give me a hint on how to solve it.
I've implemented PointNet, and the full code is available at https://github.com/RaraKim/PointNet/blob/master/PointNet_pytorch.ipynb
Thank you!
I looked at your code, and I believe you have some tensors that are declared manually. For torch tensors, the default value of the "requires_grad" flag is False, so I think your backpropagation isn't working correctly. Can you try to fix that? I'll be happy to help further if the issue persists.