diff --git a/README.md b/README.md index a1cda23..e9d3411 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ ![CircleCI](https://img.shields.io/circleci/build/github/brandonchinn178/persistent-mtl) ![Hackage](https://img.shields.io/hackage/v/persistent-mtl) +![Codecov](https://img.shields.io/codecov/c/gh/brandonchinn178/persistent-mtl) TODO diff --git a/package.yaml b/package.yaml index 795e807..4490e2c 100644 --- a/package.yaml +++ b/package.yaml @@ -13,6 +13,7 @@ library: - resource-pool >= 0.2.3.2 && < 0.3 - resourcet >= 1.2.1 && < 2 - text >= 1.2.3.0 && < 2 + - transformers >= 0.5.2.0 && < 0.6 - unliftio-core >= 0.1.2.0 && < 0.3 tests: diff --git a/persistent-mtl.cabal b/persistent-mtl.cabal index fc5655f..b7f9542 100644 --- a/persistent-mtl.cabal +++ b/persistent-mtl.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: a9e4d579a32b54a7ae519c2276148dd5c124932d51972d18cc1050f294ef033f +-- hash: ed87d93280415b4fcbbb88057eb1df14d7ed6b3251947b40b2b150598adef468 name: persistent-mtl version: 0.1.0.0 @@ -14,6 +14,10 @@ build-type: Simple library exposed-modules: Database.Persist.Monad + Database.Persist.Monad.Class + Database.Persist.Monad.Shim + Database.Persist.Monad.SqlQueryRep + Database.Persist.Monad.TestUtils other-modules: Paths_persistent_mtl hs-source-dirs: @@ -26,6 +30,7 @@ library , resource-pool >=0.2.3.2 && <0.3 , resourcet >=1.2.1 && <2 , text >=1.2.3.0 && <2 + , transformers >=0.5.2.0 && <0.6 , unliftio-core >=0.1.2.0 && <0.3 default-language: Haskell2010 diff --git a/src/Database/Persist/Monad.hs b/src/Database/Persist/Monad.hs index 3b80f33..dfd115a 100644 --- a/src/Database/Persist/Monad.hs +++ b/src/Database/Persist/Monad.hs @@ -4,47 +4,34 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} module Database.Persist.Monad - ( MonadSqlQuery(..) + ( + -- * Type class for executing database queries + MonadSqlQuery(..) + , SqlQueryRep(..) - -- * SqlQueryT monad transformer + -- * SqlQueryT monad transformer , SqlQueryT , SqlQueryBackend(..) , runSqlQueryT - -- * Test utility - , MockSqlQueryT - , runMockSqlQueryT - , withRecord - - -- * Coerced functions - , SqlQueryRep(..) - , selectList - , insert - , insert_ - , runMigrationSilent + -- * Lifted functions + , module Database.Persist.Monad.Shim ) where -import Control.Monad (msum) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Unlift (MonadUnliftIO(..), wrappedWithRunInIO) -import Control.Monad.Reader (ReaderT, ask, lift, local, runReaderT) +import Control.Monad.Reader (ReaderT, ask, local, runReaderT) +import Control.Monad.Trans.Class (MonadTrans(..)) import Data.Pool (Pool) -import Data.Proxy (Proxy(..)) -import Data.Text (Text) -import Data.Typeable (Typeable, eqT, typeRep, (:~:)(..)) -import Database.Persist (Entity, Filter, Key, PersistRecordBackend, SelectOpt) -import Database.Persist.Sql (Migration, SqlBackend, runSqlPool) -import qualified Database.Persist.Sql as Persist +import Database.Persist.Sql (SqlBackend, runSqlConn, runSqlPool) -class MonadSqlQuery m where - runQueryRep :: Typeable record => SqlQueryRep record a -> m a - runRawQuery :: Persist.SqlPersistT m a -> m a - withTransaction :: m a -> m a +import Database.Persist.Monad.Class +import Database.Persist.Monad.Shim +import Database.Persist.Monad.SqlQueryRep -{- SqlQueryT -} +{- SqlQueryT monad -} data SqlQueryEnv = SqlQueryEnv { backend :: SqlQueryBackend @@ -58,11 +45,23 @@ newtype SqlQueryT m a = SqlQueryT , Applicative , Monad , MonadIO + , MonadTrans ) +instance MonadUnliftIO m => MonadSqlQuery (SqlQueryT m) where + runQueryRep queryRep = + withCurrentConnection $ \conn -> + runSqlConn (runSqlQueryRep queryRep) conn + + withTransaction action = + withCurrentConnection $ \conn -> + SqlQueryT . local (\env -> env { currentConn = Just conn }) . unSqlQueryT $ action + instance MonadUnliftIO m => MonadUnliftIO (SqlQueryT m) where withRunInIO = wrappedWithRunInIO SqlQueryT unSqlQueryT +{- Running SqlQueryT -} + data SqlQueryBackend = BackendSingle SqlBackend | BackendPool (Pool SqlBackend) @@ -79,89 +78,3 @@ withCurrentConnection f = SqlQueryT ask >>= \case -- Otherwise, get a new connection SqlQueryEnv { backend = BackendSingle conn } -> f conn SqlQueryEnv { backend = BackendPool pool } -> runSqlPool (lift . f =<< ask) pool - -instance MonadUnliftIO m => MonadSqlQuery (SqlQueryT m) where - runQueryRep = runRawQuery . runSqlQueryRep - - runRawQuery m = withCurrentConnection (Persist.runSqlConn m) - - withTransaction action = - withCurrentConnection $ \conn -> - SqlQueryT . local (\env -> env { currentConn = Just conn }) . unSqlQueryT $ action - -{- SqlQueryRep - TODO: generate this with TH --} - -data SqlQueryRep record a where - SelectList - :: PersistRecordBackend record SqlBackend - => [Filter record] -> [SelectOpt record] -> SqlQueryRep record [Entity record] - - Insert - :: PersistRecordBackend record SqlBackend - => record -> SqlQueryRep record (Key record) - - Insert_ - :: PersistRecordBackend record SqlBackend - => record -> SqlQueryRep record () - -instance Typeable record => Show (SqlQueryRep record a) where - show = \case - SelectList{} -> "SelectList{..}" ++ record - Insert{} -> "Insert{..}" ++ record - Insert_{} -> "Insert_{..}" ++ record - where - record = "<" ++ show (typeRep $ Proxy @record) ++ ">" - -runSqlQueryRep :: MonadIO m => SqlQueryRep record a -> Persist.SqlPersistT m a -runSqlQueryRep = \case - SelectList a b -> Persist.selectList a b - Insert a -> Persist.insert a - Insert_ a -> Persist.insert_ a - -selectList :: (PersistRecordBackend record SqlBackend, Typeable record, MonadSqlQuery m) => [Filter record] -> [SelectOpt record] -> m [Entity record] -selectList a b = runQueryRep $ SelectList a b - -insert :: (PersistRecordBackend record SqlBackend, Typeable record, MonadSqlQuery m) => record -> m (Key record) -insert a = runQueryRep $ Insert a - -insert_ :: (PersistRecordBackend record SqlBackend, Typeable record, MonadSqlQuery m) => record -> m () -insert_ a = runQueryRep $ Insert_ a - -runMigrationSilent :: (MonadUnliftIO m, MonadSqlQuery m) => Migration -> m [Text] -runMigrationSilent a = runRawQuery $ Persist.runMigrationSilent a - -{- MockSqlQueryT -} - -data MockQuery = MockQuery (forall record a. Typeable record => SqlQueryRep record a -> Maybe a) - -withRecord :: forall record. Typeable record => (forall a. SqlQueryRep record a -> Maybe a) -> MockQuery -withRecord f = MockQuery $ \(rep :: SqlQueryRep someRecord result) -> - case eqT @record @someRecord of - Just Refl -> f rep - Nothing -> Nothing - -newtype MockSqlQueryT m a = MockSqlQueryT - { unMockSqlQueryT :: ReaderT [MockQuery] m a - } deriving - ( Functor - , Applicative - , Monad - , MonadIO - ) - -runMockSqlQueryT :: MockSqlQueryT m a -> [MockQuery] -> m a -runMockSqlQueryT action mockQueries = (`runReaderT` mockQueries) . unMockSqlQueryT $ action - -instance Monad m => MonadSqlQuery (MockSqlQueryT m) where - runQueryRep rep = do - mockQueries <- MockSqlQueryT ask - maybe (error $ "Could not find mock for query: " ++ show rep) return - $ msum $ map tryMockQuery mockQueries - where - tryMockQuery (MockQuery f) = f rep - - runRawQuery _ = error "Can't run raw queries with MockSqlQueryT" - - withTransaction = id diff --git a/src/Database/Persist/Monad/Class.hs b/src/Database/Persist/Monad/Class.hs new file mode 100644 index 0000000..5a5a54e --- /dev/null +++ b/src/Database/Persist/Monad/Class.hs @@ -0,0 +1,64 @@ +module Database.Persist.Monad.Class + ( MonadSqlQuery(..) + ) where + +import Control.Monad.Trans.Class (lift) +import qualified Control.Monad.Trans.Except as Except +import qualified Control.Monad.Trans.Identity as Identity +import qualified Control.Monad.Trans.Maybe as Maybe +import qualified Control.Monad.Trans.RWS.Lazy as RWS.Lazy +import qualified Control.Monad.Trans.RWS.Strict as RWS.Strict +import qualified Control.Monad.Trans.Reader as Reader +import qualified Control.Monad.Trans.State.Lazy as State.Lazy +import qualified Control.Monad.Trans.State.Strict as State.Strict +import qualified Control.Monad.Trans.Writer.Lazy as Writer.Lazy +import qualified Control.Monad.Trans.Writer.Strict as Writer.Strict +import Data.Typeable (Typeable) + +import Database.Persist.Monad.SqlQueryRep (SqlQueryRep) + +class Monad m => MonadSqlQuery m where + runQueryRep :: Typeable record => SqlQueryRep record a -> m a + withTransaction :: m a -> m a + +{- Instances for common monad transformers -} + +instance MonadSqlQuery m => MonadSqlQuery (Reader.ReaderT r m) where + runQueryRep = lift . runQueryRep + withTransaction = Reader.mapReaderT withTransaction + +instance MonadSqlQuery m => MonadSqlQuery (Except.ExceptT e m) where + runQueryRep = lift . runQueryRep + withTransaction = Except.mapExceptT withTransaction + +instance MonadSqlQuery m => MonadSqlQuery (Identity.IdentityT m) where + runQueryRep = lift . runQueryRep + withTransaction = Identity.mapIdentityT withTransaction + +instance MonadSqlQuery m => MonadSqlQuery (Maybe.MaybeT m) where + runQueryRep = lift . runQueryRep + withTransaction = Maybe.mapMaybeT withTransaction + +instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (RWS.Lazy.RWST r w s m) where + runQueryRep = lift . runQueryRep + withTransaction = RWS.Lazy.mapRWST withTransaction + +instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (RWS.Strict.RWST r w s m) where + runQueryRep = lift . runQueryRep + withTransaction = RWS.Strict.mapRWST withTransaction + +instance MonadSqlQuery m => MonadSqlQuery (State.Lazy.StateT s m) where + runQueryRep = lift . runQueryRep + withTransaction = State.Lazy.mapStateT withTransaction + +instance MonadSqlQuery m => MonadSqlQuery (State.Strict.StateT s m) where + runQueryRep = lift . runQueryRep + withTransaction = State.Strict.mapStateT withTransaction + +instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (Writer.Lazy.WriterT w m) where + runQueryRep = lift . runQueryRep + withTransaction = Writer.Lazy.mapWriterT withTransaction + +instance (Monoid w, MonadSqlQuery m) => MonadSqlQuery (Writer.Strict.WriterT w m) where + runQueryRep = lift . runQueryRep + withTransaction = Writer.Strict.mapWriterT withTransaction diff --git a/src/Database/Persist/Monad/Shim.hs b/src/Database/Persist/Monad/Shim.hs new file mode 100644 index 0000000..2bc97c1 --- /dev/null +++ b/src/Database/Persist/Monad/Shim.hs @@ -0,0 +1,24 @@ +{-# LANGUAGE GADTs #-} + +module Database.Persist.Monad.Shim where + +import Control.Monad.IO.Unlift (MonadUnliftIO) +import Data.Text (Text) +import Data.Typeable (Typeable) +import Database.Persist +import Database.Persist.Sql + +import Database.Persist.Monad.Class (MonadSqlQuery(..)) +import Database.Persist.Monad.SqlQueryRep (SqlQueryRep(..)) + +selectList :: (PersistRecordBackend record SqlBackend, Typeable record, MonadSqlQuery m) => [Filter record] -> [SelectOpt record] -> m [Entity record] +selectList a b = runQueryRep $ SelectList a b + +insert :: (PersistRecordBackend record SqlBackend, Typeable record, MonadSqlQuery m) => record -> m (Key record) +insert a = runQueryRep $ Insert a + +insert_ :: (PersistRecordBackend record SqlBackend, Typeable record, MonadSqlQuery m) => record -> m () +insert_ a = runQueryRep $ Insert_ a + +runMigrationSilent :: (MonadUnliftIO m, MonadSqlQuery m) => Migration -> m [Text] +runMigrationSilent a = runQueryRep $ RunMigrationsSilent a diff --git a/src/Database/Persist/Monad/SqlQueryRep.hs b/src/Database/Persist/Monad/SqlQueryRep.hs new file mode 100644 index 0000000..39378d3 --- /dev/null +++ b/src/Database/Persist/Monad/SqlQueryRep.hs @@ -0,0 +1,56 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Database.Persist.Monad.SqlQueryRep + ( SqlQueryRep(..) + , runSqlQueryRep + ) where + +import Control.Monad.IO.Unlift (MonadUnliftIO) +import Data.Proxy (Proxy(..)) +import Data.Text (Text) +import Data.Typeable (Typeable, eqT, typeRep, (:~:)(..)) +import Data.Void (Void) +import Database.Persist (Entity, Filter, Key, PersistRecordBackend, SelectOpt) +import Database.Persist.Sql (Migration, SqlBackend) +import qualified Database.Persist.Sql as Persist + +-- TODO: generate this +data SqlQueryRep record a where + SelectList + :: PersistRecordBackend record SqlBackend + => [Filter record] -> [SelectOpt record] -> SqlQueryRep record [Entity record] + + Insert + :: PersistRecordBackend record SqlBackend + => record -> SqlQueryRep record (Key record) + + Insert_ + :: PersistRecordBackend record SqlBackend + => record -> SqlQueryRep record () + + RunMigrationsSilent + :: Migration -> SqlQueryRep Void [Text] + +instance Typeable record => Show (SqlQueryRep record a) where + show = \case + SelectList{} -> "SelectList{..}" ++ record + Insert{} -> "Insert{..}" ++ record + Insert_{} -> "Insert_{..}" ++ record + RunMigrationsSilent{} -> "RunMigrationsSilent{..}" ++ record + where + record = case recordTypeRep of + Just recordType -> "<" ++ show recordType ++ ">" + Nothing -> "" + recordTypeRep = case eqT @record @Void of + Just Refl -> Nothing + Nothing -> Just $ typeRep $ Proxy @record + +runSqlQueryRep :: MonadUnliftIO m => SqlQueryRep record a -> Persist.SqlPersistT m a +runSqlQueryRep = \case + SelectList a b -> Persist.selectList a b + Insert a -> Persist.insert a + Insert_ a -> Persist.insert_ a + RunMigrationsSilent a -> Persist.runMigrationSilent a diff --git a/src/Database/Persist/Monad/TestUtils.hs b/src/Database/Persist/Monad/TestUtils.hs new file mode 100644 index 0000000..f543abb --- /dev/null +++ b/src/Database/Persist/Monad/TestUtils.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Database.Persist.Monad.TestUtils + ( MockSqlQueryT + , runMockSqlQueryT + , withRecord + ) where + +import Control.Monad (msum) +import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Reader (ReaderT, ask, runReaderT) +import Data.Typeable (Typeable, eqT, (:~:)(..)) + +import Database.Persist.Monad (MonadSqlQuery(..), SqlQueryRep) + +newtype MockSqlQueryT m a = MockSqlQueryT + { unMockSqlQueryT :: ReaderT [MockQuery] m a + } deriving + ( Functor + , Applicative + , Monad + , MonadIO + ) + +runMockSqlQueryT :: MockSqlQueryT m a -> [MockQuery] -> m a +runMockSqlQueryT action mockQueries = (`runReaderT` mockQueries) . unMockSqlQueryT $ action + +instance Monad m => MonadSqlQuery (MockSqlQueryT m) where + runQueryRep rep = do + mockQueries <- MockSqlQueryT ask + maybe (error $ "Could not find mock for query: " ++ show rep) return + $ msum $ map tryMockQuery mockQueries + where + tryMockQuery (MockQuery f) = f rep + + withTransaction = id + +data MockQuery = MockQuery (forall record a. Typeable record => SqlQueryRep record a -> Maybe a) + +withRecord :: forall record. Typeable record => (forall a. SqlQueryRep record a -> Maybe a) -> MockQuery +withRecord f = MockQuery $ \(rep :: SqlQueryRep someRecord result) -> + case eqT @record @someRecord of + Just Refl -> f rep + Nothing -> Nothing diff --git a/test/Mocked.hs b/test/Mocked.hs index c735945..8293285 100644 --- a/test/Mocked.hs +++ b/test/Mocked.hs @@ -10,6 +10,7 @@ import Test.Tasty import Test.Tasty.HUnit import Database.Persist.Monad +import Database.Persist.Monad.TestUtils import Example tests :: TestTree