Skip to content

Commit

Permalink
Clean up forked threads (#205)
Browse files Browse the repository at this point in the history
Prior to this change, `runServer` would not kill the threads it forked. Namely, it would lose track of the `appAsync`s.

After this change, `runServer` will kill all threads it forks.
  • Loading branch information
chrismwendt authored Dec 22, 2023
1 parent bf607ba commit da1fd1d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 65 deletions.
105 changes: 44 additions & 61 deletions src/Network/WebSockets/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@ module Network.WebSockets.Server


--------------------------------------------------------------------------------
import Control.Concurrent (threadDelay)
import Control.Concurrent (takeMVar, tryPutMVar,
newEmptyMVar)
import qualified Control.Concurrent.Async as Async
import Control.Exception (Exception, allowInterrupt,
bracket, bracketOnError,
finally, mask_, throwIO)
import Control.Monad (forever, void, when)
import qualified Data.IORef as IORef
import Data.Maybe (isJust)
import Control.Exception (Exception, bracket,
bracketOnError, finally, mask_,
throwIO)
import Network.Socket (Socket)
import qualified Network.Socket as S
import qualified System.Clock as Clock
import System.Timeout (timeout)


--------------------------------------------------------------------------------
Expand Down Expand Up @@ -110,59 +108,44 @@ defaultServerOptions = ServerOptions
runServerWithOptions :: ServerOptions -> ServerApp -> IO a
runServerWithOptions opts app = S.withSocketsDo $
bracket
(makeListenSocket host port)
S.close $ \sock -> mask_ $ forever $ do
allowInterrupt
(conn, _) <- S.accept sock

-- This IORef holds a time at which the thread may be killed. This time
-- can be extended by calling 'tickle'.
killRef <- IORef.newIORef =<< (+ killDelay) <$> getSecs
let tickle = IORef.writeIORef killRef =<< (+ killDelay) <$> getSecs

-- Update the connection options to call 'tickle' whenever a pong is
-- received.
let connOpts'
| not useKiller = connOpts
| otherwise = connOpts
{ connectionOnPong = tickle >> connectionOnPong connOpts
}

-- Run the application.
appAsync <- Async.asyncWithUnmask $ \unmask ->
(unmask $ do
runApp conn connOpts' app) `finally`
(S.close conn)

-- Install the killer if required.
when useKiller $ void $ Async.async (killer killRef appAsync)
where
host = serverHost opts
port = serverPort opts
connOpts = serverConnectionOptions opts

-- Get the current number of seconds on some clock.
getSecs = Clock.sec <$> Clock.getTime Clock.Monotonic

-- Parse the 'serverRequirePong' options.
useKiller = isJust $ serverRequirePong opts
killDelay = maybe 0 fromIntegral (serverRequirePong opts)

-- Thread that reads the killRef, and kills the application if enough time
-- has passed.
killer killRef appAsync = do
killAt <- IORef.readIORef killRef
now <- getSecs
appState <- Async.poll appAsync
case appState of
-- Already finished/killed/crashed, we can give up.
Just _ -> return ()
-- Should not be killed yet. Wait and try again.
Nothing | now < killAt -> do
threadDelay (fromIntegral killDelay * 1000 * 1000)
killer killRef appAsync
-- Time to kill.
_ -> Async.cancelWith appAsync PongTimeout
(makeListenSocket (serverHost opts) (serverPort opts))
S.close $ \sock -> do
let connOpts = serverConnectionOptions opts

connThread conn = case serverRequirePong opts of
Nothing -> runApp conn connOpts app
Just grace -> do
heartbeat <- newEmptyMVar

let -- Update the connection options to perform a heartbeat
-- whenever a pong is received.
connOpts' = connOpts
{ connectionOnPong = do
_ <- tryPutMVar heartbeat ()
connectionOnPong connOpts
}

whileJust io = do
result <- io
case result of
Nothing -> return ()
Just _ -> whileJust io

-- Runs until a pong was not received within the grace
-- period.
heart = whileJust $ timeout (grace * 1000000) (takeMVar heartbeat)

Async.race_
(runApp conn connOpts' app)
(heart >> throwIO PongTimeout)

mainThread = do
(conn, _) <- S.accept sock
Async.withAsyncWithUnmask
(\unmask -> unmask (connThread conn) `finally` S.close conn)
(\_ -> mainThread)

mask_ mainThread


--------------------------------------------------------------------------------
Expand Down
4 changes: 0 additions & 4 deletions websockets.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ Library
binary >= 0.8.1 && < 0.11,
bytestring >= 0.9 && < 0.12,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0.1 && < 1.3,
Expand Down Expand Up @@ -151,7 +150,6 @@ Test-suite websockets-tests
binary >= 0.8.1 && < 0.11,
bytestring >= 0.9 && < 0.12,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down Expand Up @@ -212,7 +210,6 @@ Executable websockets-autobahn
binary >= 0.8.1 && < 0.11,
bytestring >= 0.9 && < 0.12,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down Expand Up @@ -240,7 +237,6 @@ Benchmark bench-mask
binary >= 0.8.1 && < 0.11,
bytestring >= 0.9 && < 0.12,
case-insensitive >= 0.3 && < 1.3,
clock >= 0.8 && < 0.9,
containers >= 0.3 && < 0.7,
network >= 2.3 && < 3.2,
random >= 1.0 && < 1.3,
Expand Down

0 comments on commit da1fd1d

Please sign in to comment.