diff --git a/src/app/ReadClient.cpp b/src/app/ReadClient.cpp index 78b0a56fdd7c26..470c875fbaac30 100644 --- a/src/app/ReadClient.cpp +++ b/src/app/ReadClient.cpp @@ -423,33 +423,54 @@ CHIP_ERROR ReadClient::BuildDataVersionFilterList(DataVersionFilterIBs::Builder continue; } - DataVersionFilterIB::Builder & filterIB = aDataVersionFilterIBsBuilder.CreateDataVersionFilter(); - ReturnErrorOnFailure(aDataVersionFilterIBsBuilder.GetError()); - ClusterPathIB::Builder & path = filterIB.CreatePath(); - ReturnErrorOnFailure(filterIB.GetError()); - ReturnErrorOnFailure(path.Endpoint(filter.mEndpointId).Cluster(filter.mClusterId).EndOfClusterPathIB()); - VerifyOrReturnError(filter.mDataVersion.HasValue(), CHIP_ERROR_INVALID_ARGUMENT); - ReturnErrorOnFailure(filterIB.DataVersion(filter.mDataVersion.Value()).EndOfDataVersionFilterIB()); - aEncodedDataVersionList = true; + TLV::TLVWriter backup; + aDataVersionFilterIBsBuilder.Checkpoint(backup); + CHIP_ERROR err = EncodeDataVersionFilter(aDataVersionFilterIBsBuilder, filter); + if (err == CHIP_NO_ERROR) + { + aEncodedDataVersionList = true; + } + else if (err == CHIP_ERROR_NO_MEMORY || err == CHIP_ERROR_BUFFER_TOO_SMALL) + { + // Packet is full, ignore the rest of the list + aDataVersionFilterIBsBuilder.Rollback(backup); + return CHIP_NO_ERROR; + } + else + { + return err; + } } return CHIP_NO_ERROR; } +CHIP_ERROR ReadClient::EncodeDataVersionFilter(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder, + DataVersionFilter const & aFilter) +{ + // Caller has checked aFilter.IsValidDataVersionFilter() + DataVersionFilterIB::Builder & filterIB = aDataVersionFilterIBsBuilder.CreateDataVersionFilter(); + ReturnErrorOnFailure(aDataVersionFilterIBsBuilder.GetError()); + ClusterPathIB::Builder & path = filterIB.CreatePath(); + ReturnErrorOnFailure(filterIB.GetError()); + ReturnErrorOnFailure(path.Endpoint(aFilter.mEndpointId).Cluster(aFilter.mClusterId).EndOfClusterPathIB()); + ReturnErrorOnFailure(filterIB.DataVersion(aFilter.mDataVersion.Value()).EndOfDataVersionFilterIB()); + return CHIP_NO_ERROR; +} + CHIP_ERROR ReadClient::GenerateDataVersionFilterList(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder, const Span & aAttributePaths, const Span & aDataVersionFilters, bool & aEncodedDataVersionList) { - if (!aDataVersionFilters.empty()) + // Give the callback a chance first, otherwise use the list we have, if any. + ReturnErrorOnFailure( + mpCallback.OnUpdateDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aEncodedDataVersionList)); + + if (!aEncodedDataVersionList) { ReturnErrorOnFailure(BuildDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aDataVersionFilters, aEncodedDataVersionList)); } - else - { - ReturnErrorOnFailure( - mpCallback.OnUpdateDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aEncodedDataVersionList)); - } return CHIP_NO_ERROR; } diff --git a/src/app/ReadClient.h b/src/app/ReadClient.h index 219a5be73f0b09..9bfb628b47e252 100644 --- a/src/app/ReadClient.h +++ b/src/app/ReadClient.h @@ -339,6 +339,9 @@ class ReadClient : public Messaging::ExchangeDelegate * This will send either a Read Request or a Subscribe Request depending on * the InteractionType this read client was initialized with. * + * If the params contain more data version filters than can fit in the request packet + * the list will be truncated as needed, i.e. filter inclusion is on a best effort basis. + * * @retval #others fail to send read request * @retval #CHIP_NO_ERROR On success. */ @@ -559,6 +562,8 @@ class ReadClient : public Messaging::ExchangeDelegate CHIP_ERROR BuildDataVersionFilterList(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder, const Span & aAttributePaths, const Span & aDataVersionFilters, bool & aEncodedDataVersionList); + CHIP_ERROR EncodeDataVersionFilter(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder, + DataVersionFilter const & aFilter); CHIP_ERROR ReadICDOperatingModeFromAttributeDataIB(TLV::TLVReader && aReader, PeerType & aType); CHIP_ERROR ProcessAttributeReportIBs(TLV::TLVReader & aAttributeDataIBsReader); CHIP_ERROR ProcessEventReportIBs(TLV::TLVReader & aEventReportIBsReader); diff --git a/src/controller/tests/data_model/BUILD.gn b/src/controller/tests/data_model/BUILD.gn index 017980c4cd11f7..d8ce821ed837af 100644 --- a/src/controller/tests/data_model/BUILD.gn +++ b/src/controller/tests/data_model/BUILD.gn @@ -35,6 +35,7 @@ chip_test_suite("data_model") { "${chip_root}/src/app/tests:helpers", "${chip_root}/src/app/util/mock:mock_ember", "${chip_root}/src/controller", + "${chip_root}/src/lib/core:string-builder-adapters", "${chip_root}/src/messaging/tests:helpers", "${chip_root}/src/transport/raw/tests:helpers", ] diff --git a/src/controller/tests/data_model/TestCommands.cpp b/src/controller/tests/data_model/TestCommands.cpp index b6d39ae41978bc..11b71415ad6ec5 100644 --- a/src/controller/tests/data_model/TestCommands.cpp +++ b/src/controller/tests/data_model/TestCommands.cpp @@ -22,7 +22,8 @@ * */ -#include +#include +#include #include "app/data-model/NullObject.h" #include diff --git a/src/controller/tests/data_model/TestRead.cpp b/src/controller/tests/data_model/TestRead.cpp index a27eb61fcb2901..906ba82a213042 100644 --- a/src/controller/tests/data_model/TestRead.cpp +++ b/src/controller/tests/data_model/TestRead.cpp @@ -16,7 +16,8 @@ * limitations under the License. */ -#include +#include +#include #include "system/SystemClock.h" #include "transport/SecureSession.h" @@ -25,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -3093,6 +3095,86 @@ TEST_F(TestRead, TestReadHandler_MultipleSubscriptionsWithDataVersionFilter) EXPECT_EQ(mpContext->GetExchangeManager().GetNumActiveExchanges(), 0u); } +TEST_F(TestRead, TestReadHandler_DataVersionFiltersTruncated) +{ + struct : public chip::Test::LoopbackTransportDelegate + { + size_t requestSize = 0; + void WillSendMessage(const Transport::PeerAddress & peer, const System::PacketBufferHandle & message) override + { + // We only care about the messages we (Alice) send to Bob, not the responses. + // Assume the first message we see in an iteration is the request. + if (peer == mpContext->GetBobAddress() && requestSize == 0) + { + requestSize = message->TotalLength(); + } + } + } loopbackDelegate; + mpContext->GetLoopback().SetLoopbackTransportDelegate(&loopbackDelegate); + + // Note that on the server side, wildcard expansion does not actually work for kTestEndpointId due + // to lack of meta-data, but we don't care about the reports we get back in this test. + AttributePathParams wildcardPath(kTestEndpointId, kInvalidClusterId, kInvalidAttributeId); + constexpr size_t maxDataVersionFilterCount = 100; + DataVersionFilter dataVersionFilters[maxDataVersionFilterCount]; + ClusterId nextClusterId = 0; + for (auto & dv : dataVersionFilters) + { + dv.mEndpointId = wildcardPath.mEndpointId; + dv.mClusterId = nextClusterId++; + dv.mDataVersion = MakeOptional(0x01000000u); + } + + // Keep increasing the number of data version filters until we see truncation kick in. + size_t lastRequestSize; + for (size_t count = 1; count <= maxDataVersionFilterCount; count++) + { + lastRequestSize = loopbackDelegate.requestSize; + loopbackDelegate.requestSize = 0; // reset + + ReadPrepareParams read(mpContext->GetSessionAliceToBob()); + read.mpAttributePathParamsList = &wildcardPath; + read.mAttributePathParamsListSize = 1; + read.mpDataVersionFilterList = dataVersionFilters; + read.mDataVersionFilterListSize = count; + + struct : public ReadClient::Callback + { + CHIP_ERROR error = CHIP_NO_ERROR; + bool done = false; + void OnError(CHIP_ERROR aError) override { error = aError; } + void OnDone(ReadClient * apReadClient) override { done = true; }; + + } readCallback; + + ReadClient readClient(app::InteractionModelEngine::GetInstance(), &mpContext->GetExchangeManager(), readCallback, + ReadClient::InteractionType::Read); + + EXPECT_EQ(readClient.SendRequest(read), CHIP_NO_ERROR); + + mpContext->GetIOContext().DriveIOUntil(System::Clock::Seconds16(5), [&]() { return readCallback.done; }); + EXPECT_EQ(readCallback.error, CHIP_NO_ERROR); + EXPECT_EQ(mpContext->GetExchangeManager().GetNumActiveExchanges(), 0u); + + EXPECT_NE(loopbackDelegate.requestSize, 0u); + EXPECT_GE(loopbackDelegate.requestSize, lastRequestSize); + if (loopbackDelegate.requestSize == lastRequestSize) + { + ChipLogProgress(DataManagement, "Data Version truncation detected after %llu elements", + static_cast(count - 1)); + // With the parameters used in this test and current encoding rules we can fit 68 data versions + // into a packet. If we're seeing substantially less then something is likely gone wrong. + EXPECT_GE(count, 60u); + ExitNow(); + } + } + ChipLogProgress(DataManagement, "Unable to detect Data Version truncation, maxDataVersionFilterCount too small?"); + ADD_FAILURE(); + +exit: + mpContext->GetLoopback().SetLoopbackTransportDelegate(nullptr); +} + TEST_F(TestRead, TestReadHandlerResourceExhaustion_MultipleReads) { auto sessionHandle = mpContext->GetSessionBobToAlice(); diff --git a/src/controller/tests/data_model/TestWrite.cpp b/src/controller/tests/data_model/TestWrite.cpp index e0f2c19480a1b0..da5505ccd613a1 100644 --- a/src/controller/tests/data_model/TestWrite.cpp +++ b/src/controller/tests/data_model/TestWrite.cpp @@ -16,7 +16,8 @@ * limitations under the License. */ -#include +#include +#include #include "app-common/zap-generated/ids/Clusters.h" #include diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 51525f967d1475..f0e950a2cfc2a9 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -94,7 +94,7 @@ class MessagingContext : public PlatformMemoryUser MessagingContext() : mInitialized(false), mAliceAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)), - mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)) + mBobAddress(LoopbackTransport::LoopbackPeer(mAliceAddress)) {} // TODO Replace VerifyOrDie with Pigweed assert after transition app/tests to Pigweed. // TODO Currently src/app/icd/server/tests is using MessagingConetext as dependency. diff --git a/src/transport/raw/tests/NetworkTestHelpers.h b/src/transport/raw/tests/NetworkTestHelpers.h index 49890406bc6afb..fca2b6b04cfd7c 100644 --- a/src/transport/raw/tests/NetworkTestHelpers.h +++ b/src/transport/raw/tests/NetworkTestHelpers.h @@ -64,6 +64,10 @@ class LoopbackTransportDelegate public: virtual ~LoopbackTransportDelegate() {} + // Called by the loopback transport when a message is requested to be sent. + // This is called even if the message is subsequently rejected or dropped. + virtual void WillSendMessage(const Transport::PeerAddress & peer, const System::PacketBufferHandle & message) {} + // Called by the loopback transport when it drops one of a configurable number of messages (mDroppedMessageCount) after a // configurable allowed number of messages (mNumMessagesToAllowBeforeDropping) virtual void OnMessageDropped() {} @@ -72,6 +76,18 @@ class LoopbackTransportDelegate class LoopbackTransport : public Transport::Base { public: + // In test scenarios using the loopback transport, we're only ever given + // the address we're sending to, but we don't have any information about + // what our local address is. Assume our fake addresses come in pairs of + // even and odd port numbers, so we can calculate one from the other by + // flipping the LSB of the port number. + static Transport::PeerAddress LoopbackPeer(const Transport::PeerAddress & address) + { + Transport::PeerAddress other(address); + other.SetPort(address.GetPort() ^ 1); + return other; + } + void InitLoopbackTransport(System::Layer * systemLayer) { Reset(); @@ -100,7 +116,7 @@ class LoopbackTransport : public Transport::Base { auto item = std::move(_this->mPendingMessageQueue.front()); _this->mPendingMessageQueue.pop(); - _this->HandleMessageReceived(item.mDestinationAddress, std::move(item.mPendingMessage)); + _this->HandleMessageReceived(LoopbackPeer(item.mDestinationAddress), std::move(item.mPendingMessage)); } } @@ -108,6 +124,11 @@ class LoopbackTransport : public Transport::Base CHIP_ERROR SendMessage(const Transport::PeerAddress & address, System::PacketBufferHandle && msgBuf) override { + if (mDelegate != nullptr) + { + mDelegate->WillSendMessage(address, msgBuf); + } + if (mNumMessagesToAllowBeforeError == 0) { ReturnErrorOnFailure(mMessageSendError);