diff --git a/CHANGELOG.md b/CHANGELOG.md index 170873f..7c5206b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 0.2.0.0 + +* Use a separate monad within `withTransaction` to prevent unsafe/arbitrary IO actions ([#7](https://github.com/brandonchinn178/persistent-mtl/issues/7), [#28](https://github.com/brandonchinn178/persistent-mtl/issues/28)) +* Add `MonadRerunnableIO` to support IO actions within `withTransaction` only if the IO action is determined to be rerunnable +* Add built-in support for retrying transactions if a serialization error occurs +* Remove `SqlQueryRep` as an export from `Database.Persist.Monad`. You shouldn't ever need it for normal usage. It is now re-exported by `Database.Persist.Monad.TestUtils`, since most of the usage of `SqlQueryRep` is in mocking queries. If you need it otherwise, you can import it directly from `Database.Persist.Monad.SqlQueryRep`. + # 0.1.0.1 Fix quickstart diff --git a/README.md b/README.md index a883ceb..a589303 100644 --- a/README.md +++ b/README.md @@ -54,12 +54,12 @@ newtype MyApp a = MyApp instance MonadUnliftIO MyApp where withRunInIO = wrappedWithRunInIO MyApp unMyApp -getYoungPeople :: (MonadIO m, MonadSqlQuery m) => m [Entity Person] +getYoungPeople :: MonadSqlQuery m => m [Entity Person] getYoungPeople = selectList [PersonAge <. 18] [] main :: IO () -main = runStderrLoggingT $ withSqlitePool "db.sqlite" 5 $ \conn -> - liftIO $ runSqlQueryT conn $ unMyApp $ do +main = runStderrLoggingT $ withSqlitePool "db.sqlite" 5 $ \pool -> + liftIO $ runSqlQueryT pool $ unMyApp $ do runMigration migrate insert_ $ Person "Alice" 25 insert_ $ Person "Bob" 10 @@ -229,6 +229,59 @@ So what does `persistent-mtl` do differently? In summary, `persistent-mtl` takes all the good things about option 2, implements them out of the box (so you don't have to do it yourself), and makes your business logic functions composable with transactions behaving the way YOU want. +### Easy transaction management + +Some databases will throw an error if two transactions conflict (e.g. [PostgreSQL](https://www.postgresql.org/docs/9.5/transaction-iso.html)). The client is expected to retry transactions if this error is thrown. `persistent` doesn't easily support this out of the box, but `persistent-mtl` does! + +```hs +import Database.PostgreSQL.Simple.Errors (isSerializationError) + +main :: IO () +main = withPostgresqlPool "..." 5 $ \pool -> do + let env = mkSqlQueryEnv pool $ \env -> env + { retryIf = isSerializationError . fromException + , retryLimit = 100 -- defaults to 10 + } + + -- in any of the marked transactions below, if someone else is querying + -- the postgresql database at the same time with queries that conflict + -- with yours, your operations will automatically be retried + runSqlQueryTWith env $ do + -- transaction 1 + insert_ $ ... + + -- transaction 2 + withTransaction $ do + insert_ $ ... + + -- transaction 2.5: transaction-within-a-transaction is supported in PostgreSQL + withTransaction $ do + insert_ $ ... + + insert_ $ ... + + -- transaction 3 + insert_ $ ... +``` + +Because of this built-in retry support, any IO actions inside `withTransaction` have to be explicitly marked with `rerunnableIO`. If you try to use a function with a `MonadIO m` constraint, you'll get a compile-time error! + +``` +.../Foo.hs:100:5: error: + • Cannot run arbitrary IO actions within a transaction. If the IO action is rerunnable, use rerunnableIO + • In a stmt of a 'do' block: arbitraryIO + In the second argument of ‘($)’, namely + ‘withTransaction + $ do insert_ record1 + arbitraryIO + insert_ record2’ + | +100 | arbitraryIO + | ^^^^^^^^^^^ +``` + +Note that this **only** applies for transactions, so `MonadIO` and `MonadSqlQuery` constraints can still co-exist (for a function with IO actions that are not rerunnable) as long as the function is never called within `withTransaction`. + ### Testing functions that use `persistent` operations Generally, I would recommend someone using `persistent` in their application to make a monad type class containing the API for their domain, like diff --git a/package.yaml b/package.yaml index 8e971b8..590e75b 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: persistent-mtl -version: 0.1.0.1 +version: 0.2.0.0 maintainer: Brandon Chinn synopsis: Monad transformer for the persistent API description: | @@ -29,6 +29,7 @@ library: - resourcet-pool >= 0.1.0.0 && < 0.2 - text >= 1.2.3.0 && < 2 - transformers >= 0.5.2.0 && < 0.6 + - unliftio >= 0.2.7.0 && < 0.3 - unliftio-core >= 0.1.2.0 && < 0.3 tests: diff --git a/persistent-mtl.cabal b/persistent-mtl.cabal index 9d36506..5ca9e6f 100644 --- a/persistent-mtl.cabal +++ b/persistent-mtl.cabal @@ -4,10 +4,10 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: a6b1252a25af52e3ddd10e843b4818e0255c269dfbc6a399eb29e41093ef0408 +-- hash: 9c53e0610dea4ca814133d0596978b961b85fd54cb6df7a82eccb6d260824dde name: persistent-mtl -version: 0.1.0.1 +version: 0.2.0.0 synopsis: Monad transformer for the persistent API description: A monad transformer and mtl-style type class for using the persistent API directly in your monad transformer stack. @@ -32,6 +32,7 @@ source-repository head library exposed-modules: + Control.Monad.IO.Rerunnable Database.Persist.Monad Database.Persist.Monad.Class Database.Persist.Monad.Shim @@ -53,6 +54,7 @@ library , resourcet-pool >=0.1.0.0 && <0.2 , text >=1.2.3.0 && <2 , transformers >=0.5.2.0 && <0.6 + , unliftio >=0.2.7.0 && <0.3 , unliftio-core >=0.1.2.0 && <0.3 default-language: Haskell2010 diff --git a/scripts/generate/templates/TestHelpers.mustache b/scripts/generate/templates/TestHelpers.mustache index aeda95c..62dde8b 100644 --- a/scripts/generate/templates/TestHelpers.mustache +++ b/scripts/generate/templates/TestHelpers.mustache @@ -12,9 +12,9 @@ import Data.Int (Int64) import Data.Map (Map) import Data.Text (Text) import Data.Void (Void) -import Database.Persist.Sql hiding (pattern Update) +import Database.Persist.Sql (CautiousMigration, Entity, Key, PersistValue, Sql) -import Database.Persist.Monad +import Database.Persist.Monad.TestUtils (SqlQueryRep(..)) import Example {-# ANN module "HLint: ignore" #-} diff --git a/src/Control/Monad/IO/Rerunnable.hs b/src/Control/Monad/IO/Rerunnable.hs new file mode 100644 index 0000000..9c0b507 --- /dev/null +++ b/src/Control/Monad/IO/Rerunnable.hs @@ -0,0 +1,71 @@ +{-| +Module: Control.Monad.IO.Rerunnable + +Defines the 'MonadRerunnableIO' type class that is functionally equivalent +to 'Control.Monad.IO.Class.MonadIO', but use of it requires the user to +explicitly acknowledge that the given IO operation can be rerun. +-} + +module Control.Monad.IO.Rerunnable + ( MonadRerunnableIO(..) + ) where + +import Control.Monad.Trans.Class (lift) +import qualified Control.Monad.Trans.Except as Except +import qualified Control.Monad.Trans.Identity as Identity +import qualified Control.Monad.Trans.Maybe as Maybe +import qualified Control.Monad.Trans.RWS.Lazy as RWS.Lazy +import qualified Control.Monad.Trans.RWS.Strict as RWS.Strict +import qualified Control.Monad.Trans.Reader as Reader +import qualified Control.Monad.Trans.Resource as Resource +import qualified Control.Monad.Trans.State.Lazy as State.Lazy +import qualified Control.Monad.Trans.State.Strict as State.Strict +import qualified Control.Monad.Trans.Writer.Lazy as Writer.Lazy +import qualified Control.Monad.Trans.Writer.Strict as Writer.Strict + +-- | A copy of 'Control.Monad.IO.Class.MonadIO' to explicitly allow only IO +-- operations that are rerunnable, e.g. in the context of a SQL transaction. +class Monad m => MonadRerunnableIO m where + -- | Lift the given IO operation to @m@. + -- + -- The given IO operation may be rerun, so use of this function requires + -- manually verifying that the given IO operation is rerunnable. + rerunnableIO :: IO a -> m a + +instance MonadRerunnableIO IO where + rerunnableIO = id + +{- Instances for common monad transformers -} + +instance MonadRerunnableIO m => MonadRerunnableIO (Reader.ReaderT r m) where + rerunnableIO = lift . rerunnableIO + +instance MonadRerunnableIO m => MonadRerunnableIO (Except.ExceptT e m) where + rerunnableIO = lift . rerunnableIO + +instance MonadRerunnableIO m => MonadRerunnableIO (Identity.IdentityT m) where + rerunnableIO = lift . rerunnableIO + +instance MonadRerunnableIO m => MonadRerunnableIO (Maybe.MaybeT m) where + rerunnableIO = lift . rerunnableIO + +instance (Monoid w, MonadRerunnableIO m) => MonadRerunnableIO (RWS.Lazy.RWST r w s m) where + rerunnableIO = lift . rerunnableIO + +instance (Monoid w, MonadRerunnableIO m) => MonadRerunnableIO (RWS.Strict.RWST r w s m) where + rerunnableIO = lift . rerunnableIO + +instance MonadRerunnableIO m => MonadRerunnableIO (State.Lazy.StateT s m) where + rerunnableIO = lift . rerunnableIO + +instance MonadRerunnableIO m => MonadRerunnableIO (State.Strict.StateT s m) where + rerunnableIO = lift . rerunnableIO + +instance (Monoid w, MonadRerunnableIO m) => MonadRerunnableIO (Writer.Lazy.WriterT w m) where + rerunnableIO = lift . rerunnableIO + +instance (Monoid w, MonadRerunnableIO m) => MonadRerunnableIO (Writer.Strict.WriterT w m) where + rerunnableIO = lift . rerunnableIO + +instance MonadRerunnableIO m => MonadRerunnableIO (Resource.ResourceT m) where + rerunnableIO = lift . rerunnableIO diff --git a/src/Database/Persist/Monad.hs b/src/Database/Persist/Monad.hs index 690aad6..c3d45c0 100644 --- a/src/Database/Persist/Monad.hs +++ b/src/Database/Persist/Monad.hs @@ -1,8 +1,11 @@ {-| Module: Database.Persist.Monad -Defines the 'SqlQueryT' monad transformer that has a 'MonadSqlQuery' instance -to execute @persistent@ database operations. +Defines the 'SqlQueryT' monad transformer, which has a 'MonadSqlQuery' instance +to execute @persistent@ database operations. Also provides easy transaction +management with 'withTransaction', which supports retrying with exponential +backoff and restricts IO actions to only allow IO actions explicitly marked +as rerunnable. Usage: @@ -18,34 +21,48 @@ myFunction = do liftIO $ print (personList :: [Person]) -- everything in here will run in a transaction - withTransaction $ + withTransaction $ do selectFirst [PersonAge >. 30] [] >>= \\case Nothing -> insert_ $ Person { name = \"Claire\", age = Just 50 } Just (Entity key person) -> replace key person{ age = Just (age person - 10) } + -- liftIO doesn't work in here, since transactions can be retried. + -- Use rerunnableIO to run IO actions, after verifying that the IO action + -- can be rerun if the transaction needs to be retried. + rerunnableIO $ putStrLn "Transaction is finished!" + -- some more business logic return () @ -} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} module Database.Persist.Monad ( -- * Type class for executing database queries MonadSqlQuery , withTransaction - , SqlQueryRep(..) -- * SqlQueryT monad transformer , SqlQueryT , runSqlQueryT + , runSqlQueryTWith + , SqlQueryEnv(..) + , mkSqlQueryEnv + + -- * Transactions + , SqlTransaction + , TransactionError(..) -- * Lifted functions , module Database.Persist.Monad.Shim @@ -53,23 +70,104 @@ module Database.Persist.Monad import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Unlift (MonadUnliftIO(..), wrappedWithRunInIO) -import Control.Monad.Reader (ReaderT, ask, local, runReaderT) +import Control.Monad.Reader (ReaderT, ask, runReaderT) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Resource (MonadResource) import Data.Acquire (withAcquire) import Data.Pool (Pool) import Data.Pool.Acquire (poolToAcquire) -import Database.Persist.Sql (SqlBackend, runSqlConn) +import Database.Persist.Sql (SqlBackend, SqlPersistT, runSqlConn) +import qualified GHC.TypeLits as GHC +import UnliftIO.Concurrent (threadDelay) +import UnliftIO.Exception (Exception, SomeException, catchJust, throwIO) +import Control.Monad.IO.Rerunnable (MonadRerunnableIO) import Database.Persist.Monad.Class import Database.Persist.Monad.Shim import Database.Persist.Monad.SqlQueryRep +{- SqlTransaction -} + +-- | The monad that tracks transaction state. +-- +-- Conceptually equivalent to 'Database.Persist.Sql.SqlPersistT', but restricts +-- IO operations, for two reasons: +-- 1. Forking a thread that uses the same 'SqlBackend' as the current thread +-- causes Bad Things to happen. +-- 2. Transactions may need to be retried, in which case IO operations in +-- a transaction are required to be rerunnable. +-- +-- You shouldn't need to explicitly use this type; your functions should only +-- declare the 'MonadSqlQuery' constraint. +newtype SqlTransaction m a = SqlTransaction + { unSqlTransaction :: SqlPersistT m a + } + deriving (Functor, Applicative, Monad, MonadRerunnableIO) + +instance + ( GHC.TypeError ('GHC.Text "Cannot run arbitrary IO actions within a transaction. If the IO action is rerunnable, use rerunnableIO") + , Monad m + ) + => MonadIO (SqlTransaction m) where + liftIO = undefined + +instance (MonadSqlQuery m, MonadUnliftIO m) => MonadSqlQuery (SqlTransaction m) where + type TransactionM (SqlTransaction m) = TransactionM m + + runQueryRep = SqlTransaction . runSqlQueryRep + + -- Delegate to 'm', since 'm' is in charge of starting/stopping transactions. + -- 'SqlTransaction' is ONLY in charge of executing queries. + withTransaction = SqlTransaction . withTransaction + +runSqlTransaction :: MonadUnliftIO m => SqlBackend -> SqlTransaction m a -> m a +runSqlTransaction conn = (`runSqlConn` conn) . unSqlTransaction + +-- | Errors that can occur within a SQL transaction. +data TransactionError + = RetryLimitExceeded + -- ^ The retry limit was reached when retrying a transaction. + deriving (Show, Eq) + +instance Exception TransactionError + {- SqlQueryT monad -} +-- | Environment to configure running 'SqlQueryT'. +-- +-- For simple usage, you can just use 'runSqlQueryT', but for more advanced +-- usage, including the ability to retry transactions, use 'mkSqlQueryEnv' with +-- 'runSqlQueryTWith'. data SqlQueryEnv = SqlQueryEnv { backendPool :: Pool SqlBackend - , currentConn :: Maybe SqlBackend + -- ^ The pool for your persistent backend. Get this from @withSqlitePool@ + -- or the equivalent for your backend. + + , retryIf :: SomeException -> Bool + -- ^ Retry a transaction when an exception matches this predicate. Will + -- retry with an exponential backoff. + -- + -- Defaults to always returning False (i.e. never retry) + + , retryLimit :: Int + -- ^ The number of times to retry, if 'retryIf' is satisfied. + -- + -- Defaults to 10. + } + +-- | Build a SqlQueryEnv from the default. +-- +-- Usage: +-- +-- @ +-- let env = mkSqlQueryEnv pool $ \\env -> env { retryIf = 10 } +-- in runSqlQueryTWith env m +-- @ +mkSqlQueryEnv :: Pool SqlBackend -> (SqlQueryEnv -> SqlQueryEnv) -> SqlQueryEnv +mkSqlQueryEnv backendPool f = f SqlQueryEnv + { backendPool + , retryIf = const False + , retryLimit = 10 } -- | The monad transformer that implements 'MonadSqlQuery'. @@ -82,18 +180,27 @@ newtype SqlQueryT m a = SqlQueryT , MonadIO , MonadTrans , MonadResource + , MonadRerunnableIO ) instance MonadUnliftIO m => MonadSqlQuery (SqlQueryT m) where - runQueryRep queryRep = do - SqlQueryEnv{currentConn} <- SqlQueryT ask - case currentConn of - Just conn -> runWithConn conn - Nothing -> withTransactionConn runWithConn - where - runWithConn = runReaderT (runSqlQueryRep queryRep) - - withTransaction action = withTransactionConn $ \_ -> action + type TransactionM (SqlQueryT m) = SqlTransaction (SqlQueryT m) + + -- Running a query directly in SqlQueryT will create a one-off transaction. + runQueryRep = withTransaction . runQueryRep + + -- Start a new transaction and run the given 'SqlTransaction' + withTransaction m = do + SqlQueryEnv{..} <- SqlQueryT ask + withAcquire (poolToAcquire backendPool) $ \conn -> + let filterRetry e = if retryIf e then Just e else Nothing + loop i = catchJust filterRetry (runSqlTransaction conn m) $ \_ -> + if i < retryLimit + then do + threadDelay $ 1000 * 2^i + loop $! i + 1 + else throwIO RetryLimitExceeded + in loop 0 instance MonadUnliftIO m => MonadUnliftIO (SqlQueryT m) where withRunInIO = wrappedWithRunInIO SqlQueryT unSqlQueryT @@ -102,16 +209,9 @@ instance MonadUnliftIO m => MonadUnliftIO (SqlQueryT m) where -- | Run the 'SqlQueryT' monad transformer with the given backend. runSqlQueryT :: Pool SqlBackend -> SqlQueryT m a -> m a -runSqlQueryT backendPool = (`runReaderT` env) . unSqlQueryT - where - env = SqlQueryEnv { currentConn = Nothing, .. } - --- | Start a new transaction and get the connection. -withTransactionConn :: MonadUnliftIO m => (SqlBackend -> SqlQueryT m a) -> SqlQueryT m a -withTransactionConn f = do - SqlQueryEnv{backendPool} <- SqlQueryT ask - withAcquire (poolToAcquire backendPool) $ \conn -> - SqlQueryT . local (setCurrentConn conn) . unSqlQueryT $ - runSqlConn (lift $ f conn) conn - where - setCurrentConn conn env = env { currentConn = Just conn } +runSqlQueryT backendPool = runSqlQueryTWith $ mkSqlQueryEnv backendPool id + +-- | Run the 'SqlQueryT' monad transformer with the explicitly provided +-- environment. +runSqlQueryTWith :: SqlQueryEnv -> SqlQueryT m a -> m a +runSqlQueryTWith env = (`runReaderT` env) . unSqlQueryT diff --git a/src/Database/Persist/Monad/Class.hs b/src/Database/Persist/Monad/Class.hs index add5e0b..74e06df 100644 --- a/src/Database/Persist/Monad/Class.hs +++ b/src/Database/Persist/Monad/Class.hs @@ -6,6 +6,7 @@ in order to interpret how to run a 'Database.Persist.Monad.SqlQueryRep.SqlQueryRep' sent by a lifted function from @Database.Persist.Monad.Shim@. -} +{-# LANGUAGE TypeFamilies #-} module Database.Persist.Monad.Class ( MonadSqlQuery(..) @@ -22,61 +23,69 @@ import qualified Control.Monad.Trans.State.Lazy as State.Lazy import qualified Control.Monad.Trans.State.Strict as State.Strict import qualified Control.Monad.Trans.Writer.Lazy as Writer.Lazy import qualified Control.Monad.Trans.Writer.Strict as Writer.Strict +import Data.Kind (Type) import Data.Typeable (Typeable) import Database.Persist.Monad.SqlQueryRep (SqlQueryRep) -- | The type-class for monads that can run persistent database queries. class Monad m => MonadSqlQuery m where - -- | The main function that interprets a SQL query operation and runs it - -- in the monadic context. + type TransactionM m :: Type -> Type + + -- | Interpret the given SQL query operation. runQueryRep :: Typeable record => SqlQueryRep record a -> m a -- | Run all queries in the given action using the same database connection. - -- - -- You should make sure to not fork any threads within this action. This - -- will almost certainly cause problems. - -- https://github.com/brandonchinn178/persistent-mtl/issues/7 - withTransaction :: m a -> m a + withTransaction :: TransactionM m a -> m a {- Instances for common monad transformers -} instance MonadSqlQuery m => MonadSqlQuery (Reader.ReaderT r m) where + type TransactionM (Reader.ReaderT r m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = Reader.mapReaderT withTransaction + withTransaction = lift . withTransaction instance MonadSqlQuery m => MonadSqlQuery (Except.ExceptT e m) where + type TransactionM (Except.ExceptT e m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = Except.mapExceptT withTransaction + withTransaction = lift . withTransaction instance MonadSqlQuery m => MonadSqlQuery (Identity.IdentityT m) where + type TransactionM (Identity.IdentityT m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = Identity.mapIdentityT withTransaction + withTransaction = lift . withTransaction instance MonadSqlQuery m => MonadSqlQuery (Maybe.MaybeT m) where + type TransactionM (Maybe.MaybeT m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = Maybe.mapMaybeT withTransaction + withTransaction = lift . withTransaction instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (RWS.Lazy.RWST r w s m) where + type TransactionM (RWS.Lazy.RWST r w s m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = RWS.Lazy.mapRWST withTransaction + withTransaction = lift . withTransaction instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (RWS.Strict.RWST r w s m) where + type TransactionM (RWS.Strict.RWST r w s m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = RWS.Strict.mapRWST withTransaction + withTransaction = lift . withTransaction instance MonadSqlQuery m => MonadSqlQuery (State.Lazy.StateT s m) where + type TransactionM (State.Lazy.StateT s m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = State.Lazy.mapStateT withTransaction + withTransaction = lift . withTransaction instance MonadSqlQuery m => MonadSqlQuery (State.Strict.StateT s m) where + type TransactionM (State.Strict.StateT s m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = State.Strict.mapStateT withTransaction + withTransaction = lift . withTransaction instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (Writer.Lazy.WriterT w m) where + type TransactionM (Writer.Lazy.WriterT w m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = Writer.Lazy.mapWriterT withTransaction + withTransaction = lift . withTransaction instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (Writer.Strict.WriterT w m) where + type TransactionM (Writer.Strict.WriterT w m) = TransactionM m runQueryRep = lift . runQueryRep - withTransaction = Writer.Strict.mapWriterT withTransaction + withTransaction = lift . withTransaction diff --git a/src/Database/Persist/Monad/TestUtils.hs b/src/Database/Persist/Monad/TestUtils.hs index 49b505f..b2ac762 100644 --- a/src/Database/Persist/Monad/TestUtils.hs +++ b/src/Database/Persist/Monad/TestUtils.hs @@ -11,6 +11,7 @@ Defines 'MockSqlQueryT', which one can use in tests in order to mock out {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} module Database.Persist.Monad.TestUtils ( MockSqlQueryT @@ -18,12 +19,16 @@ module Database.Persist.Monad.TestUtils , withRecord , mockQuery , MockQuery + -- * Specialized helpers , mockSelectSource , mockSelectKeys , mockWithRawQuery , mockRawQuery , mockRawSql + + -- * Re-exports + , SqlQueryRep(..) ) where import Conduit ((.|)) @@ -85,6 +90,8 @@ runMockSqlQueryT :: MockSqlQueryT m a -> [MockQuery] -> m a runMockSqlQueryT action mockQueries = (`runReaderT` mockQueries) . unMockSqlQueryT $ action instance MonadIO m => MonadSqlQuery (MockSqlQueryT m) where + type TransactionM (MockSqlQueryT m) = MockSqlQueryT m + runQueryRep rep = do mockQueries <- MockSqlQueryT ask maybe (error $ "Could not find mock for query: " ++ show rep) liftIO diff --git a/test/Example.hs b/test/Example.hs index d71ab2a..e593d2b 100644 --- a/test/Example.hs +++ b/test/Example.hs @@ -17,6 +17,7 @@ module Example ( TestApp , runTestApp + , runTestAppWith -- * Person , Person(..) @@ -53,6 +54,7 @@ import Database.Persist.TH ) import UnliftIO (MonadUnliftIO(..), wrappedWithRunInIO) +import Control.Monad.IO.Rerunnable (MonadRerunnableIO) import Database.Persist.Monad import TestUtils.DB (BackendType(..), withTestDB) @@ -95,6 +97,7 @@ newtype TestApp a = TestApp , Applicative , Monad , MonadIO + , MonadRerunnableIO , MonadSqlQuery , MonadResource ) @@ -109,6 +112,14 @@ runTestApp backendType m = _ <- runMigrationSilent migration m +runTestAppWith :: BackendType -> (SqlQueryEnv -> SqlQueryEnv) -> TestApp a -> IO a +runTestAppWith backendType f m = + withTestDB backendType $ \pool -> do + let env = mkSqlQueryEnv pool f + runResourceT . runSqlQueryTWith env . unTestApp $ do + _ <- runMigrationSilent migration + m + {- Person functions -} person :: String -> Person diff --git a/test/Generated.hs b/test/Generated.hs index 285ca00..9ba21b1 100644 --- a/test/Generated.hs +++ b/test/Generated.hs @@ -12,9 +12,9 @@ import Data.Int (Int64) import Data.Map (Map) import Data.Text (Text) import Data.Void (Void) -import Database.Persist.Sql hiding (pattern Update) +import Database.Persist.Sql (CautiousMigration, Entity, Key, PersistValue, Sql) -import Database.Persist.Monad +import Database.Persist.Monad.TestUtils (SqlQueryRep(..)) import Example {-# ANN module "HLint: ignore" #-} diff --git a/test/Integration.hs b/test/Integration.hs index 36723c2..a36d4ea 100644 --- a/test/Integration.hs +++ b/test/Integration.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -31,8 +32,19 @@ import Database.Persist.Sql (IsolationLevel(..)) #endif import Test.Tasty import Test.Tasty.HUnit -import UnliftIO (Exception, MonadIO, MonadUnliftIO, liftIO, throwIO, try) +import UnliftIO (MonadIO, MonadUnliftIO, liftIO) +import UnliftIO.Exception + ( Exception + , SomeException + , StringException(..) + , fromException + , throwIO + , throwString + , try + ) +import UnliftIO.IORef (atomicModifyIORef, newIORef) +import Control.Monad.IO.Rerunnable (MonadRerunnableIO, rerunnableIO) import Database.Persist.Monad import Example import TestUtils.DB (BackendType(..), allBackendTypes) @@ -45,6 +57,7 @@ tests = testGroup "Integration tests" $ testsWithBackend :: BackendType -> TestTree testsWithBackend backendType = testGroup (show backendType) [ testWithTransaction backendType + , testComposability backendType , testPersistentAPI backendType ] @@ -62,8 +75,74 @@ testWithTransaction backendType = testGroup "withTransaction" catchTestError $ withTransaction $ insertAndFail $ person "Alice" result <- getPeopleNames liftIO $ result @?= [] + + , testCase "retries transactions" $ do + let retryIf e = case fromException e of + Just (StringException "retry me" _) -> True + _ -> False + setRetry env = env { retryIf, retryLimit = 5 } + + counter <- newIORef (0 :: Int) + + result <- try @_ @SomeException $ runTestAppWith backendType setRetry $ + withTransaction $ rerunnableIO $ do + x <- atomicModifyIORef counter $ \x -> (x + 1, x) + if x > 2 + then return () + else throwString "retry me" + + case result of + Right () -> return () + Left e -> error $ "Got unexpected error: " ++ show e + + , testCase "throws error when retry hits limit" $ do + let setRetry env = env { retryIf = const True, retryLimit = 2 } + + result <- try @_ @TransactionError @() $ runTestAppWith backendType setRetry $ + withTransaction $ rerunnableIO $ throwString "retry me" + + result @?= Left RetryLimitExceeded ] +-- this should compile +testComposability :: BackendType -> TestTree +testComposability backendType = testCase "Operations can be composed" $ do + let onlySql :: MonadSqlQuery m => m () + onlySql = do + _ <- getPeople + return () + + sqlAndRerunnableIO :: (MonadSqlQuery m, MonadRerunnableIO m) => m () + sqlAndRerunnableIO = do + _ <- getPeopleNames + _ <- rerunnableIO $ newIORef True + return () + + onlyRerunnableIO :: MonadRerunnableIO m => m () + onlyRerunnableIO = do + _ <- rerunnableIO $ newIORef True + return () + + arbitraryIO :: MonadIO m => m () + arbitraryIO = do + _ <- liftIO $ newIORef True + return () + + -- everything should compose naturally by default + runTestApp backendType $ do + onlySql + sqlAndRerunnableIO + onlyRerunnableIO + arbitraryIO + + -- in a transaction, you can compose everything except arbitrary IO + runTestApp backendType $ withTransaction $ do + onlySql + sqlAndRerunnableIO + onlyRerunnableIO + -- uncomment this to get compile error + -- arbitraryIO + testPersistentAPI :: BackendType -> TestTree testPersistentAPI backendType = testGroup "Persistent API" [ testCase "get" $ do @@ -780,7 +859,7 @@ catchTestError m = do liftIO $ result @?= Left TestError insertAndFail :: - ( MonadIO m + ( MonadRerunnableIO m , MonadSqlQuery m , PersistRecordBackend record SqlBackend , Typeable record @@ -788,7 +867,7 @@ insertAndFail :: => record -> m () insertAndFail record = do insert_ record - throwIO TestError + rerunnableIO $ throwIO TestError assertNotIn :: (Eq a, Show a) => a -> [a] -> Assertion assertNotIn a as = as @?= filter (/= a) as diff --git a/test/MockSqlQueryT.hs b/test/MockSqlQueryT.hs index 43689b7..ba5ddb8 100644 --- a/test/MockSqlQueryT.hs +++ b/test/MockSqlQueryT.hs @@ -10,7 +10,6 @@ import Test.Tasty import Test.Tasty.HUnit import UnliftIO (SomeException, try) -import Database.Persist.Monad import Database.Persist.Monad.TestUtils import Example