Skip to content

Commit

Permalink
Improve support for monad stacks
Browse files Browse the repository at this point in the history
The main change in this commit is split in MonadSTM, which avoids the
injectivity requirement, enabling GeneralizedNewtypeDeriving for
MonadSTM. The remainder of the changes are related, and similarly
intended to faciliate the derivation of monad stacks.
  • Loading branch information
edsko committed Feb 10, 2020
1 parent 99c0e3a commit 70edb92
Show file tree
Hide file tree
Showing 22 changed files with 500 additions and 405 deletions.
193 changes: 118 additions & 75 deletions io-sim-classes/src/Control/Monad/Class/MonadAsync.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Control.Monad.Class.MonadAsync
( MonadAsync (..)
, MonadAsyncSTM (..)
, AsyncCancelled(..)
, ExceptionInLinkedThread(..)
, link
Expand All @@ -21,25 +23,81 @@ import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow

import Control.Monad (void)
import Control.Concurrent.Async (AsyncCancelled (..))
import qualified Control.Concurrent.Async as Async
import Control.Exception (SomeException)
import qualified Control.Exception as E
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.Async (AsyncCancelled(..))
import Control.Monad (void)
import Control.Monad.Reader
import qualified Control.Monad.STM as STM
import Data.Proxy

class (Functor async, MonadSTMTx stm) => MonadAsyncSTM async stm where
{-# MINIMAL waitCatchSTM, pollSTM #-}

waitSTM :: async a -> stm a
pollSTM :: async a -> stm (Maybe (Either SomeException a))
waitCatchSTM :: async a -> stm (Either SomeException a)

default waitSTM :: MonadThrow stm => async a -> stm a
waitSTM action = waitCatchSTM action >>= either throwM return

waitAnySTM :: [async a] -> stm (async a, a)
waitAnyCatchSTM :: [async a] -> stm (async a, Either SomeException a)
waitEitherSTM :: async a -> async b -> stm (Either a b)
waitEitherSTM_ :: async a -> async b -> stm ()
waitEitherCatchSTM :: async a -> async b
-> stm (Either (Either SomeException a)
(Either SomeException b))
waitBothSTM :: async a -> async b -> stm (a, b)

default waitAnySTM :: MonadThrow stm => [async a] -> stm (async a, a)
default waitEitherSTM :: MonadThrow stm => async a -> async b -> stm (Either a b)
default waitEitherSTM_ :: MonadThrow stm => async a -> async b -> stm ()
default waitBothSTM :: MonadThrow stm => async a -> async b -> stm (a, b)

waitAnySTM as =
foldr orElse retry $
map (\a -> do r <- waitSTM a; return (a, r)) as

waitAnyCatchSTM as =
foldr orElse retry $
map (\a -> do r <- waitCatchSTM a; return (a, r)) as

waitEitherSTM left right =
(Left <$> waitSTM left)
`orElse`
(Right <$> waitSTM right)

waitEitherSTM_ left right =
(void $ waitSTM left)
`orElse`
(void $ waitSTM right)

waitEitherCatchSTM left right =
(Left <$> waitCatchSTM left)
`orElse`
(Right <$> waitCatchSTM right)

waitBothSTM left right = do
a <- waitSTM left
`orElse`
(waitSTM right >> retry)
b <- waitSTM right
return (a,b)

class ( MonadSTM m
, MonadThread m
, Functor (Async m)
, MonadAsyncSTM (Async m) (STM m)
) => MonadAsync m where

{-# MINIMAL async, asyncThreadId, cancel, cancelWith, waitCatchSTM, pollSTM #-}
{-# MINIMAL async, asyncThreadId, cancel, cancelWith #-}

-- | An asynchronous action
type Async m :: * -> *

async :: m a -> m (Async m a)
asyncThreadId :: proxy m -> Async m a -> ThreadId m
asyncThreadId :: Proxy m -> Async m a -> ThreadId m
withAsync :: m a -> (Async m a -> m b) -> m b

wait :: Async m a -> m a
Expand All @@ -49,10 +107,6 @@ class ( MonadSTM m
cancelWith :: Exception e => Async m a -> e -> m ()
uninterruptibleCancel :: Async m a -> m ()

waitSTM :: Async m a -> STM m a
pollSTM :: Async m a -> STM m (Maybe (Either SomeException a))
waitCatchSTM :: Async m a -> STM m (Either SomeException a)

waitAny :: [Async m a] -> m (Async m a, a)
waitAnyCatch :: [Async m a] -> m (Async m a, Either SomeException a)
waitAnyCancel :: [Async m a] -> m (Async m a, a)
Expand All @@ -70,15 +124,6 @@ class ( MonadSTM m
waitEither_ :: Async m a -> Async m b -> m ()
waitBoth :: Async m a -> Async m b -> m (a, b)

waitAnySTM :: [Async m a] -> STM m (Async m a, a)
waitAnyCatchSTM :: [Async m a] -> STM m (Async m a, Either SomeException a)
waitEitherSTM :: Async m a -> Async m b -> STM m (Either a b)
waitEitherSTM_ :: Async m a -> Async m b -> STM m ()
waitEitherCatchSTM :: Async m a -> Async m b
-> STM m (Either (Either SomeException a)
(Either SomeException b))
waitBothSTM :: Async m a -> Async m b -> STM m (a, b)

race :: m a -> m b -> m (Either a b)
race_ :: m a -> m b -> m ()
concurrently :: m a -> m b -> m (a,b)
Expand All @@ -87,7 +132,6 @@ class ( MonadSTM m
default withAsync :: MonadMask m => m a -> (Async m a -> m b) -> m b
default uninterruptibleCancel
:: MonadMask m => Async m a -> m ()
default waitSTM :: MonadThrow (STM m) => Async m a -> STM m a
default waitAnyCancel :: MonadThrow m => [Async m a] -> m (Async m a, a)
default waitAnyCatchCancel :: MonadThrow m => [Async m a]
-> m (Async m a, Either SomeException a)
Expand All @@ -97,12 +141,6 @@ class ( MonadSTM m
-> m (Either (Either SomeException a)
(Either SomeException b))

default waitAnySTM :: MonadThrow (STM m) => [Async m a] -> STM m (Async m a, a)
default waitEitherSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (Either a b)
default waitEitherSTM_ :: MonadThrow (STM m) => Async m a -> Async m b -> STM m ()
default waitBothSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (a, b)


withAsync action inner = mask $ \restore -> do
a <- async (restore action)
restore (inner a)
Expand All @@ -113,7 +151,6 @@ class ( MonadSTM m
waitCatch = atomically . waitCatchSTM

uninterruptibleCancel = uninterruptibleMask_ . cancel
waitSTM action = waitCatchSTM action >>= either throwM return

waitAny = atomically . waitAnySTM
waitAnyCatch = atomically . waitAnyCatchSTM
Expand All @@ -134,36 +171,6 @@ class ( MonadSTM m
waitEitherCatchCancel left right =
waitEitherCatch left right `finally` (cancel left >> cancel right)

waitAnySTM as =
foldr orElse retry $
map (\a -> do r <- waitSTM a; return (a, r)) as

waitAnyCatchSTM as =
foldr orElse retry $
map (\a -> do r <- waitCatchSTM a; return (a, r)) as

waitEitherSTM left right =
(Left <$> waitSTM left)
`orElse`
(Right <$> waitSTM right)

waitEitherSTM_ left right =
(void $ waitSTM left)
`orElse`
(void $ waitSTM right)

waitEitherCatchSTM left right =
(Left <$> waitCatchSTM left)
`orElse`
(Right <$> waitCatchSTM right)

waitBothSTM left right = do
a <- waitSTM left
`orElse`
(waitSTM right >> retry)
b <- waitSTM right
return (a,b)

race left right = withAsync left $ \a ->
withAsync right $ \b ->
waitEither a b
Expand All @@ -180,6 +187,17 @@ class ( MonadSTM m
-- Instance for IO uses the existing async library implementations
--

instance MonadAsyncSTM Async.Async STM.STM where
waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
waitCatchSTM = Async.waitCatchSTM
waitAnySTM = Async.waitAnySTM
waitAnyCatchSTM = Async.waitAnyCatchSTM
waitEitherSTM = Async.waitEitherSTM
waitEitherSTM_ = Async.waitEitherSTM_
waitEitherCatchSTM = Async.waitEitherCatchSTM
waitBothSTM = Async.waitBothSTM

instance MonadAsync IO where

type Async IO = Async.Async
Expand All @@ -195,10 +213,6 @@ instance MonadAsync IO where
cancelWith = Async.cancelWith
uninterruptibleCancel = Async.uninterruptibleCancel

waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
waitCatchSTM = Async.waitCatchSTM

waitAny = Async.waitAny
waitAnyCatch = Async.waitAnyCatch
waitAnyCancel = Async.waitAnyCancel
Expand All @@ -210,17 +224,46 @@ instance MonadAsync IO where
waitEither_ = Async.waitEither_
waitBoth = Async.waitBoth

waitAnySTM = Async.waitAnySTM
waitAnyCatchSTM = Async.waitAnyCatchSTM
waitEitherSTM = Async.waitEitherSTM
waitEitherSTM_ = Async.waitEitherSTM_
waitEitherCatchSTM = Async.waitEitherCatchSTM
waitBothSTM = Async.waitBothSTM

race = Async.race
race_ = Async.race_
concurrently = Async.concurrently

--
-- Lift to ReaderT
--

(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
(f .: g) x y = f (g x y)

instance MonadAsync m => MonadAsync (ReaderT r m) where
type Async (ReaderT r m) = Async m

asyncThreadId _ = asyncThreadId (Proxy @m)

async (ReaderT ma) = ReaderT $ \r -> async (ma r)
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r) $ \a -> runReaderT (f a) r

race (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race (ma r) (mb r)
race_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_ (ma r) (mb r)
concurrently (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently (ma r) (mb r)

wait = lift . wait
poll = lift . poll
waitCatch = lift . waitCatch
cancel = lift . cancel
uninterruptibleCancel = lift . uninterruptibleCancel
cancelWith = lift .: cancelWith
waitAny = lift . waitAny
waitAnyCatch = lift . waitAnyCatch
waitAnyCancel = lift . waitAnyCancel
waitAnyCatchCancel = lift . waitAnyCatchCancel
waitEither = lift .: waitEither
waitEitherCatch = lift .: waitEitherCatch
waitEitherCancel = lift .: waitEitherCancel
waitEitherCatchCancel = lift .: waitEitherCatchCancel
waitEither_ = lift .: waitEither_
waitBoth = lift .: waitBoth

--
-- Linking
--
Expand Down Expand Up @@ -275,7 +318,7 @@ linkToOnly tid shouldThrow a = do
r <- waitCatch a
case r of
Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e)
_otherwise -> return ()
_otherwise -> return ()
where
exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread
exceptionInLinkedThread =
Expand Down
5 changes: 4 additions & 1 deletion io-sim-classes/src/Control/Monad/Class/MonadST.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Class.MonadST where

import Control.Monad.ST (ST, stToIO)
import Control.Monad.Reader
import Control.Monad.ST (ST, stToIO)


-- | This class is for abstracting over 'stToIO' which allows running 'ST'
Expand Down Expand Up @@ -29,3 +30,5 @@ instance MonadST IO where
instance MonadST (ST s) where
withLiftST = \f -> f id

instance MonadST m => MonadST (ReaderT r m) where
withLiftST f = withLiftST $ \g -> f (lift . g)
Loading

0 comments on commit 70edb92

Please sign in to comment.