In my application, I need to take the nth order mixed derivative of a function. However, I found that the torch.autograd.grad computation time increases exponentially as n increases. Is this expected, and is there any way around it?
This is my code for differentiating a function self.F (from R^n -> R^1):
import time
import torch

def differentiate(self, x):
    x.requires_grad_(True)
    # split the input into per-coordinate tensors so each one can be differentiated against
    xi = [x[..., i] for i in range(x.shape[-1])]
    dyi = self.F(torch.stack(xi, dim=-1))
    for i in range(self.dim):
        start_time = time.time()
        # differentiate the running result w.r.t. the i-th coordinate,
        # keeping the graph so the next (higher-order) derivative can be taken
        dyi = torch.autograd.grad(dyi.sum(), xi[i], retain_graph=True, create_graph=True)[0]
        grad_time = time.time() - start_time
        print(grad_time)
    return dyi
And these are the times printed for each iteration of the above loop:
0.0037012100219726562
0.005133152008056641
0.008165121078491211
0.019922733306884766
0.059255123138427734
0.1910409927368164
0.6340939998626709
2.1612229347229004
11.042078971862793
I assume this is because the size of the computation graph is increasing? Is there any way around this? I thought I might be able to circumvent this issue by taking a functional approach (presumably obviating the need for a computation graph), using torch.func.grad. However, this actually increased the runtime of the same code! Am I not understanding torch.func.grad properly? Would a similar implementation in JAX provide any performance increase?
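For reference, a minimal sketch of the kind of nested torch.func.grad approach I mean (the toy function f and the helper nth_mixed_grad here are just placeholders, not my actual self.F):

import torch
from torch.func import grad

def f(x0, x1, x2):
    # placeholder scalar function standing in for self.F
    return x0 * x1**2 * x2**3

def nth_mixed_grad(func, dim):
    # nest torch.func.grad once per coordinate: d^dim f / (dx_0 ... dx_{dim-1})
    g = func
    for i in range(dim):
        g = grad(g, argnums=i)
    return g

mixed = nth_mixed_grad(f, 3)
print(mixed(torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)))  # 6 * x1 * x2**2 = 108

Even in this form, each extra level of grad still has to differentiate through all of the lower-order derivatives, which is presumably why I saw no speedup.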
Maybe it is because the derivative itself keeps getting bigger. Take softmax as an example: it has 6 inputs and 6 outputs, so its first derivative is a 6x6 Jacobian matrix, and the second derivative is already a 6x6x6 tensor (more than a plain Hessian matrix), and so on for higher orders.
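You can see this growth directly; here is an illustrative sketch with a 6-element softmax using torch.func (not the original self.F):

import torch
from torch.func import jacrev

x = torch.randn(6)
softmax = lambda v: torch.softmax(v, dim=0)

J = jacrev(softmax)(x)          # first derivative: 6x6 Jacobian
H = jacrev(jacrev(softmax))(x)  # second derivative: 6x6x6 tensor
print(J.shape, H.shape)         # torch.Size([6, 6]) torch.Size([6, 6, 6])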
And you should check out JAX; there is a related example of repeated differentiation here: https://github.com/HIPS/autograd/blob/master/examples/tanh.py
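In JAX, the repeated-differentiation idea from that example looks roughly like this (an illustrative sketch, not code taken from the link):

import jax
import jax.numpy as jnp

# build successive derivatives of tanh by repeatedly applying jax.grad
d = jnp.tanh
for order in range(1, 5):
    d = jax.grad(d)
    print(order, d(1.0))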
The gradient of a vector-to-vector function is a matrix, and higher orders give multi-dimensional tensors, so the time used will increase exponentially.
Or you can refer to this article on softmax: https://zhuanlan.zhihu.com/p/657177292