PyTorch not updating weights when using autograd in loss function


I am trying to use the gradient of a network with respect to its inputs as part of my loss function. However, whenever I try to calculate it, training proceeds but the weights do not update.

import torch
import torch.optim as optim
import torch.autograd as autograd


ic = torch.rand((25, 3))
ic = torch.tensor(ic, requires_grad=True)
optimizer = optim.RMSprop([ic], lr=1e-2)

for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5*torch.stack(100*[ic])) # simplified for minimal working example
    
    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx, 
                           inputs=ic,
                           grad_outputs = torch.ones(ic.shape[0]), # batchwise
                           retain_graph=True
                          )
    dxdxy = torch.tensor(dxdxy, requires_grad=True)
    loss = torch.sum(dxdxy)
    
    loss.backward()
    optimizer.step()
    
    if itr % 5 == 0:
        print(loss)

What am I doing wrong?

1 Answer

Gil Pinsky (accepted answer):

When you run autograd.grad without setting the flag create_graph to True, the output you obtain is not connected to the computational graph, which means you won't be able to further optimize with respect to ic (and obtain the higher-order derivative you want here). From torch.autograd.grad's docstring:

create_graph (bool, optional): If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.

Using dxdxy = torch.tensor(dxdxy, requires_grad=True) as you've tried here won't help, since the computational graph connected to ic has already been lost at that point (because create_graph is False); all it does is create a new computational graph with dxdxy as a leaf node.
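
As a quick illustration (a minimal sketch with toy tensors, not your original network), you can inspect requires_grad and grad_fn on the gradient returned by autograd.grad to see the difference create_graph makes:

import torch
import torch.autograd as autograd

x = torch.rand(4, requires_grad=True)
y = torch.tanh(x).sum()

# without create_graph, the returned gradient is detached from the graph
g_detached, = autograd.grad(y, x, retain_graph=True)
print(g_detached.requires_grad, g_detached.grad_fn)    # False None

# with create_graph, the gradient stays connected and can be used in a loss
g_connected, = autograd.grad(y, x, create_graph=True)
print(g_connected.requires_grad, g_connected.grad_fn)  # True, a backward node
g_connected.sum().backward()                           # backprops a second-order term into x.grad
print(x.grad is not None)                              # True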

See the solution attached below. Note that when you create ic you can set requires_grad=True directly, so the second line is redundant (not a logical problem, just longer code):

import torch
import torch.optim as optim
import torch.autograd as autograd

ic = torch.rand((25, 3),requires_grad=True) #<-- requires_grad to True here
#ic = torch.tensor(ic, requires_grad=True) #<-- redundant
optimizer = optim.RMSprop([ic], lr=1e-2)

for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5 * torch.stack(100 * [ic]))  # simplified for minimal working example

    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx,
                           inputs=ic,
                           grad_outputs=torch.ones(ic.shape[0]),  # batchwise
                           retain_graph=True, create_graph=True # <-- important
                           )
    #dxdxy = torch.tensor(dxdxy, requires_grad=True) #<-- won't do the trick. Remove
    loss = torch.sum(dxdxy)

    loss.backward()
    optimizer.step()

    if itr % 5 == 0:
        print(loss)
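
If you want to confirm that ic really gets updated (a suggested sanity check, not part of the original answer), you can snapshot ic around the optimizer step inside the loop:

    # optional check inside the loop: compare ic before and after the update
    ic_before = ic.detach().clone()
    optimizer.step()
    print("ic updated:", not torch.equal(ic_before, ic.detach()))
    print("grad magnitude:", ic.grad.abs().sum().item())  # non-zero thanks to create_graph=True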