diff --git a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs index a909cb42f30..543c35e0e60 100644 --- a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs +++ b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs @@ -62,12 +62,14 @@ data TraceTxSubmissionInbound txid tx = TraceTxSubmissionInbound --TODO data TxSubmissionProtocolError = ProtocolErrorTxNotRequested - + | ProtocolErrorTxIdsNotRequested deriving Show instance Exception TxSubmissionProtocolError where displayException ProtocolErrorTxNotRequested = "The peer replied with a transaction we did not ask for." + displayException ProtocolErrorTxIdsNotRequested = + "The peer replied with more txids than we asked for." -- | Information maintained internally in the 'txSubmissionInbound' server @@ -227,6 +229,16 @@ txSubmissionInbound _tracer maxUnacked mpReader mpWriter = -> Collect txid tx -> m (ServerStIdle n txid tx m ()) handleReply n st (CollectTxIds reqNo txids) = do + + -- Check they didn't send more than we asked for. We don't need to check + -- for a minimum: the blocking case checks for non-zero elsewhere, and + -- for the non-blocking case it is quite normal for them to send us none. + let txidsSeq = Seq.fromList (map fst txids) + txidsMap = Map.fromList txids + + unless (Seq.length txidsSeq <= fromIntegral reqNo) $ + throwM ProtocolErrorTxIdsNotRequested + -- Upon receiving a batch of new txids we extend our available set, -- and extended the unacknowledged sequence. -- @@ -235,10 +247,8 @@ txSubmissionInbound _tracer maxUnacked mpReader mpWriter = -- transactions again in the future. let st' = st { requestedTxIdsInFlight = requestedTxIdsInFlight st - reqNo, - unacknowledgedTxIds = unacknowledgedTxIds st - <> Seq.fromList (map fst txids), - availableTxids = availableTxids st - <> Map.fromList txids + unacknowledgedTxIds = unacknowledgedTxIds st <> txidsSeq, + availableTxids = availableTxids st <> txidsMap } mpSnapshot <- atomically mempoolGetSnapshot serverIdle n (acknowledgeTxIdsInMempool st' mpSnapshot)