From ffbb5c6a4fc6fc7b9517b22477efa7f3b4c0257d Mon Sep 17 00:00:00 2001 From: Marc Scholten Date: Fri, 15 Mar 2024 10:43:27 +0100 Subject: [PATCH] Simplify session vault key handling Moved the session vault key from the ApplicationContext and RequestContext data structure to a global variable. This is the suggested way by the WAI developers. See https://www.yesodweb.com/blog/2015/10/using-wais-vault --- IHP/ApplicationContext.hs | 2 -- IHP/Controller/RequestContext.hs | 1 - IHP/Controller/Session.hs | 11 +++++++++-- IHP/ControllerSupport.hs | 4 ++-- IHP/RouterSupport.hs | 2 +- IHP/Server.hs | 13 ++++++------- IHP/Test/Mocking.hs | 17 +++++++---------- Test/Controller/AccessDeniedSpec.hs | 2 +- Test/Controller/CookieSpec.hs | 2 +- Test/Controller/NotFoundSpec.hs | 2 +- Test/Controller/ParamSpec.hs | 4 ++-- Test/View/CSSFrameworkSpec.hs | 2 +- Test/View/FormSpec.hs | 2 +- Test/ViewSupportSpec.hs | 2 +- 14 files changed, 33 insertions(+), 33 deletions(-) diff --git a/IHP/ApplicationContext.hs b/IHP/ApplicationContext.hs index a21fcd8dc..9174e71e3 100644 --- a/IHP/ApplicationContext.hs +++ b/IHP/ApplicationContext.hs @@ -2,14 +2,12 @@ module IHP.ApplicationContext where import IHP.Prelude import Network.Wai.Session (Session) -import qualified Data.Vault.Lazy as Vault import IHP.AutoRefresh.Types (AutoRefreshServer) import IHP.FrameworkConfig (FrameworkConfig) import IHP.PGListener (PGListener) data ApplicationContext = ApplicationContext { modelContext :: !ModelContext - , session :: !(Vault.Key (Session IO ByteString ByteString)) , autoRefreshServer :: !(IORef AutoRefreshServer) , frameworkConfig :: !FrameworkConfig , pgListener :: PGListener diff --git a/IHP/Controller/RequestContext.hs b/IHP/Controller/RequestContext.hs index bf6a46ddf..fffc93b66 100644 --- a/IHP/Controller/RequestContext.hs +++ b/IHP/Controller/RequestContext.hs @@ -24,6 +24,5 @@ data RequestContext = RequestContext { request :: Request , respond :: Respond , requestBody :: RequestBody - , vault :: (Vault.Key (Session IO ByteString ByteString)) , frameworkConfig :: FrameworkConfig } diff --git a/IHP/Controller/Session.hs b/IHP/Controller/Session.hs index d0e633c1c..3d5c207e6 100644 --- a/IHP/Controller/Session.hs +++ b/IHP/Controller/Session.hs @@ -24,6 +24,7 @@ module IHP.Controller.Session , getSessionEither , deleteSession , getSessionAndClear + , sessionVaultKey ) where import IHP.Prelude @@ -36,6 +37,8 @@ import qualified Network.Wai as Wai import qualified Data.Serialize as Serialize import Data.Serialize (Serialize) import Data.Serialize.Text () +import qualified Network.Wai.Session +import System.IO.Unsafe (unsafePerformIO) -- | Types of possible errors as a result of -- requesting a value from the session storage @@ -161,5 +164,9 @@ sessionVault = case vaultLookup of Just session -> session Nothing -> error "sessionInsert: The session vault is missing in the request" where - RequestContext { request, vault } = ?context.requestContext - vaultLookup = Vault.lookup vault (Wai.vault request) + RequestContext { request } = ?context.requestContext + vaultLookup = Vault.lookup sessionVaultKey request.vault + +sessionVaultKey :: Vault.Key (Network.Wai.Session.Session IO ByteString ByteString) +sessionVaultKey = unsafePerformIO Vault.newKey +{-# NOINLINE sessionVaultKey #-} \ No newline at end of file diff --git a/IHP/ControllerSupport.hs b/IHP/ControllerSupport.hs index b32988fde..197958203 100644 --- a/IHP/ControllerSupport.hs +++ b/IHP/ControllerSupport.hs @@ -259,7 +259,7 @@ requestBodyJSON = {-# INLINE createRequestContext #-} createRequestContext :: ApplicationContext -> Request -> Respond -> IO RequestContext -createRequestContext ApplicationContext { session, frameworkConfig } request respond = do +createRequestContext ApplicationContext { frameworkConfig } request respond = do let contentType = lookup hContentType (requestHeaders request) requestBody <- case contentType of "application/json" -> do @@ -270,7 +270,7 @@ createRequestContext ApplicationContext { session, frameworkConfig } request res (params, files) <- WaiParse.parseRequestBodyEx frameworkConfig.parseRequestBodyOptions WaiParse.lbsBackEnd request pure RequestContext.FormBody { .. } - pure RequestContext.RequestContext { request, respond, requestBody, vault = session, frameworkConfig } + pure RequestContext.RequestContext { request, respond, requestBody, frameworkConfig } -- | Returns a custom config parameter diff --git a/IHP/RouterSupport.hs b/IHP/RouterSupport.hs index addf6f42d..96da10814 100644 --- a/IHP/RouterSupport.hs +++ b/IHP/RouterSupport.hs @@ -838,7 +838,7 @@ withPrefix prefix routes = string prefix >> choice (map (\r -> r <* endOfInput) frontControllerToWAIApp :: forall app (autoRefreshApp :: Type). (?applicationContext :: ApplicationContext, FrontController app, WSApp autoRefreshApp, Typeable autoRefreshApp, InitControllerContext ()) => Middleware -> app -> Application -> Application frontControllerToWAIApp middleware application notFoundAction request respond = do - let requestContext = RequestContext { request, respond, requestBody = FormBody { params = [], files = [] }, vault = ?applicationContext.session, frameworkConfig = ?applicationContext.frameworkConfig } + let requestContext = RequestContext { request, respond, requestBody = FormBody { params = [], files = [] }, frameworkConfig = ?applicationContext.frameworkConfig } let ?context = requestContext diff --git a/IHP/Server.hs b/IHP/Server.hs index ce4e8e316..0157da068 100644 --- a/IHP/Server.hs +++ b/IHP/Server.hs @@ -8,6 +8,7 @@ import Network.Wai.Middleware.MethodOverridePost (methodOverridePost) import Network.Wai.Session (withSession, Session) import Network.Wai.Session.ClientSession (clientsessionStore) import qualified Web.ClientSession as ClientSession +import IHP.Controller.Session (sessionVaultKey) import qualified Data.Vault.Lazy as Vault import IHP.ApplicationContext import qualified IHP.ControllerSupport as ControllerSupport @@ -48,14 +49,12 @@ run configBuilder = do withInitalizers frameworkConfig modelContext do withPGListener \pgListener -> do - sessionVault <- Vault.newKey - autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener) let ?modelContext = modelContext - let ?applicationContext = ApplicationContext { modelContext = ?modelContext, session = sessionVault, autoRefreshServer, frameworkConfig, pgListener } + let ?applicationContext = ApplicationContext { modelContext = ?modelContext, autoRefreshServer, frameworkConfig, pgListener } - sessionMiddleware <- initSessionMiddleware sessionVault frameworkConfig + sessionMiddleware <- initSessionMiddleware frameworkConfig staticApp <- initStaticApp frameworkConfig let corsMiddleware = initCorsMiddleware frameworkConfig let requestLoggerMiddleware = frameworkConfig.requestLoggerMiddleware @@ -108,8 +107,8 @@ initStaticApp frameworkConfig = do pure (Static.staticApp appSettings) -initSessionMiddleware :: Vault.Key (Session IO ByteString ByteString) -> FrameworkConfig -> IO Middleware -initSessionMiddleware sessionVault FrameworkConfig { sessionCookie } = do +initSessionMiddleware :: FrameworkConfig -> IO Middleware +initSessionMiddleware FrameworkConfig { sessionCookie } = do let path = "Config/client_session_key.aes" hasSessionSecretEnvVar <- EnvVar.hasEnvVar "IHP_SESSION_SECRET" @@ -118,7 +117,7 @@ initSessionMiddleware sessionVault FrameworkConfig { sessionCookie } = do if hasSessionSecretEnvVar || not doesConfigDirectoryExist then ClientSession.getKeyEnv "IHP_SESSION_SECRET" else ClientSession.getKey path - let sessionMiddleware :: Middleware = withSession store "SESSION" sessionCookie sessionVault + let sessionMiddleware :: Middleware = withSession store "SESSION" sessionCookie sessionVaultKey pure sessionMiddleware initCorsMiddleware :: FrameworkConfig -> Middleware diff --git a/IHP/Test/Mocking.hs b/IHP/Test/Mocking.hs index 0598b2abd..8582cf66a 100644 --- a/IHP/Test/Mocking.hs +++ b/IHP/Test/Mocking.hs @@ -33,6 +33,7 @@ import qualified Network.Wai.Session import qualified Data.Serialize as Serialize import qualified Control.Exception as Exception import qualified IHP.PGListener as PGListener +import IHP.Controller.Session (sessionVaultKey) type ContextParameters application = (?applicationContext :: ApplicationContext, ?context :: RequestContext, ?modelContext :: ModelContext, ?application :: application, InitControllerContext application, ?mocking :: MockContext application) @@ -58,17 +59,15 @@ withIHPApp application configBuilder hspecAction = do withTestDatabase \testDatabase -> do modelContext <- createModelContext dbPoolIdleTime dbPoolMaxConnections (testDatabase.url) logger - session <- Vault.newKey pgListener <- PGListener.init modelContext autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener) - let sessionVault = Vault.insert session mempty Vault.empty - let applicationContext = ApplicationContext { modelContext = modelContext, session, autoRefreshServer, frameworkConfig, pgListener } + let sessionVault = Vault.insert sessionVaultKey mempty Vault.empty + let applicationContext = ApplicationContext { modelContext = modelContext, autoRefreshServer, frameworkConfig, pgListener } let requestContext = RequestContext { request = defaultRequest {vault = sessionVault} , requestBody = FormBody [] [] , respond = const (pure ResponseReceived) - , vault = session , frameworkConfig = frameworkConfig } (hspecAction MockContext { .. }) @@ -81,17 +80,15 @@ mockContextNoDatabase application configBuilder = do logger <- newLogger def { level = Warn } -- don't log queries modelContext <- createModelContext dbPoolIdleTime dbPoolMaxConnections databaseUrl logger - session <- Vault.newKey - let sessionVault = Vault.insert session mempty Vault.empty + let sessionVault = Vault.insert sessionVaultKey mempty Vault.empty pgListener <- PGListener.init modelContext autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener) - let applicationContext = ApplicationContext { modelContext = modelContext, session, autoRefreshServer, frameworkConfig, pgListener } + let applicationContext = ApplicationContext { modelContext = modelContext, autoRefreshServer, frameworkConfig, pgListener } let requestContext = RequestContext { request = defaultRequest {vault = sessionVault} , requestBody = FormBody [] [] , respond = \resp -> pure ResponseReceived - , vault = session , frameworkConfig = frameworkConfig } pure MockContext{..} @@ -230,8 +227,8 @@ withUser user callback = insertSession key value = pure () - newVault = Vault.insert vaultKey newSession (Wai.vault request) - RequestContext { request, vault = vaultKey } = ?mocking.requestContext + newVault = Vault.insert sessionVaultKey newSession (Wai.vault request) + RequestContext { request } = ?mocking.requestContext sessionValue = Serialize.encode (user.id) sessionKey = cs (Session.sessionKey @user) diff --git a/Test/Controller/AccessDeniedSpec.hs b/Test/Controller/AccessDeniedSpec.hs index 34ca74011..dba55e502 100644 --- a/Test/Controller/AccessDeniedSpec.hs +++ b/Test/Controller/AccessDeniedSpec.hs @@ -74,7 +74,7 @@ config = do makeApplication :: (?applicationContext :: ApplicationContext) => IO Application makeApplication = do store <- Session.mapStore_ - let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session + let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey pure (sessionMiddleware $ (Server.application handleNotFound) (\app -> app)) assertAccessDenied :: SResponse -> IO () diff --git a/Test/Controller/CookieSpec.hs b/Test/Controller/CookieSpec.hs index ca950a556..8750b175e 100644 --- a/Test/Controller/CookieSpec.hs +++ b/Test/Controller/CookieSpec.hs @@ -37,6 +37,6 @@ createControllerContext = do let requestBody = FormBody { params = [], files = [] } request = Wai.defaultRequest - requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" } + requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" } let ?requestContext = requestContext newControllerContext diff --git a/Test/Controller/NotFoundSpec.hs b/Test/Controller/NotFoundSpec.hs index 3448c5cec..b79ca3cd3 100644 --- a/Test/Controller/NotFoundSpec.hs +++ b/Test/Controller/NotFoundSpec.hs @@ -74,7 +74,7 @@ config = do makeApplication :: (?applicationContext :: ApplicationContext) => IO Application makeApplication = do store <- Session.mapStore_ - let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session + let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey pure (sessionMiddleware $ (Server.application handleNotFound) (\app -> app)) assertNotFound :: SResponse -> IO () diff --git a/Test/Controller/ParamSpec.hs b/Test/Controller/ParamSpec.hs index 630ad13c0..d45ed7c00 100644 --- a/Test/Controller/ParamSpec.hs +++ b/Test/Controller/ParamSpec.hs @@ -434,14 +434,14 @@ createControllerContextWithParams params = let requestBody = FormBody { params, files = [] } request = Wai.defaultRequest - requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" } + requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" } in FrozenControllerContext { requestContext, customFields = TypeMap.empty } createControllerContextWithJson params = let requestBody = JSONBody { jsonPayload = Just (json params), rawPayload = cs params } request = Wai.defaultRequest - requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" } + requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" } in FrozenControllerContext { requestContext, customFields = TypeMap.empty } json :: Text -> Aeson.Value diff --git a/Test/View/CSSFrameworkSpec.hs b/Test/View/CSSFrameworkSpec.hs index 11f1a2c24..945defe00 100644 --- a/Test/View/CSSFrameworkSpec.hs +++ b/Test/View/CSSFrameworkSpec.hs @@ -721,5 +721,5 @@ createControllerContextWithCSSFramework cssFramework = do option cssFramework let requestBody = FormBody { params = [], files = [] } let request = Wai.defaultRequest - let requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = frameworkConfig } + let requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = frameworkConfig } pure FrozenControllerContext { requestContext, customFields = TypeMap.empty } \ No newline at end of file diff --git a/Test/View/FormSpec.hs b/Test/View/FormSpec.hs index 7cc5d27bb..75532c411 100644 --- a/Test/View/FormSpec.hs +++ b/Test/View/FormSpec.hs @@ -49,7 +49,7 @@ createControllerContext = do frameworkConfig <- FrameworkConfig.buildFrameworkConfig (pure ()) let requestBody = FormBody { params = [], files = [] } let request = Wai.defaultRequest - let requestContext = RequestContext { request, respond = undefined, requestBody, vault = undefined, frameworkConfig = frameworkConfig } + let requestContext = RequestContext { request, respond = undefined, requestBody, frameworkConfig = frameworkConfig } pure FrozenControllerContext { requestContext, customFields = mempty } data Project' = Project {id :: (Id' "projects"), title :: Text, meta :: MetaBag} deriving (Eq, Show) diff --git a/Test/ViewSupportSpec.hs b/Test/ViewSupportSpec.hs index f5fe528d7..191bda6e8 100644 --- a/Test/ViewSupportSpec.hs +++ b/Test/ViewSupportSpec.hs @@ -101,7 +101,7 @@ config = do makeApplication :: (?applicationContext :: ApplicationContext) => IO Application makeApplication = do store <- Session.mapStore_ - let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session + let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey pure (sessionMiddleware $ (Server.application handleNotFound (\app -> app))) tests :: Spec