Using Autograd's .backward() to calculate an intermediate value in the forward pass of a PyTorch model


Hello, I am new to PyTorch. I have a simple PyTorch module whose output is a scalar loss that depends on the derivatives of some polynomial functions. For example, say the output of the forward pass is input*derivative(x^2+y^2).

One way to implement this is to write down the derivatives of the polynomials explicitly and make them part of the forward model, so output = inputs*(2x+2y). However, this is not robust: if I include more polynomials, I have to add more derivative functions by hand, which is time-consuming and error-prone.
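As a quick check of what "let Autograd compute the derivative" means here, the sketch below evaluates x^2+y^2 at a sample point and asks torch.autograd.grad for both partial derivatives. It assumes that "derivative(x^2+y^2)" means the sum of the partial derivatives, since that is what 2x+2y corresponds to; the sample values 1.5 and -0.5 are arbitrary.

```python
import torch

# Arbitrary sample point (assumption for illustration only)
x = torch.tensor(1.5, requires_grad=True)
y = torch.tensor(-0.5, requires_grad=True)

p = x**2 + y**2  # the polynomial x^2 + y^2

# Partial derivatives dp/dx and dp/dy computed by autograd
dp_dx, dp_dy = torch.autograd.grad(p, (x, y))

# Hand-written version from the question: 2x + 2y (detached, plain values)
analytic = 2 * x.detach() + 2 * y.detach()

print(dp_dx + dp_dy, analytic)
```

Adding another polynomial only requires adding another term to p; autograd derives it without a new hand-written derivative function.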

Instead, I want to define the polynomials, use Autograd to get their derivatives, and plug those derivatives into the output formula. Say the polynomial function is called n; I call n.backward(retain_graph=True) inside the forward pass. However, this does not work properly: the gradients of the loss with respect to the model parameters are very different in magnitude from the ones I get when I use the analytic expression in the forward pass.
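A minimal reproduction of the n.backward() pattern described above, assuming a single learnable coefficient learn = 5, x = 2 and a hypothetical input of 3 (mirroring the defaults of the simplified example further down). It shows two documented behaviours of Tensor.backward(): it accumulates gradients into the .grad of every leaf in the graph, including the parameter, and the value it leaves in xn.grad is a plain tensor with no grad_fn, so it carries no link back to learn.

```python
import torch

learn = torch.nn.Parameter(torch.tensor(5.0))  # learnable coefficient
x = torch.tensor(2.0)
inputs = torch.tensor(3.0)  # hypothetical input value

def derivative_via_backward(x):
    # Pattern from the question: call backward() on the polynomial inside forward
    xn = x.clone().detach().requires_grad_(True)
    n = learn * xn**2
    # Writes dn/dxn into xn.grad AND dn/dlearn (= xn^2) into learn.grad
    n.backward(retain_graph=True)
    return xn.grad  # leaf gradient: correct value, but not part of any graph

der = derivative_via_backward(x)
outputs = inputs * der

print(der)                                # 2 * learn * x
print(der.grad_fn, outputs.requires_grad)  # no graph connecting outputs to learn
print(learn.grad)                          # already populated before any loss.backward()
```

The derivative value itself is right, which matches the observation that it agrees with the analytic expression. But outputs has no path back to learn (calling outputs.backward() here would raise a RuntimeError because it does not require grad), and learn.grad already contains a contribution from n.backward() that a later loss.backward() would add to.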

Note that the output of n.backward() and the analytic expression for the derivative match, so the derivative of the polynomial itself is computed correctly; the problem is connecting it to the final loss. It seems that the backward() call also alters the gradients of the model parameters while it computes the derivatives for the polynomial coefficients. I am sure this comes from my limited understanding of PyTorch, and that calling n.backward() inside the forward pass is somehow interfering with the later loss.backward() call.
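For comparison, a sketch of the alternative that keeps the derivative inside the graph: torch.autograd.grad with create_graph=True returns the derivative as a differentiable tensor and, unlike backward(), does not accumulate anything into the .grad of other leaves. The values (learn = 5, x = 2, hypothetical input 3) are assumptions chosen to match the simplified example's defaults; both paths should then give the same learn.grad, namely inputs * 2 * x.

```python
import torch

x = torch.tensor(2.0)
inputs = torch.tensor(3.0)  # hypothetical input value

# 1) Analytic derivative written by hand: d(learn * x^2)/dx = 2 * learn * x
learn_a = torch.nn.Parameter(torch.tensor(5.0))
out_a = inputs * learn_a * (2.0 * x)
out_a.backward()

# 2) Derivative from autograd, kept differentiable with create_graph=True
learn_b = torch.nn.Parameter(torch.tensor(5.0))
xn = x.clone().detach().requires_grad_(True)
n = learn_b * xn**2
dn_dx = torch.autograd.grad(outputs=n, inputs=xn, create_graph=True)[0]
grad_before = learn_b.grad  # None: autograd.grad does not write into learn_b.grad
out_b = inputs * dn_dx      # dn_dx still depends on learn_b through the graph
out_b.backward()

print(learn_a.grad, learn_b.grad)
```

The design difference is that backward() is meant to be called once on the final loss, while autograd.grad(..., create_graph=True) is the usual tool for an intermediate derivative that must itself be differentiated later.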

Here is a simplified example. The problem is that model.learn.grad is not the same when using the 'analytic' loss type and the 'autograd' loss type:


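```python
import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self, grin_type='radial', loss_type='analytic', device='cpu', dtype=torch.float32):
        super().__init__()

        self.dtype = dtype
        self.device = device
        self.loss_type = loss_type
        self.grin_type = grin_type
        self.x = torch.tensor(2., dtype=dtype, device=self.device)  # mm

        # Learnable coefficient of the polynomial learn * x^2
        self.learn = nn.Parameter(torch.tensor(5., dtype=dtype, device=self.device))

    def forward(self, inputs, plotting=0):

        if self.loss_type == 'analytic':
            # Hand-written derivative: d(learn * x^2)/dx = 2 * learn * x
            outputs = inputs * self.learn * (2. * self.x)

        elif self.loss_type == 'autograd':
            # Derivative of the polynomial obtained from autograd
            self.der = self.calc_derivative(self.x)
            outputs = inputs * self.der

        else:
            raise ValueError(f"unknown loss_type: {self.loss_type}")

        return outputs

    def poly_fun(self, x):
        # Polynomial n(x) = learn * x^2
        return self.learn * torch.square(x)

    def calc_derivative(self, x):
        # Fresh leaf copy of x so autograd can differentiate with respect to it
        xn = x.clone().detach().requires_grad_(True)
        n = self.poly_fun(xn)

        # create_graph=True keeps the derivative attached to the graph (and to self.learn);
        # the trailing * n / n is mathematically a multiplication by 1 (NaN if n == 0)
        dloss_dx = torch.autograd.grad(outputs=n, inputs=xn, create_graph=True)[0] * n / n

        return dloss_dx
```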