diff --git a/src/overlay/FlowControl.cpp b/src/overlay/FlowControl.cpp index fe4c81adb7..0960259ee2 100644 --- a/src/overlay/FlowControl.cpp +++ b/src/overlay/FlowControl.cpp @@ -70,7 +70,7 @@ FlowControl::maybeReleaseCapacity(StellarMessage const& msg) releaseAssert(threadIsMain()); std::lock_guard guard(mFlowControlMutex); - if (msg.type() == SEND_MORE || msg.type() == SEND_MORE_EXTENDED) + if (msg.type() == SEND_MORE_EXTENDED) { if (mNoOutboundCapacity) { @@ -283,8 +283,8 @@ FlowControl::canRead() const uint32_t FlowControl::getNumMessages(StellarMessage const& msg) { - return msg.type() == SEND_MORE ? msg.sendMoreMessage().numMessages - : msg.sendMoreExtendedMessage().numMessages; + releaseAssert(msg.type() == SEND_MORE_EXTENDED); + return msg.sendMoreExtendedMessage().numMessages; } bool @@ -294,10 +294,7 @@ FlowControl::isSendMoreValid(StellarMessage const& msg, releaseAssert(threadIsMain()); std::lock_guard guard(mFlowControlMutex); - bool sendMoreExtendedType = msg.type() == SEND_MORE_EXTENDED; - bool sendMoreType = msg.type() == SEND_MORE; - - if (!sendMoreExtendedType && !sendMoreType) + if (msg.type() != SEND_MORE_EXTENDED) { errorMsg = fmt::format("unexpected message type {}", diff --git a/src/overlay/FlowControlCapacity.cpp b/src/overlay/FlowControlCapacity.cpp index 69d2f66bcf..ec86fab896 100644 --- a/src/overlay/FlowControlCapacity.cpp +++ b/src/overlay/FlowControlCapacity.cpp @@ -36,7 +36,7 @@ void FlowControlMessageCapacity::releaseOutboundCapacity(StellarMessage const& msg) { ZoneScoped; - releaseAssert(msg.type() == SEND_MORE || msg.type() == SEND_MORE_EXTENDED); + releaseAssert(msg.type() == SEND_MORE_EXTENDED); auto numMessages = FlowControl::getNumMessages(msg); if (!hasOutboundCapacity(msg) && numMessages != 0) { diff --git a/src/overlay/test/OverlayTests.cpp b/src/overlay/test/OverlayTests.cpp index 26365c6058..91017e0d77 100644 --- a/src/overlay/test/OverlayTests.cpp +++ b/src/overlay/test/OverlayTests.cpp @@ -562,6 +562,15 @@ TEST_CASE("loopback peer flow control activation", "[overlay][flowcontrol]") conn.getAcceptor()->sendSendMore(0, 0); dropReason = "invalid message SEND_MORE_EXTENDED"; } + SECTION("invalid message type") + { + // Manually construct a SEND_MORE message and send it + auto m = std::make_shared(); + m->type(SEND_MORE); + m->sendMoreMessage().numMessages = 1; + conn.getAcceptor()->sendAuthenticatedMessageForTesting(m); + dropReason = "unexpected message type SEND_MORE"; + } testutil::crankSome(clock); REQUIRE(!conn.getInitiator()->isConnectedForTesting()); REQUIRE(!conn.getAcceptor()->isConnectedForTesting());