From 8ea1ed6ebfe0658caaf0bf9dd33762bc2dd1d123 Mon Sep 17 00:00:00 2001 From: Sasha Bogicevic Date: Thu, 21 Sep 2023 18:23:12 +0200 Subject: [PATCH] Remove seen messages - We record what each party has seen and remove messages seen by all parties. - Use IntMap instead of Vector for storing sent messages because we need to be able to remove old/seen indices without re-indexing. - Use Map to keep track of the last seen message by party - Introduce a test case to assert the withNetwork logs removal of old messages. --- hydra-node/src/Hydra/Network/Reliability.hs | 49 ++++++++++++------- .../test/Hydra/Network/ReliabilitySpec.hs | 13 +++-- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/hydra-node/src/Hydra/Network/Reliability.hs b/hydra-node/src/Hydra/Network/Reliability.hs index 1c34dab17d3..ac6fe60c9b4 100644 --- a/hydra-node/src/Hydra/Network/Reliability.hs +++ b/hydra-node/src/Hydra/Network/Reliability.hs @@ -62,6 +62,7 @@ import Control.Concurrent.Class.MonadSTM ( writeTVar, ) import Control.Tracer (Tracer) +import qualified Data.IntMap as IMap import qualified Data.Map.Strict as Map import Data.Vector ( Vector, @@ -132,11 +133,9 @@ instance Arbitrary ReliabilityLog where -- layer is tied to a specific structure of other layers, eg. be between -- `withHeartbeat` and `withAuthenticate` layers. -- --- TODO: garbage-collect the `sentMessages` which otherwise will grow forever -- TODO: better use of Vectors? We should perhaps use a `MVector` to be able to -- mutate in-place and not need `zipWith` withReliability :: - forall m msg a. (MonadThrow (STM m), MonadThrow m, MonadAsync m) => -- | Tracer for logging messages. Tracer m ReliabilityLog -> @@ -149,7 +148,7 @@ withReliability :: NetworkComponent m (Authenticated (Heartbeat msg)) (Heartbeat msg) a withReliability tracer me otherParties withRawNetwork callback action = do ackCounter <- newTVarIO $ replicate (length allParties) 0 - sentMessages <- newTVarIO Map.empty + sentMessages <- newTVarIO IMap.empty seenMessages <- newTVarIO $ Map.fromList $ (,0) <$> toList allParties resendQ <- newTQueueIO ourIndex <- findPartyIndex me @@ -186,7 +185,6 @@ withReliability tracer me otherParties withRawNetwork callback action = do if length acks /= length allParties then ignoreMalformedMessages else do - deleteSeenMessages sentMessages seenMessages acks party partyIndex <- findPartyIndex party traceWith tracer Callbacking (shouldCallback, messageAckForParty, knownAckForParty, knownAcks) <- atomically $ do @@ -222,6 +220,9 @@ withReliability tracer me otherParties withRawNetwork callback action = do -- resend messages if party did not acknowledge our latest idx resendMessages resend partyIndex sentMessages knownAcks acks messageAckForParty knownAckForParty + updateSeenMessages seenMessages acks party + + deleteSeenMessages sentMessages seenMessages ignoreMalformedMessages = pure () @@ -238,14 +239,14 @@ withReliability tracer me otherParties withRawNetwork callback action = do let missing = fromList [messageAckForUs + 1 .. knownAckForUs] messages <- readTVarIO sentMessages forM_ missing $ \idx -> do - case messages Map.!? (idx - 1) of + case messages IMap.!? (idx - 1) of Nothing -> throwIO $ ReliabilityFailedToFindMsg $ "FIXME: this should never happen, there's no sent message at index " <> show idx <> ", messages length = " - <> show (Map.size messages) + <> show (IMap.size messages) <> ", latest message ack: " <> show knownAckForUs <> ", acked: " @@ -255,22 +256,34 @@ withReliability tracer me otherParties withRawNetwork callback action = do traceWith tracer (Resending missing messageAcks newAcks' partyIndex) atomically $ resend $ ReliableMsg newAcks' missingMsg - deleteSeenMessages sentMessages seenMessages acks party = do + updateSeenMessages seenMessages acks party = do myIndex <- findPartyIndex me let messageAckForUs = acks ! myIndex - (queueLength, deleted) <- atomically $ do - modifyTVar' seenMessages (Map.update (const $ Just messageAckForUs) party) - _messages <- readTVar seenMessages - return (0, 1) - -- TODO: here we want to delete old messages but it would be good to - -- convert this structure from vector to map - -- modifyTVar' sentMessages (Map.delete (const $ Just messageAckForUs) party) - traceWith tracer (ClearedMessageQueue queueLength deleted) + atomically $ modifyTVar' seenMessages (Map.insert party messageAckForUs) + + deleteSeenMessages sentMessages seenMessages = do + clearedMessages <- atomically $ do + seenMessages' <- readTVar seenMessages + let messageReceivedByEveryone = + case sortBy (\(_, a) (_, b) -> compare a b) (Map.toList seenMessages') of + [] -> 0 -- should not happen + ((_, v) : _) -> v + sentMessages' <- readTVar sentMessages + if IMap.member messageReceivedByEveryone sentMessages' + then do + let updatedMap = IMap.delete messageReceivedByEveryone sentMessages' + writeTVar sentMessages updatedMap + pure $ Just (IMap.size updatedMap, messageReceivedByEveryone) + else pure Nothing + + case clearedMessages of + Nothing -> pure () + Just (messageQueueLength, deletedMessages) -> traceWith tracer (ClearedMessageQueue{messageQueueLength, deletedMessages}) insertNewMsg msg m = - case Map.lookupMax m of - Nothing -> Map.insert 1 msg m - Just (k, _) -> Map.insert (k + 1) msg m + case IMap.lookupMax m of + Nothing -> IMap.insert 1 msg m + Just (k, _) -> IMap.insert (k + 1) msg m -- find the index of a party in the list of parties or fail with 'ReliabilityMissingPartyIndex' findPartyIndex party = diff --git a/hydra-node/test/Hydra/Network/ReliabilitySpec.hs b/hydra-node/test/Hydra/Network/ReliabilitySpec.hs index 0b901658cd3..368496baf16 100644 --- a/hydra-node/test/Hydra/Network/ReliabilitySpec.hs +++ b/hydra-node/test/Hydra/Network/ReliabilitySpec.hs @@ -73,17 +73,20 @@ spec = parallel $ do (captureTraces emittedTraces) alice [bob, carol] - ( \incoming _ -> do - incoming (Authenticated (ReliableMsg (fromList [1, 1, 0]) (Data "node-2" msg)) bob) - incoming (Authenticated (ReliableMsg (fromList [1, 0, 1]) (Data "node-3" msg)) carol) + ( \incoming action -> do + incoming (Authenticated (ReliableMsg (fromList [1, 0, 0]) (Data "node-1" msg)) alice) + action $ Network{broadcast = \_ -> pure ()} + incoming (Authenticated (ReliableMsg (fromList [1, 0, 0]) (Data "node-2" msg)) bob) + action $ Network{broadcast = \_ -> pure ()} + incoming (Authenticated (ReliableMsg (fromList [1, 0, 0]) (Data "node-3" msg)) carol) + action $ Network{broadcast = \_ -> pure ()} ) noop $ \Network{broadcast} -> do broadcast (Data "node-1" msg) - threadDelay 1 readTVarIO emittedTraces - receivedTraces `shouldContain` [ClearedMessageQueue{messageQueueLength = 0, deletedMessages = 1}] + receivedTraces `shouldContain` [ClearedMessageQueue{messageQueueLength = 1, deletedMessages = 1}] describe "sending messages" $ do prop "broadcast messages to the network assigning a sequential id" $ \(messages :: [String]) ->