Skip to content

Commit

Permalink
Merge #1539
Browse files Browse the repository at this point in the history
1539: Improve support for monad stacks r=edsko a=edsko

The main change in this commit is split in MonadSTM, which avoids the injectivity requirement, enabling GeneralizedNewtypeDeriving for MonadSTM. The remainder of the changes are related, and similarly
intended to faciliate the derivation of monad stacks.

Note that a similar change could be made to `MonadTimer`, but I didn't need it and so I've not done that yet.

With this PR, we can define something like

```haskell
newtype MockTime m a = MockTime (ReaderT () m a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadTrans
           , MonadThrow
           , MonadSTM
           , MonadDelay
           , MonadTime
           )
```

and have it Just Work (TM). (Of course, this particular definition isn't very useful, just a proof of concept.)

Co-authored-by: Edsko de Vries <[email protected]>
  • Loading branch information
iohk-bors[bot] and edsko authored Feb 10, 2020
2 parents 99c0e3a + 70edb92 commit 2275859
Show file tree
Hide file tree
Showing 22 changed files with 500 additions and 405 deletions.
193 changes: 118 additions & 75 deletions io-sim-classes/src/Control/Monad/Class/MonadAsync.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Control.Monad.Class.MonadAsync
( MonadAsync (..)
, MonadAsyncSTM (..)
, AsyncCancelled(..)
, ExceptionInLinkedThread(..)
, link
Expand All @@ -21,25 +23,81 @@ import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow

import Control.Monad (void)
import Control.Concurrent.Async (AsyncCancelled (..))
import qualified Control.Concurrent.Async as Async
import Control.Exception (SomeException)
import qualified Control.Exception as E
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.Async (AsyncCancelled(..))
import Control.Monad (void)
import Control.Monad.Reader
import qualified Control.Monad.STM as STM
import Data.Proxy

class (Functor async, MonadSTMTx stm) => MonadAsyncSTM async stm where
{-# MINIMAL waitCatchSTM, pollSTM #-}

waitSTM :: async a -> stm a
pollSTM :: async a -> stm (Maybe (Either SomeException a))
waitCatchSTM :: async a -> stm (Either SomeException a)

default waitSTM :: MonadThrow stm => async a -> stm a
waitSTM action = waitCatchSTM action >>= either throwM return

waitAnySTM :: [async a] -> stm (async a, a)
waitAnyCatchSTM :: [async a] -> stm (async a, Either SomeException a)
waitEitherSTM :: async a -> async b -> stm (Either a b)
waitEitherSTM_ :: async a -> async b -> stm ()
waitEitherCatchSTM :: async a -> async b
-> stm (Either (Either SomeException a)
(Either SomeException b))
waitBothSTM :: async a -> async b -> stm (a, b)

default waitAnySTM :: MonadThrow stm => [async a] -> stm (async a, a)
default waitEitherSTM :: MonadThrow stm => async a -> async b -> stm (Either a b)
default waitEitherSTM_ :: MonadThrow stm => async a -> async b -> stm ()
default waitBothSTM :: MonadThrow stm => async a -> async b -> stm (a, b)

waitAnySTM as =
foldr orElse retry $
map (\a -> do r <- waitSTM a; return (a, r)) as

waitAnyCatchSTM as =
foldr orElse retry $
map (\a -> do r <- waitCatchSTM a; return (a, r)) as

waitEitherSTM left right =
(Left <$> waitSTM left)
`orElse`
(Right <$> waitSTM right)

waitEitherSTM_ left right =
(void $ waitSTM left)
`orElse`
(void $ waitSTM right)

waitEitherCatchSTM left right =
(Left <$> waitCatchSTM left)
`orElse`
(Right <$> waitCatchSTM right)

waitBothSTM left right = do
a <- waitSTM left
`orElse`
(waitSTM right >> retry)
b <- waitSTM right
return (a,b)

class ( MonadSTM m
, MonadThread m
, Functor (Async m)
, MonadAsyncSTM (Async m) (STM m)
) => MonadAsync m where

{-# MINIMAL async, asyncThreadId, cancel, cancelWith, waitCatchSTM, pollSTM #-}
{-# MINIMAL async, asyncThreadId, cancel, cancelWith #-}

-- | An asynchronous action
type Async m :: * -> *

async :: m a -> m (Async m a)
asyncThreadId :: proxy m -> Async m a -> ThreadId m
asyncThreadId :: Proxy m -> Async m a -> ThreadId m
withAsync :: m a -> (Async m a -> m b) -> m b

wait :: Async m a -> m a
Expand All @@ -49,10 +107,6 @@ class ( MonadSTM m
cancelWith :: Exception e => Async m a -> e -> m ()
uninterruptibleCancel :: Async m a -> m ()

waitSTM :: Async m a -> STM m a
pollSTM :: Async m a -> STM m (Maybe (Either SomeException a))
waitCatchSTM :: Async m a -> STM m (Either SomeException a)

waitAny :: [Async m a] -> m (Async m a, a)
waitAnyCatch :: [Async m a] -> m (Async m a, Either SomeException a)
waitAnyCancel :: [Async m a] -> m (Async m a, a)
Expand All @@ -70,15 +124,6 @@ class ( MonadSTM m
waitEither_ :: Async m a -> Async m b -> m ()
waitBoth :: Async m a -> Async m b -> m (a, b)

waitAnySTM :: [Async m a] -> STM m (Async m a, a)
waitAnyCatchSTM :: [Async m a] -> STM m (Async m a, Either SomeException a)
waitEitherSTM :: Async m a -> Async m b -> STM m (Either a b)
waitEitherSTM_ :: Async m a -> Async m b -> STM m ()
waitEitherCatchSTM :: Async m a -> Async m b
-> STM m (Either (Either SomeException a)
(Either SomeException b))
waitBothSTM :: Async m a -> Async m b -> STM m (a, b)

race :: m a -> m b -> m (Either a b)
race_ :: m a -> m b -> m ()
concurrently :: m a -> m b -> m (a,b)
Expand All @@ -87,7 +132,6 @@ class ( MonadSTM m
default withAsync :: MonadMask m => m a -> (Async m a -> m b) -> m b
default uninterruptibleCancel
:: MonadMask m => Async m a -> m ()
default waitSTM :: MonadThrow (STM m) => Async m a -> STM m a
default waitAnyCancel :: MonadThrow m => [Async m a] -> m (Async m a, a)
default waitAnyCatchCancel :: MonadThrow m => [Async m a]
-> m (Async m a, Either SomeException a)
Expand All @@ -97,12 +141,6 @@ class ( MonadSTM m
-> m (Either (Either SomeException a)
(Either SomeException b))

default waitAnySTM :: MonadThrow (STM m) => [Async m a] -> STM m (Async m a, a)
default waitEitherSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (Either a b)
default waitEitherSTM_ :: MonadThrow (STM m) => Async m a -> Async m b -> STM m ()
default waitBothSTM :: MonadThrow (STM m) => Async m a -> Async m b -> STM m (a, b)


withAsync action inner = mask $ \restore -> do
a <- async (restore action)
restore (inner a)
Expand All @@ -113,7 +151,6 @@ class ( MonadSTM m
waitCatch = atomically . waitCatchSTM

uninterruptibleCancel = uninterruptibleMask_ . cancel
waitSTM action = waitCatchSTM action >>= either throwM return

waitAny = atomically . waitAnySTM
waitAnyCatch = atomically . waitAnyCatchSTM
Expand All @@ -134,36 +171,6 @@ class ( MonadSTM m
waitEitherCatchCancel left right =
waitEitherCatch left right `finally` (cancel left >> cancel right)

waitAnySTM as =
foldr orElse retry $
map (\a -> do r <- waitSTM a; return (a, r)) as

waitAnyCatchSTM as =
foldr orElse retry $
map (\a -> do r <- waitCatchSTM a; return (a, r)) as

waitEitherSTM left right =
(Left <$> waitSTM left)
`orElse`
(Right <$> waitSTM right)

waitEitherSTM_ left right =
(void $ waitSTM left)
`orElse`
(void $ waitSTM right)

waitEitherCatchSTM left right =
(Left <$> waitCatchSTM left)
`orElse`
(Right <$> waitCatchSTM right)

waitBothSTM left right = do
a <- waitSTM left
`orElse`
(waitSTM right >> retry)
b <- waitSTM right
return (a,b)

race left right = withAsync left $ \a ->
withAsync right $ \b ->
waitEither a b
Expand All @@ -180,6 +187,17 @@ class ( MonadSTM m
-- Instance for IO uses the existing async library implementations
--

instance MonadAsyncSTM Async.Async STM.STM where
waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
waitCatchSTM = Async.waitCatchSTM
waitAnySTM = Async.waitAnySTM
waitAnyCatchSTM = Async.waitAnyCatchSTM
waitEitherSTM = Async.waitEitherSTM
waitEitherSTM_ = Async.waitEitherSTM_
waitEitherCatchSTM = Async.waitEitherCatchSTM
waitBothSTM = Async.waitBothSTM

instance MonadAsync IO where

type Async IO = Async.Async
Expand All @@ -195,10 +213,6 @@ instance MonadAsync IO where
cancelWith = Async.cancelWith
uninterruptibleCancel = Async.uninterruptibleCancel

waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
waitCatchSTM = Async.waitCatchSTM

waitAny = Async.waitAny
waitAnyCatch = Async.waitAnyCatch
waitAnyCancel = Async.waitAnyCancel
Expand All @@ -210,17 +224,46 @@ instance MonadAsync IO where
waitEither_ = Async.waitEither_
waitBoth = Async.waitBoth

waitAnySTM = Async.waitAnySTM
waitAnyCatchSTM = Async.waitAnyCatchSTM
waitEitherSTM = Async.waitEitherSTM
waitEitherSTM_ = Async.waitEitherSTM_
waitEitherCatchSTM = Async.waitEitherCatchSTM
waitBothSTM = Async.waitBothSTM

race = Async.race
race_ = Async.race_
concurrently = Async.concurrently

--
-- Lift to ReaderT
--

(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
(f .: g) x y = f (g x y)

instance MonadAsync m => MonadAsync (ReaderT r m) where
type Async (ReaderT r m) = Async m

asyncThreadId _ = asyncThreadId (Proxy @m)

async (ReaderT ma) = ReaderT $ \r -> async (ma r)
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r) $ \a -> runReaderT (f a) r

race (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race (ma r) (mb r)
race_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_ (ma r) (mb r)
concurrently (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently (ma r) (mb r)

wait = lift . wait
poll = lift . poll
waitCatch = lift . waitCatch
cancel = lift . cancel
uninterruptibleCancel = lift . uninterruptibleCancel
cancelWith = lift .: cancelWith
waitAny = lift . waitAny
waitAnyCatch = lift . waitAnyCatch
waitAnyCancel = lift . waitAnyCancel
waitAnyCatchCancel = lift . waitAnyCatchCancel
waitEither = lift .: waitEither
waitEitherCatch = lift .: waitEitherCatch
waitEitherCancel = lift .: waitEitherCancel
waitEitherCatchCancel = lift .: waitEitherCatchCancel
waitEither_ = lift .: waitEither_
waitBoth = lift .: waitBoth

--
-- Linking
--
Expand Down Expand Up @@ -275,7 +318,7 @@ linkToOnly tid shouldThrow a = do
r <- waitCatch a
case r of
Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e)
_otherwise -> return ()
_otherwise -> return ()
where
exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread
exceptionInLinkedThread =
Expand Down
5 changes: 4 additions & 1 deletion io-sim-classes/src/Control/Monad/Class/MonadST.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Class.MonadST where

import Control.Monad.ST (ST, stToIO)
import Control.Monad.Reader
import Control.Monad.ST (ST, stToIO)


-- | This class is for abstracting over 'stToIO' which allows running 'ST'
Expand Down Expand Up @@ -29,3 +30,5 @@ instance MonadST IO where
instance MonadST (ST s) where
withLiftST = \f -> f id

instance MonadST m => MonadST (ReaderT r m) where
withLiftST f = withLiftST $ \g -> f (lift . g)
Loading

0 comments on commit 2275859

Please sign in to comment.