minimal Numeric.AD example won't compile


I am trying to compile the following minimal example from Numeric.AD:

import Numeric.AD 
timeAndGrad f l = grad f l
main = putStrLn "hi"

and I run into this error:

test.hs:3:24:
    Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
                                       s a)
                                  -> Numeric.AD.Internal.Reverse.Reverse s a’
                with actual type ‘t’
      because type variable ‘s’ would escape its scope
    This (rigid, skolem) type variable is bound by
      a type expected by the context:
        Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
        f (Numeric.AD.Internal.Reverse.Reverse s a)
        -> Numeric.AD.Internal.Reverse.Reverse s a
      at test.hs:3:19-26
    Relevant bindings include
      l :: f a (bound at test.hs:3:15)
      f :: t (bound at test.hs:3:13)
      timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
    In the first argument of ‘grad’, namely ‘f’
    In the expression: grad f l

Any clue why this is happening? From looking at previous examples, I gather that defining timeAndGrad this way "flattens" grad's rank-2 type:

grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

But I actually need to do something like this in my code; this is just the most minimal example that won't compile. What I ultimately want to do looks more like this:

example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
    where gradient = grad f x
          gradientFn = grad f
          (other where clauses involving gradient and gradient "function")
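For comparison, here is a minimal sketch of how where clauses like the ones above can type-check once the function argument gets a rank-2 type, as the answer suggests. It assumes the argument can be written as a Floating-polymorphic function; the tuple return type and the helper binding shifted are hypothetical, chosen only because SomeType is unspecified.

```haskell
{-# LANGUAGE RankNTypes #-}
import Numeric.AD (grad)

-- Hypothetical shape for the question's `example`. The argument is rank-2
-- polymorphic, so grad can instantiate it at its own internal AD type.
example :: (forall n. Floating n => [n] -> n) -> [Double] -> ([Double], [Double])
example f x = (gradient, gradientFn shifted)
  where
    gradient   = grad f x       -- the gradient at x
    gradientFn = grad f         -- a reusable gradient "function"
    shifted    = map (+ 1) x    -- hypothetical second point, for illustration
```

Because f is a lambda-bound argument with a polymorphic type, both the fully applied grad f x and the partially applied grad f are accepted. With the Euclidean norm and x = [3, 4], the first component should be close to [0.6, 0.8].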

Here's a slightly more complicated version with type signatures that does compile.

{-# LANGUAGE RankNTypes #-}

import Numeric.AD 
import Numeric.AD.Internal.Reverse

-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l

-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
       where f' = Lift . f . extractAll
       -- I've tried the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but none of them yields the correct gradient. I'm not sure how the modes work.

extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
           where extract (Lift x) = x -- non-exhaustive pattern match

dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)

-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]

However, I can't figure out how to use the first version, grad2, in code because I don't know how to deal with Reverse s a. The second version, grad2', has the right type because I use the internal constructor Lift to create a Reverse s a, but I must not be understanding how the internals (specifically the parameter s) works, because the output gradient is all 0s. Using the other constructor Reverse (not shown here) also produces the wrong gradient.

Alternatively, are there examples of libraries or code that use the ad package? I think my use case is a very common one.

Answer by leftaroundabout (best answer):

With where f' = Lift . f . extractAll, you essentially create a back door into the underlying automatic-differentiation type that throws away all the derivatives and keeps only the constant values. If you then use this with grad, it's hardly surprising that you get a zero result!
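The same effect can be reproduced with a toy forward-mode dual-number type. This is a hypothetical illustration, not the ad package's Reverse type (which works in reverse mode with a tape); Dual, liftConst, derivWith, square and squareViaConstant are all made-up names. Here liftConst stands in for Lift: it wraps a plain value with a derivative of 0, so anything computed on plain values and lifted back carries no derivative information.

```haskell
-- A toy forward-mode dual number: a value paired with its derivative.
-- NOT ad's Reverse type; it only illustrates the "lift a constant" effect.
data Dual = Dual Double Double   -- value, derivative
  deriving Show

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' * Dual b b' = Dual (a * b) (a' * b + a * b')   -- product rule
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0           -- constants: derivative 0

-- Lifting a plain value: the derivative is 0 by construction.
liftConst :: Double -> Dual
liftConst x = Dual x 0

-- Differentiate f at x by seeding the input's derivative with 1.
derivWith :: (Dual -> Dual) -> Double -> Double
derivWith f x = let Dual _ d = f (Dual x 1) in d

square :: Num n => n -> n
square x = x * x

-- The question's trick, transplanted: strip the input down to a plain value,
-- apply the function on plain numbers, then lift the result as a constant.
squareViaConstant :: Dual -> Dual
squareViaConstant (Dual v _) = liftConst (square v)
```

Here derivWith square 3 gives 6, while derivWith squareViaConstant 3 gives 0, which mirrors the all-zero gradient in the question.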

The sensible way is to just use grad as it is:

import Numeric.AD

dist :: Floating a => [a] -> a
dist [x, y] = sqrt $ x^2 + y^2
-- preferable is of course `dist = sqrt . sum . map (^2)`

main = print $ grad dist [1,2]
-- output: [0.4472135954999579,0.8944271909999159]
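A sketch of the point-free variant mentioned in the comment above; the name dist' is hypothetical. Unlike the two-element pattern match, it accepts any number of coordinates, and because it is still Floating-polymorphic, grad can differentiate it directly.

```haskell
import Numeric.AD (grad)

-- Hypothetical point-free Euclidean norm: works for lists of any length.
dist' :: Floating a => [a] -> a
dist' = sqrt . sum . map (^ 2)
```

Each partial derivative of the Euclidean norm is the coordinate divided by the norm, so grad dist' [3, 4, 12] should be close to [3/13, 4/13, 12/13] (the norm is 13).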

You don't really need to know anything more complicated to use automatic differentiation. As long as you only differentiate Num- or Floating-polymorphic functions, everything will work as-is. If you need to differentiate a function that's passed in as an argument, you need to give that argument a rank-2 polymorphic type (an alternative would be to switch to the rank-1 versions of the ad functions, but I daresay that is less elegant and doesn't really gain you much).

{-# LANGUAGE Rank2Types, UnicodeSyntax #-}

mainWith :: (∀n . Floating n => [n] -> n) -> IO ()
mainWith f = print $ grad f [1,2]

main = mainWith dist
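Applied to the question's original timeAndGrad, the same idea is a one-line signature change. This is a sketch assuming the argument can be any Floating-polymorphic function; generalizing the container to any Traversable t mirrors grad's own Traversable f constraint.

```haskell
{-# LANGUAGE RankNTypes #-}
import Numeric.AD (grad)

-- The question's timeAndGrad with a rank-2 argument type. The forall sits
-- inside the argument, so the caller must supply a function that works for
-- every Floating type n; grad then chooses n = Reverse s a with its own s,
-- which is exactly what the "s would escape its scope" error was about.
timeAndGrad :: (Traversable t, Floating a)
            => (forall n. Floating n => t n -> n) -> t a -> t a
timeAndGrad f l = grad f l
```

With dist from above, timeAndGrad dist [1,2] should print the same gradient as the earlier example.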