Linear combinations of Zygote.Grads


I am building and training a neural network model with Flux, and I am wondering if there is a way to take linear combinations of Zygote.Grads types.

Here is a minimal example. This is how the gradient is typically computed:

using Flux

m = hcat(2.0); b = hcat(-1.0);  # 1 x 1 matrices with arbitrary values

f(x) = m*x .+ b
ps = Flux.params(m, b)  # parameters to be adjusted

inputs = [0.3 1.5]  # arbitrary 1 x 2 matrix

loss(x) = sum( f(x).^2 )

gs = Flux.gradient(() -> loss(inputs), ps)  # the typical way
@show gs[m], gs[b]  # 5.76, 3.2

But I want to do the same calculation by computing the gradients at a deeper level and then assembling them at the end. For example:

input1 = hcat(inputs[1, 1]); input2 = hcat(inputs[1, 2]);  # turn each input into a 1 x 1 matrix

grad1 = Flux.gradient(() -> f(input1)[1], ps)  # df/dp using input1 (where p is m or b)
grad2 = Flux.gradient(() -> f(input2)[1], ps)  # df/dp using input2 (where p is m or b)

predicted1 = f(input1)[1]
predicted2 = f(input2)[1]

myGrad_m = (2 * predicted1 * grad1[m]) + (2 * predicted2 * grad2[m])  # 5.76
myGrad_b = (2 * predicted1 * grad1[b]) + (2 * predicted2 * grad2[b])  # 3.2

Above, I used the chain rule and linearity of the derivative to decompose the gradient of the loss() function:

d(loss)/dp = d( sum(f^2) ) / dp = sum( d(f^2)/dp ) = sum( 2*f * df/dp )

Then I calculated df/dp for each input using Flux.gradient and combined the results at the end.

But notice that I had to combine m and b separately. This was fine because there were only 2 parameters.

However, if there were 1000 parameters, I would want to do something like this, which is a linear combination of the Zygote.Grads:

myGrad = (2 * predicted1 * grad1) + (2 * predicted2 * grad2)

But I get an error saying that the + and * operators are not defined for these types. How can I get this shortcut to work?

Answer by darsnack (best answer)

Just turn each * and + into .* and .+ (i.e. use broadcasting), or use map to apply a function to multiple Grads at once. This is described in the Zygote docs. Note that for this to work, all the Grads must share the same keys (that is, they must correspond to the same parameters).
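
As a rough sketch of what that could look like for the example in the question: the names c1, c2, myGrad and myGrad_map are made up here for illustration, and the code assumes a Flux/Zygote version that still supports implicit parameters (Flux.params) and broadcasting between a scalar and a Grads, so treat it as an illustration rather than a definitive recipe.

```julia
using Flux  # assumption: a version where Flux.params and implicit-style gradients still work

m = hcat(2.0); b = hcat(-1.0)   # 1 x 1 parameter matrices, as in the question
f(x) = m*x .+ b
ps = Flux.params(m, b)

inputs = [0.3 1.5]
input1 = hcat(inputs[1, 1]); input2 = hcat(inputs[1, 2])

# Per-input gradients df/dp; both Grads have the same keys (m and b)
grad1 = Flux.gradient(() -> f(input1)[1], ps)
grad2 = Flux.gradient(() -> f(input2)[1], ps)

# Scalar weights 2*f from the chain rule (c1, c2 are illustrative names)
c1 = 2 * f(input1)[1]
c2 = 2 * f(input2)[1]

# Option 1: broadcasting over the Grads objects
myGrad = c1 .* grad1 .+ c2 .* grad2

# Option 2: map a function over the matching entries of both Grads
myGrad_map = map((g1, g2) -> c1 .* g1 .+ c2 .* g2, grad1, grad2)

# Reference: gradient of the full loss, for comparison
gs = Flux.gradient(() -> sum(f(inputs).^2), ps)
@show myGrad[m] ≈ gs[m]
@show myGrad[b] ≈ gs[b]
```

Either result is again a Grads, so it can be indexed by parameter (myGrad[m], myGrad[b]) just like gs. If the scalar-times-Grads broadcast is not available in your version, the map form only relies on ordinary array arithmetic inside the anonymous function.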