diff --git a/src/messaging/BUILD.gn b/src/messaging/BUILD.gn index 5e596b615ac596..f5632a0e6b6b11 100644 --- a/src/messaging/BUILD.gn +++ b/src/messaging/BUILD.gn @@ -44,7 +44,6 @@ static_library("messaging") { "ExchangeContext.cpp", "ExchangeContext.h", "ExchangeDelegate.h", - "ExchangeMessageDispatch.cpp", "ExchangeMessageDispatch.h", "ExchangeMgr.cpp", "ExchangeMgr.h", diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index a89bf15ed2bc2c..c0d26b7747e378 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -147,11 +147,6 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp // an error arising below. at the end, we have to close it. ExchangeHandle ref(*this); - // If session requires MRP, NoAutoRequestAck send flag is not specified and is not a group exchange context, request reliable - // transmission. - bool reliableTransmissionRequested = - GetSessionHandle()->RequireMRP() && !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck) && !IsGroupExchangeContext(); - // If a response message is expected... if (sendFlags.Has(SendMessageFlags::kExpectResponse) && !IsGroupExchangeContext()) { @@ -184,9 +179,78 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp } // Create a new scope for `err`, to avoid shadowing warning previous `err`. - CHIP_ERROR err = mDispatch.SendMessage(GetExchangeMgr()->GetSessionManager(), mSession.Get(), mExchangeId, IsInitiator(), - GetReliableMessageContext(), reliableTransmissionRequested, protocolId, msgType, - std::move(msgBuf)); + CHIP_ERROR err = ([&] { + VerifyOrReturnError(mDispatch.MessagePermitted(protocolId.GetProtocolId(), msgType), CHIP_ERROR_INVALID_ARGUMENT); + + PayloadHeader payloadHeader; + payloadHeader.SetExchangeID(mExchangeId).SetMessageType(protocolId, msgType).SetInitiator(IsInitiator()); + + ReliableMessageContext * reliableMessageContext = GetReliableMessageContext(); + + // If there is a pending acknowledgment piggyback it on this message. + if (reliableMessageContext->HasPiggybackAckPending()) + { + payloadHeader.SetAckMessageCounter(reliableMessageContext->TakePendingPeerAckMessageCounter()); + +#if !defined(NDEBUG) + if (!payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck)) + { + ChipLogDetail(ExchangeManager, + "Piggybacking Ack for MessageCounter:" ChipLogFormatMessageCounter + " on exchange: " ChipLogFormatExchangeId, + payloadHeader.GetAckMessageCounter().Value(), ChipLogValueExchangeId(mExchangeId, IsInitiator())); + } +#endif + } + + SessionManager * sessionManager = GetExchangeMgr()->GetSessionManager(); + SessionHandle session = GetSessionHandle(); + + // If session requires MRP, NoAutoRequestAck send flag is not specified, request reliable transmission. + if (mSession->RequireMRP() && !sendFlags.Has(SendMessageFlags::kNoAutoRequestAck) && reliableMessageContext->AutoRequestAck()) + { + auto * reliableMessageMgr = reliableMessageContext->GetReliableMessageMgr(); + + payloadHeader.SetNeedsAck(true); + + ReliableMessageMgr::RetransTableEntry * entry = nullptr; + + // Add to Table for subsequent sending + ReturnErrorOnFailure(reliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); + auto deleter = [reliableMessageMgr](ReliableMessageMgr::RetransTableEntry * e) { + reliableMessageMgr->ClearRetransTable(*e); + }; + std::unique_ptr entryOwner(entry, deleter); + + ReturnErrorOnFailure(sessionManager->PrepareMessage(session, payloadHeader, std::move(msgBuf), entryOwner->retainedBuf)); + CHIP_ERROR err2 = sessionManager->SendPreparedMessage(session, entryOwner->retainedBuf); + if (err2 == CHIP_ERROR_POSIX(ENOBUFS)) + { + // sendmsg on BSD-based systems never blocks, no matter how the + // socket is configured, and will return ENOBUFS in situation in + // which Linux, for example, blocks. + // + // This is typically a transient situation, so we pretend like this + // packet drop happened somewhere on the network instead of inside + // sendmsg and will just resend it in the normal MRP way later. + ChipLogError(ExchangeManager, "Ignoring ENOBUFS: %" CHIP_ERROR_FORMAT " on exchange " ChipLogFormatExchangeId, + err2.Format(), ChipLogValueExchangeId(mExchangeId, IsInitiator())); + err2 = CHIP_NO_ERROR; + } + ReturnErrorOnFailure(err2); + reliableMessageMgr->StartRetransmision(entryOwner.release()); + } + else + { + // If the channel itself is providing reliability, let's not request MRP acks + payloadHeader.SetNeedsAck(false); + EncryptedPacketBufferHandle preparedMessage; + ReturnErrorOnFailure(sessionManager->PrepareMessage(session, payloadHeader, std::move(msgBuf), preparedMessage)); + ReturnErrorOnFailure(sessionManager->SendPreparedMessage(session, preparedMessage)); + } + + return CHIP_NO_ERROR; + })(); if (err != CHIP_NO_ERROR && IsResponseExpected()) { CancelResponseTimer(); @@ -272,10 +336,10 @@ void ExchangeContextDeletor::Release(ExchangeContext * ec) ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, const SessionHandle & session, bool Initiator, ExchangeDelegate * delegate) : - mDispatch((delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance()), - mSession(*this) + mDispatch(ExchangeManager::GetDispatchForDelegate(delegate)), mSession(*this) { VerifyOrDie(mExchangeMgr == nullptr); + VerifyOrDie(mDispatch.IsEncryptionRequired() == session->IsEncrypted()); mExchangeMgr = em; mExchangeId = ExchangeId; @@ -286,9 +350,7 @@ ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, cons SetDropAckDebug(false); SetAckPending(false); SetMsgRcvdFromPeer(false); - - // Do not request Ack for multicast - SetAutoRequestAck(!session->IsGroupSession()); + SetAutoRequestAck(true); #if defined(CHIP_EXCHANGE_CONTEXT_DETAIL_LOGGING) ChipLogDetail(ExchangeManager, "ec++ id: " ChipLogFormatExchange, ChipLogValueExchange(this)); @@ -331,10 +393,6 @@ bool ExchangeContext::MatchExchange(const SessionHandle & session, const PacketH // AND The Session associated with the incoming message matches the Session associated with the exchange. && (mSession.Contains(session)) - // TODO: This check should be already implied by the equality of session check, - // It should be removed after we have implemented the temporary node id for PASE and CASE sessions - && (IsEncryptionRequired() == packetHeader.IsEncrypted()) - // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast // case, the initiator is ill defined) @@ -457,8 +515,24 @@ CHIP_ERROR ExchangeContext::HandleMessage(uint32_t messageCounter, const Payload MessageHandled(); }); - ReturnErrorOnFailure( - mDispatch.OnMessageReceived(messageCounter, payloadHeader, peerAddress, msgFlags, GetReliableMessageContext())); + VerifyOrReturnError(mDispatch.MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), CHIP_ERROR_INVALID_ARGUMENT); + + if (mSession->RequireMRP()) + { + ReliableMessageContext * reliableMessageContext = GetReliableMessageContext(); + + if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() && + payloadHeader.GetAckMessageCounter().HasValue()) + { + reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value()); + } + + if (payloadHeader.NeedsAck()) + { + // An acknowledgment needs to be sent back to the peer for this message on this exchange, + ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageCounter, msgFlags)); + } + } if (IsAckPending() && !mDelegate) { diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index 4e616194584037..4f7aa66df980c3 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -77,8 +77,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, */ bool IsInitiator() const; - bool IsEncryptionRequired() const { return mDispatch.IsEncryptionRequired(); } - bool IsGroupExchangeContext() const { return mSession && mSession->IsGroupSession(); } // Implement SessionReleaseDelegate diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp deleted file mode 100644 index 7ca2e5b59901c9..00000000000000 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ /dev/null @@ -1,143 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file provides implementation of ExchangeMessageDispatch class. - */ - -#ifndef __STDC_FORMAT_MACROS -#define __STDC_FORMAT_MACROS -#endif - -#ifndef __STDC_LIMIT_MACROS -#define __STDC_LIMIT_MACROS -#endif - -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace chip { -namespace Messaging { - -CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionManager * sessionManager, const SessionHandle & session, uint16_t exchangeId, - bool isInitiator, ReliableMessageContext * reliableMessageContext, - bool isReliableTransmission, Protocols::Id protocol, uint8_t type, - System::PacketBufferHandle && message) -{ - ReturnErrorCodeIf(!MessagePermitted(protocol.GetProtocolId(), type), CHIP_ERROR_INVALID_ARGUMENT); - - PayloadHeader payloadHeader; - payloadHeader.SetExchangeID(exchangeId).SetMessageType(protocol, type).SetInitiator(isInitiator); - - // If there is a pending acknowledgment piggyback it on this message. - if (reliableMessageContext->HasPiggybackAckPending()) - { - payloadHeader.SetAckMessageCounter(reliableMessageContext->TakePendingPeerAckMessageCounter()); - -#if !defined(NDEBUG) - if (!payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck)) - { - ChipLogDetail(ExchangeManager, - "Piggybacking Ack for MessageCounter:" ChipLogFormatMessageCounter - " on exchange: " ChipLogFormatExchangeId, - payloadHeader.GetAckMessageCounter().Value(), ChipLogValueExchangeId(exchangeId, isInitiator)); - } -#endif - } - - if (IsReliableTransmissionAllowed() && reliableMessageContext->AutoRequestAck() && - reliableMessageContext->GetReliableMessageMgr() != nullptr && isReliableTransmission) - { - auto * reliableMessageMgr = reliableMessageContext->GetReliableMessageMgr(); - - payloadHeader.SetNeedsAck(true); - - ReliableMessageMgr::RetransTableEntry * entry = nullptr; - - // Add to Table for subsequent sending - ReturnErrorOnFailure(reliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); - auto deleter = [reliableMessageMgr](ReliableMessageMgr::RetransTableEntry * e) { - reliableMessageMgr->ClearRetransTable(*e); - }; - std::unique_ptr entryOwner(entry, deleter); - - ReturnErrorOnFailure(sessionManager->PrepareMessage(session, payloadHeader, std::move(message), entryOwner->retainedBuf)); - CHIP_ERROR err = sessionManager->SendPreparedMessage(session, entryOwner->retainedBuf); - if (err == CHIP_ERROR_POSIX(ENOBUFS)) - { - // sendmsg on BSD-based systems never blocks, no matter how the - // socket is configured, and will return ENOBUFS in situation in - // which Linux, for example, blocks. - // - // This is typically a transient situation, so we pretend like this - // packet drop happened somewhere on the network instead of inside - // sendmsg and will just resend it in the normal MRP way later. - ChipLogError(ExchangeManager, "Ignoring ENOBUFS: %" CHIP_ERROR_FORMAT " on exchange " ChipLogFormatExchangeId, - err.Format(), ChipLogValueExchangeId(exchangeId, isInitiator)); - err = CHIP_NO_ERROR; - } - ReturnErrorOnFailure(err); - reliableMessageMgr->StartRetransmision(entryOwner.release()); - } - else - { - // If the channel itself is providing reliability, let's not request MRP acks - payloadHeader.SetNeedsAck(false); - EncryptedPacketBufferHandle preparedMessage; - ReturnErrorOnFailure(sessionManager->PrepareMessage(session, payloadHeader, std::move(message), preparedMessage)); - ReturnErrorOnFailure(sessionManager->SendPreparedMessage(session, preparedMessage)); - } - - return CHIP_NO_ERROR; -} - -CHIP_ERROR ExchangeMessageDispatch::OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, - const Transport::PeerAddress & peerAddress, MessageFlags msgFlags, - ReliableMessageContext * reliableMessageContext) -{ - ReturnErrorCodeIf(!MessagePermitted(payloadHeader.GetProtocolID().GetProtocolId(), payloadHeader.GetMessageType()), - CHIP_ERROR_INVALID_ARGUMENT); - - if (IsReliableTransmissionAllowed() && !reliableMessageContext->GetExchangeContext()->IsGroupExchangeContext()) - { - if (!msgFlags.Has(MessageFlagValues::kDuplicateMessage) && payloadHeader.IsAckMsg() && - payloadHeader.GetAckMessageCounter().HasValue()) - { - reliableMessageContext->HandleRcvdAck(payloadHeader.GetAckMessageCounter().Value()); - } - - if (payloadHeader.NeedsAck()) - { - // An acknowledgment needs to be sent back to the peer for this message on this exchange, - - ReturnErrorOnFailure(reliableMessageContext->HandleNeedsAck(messageCounter, msgFlags)); - } - } - - return CHIP_NO_ERROR; -} - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index 464387f4e05ff9..4ca3c599e5d2e9 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -34,23 +34,9 @@ class ReliableMessageContext; class ExchangeMessageDispatch { public: - ExchangeMessageDispatch() {} virtual ~ExchangeMessageDispatch() {} - virtual bool IsEncryptionRequired() const { return true; } - - CHIP_ERROR SendMessage(SessionManager * sessionManager, const SessionHandle & session, uint16_t exchangeId, bool isInitiator, - ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, - uint8_t type, System::PacketBufferHandle && message); - CHIP_ERROR OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, - const Transport::PeerAddress & peerAddress, MessageFlags msgFlags, - ReliableMessageContext * reliableMessageContext); - -protected: virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; - - // TODO: remove IsReliableTransmissionAllowed, this function should be provided over session. - virtual bool IsReliableTransmissionAllowed() const { return true; } }; } // namespace Messaging diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index adfc214db89a46..58387d2225dd8e 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -288,6 +288,12 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const } } + if (GetDispatchForDelegate(delegate).IsEncryptionRequired() != session->IsEncrypted()) + { + ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_NO_UNSOLICITED_MESSAGE_HANDLER)); + return; + } + // If rcvd msg is from initiator then this exchange is created as not Initiator. // If rcvd msg is not from initiator then this exchange is created as Initiator. // Note that if matchingUMH is not null then rcvd msg if from initiator. @@ -310,13 +316,6 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: %p", ChipLogValueExchange(ec), ec->GetDelegate()); - if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted()) - { - ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE)); - ec->Close(); - return; - } - CHIP_ERROR err = ec->HandleMessage(packetHeader.GetMessageCounter(), payloadHeader, source, msgFlags, std::move(msgBuf)); if (err != CHIP_NO_ERROR) { diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index a52e24024f468a..853a5a79f9dfcc 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -194,6 +194,11 @@ class DLL_EXPORT ExchangeManager : public SessionMessageDelegate size_t GetNumActiveExchanges() { return mContextPool.Allocated(); } + static ExchangeMessageDispatch & GetDispatchForDelegate(ExchangeDelegate * delegate) + { + return (delegate != nullptr) ? delegate->GetMessageDispatch() : ApplicationExchangeDispatch::Instance(); + } + private: enum class State { diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 513935d3b0ee69..11b3cc453309cc 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -135,14 +135,8 @@ class MockAppDelegate : public UnsolicitedMessageHandler, public ExchangeDelegat class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationExchangeDispatch { public: - bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; } - bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } - bool IsEncryptionRequired() const override { return mRequireEncryption; } - - bool mRetainMessageOnSend = true; - bool mRequireEncryption = false; }; @@ -332,6 +326,7 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) CHIP_ERROR err = CHIP_NO_ERROR; MockSessionEstablishmentDelegate mockSender; + mockSender.mMessageDispatch.mRequireEncryption = true; ExchangeContext * exchange = ctx.NewExchangeToAlice(&mockSender); NL_TEST_ASSERT(inSuite, exchange != nullptr); @@ -343,7 +338,6 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) 64_ms32, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL }); - mockSender.mMessageDispatch.mRetainMessageOnSend = false; // Let's drop the initial message gLoopback.mSentMessageCount = 0; gLoopback.mNumMessagesToDrop = 1; @@ -391,6 +385,7 @@ void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inConte gLoopback.mNumMessagesToDrop = 0; gLoopback.mDroppedMessageCount = 0; + // Send a plaintext message targeting a encrypted unsolicited message handler // We are sending a malicious packet, doesn't expect an ack err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer), SendFlags(SendMessageFlags::kNoAutoRequestAck)); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 9cb75a596ff6a4..c05886284c8dc5 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -37,10 +37,7 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi return instance; } - SessionEstablishmentExchangeDispatch() {} ~SessionEstablishmentExchangeDispatch() override {} - -protected: bool MessagePermitted(uint16_t protocol, uint8_t type) override; bool IsEncryptionRequired() const override { return false; } }; diff --git a/src/transport/Session.h b/src/transport/Session.h index b51c78964cf3d2..242b8da98015e5 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -90,6 +90,8 @@ class Session bool IsSecureSession() const { return GetSessionType() == SessionType::kSecure; } + bool IsEncrypted() const { return IsSecureSession() || IsGroupSession(); } + protected: // This should be called by sub-classes at the very beginning of the destructor, before any data field is disposed, such that // the session is still functional during the callback.