Higher-order derivatives in PyTorch increase in computation time exponentially


In my application, I need to take the nth order mixed derivative of a function. However, I found that the torch.autograd.grad computation time increases exponentially as n increases. Is this expected, and is there any way around it?

This is my code for differentiating a function self.F (from R^n -> R^1):

def differentiate(self, x):
    x.requires_grad_(True)
    # split the input into per-coordinate tensors so each coordinate
    # can be differentiated with respect to separately
    xi = [x[..., i] for i in range(x.shape[-1])]
    dyi = self.F(torch.stack(xi, dim=-1))
    for i in range(self.dim):
        start_time = time.time()
        # differentiate the previous result with respect to the i-th coordinate,
        # keeping the graph so the next (higher-order) derivative can be taken
        dyi = torch.autograd.grad(dyi.sum(), xi[i], retain_graph=True, create_graph=True)[0]
        grad_time = time.time() - start_time
        print(grad_time)
    return dyi

And these are the times printed for each iteration of the above loop:

0.0037012100219726562
0.005133152008056641
0.008165121078491211
0.019922733306884766
0.059255123138427734
0.1910409927368164
0.6340939998626709
2.1612229347229004
11.042078971862793

I assume this is because the size of the computation graph is increasing? Is there any way around this? I thought I might be able to circumvent this issue by taking a functional approach (presumably obviating the need for a computation graph), using torch.func.grad. However, this actually increased the runtime of the same code! Am I not understanding torch.func.grad properly? Would a similar implementation in JAX provide any performance increase?
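
For concreteness, a nested torch.func.grad version of the same idea might look something like this (a simplified sketch, with sin(x).prod() as a toy stand-in for self.F):

import torch
from torch.func import grad

def f(x):
    # toy scalar-valued function standing in for self.F (R^n -> R^1)
    return torch.sin(x).prod()

def nth_mixed_derivative(func, n):
    # build the n-th mixed partial by nesting torch.func.grad:
    # each step differentiates the previous result and picks one coordinate
    g = func
    for i in range(n):
        g = (lambda g_, i_: lambda x: grad(g_)(x)[i_])(g, i)
    return g

x = torch.randn(4)
print(nth_mixed_derivative(f, 4)(x))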


1 Answer

Jiu_Zou:

Take softmax as an example: with 6 inputs and 6 outputs, its first derivative is a 6x6 Jacobian matrix. If you take the second derivative, you get a 6x6x6 three-dimensional Hessian-like tensor, and every further order of differentiation adds another dimension.
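
You can see the shapes grow with torch.autograd.functional, for example (a small illustration, not the asker's code):

import torch

x = torch.randn(6)

# Jacobian of softmax: 6 inputs and 6 outputs give a 6x6 matrix
J = torch.autograd.functional.jacobian(lambda v: torch.softmax(v, dim=0), x)
print(J.shape)  # torch.Size([6, 6])

# the second derivative of a single output is already a 6x6 Hessian;
# stacking it over all 6 outputs would give a 6x6x6 tensor
H0 = torch.autograd.functional.hessian(lambda v: torch.softmax(v, dim=0)[0], x)
print(H0.shape)  # torch.Size([6, 6])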

You should also look at JAX (or autograd), for example this tanh example: https://github.com/HIPS/autograd/blob/master/examples/tanh.py

The gradient of a vector-to-vector function is a matrix, and higher orders give multi-dimensional tensors, so the time needed grows exponentially with the order of differentiation. The docstring of that example explains the elementwise case:

Mathematically we can only take gradients of scalar-valued functions, but autograd's elementwise_grad function also handles numpy's familiar vectorization of scalar functions, which is used in this example.

To be precise, elementwise_grad(fun)(x) always returns the value of a vector-Jacobian product, where the Jacobian of fun is evaluated at x and the vector is an all-ones vector with the same size as the output of fun. When vectorizing a scalar-valued function over many arguments, the Jacobian of the overall vector-to-vector mapping is diagonal, and so this vector-Jacobian product simply returns the diagonal elements of the Jacobian, which is the (elementwise) gradient of the function at each input value over which the function is vectorized.
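
In PyTorch terms, that vector-Jacobian product with an all-ones vector looks roughly like this (a small sketch using tanh as the elementwise function):

import torch

x = torch.linspace(-3, 3, 7, requires_grad=True)
y = torch.tanh(x)  # elementwise, so its Jacobian is diagonal

# a vector-Jacobian product with an all-ones vector returns the diagonal
# of the Jacobian, i.e. the elementwise derivative 1 - tanh(x)^2
(g,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
print(torch.allclose(g, 1 - torch.tanh(x) ** 2))  # True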

Or you can refer to this article on the softmax derivative: https://zhuanlan.zhihu.com/p/657177292