Skip to content

Commit

Permalink
Drain rabbitmq consumers slowly from Cannon (#4342)
Browse files Browse the repository at this point in the history
* nit: Fixed grammar.

* Cannon: Add draining for RabbitMQ consumers.

* lint: removed dead code.

* Changelog.

* Deleted outdated todo.

* hi ci
  • Loading branch information
elland authored Nov 19, 2024
1 parent cb83614 commit 12f3687
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 17 deletions.
1 change: 1 addition & 0 deletions changelog.d/5-internal/drain-rabbit
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add rabbitmq consumers to the draining step on Cannon, in case of termination.
2 changes: 1 addition & 1 deletion services/cannon/src/Cannon/Options.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ data DrainOpts = DrainOpts
-- there are not many users connected. Must not be set to 0.
_drainOptsMillisecondsBetweenBatches :: Word64,
-- | Batch size is calculated considering actual number of websockets and
-- gracePeriod. If this number is too little, '_drainOptsMinBatchSize' is
-- gracePeriod. If this number is too small, '_drainOptsMinBatchSize' is
-- used.
_drainOptsMinBatchSize :: Word64
}
Expand Down
68 changes: 65 additions & 3 deletions services/cannon/src/Cannon/RabbitMqConsumerApp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,78 @@
module Cannon.RabbitMqConsumerApp where

import Cannon.App (rejectOnError)
import Cannon.Dict qualified as D
import Cannon.Options
import Cannon.WS hiding (env)
import Cassandra as C
import Cassandra as C hiding (batch)
import Control.Concurrent.Async
import Control.Concurrent.Timeout
import Control.Exception (Handler (..), bracket, catch, catches, throwIO, try)
import Control.Lens hiding ((#))
import Control.Monad.Codensity
import Data.Aeson
import Data.Aeson hiding (Key)
import Data.Id
import Imports
import Data.List.Extra hiding (delete)
import Data.Timeout (TimeoutUnit (..), (#))
import Imports hiding (min, threadDelay)
import Network.AMQP qualified as Q
import Network.AMQP.Extended (withConnection)
import Network.WebSockets
import Network.WebSockets qualified as WS
import System.Logger qualified as Log
import UnliftIO.Async (pooledMapConcurrentlyN_)
import Wire.API.Event.WebSocketProtocol
import Wire.API.Notification

-- | Close every RabbitMQ connection currently registered in
-- @e.rabbitConnections@, in batches spread over the configured drain grace
-- period.
--
-- The batch size is the number of connections divided by the number of
-- batches that fit into the grace period, but never smaller than the
-- configured minimum batch size. Each batch is closed in a background thread
-- (at most 16 connections concurrently); this function does not wait for a
-- batch to finish before sleeping and moving on to the next one.
--
-- If the grace period (plus one second) elapses before all batches have been
-- dispatched, an error is logged.
drainRabbitQueues :: Env -> IO ()
drainRabbitQueues e = do
  conns <- D.toList e.rabbitConnections
  numberOfConns <- fromIntegral <$> D.size e.rabbitConnections

  let opts = e.drainOpts
      -- Clamped to at least 1: when the grace period is shorter than the time
      -- between batches the quotient is 0, and dividing by it below would
      -- throw a divide-by-zero exception in the middle of shutdown.
      maxNumberOfBatches = max 1 ((opts ^. gracePeriodSeconds * 1000) `div` (opts ^. millisecondsBetweenBatches))
      computedBatchSize = numberOfConns `div` maxNumberOfBatches
      -- Clamped to at least 1: 'chunksOf' rejects non-positive chunk sizes,
      -- which would otherwise happen when the configured minimum is 0 and
      -- the computed size rounds down to 0.
      batchSize = max 1 (max (opts ^. minBatchSize) computedBatchSize)

  -- Argument order must match 'logDraining': computed size first, then the
  -- configured minimum, then the effective batch size.
  logDraining e.logg numberOfConns computedBatchSize (opts ^. minBatchSize) batchSize maxNumberOfBatches

  -- Sleeps for the grace period + 1 second. If the sleep completes, it means
  -- that draining didn't finish, and we should log that.
  timeoutAction <- async $ do
    -- Allocate 1 second more than the grace period to allow for overhead of
    -- spawning threads.
    liftIO $ threadDelay $ ((opts ^. gracePeriodSeconds) # Second + 1 # Second)
    logExpired e.logg (opts ^. gracePeriodSeconds)

  for_ (chunksOf (fromIntegral batchSize) conns) $ \batch -> do
    -- 16 was chosen with a roll of a fair die.
    void . async $ pooledMapConcurrentlyN_ 16 (uncurry (closeConn e.logg)) batch
    liftIO $ threadDelay ((opts ^. millisecondsBetweenBatches) # MilliSecond)
  cancel timeoutAction
  -- NOTE(review): batches are closed in fire-and-forget threads, so this is
  -- logged once every batch has been dispatched, not necessarily once every
  -- connection has actually been closed.
  Log.info e.logg $ Log.msg (Log.val "Draining complete")
  where
    -- Log, close and unregister a single RabbitMQ connection.
    closeConn :: Log.Logger -> Key -> Q.Connection -> IO ()
    closeConn l key conn = do
      Log.info l $
        Log.msg (Log.val "closing rabbitmq connection")
          . Log.field "key" (show key)
      Q.closeConnection conn
      void $ D.remove key e.rabbitConnections

    -- Log that the grace period elapsed before draining finished.
    logExpired :: Log.Logger -> Word64 -> IO ()
    logExpired l period = do
      Log.err l $ Log.msg (Log.val "Drain grace period expired") . Log.field "gracePeriodSeconds" period

    -- Log the parameters of the drain that is about to start. Arguments, in
    -- order: number of connections, computed batch size, configured minimum
    -- batch size, effective batch size, maximum number of batches.
    logDraining :: Log.Logger -> Word64 -> Word64 -> Word64 -> Word64 -> Word64 -> IO ()
    logDraining l count computedSize minSize effectiveSize batches = do
      Log.info l $
        Log.msg (Log.val "draining all rabbitmq connections")
          . Log.field "numberOfConns" count
          . Log.field "computedBatchSize" computedSize
          . Log.field "minBatchSize" minSize
          . Log.field "batchSize" effectiveSize
          . Log.field "maxNumberOfBatches" batches

rabbitMQWebSocketApp :: UserId -> ClientId -> Env -> ServerApp
rabbitMQWebSocketApp uid cid e pendingConn = do
wsVar <- newEmptyMVar
Expand Down Expand Up @@ -126,11 +182,16 @@ rabbitMQWebSocketApp uid cid e pendingConn = do
-- create rabbitmq connection
conn <- Codensity $ withConnection e.logg e.rabbitmq

-- Store it in the env
let key = mkKeyRabbit uid cid
D.insert key conn e.rabbitConnections

-- create rabbitmq channel
amqpChan <- Codensity $ bracket (Q.openChannel conn) Q.closeChannel

-- propagate rabbitmq connection failure
lift $ Q.addConnectionClosedHandler conn True $ do
void $ D.remove key e.rabbitConnections
putMVar msgVar $
Left (Q.ConnectionClosedException Q.Normal "")

Expand All @@ -149,6 +210,7 @@ rabbitMQWebSocketApp uid cid e pendingConn = do
catch (WS.sendBinaryData wsConn (encode (EventMessage eventData))) $
\(err :: SomeException) -> do
logSendFailure err
void $ D.remove key e.rabbitConnections
throwIO err

-- get ack from wsVar and forward to rabbitmq
Expand Down
9 changes: 6 additions & 3 deletions services/cannon/src/Cannon/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import Cannon.API.Public
import Cannon.App (maxPingInterval)
import Cannon.Dict qualified as D
import Cannon.Options
import Cannon.Types (Cannon, applog, clients, env, mkEnv, runCannon, runCannonToServant)
import Cannon.RabbitMqConsumerApp (drainRabbitQueues)
import Cannon.Types (Cannon, applog, clients, connectionLimit, env, mkEnv, runCannon, runCannonToServant)
import Cannon.WS hiding (drainOpts, env)
import Cassandra.Util (defInitCassandra)
import Control.Concurrent
Expand Down Expand Up @@ -76,8 +77,9 @@ run o = withTracer \tracer -> do
cassandra <- defInitCassandra (o ^. cassandraOpts) g
e <-
mkEnv ext o cassandra g
<$> D.empty 128
<*> newManager defaultManagerSettings {managerConnCount = 128}
<$> D.empty connectionLimit
<*> D.empty connectionLimit
<*> newManager defaultManagerSettings {managerConnCount = connectionLimit}
<*> createSystemRandom
<*> mkClock
<*> pure (o ^. Cannon.Options.rabbitmq)
Expand Down Expand Up @@ -124,6 +126,7 @@ run o = withTracer \tracer -> do
signalHandler :: Env -> ThreadId -> Signals.Handler
signalHandler e mainThread = CatchOnce $ do
runWS e drain
drainRabbitQueues e
throwTo mainThread SignalledToExit

-- | This is called when the main thread receives the exception generated by
Expand Down
17 changes: 12 additions & 5 deletions services/cannon/src/Cannon/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
module Cannon.Types
( Env (..),
Cannon,
connectionLimit,
mapConcurrentlyCannon,
mkEnv,
runCannon,
Expand All @@ -42,20 +43,25 @@ import Control.Monad.Catch
import Data.Id
import Data.Text.Encoding
import Imports
import Network.AMQP qualified as Q
import Network.AMQP.Extended (AmqpEndpoint)
import Prometheus
import Servant qualified
import System.Logger qualified as Logger
import System.Logger.Class hiding (info)
import System.Random.MWC (GenIO)

-- | Shared sizing constant for Cannon's connection-related resources.
--
-- NOTE(review): presumably used as the initial size of the connection
-- dictionaries and as the HTTP manager's connection count; confirm against
-- the call sites whether this is a hard limit or only a sizing hint.
connectionLimit :: Int
connectionLimit = 128

-----------------------------------------------------------------------------
-- Cannon monad

data Env = Env
{ opts :: !Opts,
applog :: !Logger,
dict :: !(Dict Key Websocket),
websockets :: !(Dict Key Websocket),
rabbitConnections :: (Dict Key Q.Connection),
reqId :: !RequestId,
env :: !WS.Env
}
Expand Down Expand Up @@ -95,20 +101,21 @@ mkEnv ::
ClientState ->
Logger ->
Dict Key Websocket ->
Dict Key Q.Connection ->
Manager ->
GenIO ->
Clock ->
AmqpEndpoint ->
Env
mkEnv external o cs l d p g t rabbitmqOpts =
Env o l d (RequestId defRequestId) $
WS.env external (o ^. cannon . port) (encodeUtf8 $ o ^. gundeck . host) (o ^. gundeck . port) l p d g t (o ^. drainOpts) rabbitmqOpts cs
mkEnv external o cs l d conns p g t rabbitmqOpts =
Env o l d conns (RequestId defRequestId) $
WS.env external (o ^. cannon . port) (encodeUtf8 $ o ^. gundeck . host) (o ^. gundeck . port) l p d conns g t (o ^. drainOpts) rabbitmqOpts cs

-- | Run a 'Cannon' action in 'IO' by unwrapping it and supplying the given
-- environment to the underlying reader.
runCannon :: Env -> Cannon a -> IO a
runCannon environment action = unCannon action `runReaderT` environment

clients :: Cannon (Dict Key Websocket)
clients = Cannon $ asks dict
clients = Cannon $ asks websockets

wsenv :: Cannon WS.Env
wsenv = Cannon $ do
Expand Down
17 changes: 12 additions & 5 deletions services/cannon/src/Cannon/WS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module Cannon.WS
connIdent,
Key,
mkKey,
mkKeyRabbit,
key2bytes,
client,
sendMsg,
Expand Down Expand Up @@ -68,6 +69,7 @@ import Data.List.Extra (chunksOf)
import Data.Text.Encoding (decodeUtf8)
import Data.Timeout (TimeoutUnit (..), (#))
import Imports hiding (threadDelay)
import Network.AMQP qualified as Q
import Network.AMQP.Extended
import Network.HTTP.Types.Method
import Network.HTTP.Types.Status
Expand All @@ -90,6 +92,9 @@ newtype Key = Key
-- | Build a 'Key' from a user ID and a connection ID. The user ID is
-- serialised with @toByteString'@ and the connection ID with 'fromConnId'.
mkKey :: UserId -> ConnId -> Key
mkKey uid connId = Key (toByteString' uid, fromConnId connId)

-- | Build a 'Key' from a user ID and a client ID, both serialised with
-- @toByteString'@. Unlike 'mkKey', the second component is a client ID
-- rather than a connection ID.
mkKeyRabbit :: UserId -> ClientId -> Key
mkKeyRabbit uid cid = Key (toByteString' uid, toByteString' cid)

-- | Render a 'Key' as a single 'ByteString': its two components joined by a
-- dot.
key2bytes :: Key -> ByteString
key2bytes key = case key of
  Key (firstPart, secondPart) -> firstPart <> "." <> secondPart

Expand Down Expand Up @@ -144,7 +149,8 @@ data Env = Env
reqId :: !RequestId,
logg :: !Logger,
manager :: !Manager,
dict :: !(Dict Key Websocket),
websockets :: !(Dict Key Websocket),
rabbitConnections :: !(Dict Key Q.Connection),
rand :: !GenIO,
clock :: !Clock,
drainOpts :: DrainOpts,
Expand Down Expand Up @@ -192,6 +198,7 @@ env ::
Logger ->
Manager ->
Dict Key Websocket ->
Dict Key Q.Connection ->
GenIO ->
Clock ->
DrainOpts ->
Expand All @@ -206,13 +213,13 @@ runWS e m = liftIO $ runReaderT (_conn m) e
registerLocal :: Key -> Websocket -> WS ()
registerLocal k c = do
trace $ client (key2bytes k) . msg (val "register")
d <- WS $ asks dict
d <- WS $ asks websockets
D.insert k c d

unregisterLocal :: Key -> Websocket -> WS Bool
unregisterLocal k c = do
trace $ client (key2bytes k) . msg (val "unregister")
d <- WS $ asks dict
d <- WS $ asks websockets
D.removeIf (maybe False ((connIdent c ==) . connIdent)) k d

registerRemote :: Key -> Maybe ClientId -> WS ()
Expand Down Expand Up @@ -250,7 +257,7 @@ sendMsg message k c = do
traceLog m = trace $ client kb . msg (logMsg m)

logMsg :: (WebSocketsData a) => a -> Builder
logMsg m = val "sendMsgConduit: \"" +++ L.take 128 (toLazyByteString m) +++ val "...\""
logMsg m = val "sendMsgConduit: \"" +++ L.take 129 (toLazyByteString m) +++ val "...\""

kb = key2bytes k

Expand Down Expand Up @@ -294,7 +301,7 @@ sendMsg message k c = do
drain :: WS ()
drain = do
opts <- asks drainOpts
websockets <- asks dict
websockets <- asks websockets
numberOfConns <- fromIntegral <$> D.size websockets
let maxNumberOfBatches = (opts ^. gracePeriodSeconds * 1000) `div` (opts ^. millisecondsBetweenBatches)
computedBatchSize = numberOfConns `div` maxNumberOfBatches
Expand Down

0 comments on commit 12f3687

Please sign in to comment.