ValueError('max_value must not be zero or nan during training or validation')

41 views Asked by At

I can see that `max_value` is NaN in `transforms.py`:

enter image description here

Screenshot 2023-10-27 at 10.14.45 PM.png

Because of this, I am getting the `ValueError` raised by this check:

enter image description here

-> maxval = [v.reshape(-1) for v in batch.max_value]

(Pdb) n

--Return--

/home/hpc/rzku/mlvl109h/cine-vn-vortex/cinevn/pl/varnet_module.py(317)()->[tensor([0.002...torch.float64)]

-> maxval = [v.reshape(-1) for v in batch.max_value]

While debugging, I saw that `maxval` is calculated as shown above, and I think that part is fine.

I assume `batch.max_value` is not returning a valid value. In the transforms, `max_value` is only documented as `max_value: Maximum absolute image value.`, and nothing else is done to it there. The line `maxval = [v.reshape(-1) for v in batch.max_value]` appears only in `varnet_module`.

Something seems to be wrong with the batch, but I am not sure what. Any guesses are welcome. While debugging, I also found that the sanity check runs correctly (I used only 2 samples here, but I also tried with the actual data). [screenshot] The code in `varnet_module` is:

def train_val_forward(self, batch):
        """Run a training/validation forward pass and compute the loss.

        ``batch.masked_kspace`` is always reconstructed and scored with
        ``self.loss``. When ``batch.mask`` is not None, ``batch.noised_kspace`` is
        reconstructed too, and ``self.consistency_loss_fn(noised_output, output)``
        is added to the loss.

        Args:
            batch: batch object exposing ``max_value``, ``masked_kspace``, ``mask``,
                ``num_low_frequencies``, ``sens_maps`` and ``target`` (plus
                ``noised_kspace`` when ``mask`` is not None).

        Returns:
            Tuple ``(target, output, loss)``: the cropped target, the reconstruction
            of ``masked_kspace`` mapped back to the target's scale (when
            ``self.normalize`` is set), and the loss.

        Raises:
            ValueError: if any entry of ``batch.max_value`` is zero or NaN. It is
                used (possibly divided by the normalization value) as the loss
                ``data_range``. The message lists the offending values and, when
                available, the file names and slice numbers so the bad samples can
                be traced back to the dataset.
        """
        max_value = batch.max_value
        invalid = (max_value == 0) | max_value.isnan()
        if torch.any(invalid):
            raise ValueError(
                'max_value must not be zero or nan during training or validation '
                f'(max_value={max_value.tolist()}, '
                f"fname={getattr(batch, 'fname', None)}, "
                f"slice_num={getattr(batch, 'slice_num', None)})"
            )

        def _reconstruct(kspace, with_loss):
            """Reconstruct ``kspace`` and optionally score it against ``batch.target``.

            Returns ``(target, output, loss)``. ``output`` is multiplied back by the
            normalization value when ``self.normalize`` is set. ``loss`` is None when
            ``with_loss`` is False.
            """
            norm_val = None
            if self.normalize:
                # max magnitude of the target, one value per entry along dim 0
                norm_val = batch.target.abs().flatten(1).max(dim=1).values
                # pytorch broadcasts from right to left, hence we need to expand the dimensions of norm_val manually
                kspace = kspace / norm_val[(...,) + (None,) * (kspace.ndim - 1)]
            # forward through network
            output = self(kspace, batch.mask, batch.num_low_frequencies, batch.sens_maps)
            # crop phase oversampling
            target, output = transforms.center_crop_to_smallest(batch.target, output, ndim=2)
            loss = None
            if with_loss:
                # normalize target similar to output for loss calculation
                if norm_val is None:
                    target_for_loss = target
                    data_range = batch.max_value
                else:
                    target_for_loss = target / norm_val[(...,) + (None,) * (target.ndim - 1)]
                    data_range = batch.max_value / norm_val
                # unsqueeze(1) adds a channel dimension for the loss
                loss = self.loss(pred=output.unsqueeze(1), targ=target_for_loss.unsqueeze(1), data_range=data_range)
            # unnormalize output
            if norm_val is not None:
                output = output * norm_val[(...,) + (None,) * (output.ndim - 1)]  # don't use inplace operation here!
            return target, output, loss

        target, output, loss = _reconstruct(batch.masked_kspace, with_loss=True)

        # NOTE(review): the consistency branch is gated on batch.mask being present;
        # confirm whether the intended condition is batch.noised_kspace instead.
        if batch.mask is not None:
            # the noised reconstruction only feeds the consistency loss
            _, noised_output, _ = _reconstruct(batch.noised_kspace, with_loss=False)
            consistency_loss = self.consistency_loss_fn(noised_output, output)
            # out-of-place add keeps the autograd graph intact
            loss = loss + consistency_loss

        return target, output, loss

    
        

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it on every step.

        Returns:
            The loss produced by ``train_val_forward``.
        """
        _target, _output, loss = self.train_val_forward(batch)
        log_options = dict(on_step=True, on_epoch=False, sync_dist=True)
        self.log('train_loss', loss, **log_options)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the shared forward pass on one validation batch.

        Returns:
            Dict with the keys ``'output'``, ``'target'`` and ``'val_loss'``, which
            are the keys read by ``on_validation_batch_end``.
        """
        target, output, val_loss = self.train_val_forward(batch)
        return dict(output=output, target=target, val_loss=val_loss)

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Update the running validation metrics with one batch's results.

        Args:
            outputs: dict returned by ``validation_step`` with the keys
                ``'target'``, ``'output'`` and ``'val_loss'``.
            batch: the validation batch; ``max_value``, ``annotations``, ``fname``
                and ``slice_num`` are read from it.

        Raises:
            RuntimeError: if ``outputs`` is not a dict.
        """
        if not isinstance(outputs, dict):
            raise RuntimeError('outputs must be a dict')

        # all image metrics are computed on magnitudes
        target, output = (outputs[key].abs() for key in ('target', 'output'))
        # one flattened max value per entry of batch.max_value, passed as maxvals
        maxval = list(map(lambda value: value.reshape(-1), batch.max_value))
        # any NaN annotation disables the x-t SSIM for the whole batch
        annotations_missing = batch.annotations.isnan().any()
        center = None if annotations_missing else [ann[0].to(int) for ann in batch.annotations]

        self.val_loss.update(outputs['val_loss'])
        ident = (batch.fname, batch.slice_num)
        self.nmse.update(*ident, target, output)
        for ranged_metric in (self.ssim, self.psnr):
            ranged_metric.update(*ident, target, output, maxvals=maxval)
        self.hfen.update(*ident, target, output)
        if center is not None and self.ssim_xt is not None:
            self.ssim_xt.update(*ident, target, output, center, maxvals=maxval)

    def on_validation_epoch_end(self):
        """Log the accumulated validation loss and metrics at the end of the epoch."""
        always_logged = (
            ('validation_loss', self.val_loss, {'prog_bar': True}),
            ('val_metrics/nmse', self.nmse, {}),
            ('val_metrics/ssim', self.ssim, {'prog_bar': True}),
            ('val_metrics/psnr', self.psnr, {}),
            ('val_metrics/hfen', self.hfen, {}),
        )
        for name, metric, options in always_logged:
            self.log(name, metric, **options)
        # x-t SSIM is logged only if configured and updated at least once this epoch
        ssim_xt = self.ssim_xt
        if ssim_xt is not None and ssim_xt._update_count > 0:
            self.log('val_metrics/ssim_xt', ssim_xt)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Reconstruct one test batch and crop it using ``batch.header``.

        Returns:
            Dict with a single ``'output'`` key holding the cropped reconstruction.
        """
        reconstruction = self(
            batch.masked_kspace,
            batch.mask,
            batch.num_low_frequencies,
            batch.sens_maps,
        )
        cropped = transforms.batched_crop_to_recon_size(reconstruction, batch.header)
        return dict(output=cropped)

    def configure_optimizers(self):
        """Create an Adam optimizer and a StepLR scheduler from the module's hyper-parameters.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        step_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.lr_step_size,
            gamma=self.lr_gamma,
        )
        return [optimizer], [step_scheduler]
0

There are 0 answers