Skip to content

Commit

Permalink
ReadClient: Truncate data version list during encoding if necessary (#…
Browse files Browse the repository at this point in the history
…34111)

* ReadClient: Truncate data version list during encoding if necessary

The existing code made the assumption that if a list of versions was able to
fit into the request packet when generating the first subscribe request, then
any resubscribe containing data versions for the same clusters would also fit.
However the data version numbers themselves can be updated when we receive
reports, and this can cause the list to no longer fit the request packet,
leaving us in a state where every resubscribe attempt would fail.

Note that this change means even an initial subscribe request with a data
version list that is too long will no longer fail; ReadClient will simply
truncate the list as needed in all cases.

* Apply comment suggestions from code review

Co-authored-by: Boris Zbarsky <[email protected]>

* Treat CHIP_ERROR_BUFFER_TOO_SMALL the same

* Switch data_model tests to pw_unit_test

* Add WillSendMessage to loopback delegate and make source addresses more plausible

* Add test for ReadClient data version truncation

* Make the linter happy

---------

Co-authored-by: Boris Zbarsky <[email protected]>
  • Loading branch information
ksperling-apple and bzbarsky-apple authored Jul 2, 2024
1 parent 247f05d commit 102faca
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 19 deletions.
49 changes: 35 additions & 14 deletions src/app/ReadClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,33 +423,54 @@ CHIP_ERROR ReadClient::BuildDataVersionFilterList(DataVersionFilterIBs::Builder
continue;
}

DataVersionFilterIB::Builder & filterIB = aDataVersionFilterIBsBuilder.CreateDataVersionFilter();
ReturnErrorOnFailure(aDataVersionFilterIBsBuilder.GetError());
ClusterPathIB::Builder & path = filterIB.CreatePath();
ReturnErrorOnFailure(filterIB.GetError());
ReturnErrorOnFailure(path.Endpoint(filter.mEndpointId).Cluster(filter.mClusterId).EndOfClusterPathIB());
VerifyOrReturnError(filter.mDataVersion.HasValue(), CHIP_ERROR_INVALID_ARGUMENT);
ReturnErrorOnFailure(filterIB.DataVersion(filter.mDataVersion.Value()).EndOfDataVersionFilterIB());
aEncodedDataVersionList = true;
TLV::TLVWriter backup;
aDataVersionFilterIBsBuilder.Checkpoint(backup);
CHIP_ERROR err = EncodeDataVersionFilter(aDataVersionFilterIBsBuilder, filter);
if (err == CHIP_NO_ERROR)
{
aEncodedDataVersionList = true;
}
else if (err == CHIP_ERROR_NO_MEMORY || err == CHIP_ERROR_BUFFER_TOO_SMALL)
{
// Packet is full, ignore the rest of the list
aDataVersionFilterIBsBuilder.Rollback(backup);
return CHIP_NO_ERROR;
}
else
{
return err;
}
}
return CHIP_NO_ERROR;
}

CHIP_ERROR ReadClient::EncodeDataVersionFilter(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
DataVersionFilter const & aFilter)
{
// Caller has checked aFilter.IsValidDataVersionFilter()
DataVersionFilterIB::Builder & filterIB = aDataVersionFilterIBsBuilder.CreateDataVersionFilter();
ReturnErrorOnFailure(aDataVersionFilterIBsBuilder.GetError());
ClusterPathIB::Builder & path = filterIB.CreatePath();
ReturnErrorOnFailure(filterIB.GetError());
ReturnErrorOnFailure(path.Endpoint(aFilter.mEndpointId).Cluster(aFilter.mClusterId).EndOfClusterPathIB());
ReturnErrorOnFailure(filterIB.DataVersion(aFilter.mDataVersion.Value()).EndOfDataVersionFilterIB());
return CHIP_NO_ERROR;
}

CHIP_ERROR ReadClient::GenerateDataVersionFilterList(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
const Span<AttributePathParams> & aAttributePaths,
const Span<DataVersionFilter> & aDataVersionFilters,
bool & aEncodedDataVersionList)
{
if (!aDataVersionFilters.empty())
// Give the callback a chance first, otherwise use the list we have, if any.
ReturnErrorOnFailure(
mpCallback.OnUpdateDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aEncodedDataVersionList));

if (!aEncodedDataVersionList)
{
ReturnErrorOnFailure(BuildDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aDataVersionFilters,
aEncodedDataVersionList));
}
else
{
ReturnErrorOnFailure(
mpCallback.OnUpdateDataVersionFilterList(aDataVersionFilterIBsBuilder, aAttributePaths, aEncodedDataVersionList));
}

return CHIP_NO_ERROR;
}
Expand Down
5 changes: 5 additions & 0 deletions src/app/ReadClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ class ReadClient : public Messaging::ExchangeDelegate
* This will send either a Read Request or a Subscribe Request depending on
* the InteractionType this read client was initialized with.
*
* If the params contain more data version filters than can fit in the request packet
* the list will be truncated as needed, i.e. filter inclusion is on a best effort basis.
*
* @retval #others fail to send read request
* @retval #CHIP_NO_ERROR On success.
*/
Expand Down Expand Up @@ -559,6 +562,8 @@ class ReadClient : public Messaging::ExchangeDelegate
CHIP_ERROR BuildDataVersionFilterList(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
const Span<AttributePathParams> & aAttributePaths,
const Span<DataVersionFilter> & aDataVersionFilters, bool & aEncodedDataVersionList);
CHIP_ERROR EncodeDataVersionFilter(DataVersionFilterIBs::Builder & aDataVersionFilterIBsBuilder,
DataVersionFilter const & aFilter);
CHIP_ERROR ReadICDOperatingModeFromAttributeDataIB(TLV::TLVReader && aReader, PeerType & aType);
CHIP_ERROR ProcessAttributeReportIBs(TLV::TLVReader & aAttributeDataIBsReader);
CHIP_ERROR ProcessEventReportIBs(TLV::TLVReader & aEventReportIBsReader);
Expand Down
1 change: 1 addition & 0 deletions src/controller/tests/data_model/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ chip_test_suite("data_model") {
"${chip_root}/src/app/tests:helpers",
"${chip_root}/src/app/util/mock:mock_ember",
"${chip_root}/src/controller",
"${chip_root}/src/lib/core:string-builder-adapters",
"${chip_root}/src/messaging/tests:helpers",
"${chip_root}/src/transport/raw/tests:helpers",
]
Expand Down
3 changes: 2 additions & 1 deletion src/controller/tests/data_model/TestCommands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
*
*/

#include <gtest/gtest.h>
#include <lib/core/StringBuilderAdapters.h>
#include <pw_unit_test/framework.h>

#include "app/data-model/NullObject.h"
#include <app-common/zap-generated/cluster-objects.h>
Expand Down
84 changes: 83 additions & 1 deletion src/controller/tests/data_model/TestRead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <lib/core/StringBuilderAdapters.h>
#include <pw_unit_test/framework.h>

#include "system/SystemClock.h"
#include "transport/SecureSession.h"
Expand All @@ -25,6 +26,7 @@
#include <app/ConcreteAttributePath.h>
#include <app/ConcreteEventPath.h>
#include <app/InteractionModelEngine.h>
#include <app/ReadClient.h>
#include <app/tests/AppTestContext.h>
#include <app/util/mock/Constants.h>
#include <app/util/mock/Functions.h>
Expand Down Expand Up @@ -3093,6 +3095,86 @@ TEST_F(TestRead, TestReadHandler_MultipleSubscriptionsWithDataVersionFilter)
EXPECT_EQ(mpContext->GetExchangeManager().GetNumActiveExchanges(), 0u);
}

TEST_F(TestRead, TestReadHandler_DataVersionFiltersTruncated)
{
struct : public chip::Test::LoopbackTransportDelegate
{
size_t requestSize = 0;
void WillSendMessage(const Transport::PeerAddress & peer, const System::PacketBufferHandle & message) override
{
// We only care about the messages we (Alice) send to Bob, not the responses.
// Assume the first message we see in an iteration is the request.
if (peer == mpContext->GetBobAddress() && requestSize == 0)
{
requestSize = message->TotalLength();
}
}
} loopbackDelegate;
mpContext->GetLoopback().SetLoopbackTransportDelegate(&loopbackDelegate);

// Note that on the server side, wildcard expansion does not actually work for kTestEndpointId due
// to lack of meta-data, but we don't care about the reports we get back in this test.
AttributePathParams wildcardPath(kTestEndpointId, kInvalidClusterId, kInvalidAttributeId);
constexpr size_t maxDataVersionFilterCount = 100;
DataVersionFilter dataVersionFilters[maxDataVersionFilterCount];
ClusterId nextClusterId = 0;
for (auto & dv : dataVersionFilters)
{
dv.mEndpointId = wildcardPath.mEndpointId;
dv.mClusterId = nextClusterId++;
dv.mDataVersion = MakeOptional(0x01000000u);
}

// Keep increasing the number of data version filters until we see truncation kick in.
size_t lastRequestSize;
for (size_t count = 1; count <= maxDataVersionFilterCount; count++)
{
lastRequestSize = loopbackDelegate.requestSize;
loopbackDelegate.requestSize = 0; // reset

ReadPrepareParams read(mpContext->GetSessionAliceToBob());
read.mpAttributePathParamsList = &wildcardPath;
read.mAttributePathParamsListSize = 1;
read.mpDataVersionFilterList = dataVersionFilters;
read.mDataVersionFilterListSize = count;

struct : public ReadClient::Callback
{
CHIP_ERROR error = CHIP_NO_ERROR;
bool done = false;
void OnError(CHIP_ERROR aError) override { error = aError; }
void OnDone(ReadClient * apReadClient) override { done = true; };

} readCallback;

ReadClient readClient(app::InteractionModelEngine::GetInstance(), &mpContext->GetExchangeManager(), readCallback,
ReadClient::InteractionType::Read);

EXPECT_EQ(readClient.SendRequest(read), CHIP_NO_ERROR);

mpContext->GetIOContext().DriveIOUntil(System::Clock::Seconds16(5), [&]() { return readCallback.done; });
EXPECT_EQ(readCallback.error, CHIP_NO_ERROR);
EXPECT_EQ(mpContext->GetExchangeManager().GetNumActiveExchanges(), 0u);

EXPECT_NE(loopbackDelegate.requestSize, 0u);
EXPECT_GE(loopbackDelegate.requestSize, lastRequestSize);
if (loopbackDelegate.requestSize == lastRequestSize)
{
ChipLogProgress(DataManagement, "Data Version truncation detected after %llu elements",
static_cast<unsigned long long>(count - 1));
// With the parameters used in this test and current encoding rules we can fit 68 data versions
// into a packet. If we're seeing substantially less then something is likely gone wrong.
EXPECT_GE(count, 60u);
ExitNow();
}
}
ChipLogProgress(DataManagement, "Unable to detect Data Version truncation, maxDataVersionFilterCount too small?");
ADD_FAILURE();

exit:
mpContext->GetLoopback().SetLoopbackTransportDelegate(nullptr);
}

TEST_F(TestRead, TestReadHandlerResourceExhaustion_MultipleReads)
{
auto sessionHandle = mpContext->GetSessionBobToAlice();
Expand Down
3 changes: 2 additions & 1 deletion src/controller/tests/data_model/TestWrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <lib/core/StringBuilderAdapters.h>
#include <pw_unit_test/framework.h>

#include "app-common/zap-generated/ids/Clusters.h"
#include <app-common/zap-generated/cluster-objects.h>
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/tests/MessagingContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class MessagingContext : public PlatformMemoryUser

MessagingContext() :
mInitialized(false), mAliceAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)),
mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT))
mBobAddress(LoopbackTransport::LoopbackPeer(mAliceAddress))
{}
// TODO Replace VerifyOrDie with Pigweed assert after transition app/tests to Pigweed.
// TODO Currently src/app/icd/server/tests is using MessagingConetext as dependency.
Expand Down
23 changes: 22 additions & 1 deletion src/transport/raw/tests/NetworkTestHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class LoopbackTransportDelegate
public:
virtual ~LoopbackTransportDelegate() {}

// Called by the loopback transport when a message is requested to be sent.
// This is called even if the message is subsequently rejected or dropped.
virtual void WillSendMessage(const Transport::PeerAddress & peer, const System::PacketBufferHandle & message) {}

// Called by the loopback transport when it drops one of a configurable number of messages (mDroppedMessageCount) after a
// configurable allowed number of messages (mNumMessagesToAllowBeforeDropping)
virtual void OnMessageDropped() {}
Expand All @@ -72,6 +76,18 @@ class LoopbackTransportDelegate
class LoopbackTransport : public Transport::Base
{
public:
// In test scenarios using the loopback transport, we're only ever given
// the address we're sending to, but we don't have any information about
// what our local address is. Assume our fake addresses come in pairs of
// even and odd port numbers, so we can calculate one from the other by
// flipping the LSB of the port number.
static Transport::PeerAddress LoopbackPeer(const Transport::PeerAddress & address)
{
Transport::PeerAddress other(address);
other.SetPort(address.GetPort() ^ 1);
return other;
}

void InitLoopbackTransport(System::Layer * systemLayer)
{
Reset();
Expand Down Expand Up @@ -100,14 +116,19 @@ class LoopbackTransport : public Transport::Base
{
auto item = std::move(_this->mPendingMessageQueue.front());
_this->mPendingMessageQueue.pop();
_this->HandleMessageReceived(item.mDestinationAddress, std::move(item.mPendingMessage));
_this->HandleMessageReceived(LoopbackPeer(item.mDestinationAddress), std::move(item.mPendingMessage));
}
}

static constexpr uint32_t kUnlimitedMessageCount = std::numeric_limits<uint32_t>::max();

CHIP_ERROR SendMessage(const Transport::PeerAddress & address, System::PacketBufferHandle && msgBuf) override
{
if (mDelegate != nullptr)
{
mDelegate->WillSendMessage(address, msgBuf);
}

if (mNumMessagesToAllowBeforeError == 0)
{
ReturnErrorOnFailure(mMessageSendError);
Expand Down

0 comments on commit 102faca

Please sign in to comment.