I've been reading this research paper- https://arxiv.org/abs/1908.00413, and trying to implement the code from GitHub- https://github.com/hoyeoplee/MeLU, however, I run into a runtime error while training the model. Can anyone suggest the possible reasons that can cause this error?
The model training code is as follows-
def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_filename=None):
if config['use_cuda']:
melu.cuda()
training_set_size = len(total_dataset)
melu.train()
for _ in range(num_epoch):
random.shuffle(total_dataset)
num_batch = int(training_set_size / batch_size)
a,b,c,d = zip(*total_dataset)
for i in range(num_batch):
try:
supp_xs = list(a[batch_size*i:batch_size*(i+1)])
supp_ys = list(b[batch_size*i:batch_size*(i+1)])
query_xs = list(c[batch_size*i:batch_size*(i+1)])
query_ys = list(d[batch_size*i:batch_size*(i+1)])
except IndexError:
continue
melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner'])
if model_save:
torch.save(melu.state_dict(), model_filename)
And the code for global update is as follows-
def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
batch_sz = len(support_set_xs)
losses_q = []
if self.use_cuda:
for i in range(batch_sz):
support_set_xs[i] = support_set_xs[i].cuda()
support_set_ys[i] = support_set_ys[i].cuda()
query_set_xs[i] = query_set_xs[i].cuda()
query_set_ys[i] = query_set_ys[i].cuda()
for i in range(batch_sz):
query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update)
loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
losses_q.append(loss_q)
losses_q = torch.stack(losses_q).mean(0)
self.meta_optim.zero_grad()
losses_q.backward()
self.meta_optim.step()
self.store_parameters()
return