Skip to content

Commit

Permalink
Remove sent messages from memory
Browse files Browse the repository at this point in the history
  • Loading branch information
v0d1ch committed Oct 10, 2023
1 parent 086cca2 commit eb93944
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
33 changes: 12 additions & 21 deletions hydra-node/src/Hydra/Network/Reliability.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ import Cardano.Binary (serialize')
import Cardano.Crypto.Util (SignableRepresentation (getSignableRepresentation))
import Control.Concurrent.Class.MonadSTM (
MonadSTM (readTQueue, readTVarIO, writeTQueue),
modifyTVar',
newTQueueIO,
newTVarIO,
writeTVar,
Expand All @@ -86,7 +85,7 @@ import Hydra.Network (Network (..), NetworkComponent)
import Hydra.Network.Authenticate (Authenticated (..))
import Hydra.Network.Heartbeat (Heartbeat (..), isPing)
import Hydra.Party (Party)
import Hydra.Persistence (Persistence (..), append, loadAll, PersistenceIncremental)
import Hydra.Persistence (Persistence (..), append, loadAll, PersistenceIncremental (..))
import Test.QuickCheck (getPositive, listOf)

data ReliableMsg msg = ReliableMsg
Expand Down Expand Up @@ -154,25 +153,23 @@ withReliability tracer msgPersistence ackPersistence me otherParties withRawNetw
Nothing -> pure $ replicate (length allParties) 0
Just existingCounter -> pure existingCounter
ackCounter <- newTVarIO startingAckCounter
storedMessages <- loadAll msgPersistence
sentMessages <- newTVarIO $ IMap.fromList (zip [1 ..] storedMessages)
resendQ <- newTQueueIO
let ourIndex = fromMaybe (error "This cannot happen because we constructed the list with our party inside.") (findPartyIndex me)
let resend = writeTQueue resendQ
withRawNetwork (reliableCallback ackCounter sentMessages resend ourIndex) $ \network@Network{broadcast} -> do
withRawNetwork (reliableCallback ackCounter resend ourIndex) $ \network@Network{broadcast} -> do
withAsync (forever $ atomically (readTQueue resendQ) >>= broadcast) $ \_ ->
reliableBroadcast ourIndex ackCounter sentMessages network
reliableBroadcast ourIndex ackCounter msgPersistence network
where
allParties = fromList $ sort $ me : otherParties

reliableBroadcast ourIndex ackCounter sentMessages Network{broadcast} =
reliableBroadcast ourIndex ackCounter PersistenceIncremental{append} Network{broadcast} =
action $
Network
{ broadcast = \msg ->
case msg of
Data{} -> do
ackCounter' <- atomically $ incrementCountersFor msg
append msgPersistence msg
ackCounter' <- atomically incrementAckCounter
append msg
void $ save ackPersistence ackCounter'
traceWith tracer (BroadcastCounter ourIndex ackCounter')
broadcast $ ReliableMsg ackCounter' msg
Expand All @@ -182,13 +179,12 @@ withReliability tracer msgPersistence ackPersistence me otherParties withRawNetw
broadcast $ ReliableMsg acks msg
}
where
incrementCountersFor msg = do
incrementAckCounter = do
acks <- readTVar ackCounter
writeTVar ackCounter $ constructAcks acks ourIndex
modifyTVar' sentMessages (insertNewMsg msg)
readTVar ackCounter

reliableCallback ackCounter sentMessages resend ourIndex (Authenticated (ReliableMsg acks msg) party) = do
reliableCallback ackCounter resend ourIndex (Authenticated (ReliableMsg acks msg) party) = do
if length acks /= length allParties
then ignoreMalformedMessages
else do
Expand Down Expand Up @@ -219,7 +215,7 @@ withReliability tracer msgPersistence ackPersistence me otherParties withRawNetw
else traceWith tracer (Ignored acks knownAcks partyIndex)

when (isPing msg) $
resendMessagesIfLagging resend partyIndex sentMessages knownAcks acks ourIndex
resendMessagesIfLagging resend partyIndex msgPersistence knownAcks acks ourIndex
Nothing -> pure ()

ignoreMalformedMessages = pure ()
Expand All @@ -229,7 +225,7 @@ withReliability tracer msgPersistence ackPersistence me otherParties withRawNetw

partyIndexes = generate (length allParties) id

resendMessagesIfLagging resend partyIndex sentMessages knownAcks messageAcks myIndex = do
resendMessagesIfLagging resend partyIndex PersistenceIncremental{loadAll} knownAcks messageAcks myIndex = do
let mmessageAckForUs = messageAcks !? myIndex
let mknownAckForUs = knownAcks !? myIndex
case (mmessageAckForUs, mknownAckForUs) of
Expand All @@ -238,7 +234,8 @@ withReliability tracer msgPersistence ackPersistence me otherParties withRawNetw
-- latest message sent
when (messageAckForUs < knownAckForUs) $ do
let missing = fromList [messageAckForUs + 1 .. knownAckForUs]
messages <- readTVarIO sentMessages
storedMessages <- loadAll
let messages = IMap.fromList (zip [1 ..] storedMessages)
forM_ missing $ \idx -> do
case messages IMap.!? idx of
Nothing ->
Expand All @@ -258,12 +255,6 @@ withReliability tracer msgPersistence ackPersistence me otherParties withRawNetw
atomically $ resend $ ReliableMsg newAcks' missingMsg
_ -> pure ()

-- Find the maximum index and increment it by one to store the next message.
insertNewMsg msg m =
case IMap.lookupMax m of
Nothing -> IMap.insert 1 msg m
Just (k, _) -> IMap.insert (k + 1) msg m

-- Find the index of a party in the list of all parties.
-- NOTE: This should never fail.
findPartyIndex party =
Expand Down
30 changes: 19 additions & 11 deletions hydra-node/test/Hydra/Network/ReliabilitySpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import Test.QuickCheck (
(===),
)
import Prelude (unlines)
import Data.Sequence ((|>))

spec :: Spec
spec = parallel $ do
Expand Down Expand Up @@ -92,8 +93,9 @@ spec = parallel $ do
prop "broadcast messages to the network assigning a sequential id" $ \(messages :: [String]) ->
let sentMsgs = runSimOrThrow $ do
sentMessages <- newTVarIO empty
messagePersistence <- msgPersistence

withReliability nullTracer msgPersistence ackPersistence alice [] (captureOutgoing sentMessages) noop $ \Network{broadcast} -> do
withReliability nullTracer messagePersistence ackPersistence alice [] (captureOutgoing sentMessages) noop $ \Network{broadcast} -> do
mapM_ (broadcast . Data "node-1") messages

fromList . toList <$> readTVarIO sentMessages
Expand All @@ -111,14 +113,16 @@ spec = parallel $ do
randomSeed <- newTVarIO $ mkStdGen seed
aliceToBob <- newTQueueIO
bobToAlice <- newTQueueIO
aliceMessagePersistence <- msgPersistence
bobMessagePersistence <- msgPersistence
let
-- this is a NetworkComponent that broadcasts authenticated messages
-- mediated through a read and a write TQueue but drops 0.2 % of them
aliceFailingNetwork = failingNetwork randomSeed alice (bobToAlice, aliceToBob)
bobFailingNetwork = failingNetwork randomSeed bob (aliceToBob, bobToAlice)

bobReliabilityStack = reliabilityStack msgPersistence bobFailingNetwork emittedTraces "bob" bob [alice]
aliceReliabilityStack = reliabilityStack msgPersistence aliceFailingNetwork emittedTraces "alice" alice [bob]
bobReliabilityStack = reliabilityStack aliceMessagePersistence bobFailingNetwork emittedTraces "bob" bob [alice]
aliceReliabilityStack = reliabilityStack bobMessagePersistence aliceFailingNetwork emittedTraces "alice" alice [bob]

runAlice = runPeer aliceReliabilityStack "alice" messagesReceivedByAlice messagesReceivedByBob aliceToBobMessages bobToAliceMessages
runBob = runPeer bobReliabilityStack "bob" messagesReceivedByBob messagesReceivedByAlice bobToAliceMessages aliceToBobMessages
Expand All @@ -140,9 +144,10 @@ spec = parallel $ do
it "broadcast updates counter from peers" $ do
let receivedMsgs = runSimOrThrow $ do
sentMessages <- newTVarIO empty
messagePersistence <- msgPersistence
withReliability
nullTracer
msgPersistence
messagePersistence
ackPersistence
alice
[bob]
Expand Down Expand Up @@ -232,13 +237,13 @@ noop = const $ pure ()
aliceReceivesMessages :: (FromJSON msg, ToJSON msg) => [Authenticated (ReliableMsg (Heartbeat msg))] -> [Authenticated (Heartbeat msg)]
aliceReceivesMessages messages = runSimOrThrow $ do
receivedMessages <- newTVarIO empty

messagePersistence <- msgPersistence
let baseNetwork incoming _ = mapM incoming messages

aliceReliabilityStack =
withReliability
nullTracer
msgPersistence
messagePersistence
ackPersistence
alice
[bob, carol]
Expand Down Expand Up @@ -271,11 +276,14 @@ captureTraces ::
captureTraces tvar = Tracer $ \msg -> do
atomically $ modifyTVar' tvar (msg :)

msgPersistence :: Applicative m => PersistenceIncremental a m
msgPersistence =
PersistenceIncremental
{ append = \_ -> pure ()
, loadAll = pure []
msgPersistence :: MonadSTM m => m (PersistenceIncremental a m)
msgPersistence = do
messages <- newTVarIO mempty
pure PersistenceIncremental
{ append = \msg -> atomically $ do
modifyTVar' messages (|> msg )

, loadAll = toList <$> readTVarIO messages
}

ackPersistence :: Applicative m => Persistence a m
Expand Down

0 comments on commit eb93944

Please sign in to comment.