diff --git a/src/Network/WebSockets/Server.hs b/src/Network/WebSockets/Server.hs index ed77d86..827ed37 100644 --- a/src/Network/WebSockets/Server.hs +++ b/src/Network/WebSockets/Server.hs @@ -19,17 +19,15 @@ module Network.WebSockets.Server -------------------------------------------------------------------------------- -import Control.Concurrent (threadDelay) +import Control.Concurrent (takeMVar, tryPutMVar, + newEmptyMVar) import qualified Control.Concurrent.Async as Async -import Control.Exception (Exception, allowInterrupt, - bracket, bracketOnError, - finally, mask_, throwIO) -import Control.Monad (forever, void, when) -import qualified Data.IORef as IORef -import Data.Maybe (isJust) +import Control.Exception (Exception, bracket, + bracketOnError, finally, mask_, + throwIO) import Network.Socket (Socket) import qualified Network.Socket as S -import qualified System.Clock as Clock +import System.Timeout (timeout) -------------------------------------------------------------------------------- @@ -110,59 +108,44 @@ defaultServerOptions = ServerOptions runServerWithOptions :: ServerOptions -> ServerApp -> IO a runServerWithOptions opts app = S.withSocketsDo $ bracket - (makeListenSocket host port) - S.close $ \sock -> mask_ $ forever $ do - allowInterrupt - (conn, _) <- S.accept sock - - -- This IORef holds a time at which the thread may be killed. This time - -- can be extended by calling 'tickle'. - killRef <- IORef.newIORef =<< (+ killDelay) <$> getSecs - let tickle = IORef.writeIORef killRef =<< (+ killDelay) <$> getSecs - - -- Update the connection options to call 'tickle' whenever a pong is - -- received. - let connOpts' - | not useKiller = connOpts - | otherwise = connOpts - { connectionOnPong = tickle >> connectionOnPong connOpts - } - - -- Run the application. - appAsync <- Async.asyncWithUnmask $ \unmask -> - (unmask $ do - runApp conn connOpts' app) `finally` - (S.close conn) - - -- Install the killer if required. - when useKiller $ void $ Async.async (killer killRef appAsync) - where - host = serverHost opts - port = serverPort opts - connOpts = serverConnectionOptions opts - - -- Get the current number of seconds on some clock. - getSecs = Clock.sec <$> Clock.getTime Clock.Monotonic - - -- Parse the 'serverRequirePong' options. - useKiller = isJust $ serverRequirePong opts - killDelay = maybe 0 fromIntegral (serverRequirePong opts) - - -- Thread that reads the killRef, and kills the application if enough time - -- has passed. - killer killRef appAsync = do - killAt <- IORef.readIORef killRef - now <- getSecs - appState <- Async.poll appAsync - case appState of - -- Already finished/killed/crashed, we can give up. - Just _ -> return () - -- Should not be killed yet. Wait and try again. - Nothing | now < killAt -> do - threadDelay (fromIntegral killDelay * 1000 * 1000) - killer killRef appAsync - -- Time to kill. - _ -> Async.cancelWith appAsync PongTimeout + (makeListenSocket (serverHost opts) (serverPort opts)) + S.close $ \sock -> do + let connOpts = serverConnectionOptions opts + + connThread conn = case serverRequirePong opts of + Nothing -> runApp conn connOpts app + Just grace -> do + heartbeat <- newEmptyMVar + + let -- Update the connection options to perform a heartbeat + -- whenever a pong is received. + connOpts' = connOpts + { connectionOnPong = do + _ <- tryPutMVar heartbeat () + connectionOnPong connOpts + } + + whileJust io = do + result <- io + case result of + Nothing -> return () + Just _ -> whileJust io + + -- Runs until a pong was not received within the grace + -- period. + heart = whileJust $ timeout (grace * 1000000) (takeMVar heartbeat) + + Async.race_ + (runApp conn connOpts' app) + (heart >> throwIO PongTimeout) + + mainThread = do + (conn, _) <- S.accept sock + Async.withAsyncWithUnmask + (\unmask -> unmask (connThread conn) `finally` S.close conn) + (\_ -> mainThread) + + mask_ mainThread -------------------------------------------------------------------------------- diff --git a/websockets.cabal b/websockets.cabal index 35e9f02..df18c58 100644 --- a/websockets.cabal +++ b/websockets.cabal @@ -92,7 +92,6 @@ Library binary >= 0.8.1 && < 0.11, bytestring >= 0.9 && < 0.12, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0.1 && < 1.3, @@ -151,7 +150,6 @@ Test-suite websockets-tests binary >= 0.8.1 && < 0.11, bytestring >= 0.9 && < 0.12, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3, @@ -212,7 +210,6 @@ Executable websockets-autobahn binary >= 0.8.1 && < 0.11, bytestring >= 0.9 && < 0.12, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3, @@ -240,7 +237,6 @@ Benchmark bench-mask binary >= 0.8.1 && < 0.11, bytestring >= 0.9 && < 0.12, case-insensitive >= 0.3 && < 1.3, - clock >= 0.8 && < 0.9, containers >= 0.3 && < 0.7, network >= 2.3 && < 3.2, random >= 1.0 && < 1.3,