I am trying to compile the following minimal example using Numeric.AD:
import Numeric.AD
timeAndGrad f l = grad f l
main = putStrLn "hi"
and I run into this error:
test.hs:3:24:
Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
s a)
-> Numeric.AD.Internal.Reverse.Reverse s a’
with actual type ‘t’
because type variable ‘s’ would escape its scope
This (rigid, skolem) type variable is bound by
a type expected by the context:
Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
f (Numeric.AD.Internal.Reverse.Reverse s a)
-> Numeric.AD.Internal.Reverse.Reverse s a
at test.hs:3:19-26
Relevant bindings include
l :: f a (bound at test.hs:3:15)
f :: t (bound at test.hs:3:13)
timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
In the first argument of ‘grad’, namely ‘f’
In the expression: grad f l
Any clue as to why this is happening? From looking at previous examples, I gather that this is "flattening" grad's type:
grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a
but I actually need to do something like this in my code. In fact, this is the smallest example I could find that fails to compile. The more complicated thing I want to do looks something like this:
example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
  where gradient = grad f x
        gradientFn = grad f
        (other where clauses involving gradient and gradient "function")
Here's a slightly more complicated version with type signatures that does compile.
{-# LANGUAGE RankNTypes #-}
import Numeric.AD
import Numeric.AD.Internal.Reverse
-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l
-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
  where f' = Lift . f . extractAll
-- I've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work.
extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
  where extract (Lift x) = x -- non-exhaustive pattern match
dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)
-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]
However, I can't figure out how to use the first version, grad2, in code because I don't know how to deal with Reverse s a. The second version, grad2', has the right type because I use the internal constructor Lift to create a Reverse s a, but I must not be understanding how the internals (specifically the parameter s) work, because the output gradient is all 0s. Using the other constructor, Reverse (not shown here), also produces the wrong gradient.
Alternatively, are there examples of libraries or code where people have used the ad library? I think my use case is a very common one.
With
where f' = Lift . f . extractAll
you essentially create a back door into the underlying automatic-differentiation type that throws away all the derivatives and keeps only the constant values. If you then use this for grad, it's hardly surprising that you get a zero result!
The sensible way is to just use grad as it is. You don't really need to know anything more complicated to use automatic differentiation: as long as you only differentiate Num- or Floating-polymorphic functions, everything will work as-is. If you need to differentiate a function that's passed in as an argument, you need to make that argument rank-2 polymorphic (an alternative would be to switch to the rank-1 versions of the ad functions, but I daresay that is less elegant and doesn't really gain you much).
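A minimal sketch of "just use grad as it is", assuming the ad package (4.x) is installed: the question's dist is given a Floating-polymorphic signature, so grad can instantiate it at its own reverse-mode number type with no Lift or extractAll involved. The catch-all dist clause is an addition of mine to make the function total.

```haskell
import Numeric.AD (grad)

-- Euclidean norm of a 2-D point. The signature works for any Floating
-- type, so grad can run it on its internal AD number type.
dist :: Floating a => [a] -> a
dist [x, y] = sqrt (x ^ 2 + y ^ 2)
dist _      = error "dist: expected exactly two coordinates"

main :: IO ()
main = print (grad dist [1, 2 :: Double])
```

The exact gradient at [1, 2] is [1/√5, 2/√5], roughly [0.447, 0.894], so the printed values should be close to that instead of [0.0, 0.0].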
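For a function passed in as an argument, the rank-2 approach can be sketched as follows, assuming ad 4.x and the RankNTypes extension. Instead of spelling out ad's internal Reverse and Tape types, the argument is required to work for every Floating type b, which lets grad choose its AD type for b. timeAndGrad and example mirror the names in the question; their exact shapes (including the tuple that example returns) are hypothetical.

```haskell
{-# LANGUAGE RankNTypes #-}

import Numeric.AD (grad)

-- Hypothetical helper from the question with a rank-2 argument:
-- f must be usable at *every* Floating type b.
timeAndGrad :: Floating a => (forall b. Floating b => [b] -> b) -> [a] -> [a]
timeAndGrad f l = grad f l

-- Hypothetical version of the question's `example`: both the gradient
-- at x and the partially applied gradient function live in `where`.
example :: (forall b. Floating b => [b] -> b) -> [Double] -> ([Double], [Double])
example f x = (gradient, gradientFn (map (+ 1) x))
  where
    gradient :: [Double]
    gradient = grad f x          -- gradient of f at the point x
    gradientFn :: [Double] -> [Double]
    gradientFn = grad f          -- gradient as a function of the point

dist :: Floating a => [a] -> a
dist [x, y] = sqrt (x ^ 2 + y ^ 2)
dist _      = error "dist: expected exactly two coordinates"

main :: IO ()
main = do
  print (timeAndGrad dist [1, 2 :: Double])
  print (example dist [1, 2])
```

Because f keeps its polymorphic type inside example, it can be handed to grad any number of times in the where clauses, which is the pattern the question asks about.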