From 2365cdf9b33dc709baa7f830541c552cccbb873a Mon Sep 17 00:00:00 2001 From: Sasha Bogicevic Date: Thu, 21 Sep 2023 12:14:36 +0200 Subject: [PATCH] Introduce deleteSeenMessages --- hydra-node/src/Hydra/Network/Reliability.hs | 23 +++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/hydra-node/src/Hydra/Network/Reliability.hs b/hydra-node/src/Hydra/Network/Reliability.hs index 44110fca802..805311e763e 100644 --- a/hydra-node/src/Hydra/Network/Reliability.hs +++ b/hydra-node/src/Hydra/Network/Reliability.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TupleSections #-} {-# OPTIONS_GHC -Wno-orphans #-} -- | A `Network` layer that guarantees delivery of `msg` in order even in the @@ -61,6 +62,7 @@ import Control.Concurrent.Class.MonadSTM ( writeTVar, ) import Control.Tracer (Tracer) +import qualified Data.Map.Strict as Map import Data.Vector ( Vector, elemIndex, @@ -137,6 +139,7 @@ instance Arbitrary ReliabilityLog where -- TODO: better use of Vectors? We should perhaps use a `MVector` to be able to -- mutate in-place and not need `zipWith` withReliability :: + forall m msg a. (MonadThrow (STM m), MonadThrow m, MonadAsync m) => -- | Tracer for logging messages. Tracer m ReliabilityLog -> @@ -150,10 +153,11 @@ withReliability :: withReliability tracer me otherParties withRawNetwork callback action = do ackCounter <- newTVarIO $ replicate (length allParties) 0 sentMessages <- newTVarIO empty + seenMessages <- newTVarIO $ Map.fromList $ (,0) <$> toList allParties resendQ <- newTQueueIO ourIndex <- findPartyIndex me let resend = writeTQueue resendQ - withRawNetwork (reliableCallback ackCounter sentMessages resend) $ \network@Network{broadcast} -> do + withRawNetwork (reliableCallback ackCounter sentMessages seenMessages resend) $ \network@Network{broadcast} -> do withAsync (forever $ atomically (readTQueue resendQ) >>= broadcast) $ \_ -> reliableBroadcast ourIndex ackCounter sentMessages network where @@ -181,10 +185,11 @@ withReliability tracer me otherParties withRawNetwork callback action = do broadcast $ ReliableMsg acks msg } - reliableCallback ackCounter sentMessages resend (Authenticated (ReliableMsg acks msg) party) = do + reliableCallback ackCounter sentMessages seenMessages resend (Authenticated (ReliableMsg acks msg) party) = do if length acks /= length allParties then ignoreMalformedMessages else do + deleteSeenMessages sentMessages seenMessages acks party partyIndex <- findPartyIndex party traceWith tracer Callbacking (shouldCallback, messageAckForParty, knownAckForParty, knownAcks) <- atomically $ do @@ -220,8 +225,6 @@ withReliability tracer me otherParties withRawNetwork callback action = do -- resend messages if party did not acknowledge our latest idx resendMessages resend partyIndex sentMessages knownAcks acks messageAckForParty knownAckForParty - -- TODO - traceWith tracer (ClearedMessageQueue 0 1) ignoreMalformedMessages = pure () @@ -255,6 +258,18 @@ withReliability tracer me otherParties withRawNetwork callback action = do traceWith tracer (Resending missing messageAcks newAcks' partyIndex) atomically $ resend $ ReliableMsg newAcks' missingMsg + deleteSeenMessages sentMessages seenMessages acks party = do + myIndex <- findPartyIndex me + let messageAckForUs = acks ! myIndex + (queueLength, deleted) <- atomically $ do + modifyTVar' seenMessages (Map.update (const $ Just messageAckForUs) party) + _messages <- readTVar seenMessages + return (0, 1) + -- TODO: here we want to delete old messages but it would be good to + -- convert this structure from vector to map + -- modifyTVar' sentMessages (Map.delete (const $ Just messageAckForUs) party) + traceWith tracer (ClearedMessageQueue queueLength deleted) + -- find the index of a party in the list of parties or fail with 'ReliabilityMissingPartyIndex' findPartyIndex party = maybe (throwIO $ ReliabilityMissingPartyIndex party) pure $ elemIndex party allParties