From 9b2b76006bc73f1b836241397a7a892f62374194 Mon Sep 17 00:00:00 2001 From: Sasha Bogicevic Date: Thu, 14 Sep 2023 16:15:18 +0200 Subject: [PATCH] Use Map for messages We want to be able to delete old messages but keep the indices growing. Is there a better data structure for this? --- hydra-node/src/Hydra/Network/Reliability.hs | 19 ++++++++----- .../test/Hydra/Network/ReliabilitySpec.hs | 27 ++++++++++--------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/hydra-node/src/Hydra/Network/Reliability.hs b/hydra-node/src/Hydra/Network/Reliability.hs index 51a06c0d2b3..6c58d01305f 100644 --- a/hydra-node/src/Hydra/Network/Reliability.hs +++ b/hydra-node/src/Hydra/Network/Reliability.hs @@ -46,19 +46,17 @@ import Control.Concurrent.Class.MonadSTM ( writeTVar, ) import Control.Tracer (Tracer) +import qualified Data.Map.Strict as Map import Data.Maybe (fromJust) import Data.Vector ( Vector, elemIndex, - empty, fromList, generate, length, replicate, - snoc, zipWith, (!), - (!?), ) import Hydra.Logging (traceWith) import Hydra.Network (Network (..), NetworkComponent) @@ -103,7 +101,7 @@ withReliability :: NetworkComponent m (Authenticated msg) (Authenticated msg) a withReliability tracer us allParties withRawNetwork callback action = do ackCounter <- newTVarIO $ replicate (length allParties) 0 - sentMessages <- newTVarIO empty + sentMessages <- newTVarIO Map.empty resendQ <- newTQueueIO let resend = writeTQueue resendQ withRawNetwork (reliableCallback ackCounter sentMessages resend) $ \network@Network{broadcast} -> do @@ -119,7 +117,7 @@ withReliability tracer us allParties withRawNetwork callback action = do let ourIndex = fromJust $ elemIndex us allParties let newAcks = constructAcks acks ourIndex writeTVar ackCounter newAcks - modifyTVar' sentMessages (`snoc` msg) + modifyTVar' sentMessages (mapInsert msg) readTVar ackCounter traceWith tracer (BroadcastCounter ackCounter') @@ -152,14 +150,14 @@ withReliability tracer us allParties withRawNetwork callback action = do atomically $ do messages <- readTVar sentMessages forM_ missing $ \idx -> do - case messages !? (idx - 1) of + case messages Map.!? (idx - 1) of Nothing -> throwIO $ ReliabilityFailedToFindMsg $ "FIXME: this should never happen, there's no sent message at index " <> show idx <> ", messages length = " - <> show (length messages) + <> show (Map.size messages) <> ", latest message ack: " <> show latestMsgAck <> ", acked: " @@ -173,6 +171,13 @@ withReliability tracer us allParties withRawNetwork callback action = do constructAcks acks wantedIndex = zipWith (\ack i -> if i == wantedIndex then ack + 1 else ack) acks partyIndexes +mapInsert :: msg -> Map Int msg -> Map Int msg +mapInsert msg m = + case Map.lookupMax m of + Nothing -> Map.insert 1 msg m + Just (lastIndex, _) -> + Map.insert (lastIndex + 1) msg m + data ReliabilityLog = Resending {missing :: Vector Int, acknowledged :: Vector Int, localCounter :: Vector Int, party :: Party} | BroadcastCounter {localCounter :: Vector Int} diff --git a/hydra-node/test/Hydra/Network/ReliabilitySpec.hs b/hydra-node/test/Hydra/Network/ReliabilitySpec.hs index e039859cb59..477f36249a3 100644 --- a/hydra-node/test/Hydra/Network/ReliabilitySpec.hs +++ b/hydra-node/test/Hydra/Network/ReliabilitySpec.hs @@ -8,27 +8,28 @@ import Test.Hydra.Prelude import Control.Concurrent.Class.MonadSTM (MonadSTM (readTQueue, readTVarIO, writeTQueue), modifyTVar', newTQueueIO, newTVarIO) import Control.Monad.IOSim (runSimOrThrow) import Control.Tracer (nullTracer) +import qualified Data.Map.Strict as Map import qualified Data.Set as Set -import Data.Vector (empty, fromList, head, snoc) +import Data.Vector (fromList, head) import Hydra.Network (Network (..)) import Hydra.Network.Authenticate (Authenticated (..)) -import Hydra.Network.Reliability (Msg (..), withReliability) +import Hydra.Network.Reliability (Msg (..), mapInsert, withReliability) import Test.Hydra.Fixture (alice, bob, carol) import Test.QuickCheck (Positive (Positive), collect, counterexample, forAll, generate, suchThat, tabulate) spec :: Spec spec = parallel $ do let captureOutgoing msgqueue _cb action = - action $ Network{broadcast = \msg -> atomically $ modifyTVar' msgqueue (`snoc` msg)} + action $ Network{broadcast = atomically . modifyTVar' msgqueue . mapInsert} captureIncoming receivedMessages msg = - atomically $ modifyTVar' receivedMessages (`snoc` msg) + atomically $ modifyTVar' receivedMessages (mapInsert msg) msg <- runIO $ generate @String arbitrary it "forward received messages" $ do let receivedMsgs = runSimOrThrow $ do - receivedMessages <- newTVarIO empty + receivedMessages <- newTVarIO Map.empty withReliability nullTracer @@ -47,7 +48,7 @@ spec = parallel $ do prop "broadcast messages to the network assigning a sequential id" $ \(messages :: [String]) -> let sentMsgs = runSimOrThrow $ do - sentMessages <- newTVarIO empty + sentMessages <- newTVarIO Map.empty withReliability nullTracer alice (fromList [alice]) (captureOutgoing sentMessages) noop $ \Network{broadcast} -> do mapM_ (\m -> broadcast (Authenticated m alice)) messages @@ -57,7 +58,7 @@ spec = parallel $ do it "broadcasts messages to single connected peer" $ do let receivedMsgs = runSimOrThrow $ do - receivedMessages <- newTVarIO empty + receivedMessages <- newTVarIO Map.empty queue <- newTQueueIO let aliceNetwork _ action = do @@ -83,7 +84,7 @@ spec = parallel $ do prop "drops already received messages" $ \(messages :: [Positive Int]) -> let receivedMsgs = runSimOrThrow $ do - receivedMessages <- newTVarIO empty + receivedMessages <- newTVarIO Map.empty withReliability nullTracer @@ -106,7 +107,7 @@ spec = parallel $ do it "do not drop messages with same ids from different peers" $ do let receivedMsgs = runSimOrThrow $ do - receivedMessages <- newTVarIO empty + receivedMessages <- newTVarIO Map.empty withReliability nullTracer @@ -128,7 +129,7 @@ spec = parallel $ do forAll (arbitrary `suchThat` (> lastMessageKnownToBob)) $ \totalNumberOfMessages -> let messagesList = show <$> [1 .. totalNumberOfMessages] sentMsgs = runSimOrThrow $ do - sentMessages <- newTVarIO empty + sentMessages <- newTVarIO Map.empty withReliability nullTracer @@ -136,7 +137,7 @@ spec = parallel $ do (fromList [alice, bob]) ( \incoming action -> do concurrently_ - (action $ Network{broadcast = \m -> atomically $ modifyTVar' sentMessages (`snoc` message (payload m))}) + (action $ Network{broadcast = atomically . modifyTVar' sentMessages . mapInsert . message . payload}) (threadDelay 2 >> incoming (Authenticated (Msg (fromList [lastMessageKnownToBob, 1]) msg) bob)) ) noop @@ -156,14 +157,14 @@ spec = parallel $ do it "broadcast updates counter from peers" $ do let receivedMsgs = runSimOrThrow $ do - sentMessages <- newTVarIO empty + sentMessages <- newTVarIO Map.empty withReliability nullTracer alice (fromList [alice, bob]) ( \incoming action -> do concurrently_ - (action $ Network{broadcast = \m -> atomically $ modifyTVar' sentMessages (`snoc` payload m)}) + (action $ Network{broadcast = atomically . modifyTVar' sentMessages . mapInsert . payload}) (incoming (Authenticated (Msg (fromList [0, 1]) msg) bob)) ) noop