How does PyTorch compute derivatives for simple functions?


When we talk about automatic differentiation in PyTorch, we are usually shown a computational graph built from the formulas that produce each tensor, and PyTorch computes the gradients by traversing this graph backward with the chain rule. However, I want to know what happens at the bottom of that chain, i.e. at the most basic operations. Does PyTorch hardcode a whole list of basic functions with their analytical derivatives, or does it compute the gradients using numerical methods? A quick example:

import torch

def f(x):
    return x ** 2
x = torch.tensor([1.0], requires_grad=True)
y = f(x)
y.backward()
print(x.grad)  # tensor([2.])

In this example, does PyTorch compute the derivative analytically, as $$ (x^2)' = 2x = 2 \cdot 1 = 2 $$, or does it compute something like $$ \frac{1.00001^2 - 1^2}{1.00001 - 1} \approx 2 $$?

Thanks!

Answer by Hashem Ghanem

See this paper for the exact answer, specifically Section 2.1 or Figure 2.

In short, PyTorch keeps a list of basic functions together with the analytical expressions of their derivatives. So in your case (y = x^2), it evaluates $$ y' = 2x $$ at x = 1, which gives exactly 2.
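A minimal sketch of that idea, assuming a recent PyTorch release. It first inspects the node autograd records for `x ** 2` (a `PowBackward0` node, whose input is the `AccumulateGrad` node that writes into the leaf `x.grad`), then reimplements the same analytical derivative by hand with `torch.autograd.Function`. The `Square` class is hypothetical and only illustrates the pattern; PyTorch's built-in derivative formulas are implemented in C++ (in the source tree they are declared in `tools/autograd/derivatives.yaml`), not in Python.

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical hand-written op: a forward computation paired with
    its analytical derivative, mirroring how built-in ops work."""

    @staticmethod
    def forward(ctx, x):
        # Save the input because the derivative formula 2x needs it.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Chain rule: incoming gradient times the analytical derivative 2x.
        return grad_output * 2 * x


x = torch.tensor([1.0], requires_grad=True)

# Built-in op: autograd records a node that knows the formula for d(x**2)/dx.
y = x ** 2
print(type(y.grad_fn).__name__)  # PowBackward0
# The leaf x sits behind an AccumulateGrad node that stores into x.grad.
print(type(y.grad_fn.next_functions[0][0]).__name__)  # AccumulateGrad
y.backward()
builtin_grad = x.grad.clone()

# Custom op with the same analytical formula, for comparison.
x.grad = None
z = Square.apply(x)
z.backward()
custom_grad = x.grad.clone()

print(builtin_grad, custom_grad)  # tensor([2.]) tensor([2.])
```

Both paths evaluate the formula 2x directly; no perturbation of x is involved.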

The numerical method you mentioned is called numerical differentiation or finite differences, and it is an approximation of the derivative. But it is not what PyTorch does.
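For comparison, a minimal sketch of the finite-difference approach from the question, assuming float64 inputs and a step size `h = 1e-6` chosen only for illustration. `forward_difference` is a hypothetical helper. `torch.autograd.gradcheck` is a real PyTorch utility that does use finite differences, but only to *test* analytical gradients against a numerical estimate, not to compute them during `backward()`.

```python
import torch


def f(x):
    return x ** 2


def forward_difference(func, x, h=1e-6):
    # Hypothetical helper: (f(x + h) - f(x)) / h, the formula from the question.
    # For f(x) = x**2 this equals 2x + h, so it carries a truncation error of h.
    return (func(x + h) - func(x)) / h


# float64 keeps floating-point rounding error well below the truncation error.
x = torch.tensor([1.0], dtype=torch.float64, requires_grad=True)

f(x).backward()
autograd_grad = x.grad.item()  # analytical 2x evaluated at x = 1

with torch.no_grad():
    numeric_grad = forward_difference(f, x).item()  # roughly 2 + h

print(autograd_grad, numeric_grad)

# gradcheck compares autograd's analytical gradient with a finite-difference
# estimate and returns True when they agree within tolerance.
print(torch.autograd.gradcheck(f, (x,)))
```

The autograd value is exact up to floating-point representation, while the finite-difference value depends on the choice of `h`.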