I'm playing with some Data Parallel Haskell code and found myself in need of a prefix sum. However, I didn't see a built-in prefix-sum operator in the dph package.
I rolled my own, but since I'm new to dph, I'm not sure whether it takes proper advantage of parallelization:
{-# LANGUAGE ParallelArrays #-}
{-# OPTIONS_GHC -fvectorise #-}
module PrefixSum ( scanP ) where
import Data.Array.Parallel (lengthP, indexedP, mapP, zipWithP, concatP, filterP, singletonP, sliceP, (+:+), (!:))
import Data.Array.Parallel.Prelude.Int ((<=), (-), (==), Int, mod)
-- hide prelude
import qualified Prelude
-- assuming zipWithP (a -> b -> c) given
-- [:a:] of length n and
-- [:b:] of length m, n /= m
-- will return
-- [:c:] of length min n m
scanP :: (a -> a -> a) -> [:a:] -> [:a:]
scanP f xs = if lengthP xs <= 1
                then xs
                else head +:+ tail
    where -- [: x_0, x_2, ..., x_2n :]
          evens = mapP snd . filterP (even . fst) $ indexedP xs
          -- [: x_1, x_3 ... :]
          odds = mapP snd . filterP (odd . fst) $ indexedP xs
          lenEvens = lengthP evens
          lenOdds = lengthP odds
          -- calculate the prefix sums [:w:] of the pair sums [:z:]
          psums = scanP f $ zipWithP f evens odds
          -- calculate the total prefix sums as
          -- [: x_0, w_0, f w_0 x_2, w_1, f w_1 x_4, ...,
          head = singletonP (evens !: 0)
          body = concatP . zipWithP (\p e -> [: p, f p e :]) psums $ sliceP 1 lenOdds evens
          -- ending at either
          -- ... w_{n-1}, f w_{n-1} x_2n :]
          -- or
          -- ... w_{n-1}, f w_{n-1} x_2n, w_n :]
          -- depending on whether the length of [:x:] is 2n+1 or 2n+2
          tail = if lenEvens == lenOdds
                    then body +:+ singletonP (psums !: (lenEvens - 1))
                    else body
-- reimplement some of Prelude so it can be vectorised
f $ x = f x
infixr 0 $
(.) f g y = f (g y)
snd (a,b) = b
fst (a,b) = a
even n = n `mod` 2 == 0
odd n = n `mod` 2 == 1
Parallel prefix scans are supported; in fact, they're rather fundamental. So just pass (+) as your associative operator.
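
If you want to sanity-check a hand-rolled scan against a reference, here is a rough sketch of the same even/odd scheme from the question, written on ordinary lists. It is not DPH code, so it says nothing about parallel performance; it only models what the result should be. The names scanEvenOdd, split, and interleave are made up for illustration, and the sketch assumes f is associative, in which case it should agree with the Prelude's scanl1 f.

```haskell
module Main where

-- Sequential list model of the even/odd prefix-scan scheme.
-- Hypothetical helper names; assumes f is associative.
scanEvenOdd :: (a -> a -> a) -> [a] -> [a]
scanEvenOdd f xs@(x0 : _ : _) = x0 : interleave ws (zipWith f ws (drop 1 evens))
  where
    -- evens = [x_0, x_2, ...], odds = [x_1, x_3, ...]
    (evens, odds) = split xs
    -- ws = prefix scan of the pair results z_i = f x_{2i} x_{2i+1}
    ws = scanEvenOdd f (zipWith f evens odds)
scanEvenOdd _ xs = xs  -- zero or one element: nothing to combine

-- Split a list into its even-indexed and odd-indexed elements.
split :: [a] -> ([a], [a])
split (x : y : zs) = let (es, os) = split zs in (x : es, y : os)
split zs           = (zs, [])

-- Alternate elements of two lists, starting with the first list;
-- a leftover element of the first list is kept, the rest is dropped.
interleave :: [a] -> [a] -> [a]
interleave (p : ps) (q : qs) = p : q : interleave ps qs
interleave ps       []       = ps
interleave []       _        = []

main :: IO ()
main = do
  let xs = [1 .. 7] :: [Int]
  print (scanEvenOdd (+) xs)
  -- Compare against the sequential reference scan.
  print (scanEvenOdd (+) xs == scanl1 (+) xs)
```

Using a non-commutative but associative operator such as (++) on a list of strings is a handy extra check, since it shows whether the arguments are being combined in the right order.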