Skip to content

Commit

Permalink
Remove seen messages
Browse files Browse the repository at this point in the history
- We record what each party has seen and remove messages seen by all
  parties.

- Use IntMap instead of Vector for storing sent messages because we need
to be able to remove old/seen indices without re-indexing.

- Use Map to keep track of the last seen message by party

- Introduce a test case to assert the withNetwork logs removal of old
  messages.
  • Loading branch information
v0d1ch committed Sep 21, 2023
1 parent 5139a31 commit 8ea1ed6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
49 changes: 31 additions & 18 deletions hydra-node/src/Hydra/Network/Reliability.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import Control.Concurrent.Class.MonadSTM (
writeTVar,
)
import Control.Tracer (Tracer)
import qualified Data.IntMap as IMap
import qualified Data.Map.Strict as Map
import Data.Vector (
Vector,
Expand Down Expand Up @@ -132,11 +133,9 @@ instance Arbitrary ReliabilityLog where
-- layer is tied to a specific structure of other layers, eg. be between
-- `withHeartbeat` and `withAuthenticate` layers.
--
-- TODO: garbage-collect the `sentMessages` which otherwise will grow forever
-- TODO: better use of Vectors? We should perhaps use a `MVector` to be able to
-- mutate in-place and not need `zipWith`
withReliability ::
forall m msg a.
(MonadThrow (STM m), MonadThrow m, MonadAsync m) =>
-- | Tracer for logging messages.
Tracer m ReliabilityLog ->
Expand All @@ -149,7 +148,7 @@ withReliability ::
NetworkComponent m (Authenticated (Heartbeat msg)) (Heartbeat msg) a
withReliability tracer me otherParties withRawNetwork callback action = do
ackCounter <- newTVarIO $ replicate (length allParties) 0
sentMessages <- newTVarIO Map.empty
sentMessages <- newTVarIO IMap.empty
seenMessages <- newTVarIO $ Map.fromList $ (,0) <$> toList allParties
resendQ <- newTQueueIO
ourIndex <- findPartyIndex me
Expand Down Expand Up @@ -186,7 +185,6 @@ withReliability tracer me otherParties withRawNetwork callback action = do
if length acks /= length allParties
then ignoreMalformedMessages
else do
deleteSeenMessages sentMessages seenMessages acks party
partyIndex <- findPartyIndex party
traceWith tracer Callbacking
(shouldCallback, messageAckForParty, knownAckForParty, knownAcks) <- atomically $ do
Expand Down Expand Up @@ -222,6 +220,9 @@ withReliability tracer me otherParties withRawNetwork callback action = do

-- resend messages if party did not acknowledge our latest idx
resendMessages resend partyIndex sentMessages knownAcks acks messageAckForParty knownAckForParty
updateSeenMessages seenMessages acks party

deleteSeenMessages sentMessages seenMessages

ignoreMalformedMessages = pure ()

Expand All @@ -238,14 +239,14 @@ withReliability tracer me otherParties withRawNetwork callback action = do
let missing = fromList [messageAckForUs + 1 .. knownAckForUs]
messages <- readTVarIO sentMessages
forM_ missing $ \idx -> do
case messages Map.!? (idx - 1) of
case messages IMap.!? (idx - 1) of
Nothing ->
throwIO $
ReliabilityFailedToFindMsg $
"FIXME: this should never happen, there's no sent message at index "
<> show idx
<> ", messages length = "
<> show (Map.size messages)
<> show (IMap.size messages)
<> ", latest message ack: "
<> show knownAckForUs
<> ", acked: "
Expand All @@ -255,22 +256,34 @@ withReliability tracer me otherParties withRawNetwork callback action = do
traceWith tracer (Resending missing messageAcks newAcks' partyIndex)
atomically $ resend $ ReliableMsg newAcks' missingMsg

deleteSeenMessages sentMessages seenMessages acks party = do
updateSeenMessages seenMessages acks party = do
myIndex <- findPartyIndex me
let messageAckForUs = acks ! myIndex
(queueLength, deleted) <- atomically $ do
modifyTVar' seenMessages (Map.update (const $ Just messageAckForUs) party)
_messages <- readTVar seenMessages
return (0, 1)
-- TODO: here we want to delete old messages but it would be good to
-- convert this structure from vector to map
-- modifyTVar' sentMessages (Map.delete (const $ Just messageAckForUs) party)
traceWith tracer (ClearedMessageQueue queueLength deleted)
atomically $ modifyTVar' seenMessages (Map.insert party messageAckForUs)

deleteSeenMessages sentMessages seenMessages = do
clearedMessages <- atomically $ do
seenMessages' <- readTVar seenMessages
let messageReceivedByEveryone =
case sortBy (\(_, a) (_, b) -> compare a b) (Map.toList seenMessages') of
[] -> 0 -- should not happen
((_, v) : _) -> v
sentMessages' <- readTVar sentMessages
if IMap.member messageReceivedByEveryone sentMessages'
then do
let updatedMap = IMap.delete messageReceivedByEveryone sentMessages'
writeTVar sentMessages updatedMap
pure $ Just (IMap.size updatedMap, messageReceivedByEveryone)
else pure Nothing

case clearedMessages of
Nothing -> pure ()
Just (messageQueueLength, deletedMessages) -> traceWith tracer (ClearedMessageQueue{messageQueueLength, deletedMessages})

insertNewMsg msg m =
case Map.lookupMax m of
Nothing -> Map.insert 1 msg m
Just (k, _) -> Map.insert (k + 1) msg m
case IMap.lookupMax m of
Nothing -> IMap.insert 1 msg m
Just (k, _) -> IMap.insert (k + 1) msg m

-- find the index of a party in the list of parties or fail with 'ReliabilityMissingPartyIndex'
findPartyIndex party =
Expand Down
13 changes: 8 additions & 5 deletions hydra-node/test/Hydra/Network/ReliabilitySpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,20 @@ spec = parallel $ do
(captureTraces emittedTraces)
alice
[bob, carol]
( \incoming _ -> do
incoming (Authenticated (ReliableMsg (fromList [1, 1, 0]) (Data "node-2" msg)) bob)
incoming (Authenticated (ReliableMsg (fromList [1, 0, 1]) (Data "node-3" msg)) carol)
( \incoming action -> do
incoming (Authenticated (ReliableMsg (fromList [1, 0, 0]) (Data "node-1" msg)) alice)
action $ Network{broadcast = \_ -> pure ()}
incoming (Authenticated (ReliableMsg (fromList [1, 0, 0]) (Data "node-2" msg)) bob)
action $ Network{broadcast = \_ -> pure ()}
incoming (Authenticated (ReliableMsg (fromList [1, 0, 0]) (Data "node-3" msg)) carol)
action $ Network{broadcast = \_ -> pure ()}
)
noop
$ \Network{broadcast} -> do
broadcast (Data "node-1" msg)
threadDelay 1
readTVarIO emittedTraces

receivedTraces `shouldContain` [ClearedMessageQueue{messageQueueLength = 0, deletedMessages = 1}]
receivedTraces `shouldContain` [ClearedMessageQueue{messageQueueLength = 1, deletedMessages = 1}]

describe "sending messages" $ do
prop "broadcast messages to the network assigning a sequential id" $ \(messages :: [String]) ->
Expand Down

0 comments on commit 8ea1ed6

Please sign in to comment.