Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

traverse Array with Maybe more quickly #142

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
288 changes: 285 additions & 3 deletions Data/Primitive/Array.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE CPP, MagicHash, UnboxedTuples, DeriveDataTypeable, BangPatterns #-}
{-# LANGUAGE CPP, MagicHash, UnboxedTuples, DeriveDataTypeable, BangPatterns, ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}

Expand Down Expand Up @@ -59,9 +59,7 @@ import qualified GHC.ST as GHCST
import qualified Data.Foldable as F
import Data.Semigroup
#endif
#if MIN_VERSION_base(4,8,0)
import Data.Functor.Identity
#endif
#if MIN_VERSION_base(4,10,0)
import GHC.Exts (runRW#)
#elif MIN_VERSION_base(4,9,0)
Expand All @@ -74,6 +72,12 @@ import Text.ParserCombinators.ReadP
import Data.Functor.Classes (Eq1(..),Ord1(..),Show1(..),Read1(..))
#endif

import Data.Functor.Compose
import Control.Monad (join)
import Control.Monad.Trans.Maybe (MaybeT(MaybeT,runMaybeT))
import Control.Monad.Trans.State.Strict (StateT(StateT),State,runStateT)
import Control.Monad.Trans.Reader (ReaderT(ReaderT,runReaderT),Reader)

-- | Boxed arrays
data Array a = Array
{ array# :: Array# a }
Expand Down Expand Up @@ -519,11 +523,36 @@ traverseArray f = \ !ary ->
else runSTA len <$> go 0
{-# INLINE [1] traverseArray #-}

-- Note on rewrite rules for traverse. Some types admit a traversal that
-- outperforms the general one that works with all applicatives. Such types
-- include IO and ST as well as any type constructed by layering sufficiently
-- affine monad transformers on top of IO or ST. This also includes the types
-- that correspond to such monad transformers.
--
-- For example, MaybeT is sufficiently affine. Consequently, for
-- 'MaybeT (ST s)' and 'MaybeT IO', the traversal offered by traverseArrayP is
-- semantically equivalent to traverseArray, but its tail-recursiveness
-- and lack of closure allocations mean that it performs better. This also
-- gives us a faster traversal for 'Maybe', since we can hoist an arbitrary
-- 'Maybe' into 'MaybeT (ST s)', perform the faster traversal, and then run
-- the effectful computaton to get back to a 'Maybe'.
--
-- Rewrite rule are not provided for the lazy State type or for any variant
-- of Writer. Use of these types is the types is likely to build up thunks
-- on the heap anyway.
{-# RULES
"traverse/ST" forall (f :: a -> ST s b). traverseArray f =
traverseArrayP f
"traverse/IO" forall (f :: a -> IO b). traverseArray f =
traverseArrayP f
"traverse/Maybe" forall (f :: a -> Maybe b). traverseArray f =
(\xs -> runST (runMaybeT (traverseArrayP (MaybeT . return . f) xs)))
"traverse/Either" forall (f :: a -> Either e b). traverseArray f =
traverseEither f
"traverse/State" forall (f :: a -> State s b). traverseArray f =
(\xs -> StateT (\s0 -> Identity (runST (runStateT (traverseArrayP (hoistState . f) xs) s0))))
"traverse/Reader" forall (f :: a -> Reader r b). traverseArray f =
(\xs -> ReaderT (\s0 -> Identity (runST (runReaderT (traverseArrayP (hoistReader . f) xs) s0))))
#-}
#if MIN_VERSION_base(4,8,0)
{-# RULES
Expand All @@ -533,6 +562,259 @@ traverseArray f = \ !ary ->
#-}
#endif


{-# RULES
"traverse/MaybeT/init" forall (f :: a -> MaybeT m b). traverseArray f = runTraverseMonad (initMaybeT (traverseArray f) f)
"traverse/MaybeT/pop" forall (t :: TraverseMonad p n (MaybeT m) a b). runTraverseMonad t = runTraverseMonad (popMaybeT t)
"traverse/IO/run" forall (t :: TraverseMonad p n IO a b). runTraverseMonad t = finalizeTraverseMonadIO t
"traverse/ST/run" forall (t :: TraverseMonad p n (ST s) a b). runTraverseMonad t = finalizeTraverseMonadST t
#-}

-- "traverse/Maybe/run" forall (t :: TraverseMonad p n Maybe a b). runTraverseMonad t = (\xs -> runST (getCompose (finalizeTraverseMonadST (finalMaybe t) xs)))

-- type variables: full, inner (starts as empty), outer (starts with everything), from, to
data TraverseMonad p n m a b = TraverseMonad
(Array a -> p (Array b))
-- original traversal, used if we run into an unrecognized monad or monad transformer
(a -> p b)
-- original traverse function
(forall x. (forall y z. (y -> z) -> m y -> m z) -> m (n x) -> p x)
-- given an implementation of fmap for m, convert the split stack back to the original type
(forall x. (forall y z. (y -> z) -> m y -> m z) -> p x -> m (n x))
-- convert original type to split stack
(forall x y. (forall z. z -> m z) -> (forall w z. m w -> (w -> m z) -> m z) -> m (n x) -> (x -> m (n y)) -> m (n y))
-- lift monadic bind, needs both pure and bind of underlying monad
(forall x. (forall y z. (y -> z) -> m y -> m z) -> m x -> m (n x))
-- lift, given fmap
-- (forall s x. (forall f y. Applicative f => m (f y) -> f (m y)) -> m (n (ST s x)) -> ST s (m (n x)))
(forall s x y. (forall z. z -> (m z)) -> (forall w z s'. ST s' (m w) -> (w -> ST s' (m z)) -> (ST s' (m z))) -> ST s (m (n x)) -> (x -> ST s (m (n y))) -> ST s (m (n y)))
-- traverse in ST
-- this is needed to make base monads other than IO or ST (like Maybe) work

runTraverseMonad :: TraverseMonad p n m a b -> Array a -> p (Array b)
runTraverseMonad (TraverseMonad f _ _ _ _ _ _) = f
{-# NOINLINE[1] runTraverseMonad #-}

initMaybeT ::
(Array a -> MaybeT m (Array b))
-> (a -> MaybeT m b)
-> TraverseMonad (MaybeT m) Maybe m a b
initMaybeT f t = TraverseMonad f t (\_ -> MaybeT) (\_ -> runMaybeT)
(\pure' bind' m g -> bind' m $ \mx -> case mx of
Nothing -> pure' Nothing
Just x -> g x
)
(\fmap' x -> fmap' Just x)
(\pure' bind' m g -> bind' m $ \mx -> case mx of
Nothing -> return (pure' Nothing)
Just x -> g x
)

popMaybeT ::
TraverseMonad p n (MaybeT m) a b
-> TraverseMonad p (Compose Maybe n) m a b
popMaybeT (TraverseMonad f t trans transBack liftBind lift' liftBindST) = TraverseMonad f t
(\fmap' x -> trans (liftMapMaybeT fmap') (MaybeT (fmap' getCompose x)))
(\fmap' x -> fmap' Compose (runMaybeT (transBack (liftMapMaybeT fmap') x)))
(\pure' bind' m g -> fmapFromPureBind pure' bind' Compose
(runMaybeT (liftBind (liftPureMaybeT pure') (liftBindMaybeT pure' bind') (MaybeT (fmapFromPureBind pure' bind' getCompose m)) (\x -> MaybeT (fmapFromPureBind pure' bind' getCompose (g x)))))
)
(\fmap' x -> fmap' Compose (runMaybeT (lift' (liftMapMaybeT fmap') (MaybeT (fmap' Just x)))))
(\pure' bind' m g -> fmapFromPureBindST pure' bind' Compose
(fmap runMaybeT (liftBindST (liftPureMaybeT pure') (liftBindMaybeT_ST pure' bind') (fmap MaybeT (fmapFromPureBindST pure' bind' getCompose m)) (\x -> fmap MaybeT (fmapFromPureBindST pure' bind' getCompose (g x)))))
)

finalMaybe ::
TraverseMonad p n Maybe a b
-> TraverseMonad (Compose (ST s) p) (Compose Maybe n) (ST s) a b
finalMaybe (TraverseMonad f t trans transBack liftBind lift' liftBindST) = TraverseMonad
(\arr -> Compose (return (f arr)))
(\a -> Compose (return (t a)))
(\_ x -> Compose (fmap (\(Compose mn) -> trans fmap mn) x))
(\_ (Compose x) -> fmap (\p -> Compose (transBack fmap p)) x)
(\_ _ v g -> do
Compose mn <- v
r <- liftBindST pure bindMaybeST (return mn) (\y -> fmap getCompose (g y))
return (Compose r)
)
(\_ x -> fmap (Compose . lift' fmap . Just) x)
(\_ _ m g -> return (fmap Compose (liftBindST Just bindMaybeST (fmap getCompose (join m)) (\y -> fmap getCompose (join (g y))))))

bindMaybeST :: ST s (Maybe a) -> (a -> ST s (Maybe b)) -> ST s (Maybe b)
bindMaybeST sm g = do
m <- sm
case m of
Nothing -> pure Nothing
Just a -> g a

-- finalMaybe ::
-- TraverseMonad p n Maybe a b
-- -> TraverseMonad (Compose (ST s) p) (Compose Maybe n) (ST s) a b
-- finalMaybe (TraverseMonad f t trans transBack liftBind lift' trav) = TraverseMonad
-- (\arr -> Compose (return (f arr)))
-- (\a -> Compose (return (t a)))
-- (\_ x -> Compose (fmap (\(Compose mn) -> trans fmap mn) x))
-- (\_ (Compose x) -> fmap (\p -> Compose (transBack fmap p)) x)
-- (\_ _ v g -> do
-- Compose mn <- v
-- let y = fmapTwiceFromPureBind (lift' fmap . Just) (liftBind pure (>>=)) g mn
-- fmap (Compose . joinTwiceFromBind (liftBind pure (>>=)) . fmapTwiceFromPureBind (lift' fmap . Just) (liftBind pure (>>=)) getCompose) (trav sequenceA y)
-- )
-- (\_ x -> fmap (Compose . lift' fmap . Just) x)
-- (\_ -> error "uheotn")

liftPureMaybeT :: (forall a. a -> m a) -> b -> MaybeT m b
liftPureMaybeT pure' = MaybeT . pure' . Just

liftPureMaybeT_ST :: (forall a. a -> m a) -> b -> ST s (MaybeT m b)
liftPureMaybeT_ST pure' = return . MaybeT . pure' . Just

liftMapMaybeT ::
(forall a b. (a -> b) -> m a -> m b)
-> (x -> y) -> MaybeT m x -> MaybeT m y
liftMapMaybeT fmap' f (MaybeT m) = MaybeT (fmap' (fmap f) m)

liftBindMaybeT ::
(forall a. a -> m a)
-> (forall a b. m a -> (a -> m b) -> m b)
-> MaybeT m x -> (x -> MaybeT m y) -> MaybeT m y
liftBindMaybeT pure' bind' (MaybeT m) g = MaybeT $ bind' m $ \mx -> case mx of
Nothing -> pure' Nothing
Just x -> runMaybeT (g x)

liftBindMaybeT_ST ::
(forall a. a -> m a)
-> (forall a b. ST s (m a) -> (a -> ST s (m b)) -> ST s (m b))
-> ST s (MaybeT m x) -> (x -> ST s (MaybeT m y)) -> ST s (MaybeT m y)
liftBindMaybeT_ST pure' bind' sma g = fmap MaybeT $ bind' (fmap runMaybeT sma) $ \ma -> case ma of
Nothing -> return (pure' Nothing)
Just a -> fmap runMaybeT (g a)

fmapFromPureBind ::
(forall x. x -> m x)
-> (forall x y. m x -> (x -> m y) -> m y)
-> (a -> b) -> m a -> m b
fmapFromPureBind pure' bind' f ma = bind' ma (\z -> pure' (f z))

fmapFromPureBindST ::
(forall x. x -> m x)
-> (forall x y. ST s (m x) -> (x -> ST s (m y)) -> ST s (m y))
-> (a -> b) -> ST s (m a) -> ST s (m b)
fmapFromPureBindST pure' bind' f ma = bind' ma (\z -> return (pure' (f z)))

fmapTwiceFromPureBind ::
(forall x. x -> m (n x))
-> (forall x y. m (n x) -> (x -> m (n y)) -> m (n y))
-> (a -> b) -> m (n a) -> m (n b)
fmapTwiceFromPureBind pure' bind' f ma = bind' ma (\z -> pure' (f z))

joinTwiceFromBind ::
(forall x y. m (n x) -> (x -> m (n y)) -> m (n y))
-> m (n (m (n a)))
-> m (n a)
joinTwiceFromBind bind' ma = bind' ma id


finalizeTraverseMonadIO :: forall p n a b. TraverseMonad p n IO a b -> Array a -> p (Array b)
finalizeTraverseMonadIO (TraverseMonad _ f trans transBack liftBind lift' _) = \ !ary ->
trans fmap
( let
!sz = sizeofArray ary
go :: Int -> MutableArray RealWorld b -> IO (n (Array b))
go !i !mary
| i == sz = lift' fmap (unsafeFreezeArray mary)
| otherwise =
liftBind pure (>>=) (lift' fmap (indexArrayM ary i)) $ \a ->
liftBind pure (>>=) (transBack fmap (f a)) $ \b ->
liftBind pure (>>=) (lift' fmap (writeArray mary i b)) $ \_ ->
go (i + 1) mary
in liftBind pure (>>=) (lift' fmap (newArray sz badTraverseValue)) $ \mary ->
go 0 mary
)
{-# INLINE finalizeTraverseMonadIO #-}

finalizeTraverseMonadST :: forall s p n a b. TraverseMonad p n (ST s) a b -> Array a -> p (Array b)
finalizeTraverseMonadST (TraverseMonad _ f trans transBack liftBind lift' _) = \ !ary ->
trans fmap
( let
!sz = sizeofArray ary
go :: Int -> MutableArray s b -> ST s (n (Array b))
go !i !mary
| i == sz = lift' fmap (unsafeFreezeArray mary)
| otherwise =
liftBind pure (>>=) (lift' fmap (indexArrayM ary i)) $ \a ->
liftBind pure (>>=) (transBack fmap (f a)) $ \b ->
liftBind pure (>>=) (lift' fmap (writeArray mary i b)) $ \_ ->
go (i + 1) mary
in liftBind pure (>>=) (lift' fmap (newArray sz badTraverseValue)) $ \mary ->
go 0 mary
)
{-# INLINE finalizeTraverseMonadST #-}



-- finalizeTraverseMonadMaybe :: forall p n a b. TraverseMonad p n Maybe a b -> Array a -> p (Array b)
-- finalizeTraverseMonadMaybe (TraverseMonad _ f trans transBack liftBind lift') = \ !ary ->
-- runST
-- ( let
-- !sz = sizeofArray ary
-- go :: Int -> MutableArray s b -> ST s (Maybe (n (Array b)))
-- go !i !mary
-- | i == sz = do
-- result <- unsafeFreezeArray mary
-- return (lift' fmap (Just result))
-- | otherwise = case indexArray## ary i of
-- (# a #) ->
-- liftBind pure (>>=) (transBack fmap (f a)) $ \b ->
-- liftBind pure (>>=) (lift' fmap (writeArray mary i b)) $ \_ ->
-- go (i + 1) mary
-- in do mary <- newArray sz badTraverseValue
-- mnary <- go 0 mary
-- return (trans fmap mnary)
-- )
-- {-# INLINE finalizeTraverseMonadMaybe #-}

-- This is only used internally in a rewrite rule. Ideally, this function
-- would live in transformers.
hoistState :: Monad m => State s a -> StateT s m a
hoistState (StateT f) = StateT (return . runIdentity . f)
{-# INLINE hoistState #-}

-- This is only used internally in a rewrite rule. Ideally, this function
-- would live in transformers.
hoistReader :: Monad m => Reader r a -> ReaderT r m a
hoistReader (ReaderT f) = ReaderT (return . runIdentity . f)
{-# INLINE hoistReader #-}

-- This is required for Either's rewrite rule. It would be
-- much more concise to use ExceptT just like we use the
-- other monad transformers in the other rewrite rules, but
-- ExceptT isn't available on older versions of transformers.
traverseEither :: forall e a b.
(a -> Either e b)
-> Array a
-> Either e (Array b)
traverseEither f = \ !ary ->
let
!sz = sizeofArray ary
go :: forall s. Int -> MutableArray s b -> ST s (Either e (Array b))
go !i !mary
| i == sz = do
r <- unsafeFreezeArray mary
return (Right r)
| otherwise = do
a <- indexArrayM ary i
case f a of
Left e -> return (Left e)
Right b -> do
writeArray mary i b
go (i + 1) mary
in runST $ do
mary <- newArray sz badTraverseValue
go 0 mary
{-# INLINE traverseEither #-}


-- | This is the fastest, most straightforward way to traverse
-- an array, but it only works correctly with a sufficiently
-- "affine" 'PrimMonad' instance. In particular, it must only produce
Expand Down
48 changes: 48 additions & 0 deletions bench/Array/Traverse/Either.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{-# LANGUAGE BangPatterns #-}

module Array.Traverse.Either
( traverseEither
) where

import Control.Monad.ST
import Control.Monad.Trans.State.Strict
import Control.Monad.Primitive
import Data.Primitive.Array

-- This is a specialization of traverse, where the applicative is
-- chosen to be Either. In the benchmark suite, this implementation
-- is compared against an implementation that uses ExceptT to see
-- if GHC is able to optimize the ExceptT variant to code as efficient
-- as this. At the time this test was written (2018-04-23), GHC does
-- appear to optimize the ExceptT variant so that it performs as well
-- as this one.
{-# INLINE traverseEither #-}
traverseEither ::
(a -> Either e b)
-> Array a
-> Either e (Array b)
traverseEither f = \ !ary ->
let
!sz = sizeofArray ary
go !i !mary
| i == sz = do
r <- unsafeFreezeArray mary
return (Right r)
| otherwise = do
a <- indexArrayM ary i
case f a of
Left e -> return (Left e)
Right b -> do
writeArray mary i b
go (i + 1) mary
in runST $ do
mary <- newArray sz badTraverseValue
go 0 mary

badTraverseValue :: a
badTraverseValue = die "traverseEither" "bad indexing"
{-# NOINLINE badTraverseValue #-}

die :: String -> String -> a
die fun problem = error $ "Array.Traverse.Either" ++ fun ++ ": " ++ problem

Loading