SymPy - Can't calculate derivative wrt expression, is there an alternative for intermediate expressions?


I am going through Andrej Karpathy's video "The spelled-out intro to neural networks and backpropagation: building micrograd". In the "manual backpropagation example #1: simple expression" chapter, he builds an example expression with his demo library and then manually steps through backpropagation on it. The micrograd library he builds in the video stores expressions as a graph and doubles as an automatic differentiation engine: it keeps the intermediate gradient at each node of the graph so that it is readily available when applying the chain rule.

I thought it would be fun to build a SymPy version of his example, to see not just the numeric gradient at each step but also the symbolic derivative at each node of the graph. To start, I just wanted to do the basics in SymPy, but I'm not sure how to get the derivatives with respect to some of the intermediate expressions using sp.diff().

Here's the code:

import sympy as sp

a, b, c, d, e, f, L = sp.symbols('a b c d e f L')
e = a * b
d = e + c
L = d * f

sublist = {a: 2.0, b:-3.0, c:10.0, f:-2.0}
dLda = sp.diff(L, a)
dLdb = sp.diff(L, b)
dLdf = sp.diff(L, f)
dLdc = sp.diff(L, c)

dddc = sp.diff(d, c)

print(L)
print(L.subs(sublist))
print(dLdc)
print(dLdf)
print(dddc)
#dLdc = dLdd * dddc
dLdd = sp.diff(L, d)  # should get 'f' 

And I get:

f*(a*b + c)
-8.00000000000000
f
a*b + c
1

ValueError                                Traceback (most recent call last)
Cell In[48], line 20
     18 print(dddc)
     19 #dLdc = dLdd * dddc
---> 20 dLdd = sp.diff(L, d)  # should print 'f' 
<...>
ValueError: 
Can't calculate derivative wrt a*b + c.

I get that this normally wouldn't make sense: intermediate "nodes" like d aren't places where an actual network would adjust any values (i.e. in a real neural network I would only be "nudging" a, b, c and f). But if you are stepping through each node to keep track of the derivative so far, I don't think wanting dLdd is unreasonable, and it would be nice if SymPy could provide it. To reemphasize, I'm not planning to use SymPy to actually do these calculations; it's just for illustrating what's happening under the covers.

I've tried running this a couple of different ways inside a with sp.evaluate(False): block, without any luck. Maybe I'm missing something and there is some way to keep d and e around without SymPy simplifying them away.

1 answer

Answer by ti7

This looks like a SymPy gotcha: when you write the assignment below, d is no longer a simple Symbol. The Python name d is rebound to the expression built from e and c, so the original Symbol is lost (you'll get the same issue if you try to differentiate with respect to e).

d = e + c  # d Symbol lost and name re-used
L = d * f
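
To make the rebinding visible, here is a minimal sketch (the names was_symbol, now_symbol and error_message are just illustrative, not part of the original code). It checks d before and after the assignment and catches the resulting ValueError:

```python
import sympy as sp

a, b, c, d, f = sp.symbols('a b c d f')
was_symbol = d.is_Symbol      # d is still the Symbol 'd' here

d = a * b + c                 # rebinds the Python name d to an Add expression
L = d * f
now_symbol = d.is_Symbol      # the Symbol 'd' is no longer reachable via this name

try:
    sp.diff(L, d)             # same as asking for the derivative wrt a*b + c
    error_message = None
except ValueError as exc:
    error_message = str(exc).strip()

print(was_symbol, now_symbol)
print(type(d).__name__)
print(error_message)
```

SymPy never sees the Python variable name, only the object it refers to, which is why the error message mentions a*b + c rather than d.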

Instead, use Eq() (or subtract the two sides from each other) and give the new result a different name, so the original Symbols are not overwritten:

e_ = a * b
d_ = e_ + c
L  = d_ * f

sublist = {a: 2.0, b:-3.0, c:10.0, f:-2.0}
dLda = sp.diff(L, a)
dLdb = sp.diff(L, b)
dLdf = sp.diff(L, f)
dLdc = sp.diff(L, c)

dddc = sp.diff(d_, c)
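
One way to get the per-node derivatives the question asks for (including dLdd) is sketched below. It rests on an assumption: every node is defined only in terms of its direct inputs, with d and e kept as plain Symbols and the relations stored in Eq() objects. The names eq_e, eq_d, eq_L and L_full are made up for this illustration.

```python
import sympy as sp

a, b, c, d, e, f, L = sp.symbols('a b c d e f L')

# Each node is related to its direct inputs; d and e stay plain Symbols.
eq_e = sp.Eq(e, a * b)
eq_d = sp.Eq(d, e + c)
eq_L = sp.Eq(L, d * f)

# Local derivatives, one graph edge at a time.
dLdd = sp.diff(eq_L.rhs, d)   # f
dddc = sp.diff(eq_d.rhs, c)   # 1
ddde = sp.diff(eq_d.rhs, e)   # 1
deda = sp.diff(eq_e.rhs, a)   # b

# Chain rule, applied by hand as in the video.
dLdc = dLdd * dddc
dLda = dLdd * ddde * deda

# Substitute the intermediate nodes away only when the full expression is needed.
L_full = eq_L.rhs.subs(d, eq_d.rhs).subs(e, eq_e.rhs)

values = {a: 2.0, b: -3.0, c: 10.0, f: -2.0}
print(dLdd, dLdc, dLda)
print(L_full, L_full.subs(values))
```

The hand-built chain-rule products can be checked against sp.diff(L_full, a) and sp.diff(L_full, c), which differentiate the fully substituted expression directly.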