I can see that max_value is NaN in transforms.py.
Screenshot 2023-10-27 at 10.14.45 PM.png
Because of this, I am getting a ValueError.
-> maxval = [v.reshape(-1) for v in batch.max_value]
(Pdb) n
--Return--
/home/hpc/rzku/mlvl109h/cine-vn-vortex/cinevn/pl/varnet_module.py(317)()->[tensor([0.002...torch.float64)]
-> maxval = [v.reshape(-1) for v in batch.max_value]
While debugging, I saw that maxval is calculated as shown above; I think that part is fine.
I assume batch.max_value is returning nothing. In the transforms, it is only documented as "max_value: Maximum absolute image value", and nothing is done to max_value there. The line maxval = [v.reshape(-1) for v in batch.max_value] appears only in varnet_module.
There is something wrong with the batch, but I am not sure what. Let me know if you have any idea what it could be. While debugging, I also found that the sanity check is executed properly (I used only 2 samples here, but I also tried with the actual one). The code in varnet_module is:
def train_val_forward(self, batch):
    """Run the forward pass shared by training and validation and compute the loss.

    ``batch.masked_kspace`` is passed through the network and scored with
    ``self.loss``. When ``batch.mask`` is not ``None``, ``batch.noised_kspace``
    is passed through the network as well and
    ``self.consistency_loss_fn(noised_output, output)`` is added to the loss.

    If ``self.normalize`` is set, the k-space inputs and the target used for
    the loss are divided by the per-sample maximum magnitude of
    ``batch.target``; the returned output is multiplied by the same factor
    again.

    Args:
        batch: Batch object. This method reads ``max_value``, ``target``,
            ``masked_kspace``, ``noised_kspace``, ``mask``,
            ``num_low_frequencies`` and ``sens_maps`` from it.

    Returns:
        Tuple ``(target, output, loss)``, where ``target`` and ``output`` are
        the values returned by ``transforms.center_crop_to_smallest``
        (``output`` un-normalized) and ``loss`` is the total loss.

    Raises:
        ValueError: If any entry of ``batch.max_value`` is zero or NaN.
    """
    if torch.any(batch.max_value == 0) or torch.any(batch.max_value.isnan()):
        raise ValueError('max_value must not be zero or nan during training or validation')

    # One normalization factor serves both the clean and the noised pass: both
    # were computed from batch.target with the identical expression. Keeping a
    # single variable also means it is always bound, even if normalize is off.
    norm_val = None
    if self.normalize:
        norm_val = batch.target.abs().flatten(1).max(dim=1).values

    def broadcastable(reference):
        # pytorch broadcasts from right to left, hence the trailing singleton
        # dimensions of norm_val have to be added manually
        return norm_val[(...,) + (None,) * (reference.ndim - 1)]

    def reconstruct(kspace):
        # normalize (if requested), run the network, crop phase oversampling
        if norm_val is not None:
            kspace = kspace / broadcastable(kspace)
        prediction = self(kspace, batch.mask, batch.num_low_frequencies, batch.sens_maps)
        return transforms.center_crop_to_smallest(batch.target, prediction, ndim=2)

    target, output = reconstruct(batch.masked_kspace)

    # normalize target similar to output for loss calculation and unsqueeze to
    # add channel dimension
    if norm_val is None:
        target_for_loss = target
        data_range = batch.max_value
    else:
        target_for_loss = target / broadcastable(target)
        data_range = batch.max_value / norm_val
    loss = self.loss(pred=output.unsqueeze(1), targ=target_for_loss.unsqueeze(1), data_range=data_range)

    # unnormalize output
    if norm_val is not None:
        output = output * broadcastable(output)  # don't use inplace operation here!

    # NOTE(review): the gate previously tested `mask` before it was assigned,
    # which raises UnboundLocalError on every call. Binding it from batch.mask
    # first is the minimal repair. If the consistency term is instead meant to
    # depend on noised k-space being available, confirm and change this gate.
    mask = batch.mask
    if mask is not None:
        # `target` is rebound to the crop paired with the noised output, which
        # is what this branch returned before the refactoring as well.
        target, noised_output = reconstruct(batch.noised_kspace)
        if norm_val is not None:
            noised_output = noised_output * broadcastable(noised_output)  # not in place
        # Out-of-place add so the tensor returned by self.loss is not modified.
        loss = loss + self.consistency_loss_fn(noised_output, output)

    return target, output, loss
def training_step(self, batch, batch_idx):
    """Compute the loss for one training batch and log it as ``train_loss``.

    Args:
        batch: Batch forwarded unchanged to ``self.train_val_forward``.
        batch_idx: Index of the batch; not used here.

    Returns:
        The loss returned by ``self.train_val_forward``.
    """
    _target, _output, step_loss = self.train_val_forward(batch)
    # Logged per step only, never as an epoch aggregate.
    logging_options = {'on_step': True, 'on_epoch': False, 'sync_dist': True}
    self.log('train_loss', step_loss, **logging_options)
    return step_loss
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Run the shared forward pass on one validation batch.

    Args:
        batch: Batch forwarded unchanged to ``self.train_val_forward``.
        batch_idx: Index of the batch; not used here.
        dataloader_idx: Index of the dataloader; not used here.

    Returns:
        Dict with the keys ``'output'``, ``'target'`` and ``'val_loss'``.
    """
    forward_target, forward_output, forward_loss = self.train_val_forward(batch)
    return dict(output=forward_output, target=forward_target, val_loss=forward_loss)
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
    """Update the validation loss and image metrics with one batch's results.

    Args:
        outputs: Dict holding ``'target'``, ``'output'`` and ``'val_loss'``.
        batch: Batch object; ``max_value``, ``annotations``, ``fname`` and
            ``slice_num`` are read from it.
        batch_idx: Index of the batch; not used here.
        dataloader_idx: Index of the dataloader; not used here.

    Raises:
        RuntimeError: If ``outputs`` is not a dict.
    """
    if not isinstance(outputs, dict):
        raise RuntimeError('outputs must be a dict')

    # The metrics receive magnitudes of target and output.
    target_magnitude = outputs['target'].abs()
    output_magnitude = outputs['output'].abs()
    flat_maxvals = [value.reshape(-1) for value in batch.max_value]

    # Centers are taken from the annotations only when none of them is NaN.
    centers = None
    if not batch.annotations.isnan().any():
        centers = [entry[0].to(int) for entry in batch.annotations]

    self.val_loss.update(outputs['val_loss'])
    self.nmse.update(batch.fname, batch.slice_num, target_magnitude, output_magnitude)
    for ranged_metric in (self.ssim, self.psnr):
        ranged_metric.update(
            batch.fname, batch.slice_num, target_magnitude, output_magnitude, maxvals=flat_maxvals
        )
    self.hfen.update(batch.fname, batch.slice_num, target_magnitude, output_magnitude)

    xt_metric = self.ssim_xt
    if xt_metric is not None and centers is not None:
        xt_metric.update(
            batch.fname, batch.slice_num, target_magnitude, output_magnitude, centers, maxvals=flat_maxvals
        )
def on_validation_epoch_end(self):
    """Log the validation loss and the validation metrics.

    ``validation_loss`` and ``val_metrics/ssim`` are logged with
    ``prog_bar=True``. ``val_metrics/ssim_xt`` is logged only when
    ``self.ssim_xt`` is set and its ``_update_count`` is greater than zero.
    """
    self.log('validation_loss', self.val_loss, prog_bar=True)

    # (log key, metric object, extra keyword arguments for self.log)
    metric_entries = (
        ('val_metrics/nmse', self.nmse, {}),
        ('val_metrics/ssim', self.ssim, {'prog_bar': True}),
        ('val_metrics/psnr', self.psnr, {}),
        ('val_metrics/hfen', self.hfen, {}),
    )
    for log_key, metric, extra_options in metric_entries:
        self.log(log_key, metric, **extra_options)

    xt_metric = self.ssim_xt
    if xt_metric is not None and xt_metric._update_count > 0:
        self.log('val_metrics/ssim_xt', xt_metric)
def test_step(self, batch, batch_idx, dataloader_idx=0):
    """Reconstruct one test batch.

    The network is called on ``batch.masked_kspace`` without the normalization
    applied in training/validation, and the result is passed through
    ``transforms.batched_crop_to_recon_size`` together with ``batch.header``.

    Args:
        batch: Batch object; ``masked_kspace``, ``mask``,
            ``num_low_frequencies``, ``sens_maps`` and ``header`` are read.
        batch_idx: Index of the batch; not used here.
        dataloader_idx: Index of the dataloader; not used here.

    Returns:
        Dict with the single key ``'output'``.
    """
    reconstruction = self(
        batch.masked_kspace,
        batch.mask,
        batch.num_low_frequencies,
        batch.sens_maps,
    )
    cropped = transforms.batched_crop_to_recon_size(reconstruction, batch.header)
    return dict(output=cropped)
def configure_optimizers(self):
    """Create the Adam optimizer and the StepLR schedule attached to it.

    Reads ``self.lr``, ``self.weight_decay``, ``self.lr_step_size`` and
    ``self.lr_gamma``.

    Returns:
        Tuple of two one-element lists: ``([optimizer], [scheduler])``.
    """
    adam = torch.optim.Adam(
        self.parameters(),
        lr=self.lr,
        weight_decay=self.weight_decay,
    )
    step_schedule = torch.optim.lr_scheduler.StepLR(
        adam,
        step_size=self.lr_step_size,
        gamma=self.lr_gamma,
    )
    return [adam], [step_schedule]