diff --git a/hydra-node/src/Hydra/Network/Reliability.hs b/hydra-node/src/Hydra/Network/Reliability.hs index 805311e763e..1c34dab17d3 100644 --- a/hydra-node/src/Hydra/Network/Reliability.hs +++ b/hydra-node/src/Hydra/Network/Reliability.hs @@ -66,15 +66,12 @@ import qualified Data.Map.Strict as Map import Data.Vector ( Vector, elemIndex, - empty, fromList, generate, length, replicate, - snoc, zipWith, (!), - (!?), ) import Hydra.Logging (traceWith) import Hydra.Network (Network (..), NetworkComponent) @@ -152,7 +149,7 @@ withReliability :: NetworkComponent m (Authenticated (Heartbeat msg)) (Heartbeat msg) a withReliability tracer me otherParties withRawNetwork callback action = do ackCounter <- newTVarIO $ replicate (length allParties) 0 - sentMessages <- newTVarIO empty + sentMessages <- newTVarIO Map.empty seenMessages <- newTVarIO $ Map.fromList $ (,0) <$> toList allParties resendQ <- newTQueueIO ourIndex <- findPartyIndex me @@ -174,7 +171,7 @@ withReliability tracer me otherParties withRawNetwork callback action = do acks <- readTVar ackCounter let newAcks = constructAcks acks ourIndex writeTVar ackCounter newAcks - modifyTVar' sentMessages (`snoc` msg) + modifyTVar' sentMessages (insertNewMsg msg) readTVar ackCounter traceWith tracer (BroadcastCounter ourIndex ackCounter') @@ -241,14 +238,14 @@ withReliability tracer me otherParties withRawNetwork callback action = do let missing = fromList [messageAckForUs + 1 .. knownAckForUs] messages <- readTVarIO sentMessages forM_ missing $ \idx -> do - case messages !? (idx - 1) of + case messages Map.!? (idx - 1) of Nothing -> throwIO $ ReliabilityFailedToFindMsg $ "FIXME: this should never happen, there's no sent message at index " <> show idx <> ", messages length = " - <> show (length messages) + <> show (Map.size messages) <> ", latest message ack: " <> show knownAckForUs <> ", acked: " @@ -270,6 +267,11 @@ withReliability tracer me otherParties withRawNetwork callback action = do -- modifyTVar' sentMessages (Map.delete (const $ Just messageAckForUs) party) traceWith tracer (ClearedMessageQueue queueLength deleted) + insertNewMsg msg m = + case Map.lookupMax m of + Nothing -> Map.insert 1 msg m + Just (k, _) -> Map.insert (k + 1) msg m + -- find the index of a party in the list of parties or fail with 'ReliabilityMissingPartyIndex' findPartyIndex party = maybe (throwIO $ ReliabilityMissingPartyIndex party) pure $ elemIndex party allParties