diff --git a/README.md b/README.md index a4e2cef..a56393c 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ Additionally, the `examples` directory shows a number of concrete use cases, e.g * [configuration](./examples/reader.hs) * [cookies](./examples/cookies.hs) * [file upload](./examples/upload.hs) +* [session](./examples/session.hs) * and more ## More Information diff --git a/Web/Scotty.hs b/Web/Scotty.hs index 415c8a1..2786415 100644 --- a/Web/Scotty.hs +++ b/Web/Scotty.hs @@ -55,7 +55,12 @@ module Web.Scotty , ScottyM, ActionM, RoutePattern, File, Content(..), Kilobytes, ErrorHandler, Handler(..) , ScottyState, defaultScottyState -- ** Functions from Cookie module - , setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie + , setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie + -- ** Session Management + , Session (..), SessionId, SessionJar, SessionStatus + , createSessionJar, createUserSession, createSession, addSession + , readSession, getUserSession, getSession, readUserSession + , deleteSession, maintainSessions ) where import qualified Web.Scotty.Trans as Trans @@ -76,7 +81,9 @@ import qualified Network.Wai.Parse as W import Web.FormUrlEncoded (FromForm) import Web.Scotty.Internal.Types (ScottyT, ActionT, ErrorHandler, Param, RoutePattern, Options, defaultOptions, File, Kilobytes, ScottyState, defaultScottyState, ScottyException, StatusError(..), Content(..)) import UnliftIO.Exception (Handler(..), catch) -import Web.Scotty.Cookie (setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie) +import Web.Scotty.Cookie (setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie) +import Web.Scotty.Session (Session (..), SessionId, SessionJar, SessionStatus , createSessionJar, + createSession, addSession, maintainSessions) {- $setup >>> :{ @@ -594,5 +601,32 @@ literal :: String -> RoutePattern literal = Trans.literal - - +-- | Retrieves a session by its ID from the session jar. +getSession :: SessionJar a -> SessionId -> ActionM (Either SessionStatus (Session a)) +getSession = Trans.getSession + +-- | Deletes a session by its ID from the session jar. +deleteSession :: SessionJar a -> SessionId -> ActionM () +deleteSession = Trans.deleteSession + +{- | Retrieves the current user's session based on the "sess_id" cookie. +| Returns `Left SessionStatus` if the session is expired or does not exist. +-} +getUserSession :: SessionJar a -> ActionM (Either SessionStatus (Session a)) +getUserSession = Trans.getUserSession + +-- | Reads the content of a session by its ID. +readSession :: SessionJar a -> SessionId -> ActionM (Either SessionStatus a) +readSession = Trans.readSession + +-- | Reads the content of the current user's session. +readUserSession ::SessionJar a -> ActionM (Either SessionStatus a) +readUserSession = Trans.readUserSession + +-- | Creates a new session for a user, storing the content and setting a cookie. +createUserSession :: + SessionJar a -- ^ SessionJar, which can be created by createSessionJar + -> Maybe Int -- ^ Optional expiration time (in seconds) + -> a -- ^ Content + -> ActionM (Session a) +createUserSession = Trans.createUserSession diff --git a/Web/Scotty/Session.hs b/Web/Scotty/Session.hs new file mode 100644 index 0000000..82fb9ed --- /dev/null +++ b/Web/Scotty/Session.hs @@ -0,0 +1,192 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} + +{- | +Module : Web.Scotty.Session +Copyright : (c) 2025 Tushar Adhatrao, + (c) 2025 Marco Zocca + +License : BSD-3-Clause +Maintainer : +Stability : experimental +Portability : GHC + +This module provides session management functionality for Scotty web applications. + +==Example usage: + +@ +\{\-\# LANGUAGE OverloadedStrings \#\-\} + +import Web.Scotty +import Web.Scotty.Session +import Control.Monad.IO.Class (liftIO) +main :: IO () +main = do + -- Create a session jar + sessionJar <- createSessionJar + scotty 3000 $ do + -- Route to create a session + get "/create" $ do + sess <- createUserSession sessionJar "user data" + html $ "Session created with ID: " <> sessId sess + -- Route to read a session + get "/read" $ do + eSession <- getUserSession sessionJar + case eSession of + Left _-> html "No session found or session expired." + Right sess -> html $ "Session content: " <> sessContent sess +@ +-} +module Web.Scotty.Session ( + Session (..), + SessionId, + SessionJar, + SessionStatus, + + -- * Create Session Jar + createSessionJar, + + -- * Create session + createUserSession, + createSession, + + -- * Read session + readUserSession, + readSession, + getUserSession, + getSession, + + -- * Add session + addSession, + + -- * Delte session + deleteSession, + + -- * Helper functions + maintainSessions, +) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.IO.Class (MonadIO (..)) +import qualified Data.HashMap.Strict as HM +import qualified Data.Text as T +import Data.Time (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) +import System.Random (randomRIO) +import Web.Scotty.Action (ActionT) +import Web.Scotty.Cookie + +-- | Type alias for session identifiers. +type SessionId = T.Text + +-- | Status of a session lookup. +data SessionStatus = SessionNotFound | SessionExpired + deriving (Show, Eq) + +-- | Represents a session containing an ID, expiration time, and content. +data Session a = Session + { sessId :: SessionId + -- ^ Unique identifier for the session. + , sessExpiresAt :: UTCTime + -- ^ Expiration time of the session. + , sessContent :: a + -- ^ Content stored in the session. + } + deriving (Eq, Show) + +-- | Type for session storage, a transactional variable containing a map of session IDs to sessions. +type SessionJar a = TVar (HM.HashMap SessionId (Session a)) + +-- | Creates a new session jar and starts a background thread to maintain it. +createSessionJar :: IO (SessionJar a) +createSessionJar = do + storage <- newTVarIO HM.empty + _ <- forkIO $ maintainSessions storage + return storage + +-- | Continuously removes expired sessions from the session jar. +maintainSessions :: SessionJar a -> IO () +maintainSessions sessionJar = + forever $ do + now <- getCurrentTime + let stillValid sess = sessExpiresAt sess > now + atomically $ modifyTVar sessionJar $ \m -> HM.filter stillValid m + threadDelay 1000000 + + +-- | Adds or overwrites a new session to the session jar. +addSession :: SessionJar a -> Session a -> IO () +addSession sessionJar sess = + atomically $ modifyTVar sessionJar $ \m -> HM.insert (sessId sess) sess m + +-- | Retrieves a session by its ID from the session jar. +getSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Either SessionStatus (Session a)) +getSession sessionJar sId = + do + s <- liftIO $ readTVarIO sessionJar + case HM.lookup sId s of + Nothing -> pure $ Left SessionNotFound + Just sess -> do + now <- liftIO getCurrentTime + if sessExpiresAt sess < now + then deleteSession sessionJar (sessId sess) >> pure (Left SessionExpired) + else pure $ Right sess + +-- | Deletes a session by its ID from the session jar. +deleteSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m () +deleteSession sessionJar sId = + liftIO $ + atomically $ + modifyTVar sessionJar $ + HM.delete sId + +{- | Retrieves the current user's session based on the "sess_id" cookie. +| Returns `Left SessionStatus` if the session is expired or does not exist. +-} +getUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Either SessionStatus (Session a)) +getUserSession sessionJar = do + getCookie "sess_id" >>= \case + Nothing -> pure $ Left SessionNotFound + Just sid -> lookupSession sid + where + lookupSession = getSession sessionJar + +-- | Reads the content of a session by its ID. +readSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Either SessionStatus a) +readSession sessionJar sId = do + res <- getSession sessionJar sId + return $ sessContent <$> res + +-- | Reads the content of the current user's session. +readUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Either SessionStatus a) +readUserSession sessionJar = do + res <- getUserSession sessionJar + return $ sessContent <$> res + +-- | The time-to-live for sessions, in seconds. +sessionTTL :: NominalDiffTime +sessionTTL = 36000 -- in seconds + +-- | Creates a new session for a user, storing the content and setting a cookie. +createUserSession :: (MonadIO m) => + SessionJar a -- ^ SessionJar, which can be created by createSessionJar + -> Maybe Int -- ^ Optional expiration time (in seconds) + -> a -- ^ Content + -> ActionT m (Session a) +createUserSession sessionJar mbExpirationTime content = do + sess <- liftIO $ createSession sessionJar mbExpirationTime content + setSimpleCookie "sess_id" (sessId sess) + return sess + +-- | Creates a new session with a generated ID, sets its expiration, +-- | and adds it to the session jar. +createSession :: SessionJar a -> Maybe Int -> a -> IO (Session a) +createSession sessionJar mbExpirationTime content = do + sId <- liftIO $ T.pack <$> replicateM 32 (randomRIO ('a', 'z')) + now <- getCurrentTime + let expiresAt = addUTCTime (maybe sessionTTL fromIntegral mbExpirationTime) now + sess = Session sId expiresAt content + liftIO $ addSession sessionJar sess + return $ Session sId expiresAt content diff --git a/Web/Scotty/Trans.hs b/Web/Scotty/Trans.hs index b3468ea..fa84136 100644 --- a/Web/Scotty/Trans.hs +++ b/Web/Scotty/Trans.hs @@ -64,7 +64,11 @@ module Web.Scotty.Trans , ScottyT, ActionT , ScottyState, defaultScottyState -- ** Functions from Cookie module - , setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie + , setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie + -- ** Session Management + , Session (..), SessionId, SessionJar, createSessionJar, + createUserSession, createSession, readUserSession, + readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions ) where import Blaze.ByteString.Builder (fromByteString) @@ -90,6 +94,9 @@ import Web.Scotty.Body (newBodyInfo) import UnliftIO.Exception (Handler(..), catch) import Web.Scotty.Cookie (setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie) +import Web.Scotty.Session (Session (..), SessionId, SessionJar, createSessionJar, + createUserSession, createSession, readUserSession, + readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions) -- | Run a scotty application using the warp server. diff --git a/changelog.md b/changelog.md index 102bc49..1dba44c 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,6 @@ ## next [????.??.??] +* Added sessions (#317). * Fixed cookie example from `Cookie` module documentation. `getCookie` Function would return strict variant of `Text`. Will convert it into lazy variant using `fromStrict`. * Exposed simple functions of `Cookie` module via `Web.Scotty` & `Web.Scotty.Trans`. * Add tests for URL encoding of query parameters and form parameters. Add `formData` action for decoding `FromForm` instances (#321). diff --git a/examples/session.hs b/examples/session.hs new file mode 100644 index 0000000..45a0205 --- /dev/null +++ b/examples/session.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE OverloadedStrings #-} +module Main (main) where + +import Web.Scotty +import qualified Data.Text.Lazy as LT +import qualified Data.Text as T + +main :: IO () +main = do + sessionJar <- liftIO createSessionJar :: IO (SessionJar T.Text) + scotty 3000 $ do + -- Login route + get "/login" $ do + username <- queryParam "username" :: ActionM String + password <- queryParam "password" :: ActionM String + if username == "foo" && password == "bar" + then do + _ <- createUserSession sessionJar Nothing "foo" + text "Login successful!" + else + text "Invalid username or password." + -- Dashboard route + get "/dashboard" $ do + mUser <- readUserSession sessionJar + case mUser of + Nothing -> text "Hello, user." + Just userName -> text $ "Hello, " <> LT.fromStrict userName <> "." + -- Logout route + get "/logout" $ do + deleteCookie "sess_id" + text "Logged out successfully." diff --git a/scotty.cabal b/scotty.cabal index bd311a2..a347c47 100644 --- a/scotty.cabal +++ b/scotty.cabal @@ -64,6 +64,7 @@ Library Web.Scotty.Trans.Strict Web.Scotty.Internal.Types Web.Scotty.Cookie + Web.Scotty.Session other-modules: Web.Scotty.Action Web.Scotty.Body Web.Scotty.Route @@ -93,7 +94,8 @@ Library unordered-containers >= 0.2.10.0 && < 0.3, wai >= 3.0.0 && < 3.3, wai-extra >= 3.1.14, - warp >= 3.0.13 + warp >= 3.0.13, + random >= 1.0.0.0 if impl(ghc < 8.0) build-depends: fail diff --git a/test/Web/ScottySpec.hs b/test/Web/ScottySpec.hs index c6c36df..0068eff 100644 --- a/test/Web/ScottySpec.hs +++ b/test/Web/ScottySpec.hs @@ -11,6 +11,7 @@ import Data.Char import Data.String import Data.Text.Lazy (Text) import qualified Data.Text.Lazy as TL +import qualified Data.Text as T import qualified Data.Text.Lazy.Encoding as TLE import Data.Time (UTCTime(..)) import Data.Time.Calendar (fromGregorian) @@ -537,6 +538,20 @@ spec = do withApp (Scotty.get "/nested" (nested simpleApp)) $ do it "responds with the expected simpleApp response" $ do get "/nested" `shouldRespondWith` 200 {matchHeaders = ["Content-Type" <:> "text/plain"], matchBody = "Hello, Web!"} + + describe "Session Management" $ do + withApp (Scotty.get "/scotty" $ do + sessionJar <- liftIO createSessionJar + sess <- createUserSession sessionJar Nothing ("foo" :: T.Text) + mRes <- readSession sessionJar (sessId sess) + case mRes of + Left _ -> Scotty.status status400 + Right res -> do + if res /= "foo" then Scotty.status status400 + else text "all good" + ) $ do + it "Roundtrip of session by adding and fetching a value" $ do + get "/scotty" `shouldRespondWith` 200 -- Unix sockets not available on Windows #if !defined(mingw32_HOST_OS)