Skip to content

Commit

Permalink
Add async message dispatch to loopback (#11461)
Browse files Browse the repository at this point in the history
* Add async message dispatch to loopback

This PR was triggered by some test failures in some of the end-to-end IM
unit tests that utilized the loopback transport to send/receive payloads
from client to server and back. Since the current loopback transport
processes 'transmitted' messages synchronously without completing the
execution of the original context, it results in call flows that are not
typical of actual devices interacting with each other. This resulted in
a use-after-free error where the upon calling SendMessage() within the
CommandSender, the synchronous execution resulted in the eventual
destruction of the original CommandSender object immediately after
SendMessage() was called.

This PR adds support for asynchronous dispatch and handling of
transmitted messages that is more representative of real-world CHIP
node interactions to the existing loopback interface. It utilizes
SystemLayer::ScheduleWork to handle the processing of the sent message
as a bottom half handler.

It also adds a DrainAndServiceIO method on the AppContext that will
automatically drain and service the IO till all messages have been
handled.

Tests:
- Ensured the TestCommand failure doesn't happen again.

* Apply suggestions from code review

Co-authored-by: Boris Zbarsky <[email protected]>

Co-authored-by: Boris Zbarsky <[email protected]>
  • Loading branch information
2 people authored and pull[bot] committed Mar 28, 2023
1 parent 8a80e88 commit 1465445
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 4 deletions.
54 changes: 54 additions & 0 deletions src/app/tests/AppTestContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#pragma once

#include "system/SystemClock.h"
#include <messaging/tests/MessagingContext.h>
#include <transport/raw/tests/NetworkTestHelpers.h>

Expand All @@ -36,12 +37,65 @@ class AppContext : public MessagingContext
// Shutdown all layers, finalize operations
CHIP_ERROR Shutdown();

/*
* For unit-tests that simulate end-to-end transmission and reception of messages in loopback mode,
* this mode better replicates a real-functioning stack that correctly handles the processing
* of a transmitted message as an asynchronous, bottom half handler dispatched after the current execution context has
completed.
* This is achieved using SystemLayer::ScheduleWork.
* This should be used in conjunction with the DrainAndServiceIO function below to correctly service and drain the event queue.
*
*/
void EnableAsyncDispatch()
{
auto & impl = mTransportManager.GetTransport().GetImplAtIndex<0>();
impl.EnableAsyncDispatch(&mIOContext.GetSystemLayer());
}

/*
* This drives the servicing of events using the embedded IOContext while there are pending
* messages in the loopback transport's pending message queue. This should run to completion
* in well-behaved logic (i.e there isn't an indefinite ping-pong of messages transmitted back
* and forth).
*
* Consequently, this is guarded with a user-provided timeout to ensure we don't have unit-tests that stall
* in CI due to bugs in the code that is being tested.
*
* This DOES NOT ensure that all pending events are serviced to completion (i.e timers, any ScheduleWork calls).
*
*/
void DrainAndServiceIO(System::Clock::Timeout maxWait = chip::System::Clock::Seconds16(5))
{
auto & impl = mTransportManager.GetTransport().GetImplAtIndex<0>();
System::Clock::Timestamp startTime = System::SystemClock().GetMonotonicTimestamp();

while (impl.HasPendingMessages())
{
mIOContext.DriveIO();
if ((System::SystemClock().GetMonotonicTimestamp() - startTime) >= maxWait)
{
break;
}
}
}

static int Initialize(void * context)
{
auto * ctx = static_cast<AppContext *>(context);
return ctx->Init() == CHIP_NO_ERROR ? SUCCESS : FAILURE;
}

static int InitializeAsync(void * context)
{
auto * ctx = static_cast<AppContext *>(context);

VerifyOrReturnError(ctx->Init() == CHIP_NO_ERROR, FAILURE);
ctx->EnableAsyncDispatch();

return SUCCESS;
}

static int Finalize(void * context)
{
auto * ctx = static_cast<AppContext *>(context);
Expand Down
12 changes: 11 additions & 1 deletion src/controller/tests/data_model/TestCommands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ void TestCommandInteraction::TestDataResponse(nlTestSuite * apSuite, void * apCo
chip::Controller::InvokeCommandRequest<TestCluster::Commands::TestStructArrayArgumentResponse::DecodableType>(
&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb, onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, onSuccessWasCalled && !onFailureWasCalled);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
}
Expand Down Expand Up @@ -231,6 +233,8 @@ void TestCommandInteraction::TestSuccessNoDataResponse(nlTestSuite * apSuite, vo
chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb,
onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, onSuccessWasCalled && !onFailureWasCalled && statusCheck);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
}
Expand Down Expand Up @@ -263,6 +267,8 @@ void TestCommandInteraction::TestFailure(nlTestSuite * apSuite, void * apContext
chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb,
onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, !onSuccessWasCalled && onFailureWasCalled && statusCheck);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
}
Expand Down Expand Up @@ -296,6 +302,8 @@ void TestCommandInteraction::TestSuccessNoDataResponseWithClusterStatus(nlTestSu
chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb,
onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, onSuccessWasCalled && !onFailureWasCalled && statusCheck);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
}
Expand Down Expand Up @@ -329,6 +337,8 @@ void TestCommandInteraction::TestFailureWithClusterStatus(nlTestSuite * apSuite,
chip::Controller::InvokeCommandRequest(&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, request, onSuccessCb,
onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, !onSuccessWasCalled && onFailureWasCalled && statusCheck);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
}
Expand All @@ -350,7 +360,7 @@ nlTestSuite sSuite =
{
"TestCommands",
&sTests[0],
TestContext::Initialize,
TestContext::InitializeAsync,
TestContext::Finalize
};
// clang-format on
Expand Down
12 changes: 11 additions & 1 deletion src/controller/tests/data_model/TestRead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ void TestReadInteraction::TestDataResponse(nlTestSuite * apSuite, void * apConte
chip::Controller::ReadAttribute<TestCluster::Attributes::ListStructOctetString::TypeInfo>(
&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, onSuccessCb, onFailureCb);

ctx.DrainAndServiceIO();
chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run();
ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, onSuccessCbInvoked && !onFailureCbInvoked);
NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 0);
Expand Down Expand Up @@ -177,7 +179,9 @@ void TestReadInteraction::TestAttributeError(nlTestSuite * apSuite, void * apCon
chip::Controller::ReadAttribute<TestCluster::Attributes::ListStructOctetString::TypeInfo>(
&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, onSuccessCb, onFailureCb);

ctx.DrainAndServiceIO();
chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run();
ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, !onSuccessCbInvoked && onFailureCbInvoked);
NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 0);
Expand Down Expand Up @@ -210,11 +214,15 @@ void TestReadInteraction::TestReadTimeout(nlTestSuite * apSuite, void * apContex
chip::Controller::ReadAttribute<TestCluster::Attributes::ListStructOctetString::TypeInfo>(
&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, onSuccessCb, onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 1);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 2);

ctx.GetExchangeManager().OnConnectionExpired(ctx.GetSessionBobToAlice());

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, !onSuccessCbInvoked && onFailureCbInvoked);
NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveReadClients() == 0);

Expand All @@ -223,7 +231,9 @@ void TestReadInteraction::TestReadTimeout(nlTestSuite * apSuite, void * apContex
//
// NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 1);

ctx.DrainAndServiceIO();
chip::app::InteractionModelEngine::GetInstance()->GetReportingEngine().Run();
ctx.DrainAndServiceIO();

ctx.GetExchangeManager().OnConnectionExpired(ctx.GetSessionAliceToBob());

Expand All @@ -250,7 +260,7 @@ nlTestSuite sSuite =
{
"TestRead",
&sTests[0],
TestContext::Initialize,
TestContext::InitializeAsync,
TestContext::Finalize
};
// clang-format on
Expand Down
6 changes: 5 additions & 1 deletion src/controller/tests/data_model/TestWrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ void TestWriteInteraction::TestDataResponse(nlTestSuite * apSuite, void * apCont
chip::Controller::WriteAttribute<TestCluster::Attributes::ListStructOctetString::TypeInfo>(
&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, value, onSuccessCb, onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, onSuccessCbInvoked && !onFailureCbInvoked);
NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveWriteHandlers() == 0);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
Expand Down Expand Up @@ -190,6 +192,8 @@ void TestWriteInteraction::TestAttributeError(nlTestSuite * apSuite, void * apCo
chip::Controller::WriteAttribute<TestCluster::Attributes::ListStructOctetString::TypeInfo>(
&ctx.GetExchangeManager(), sessionHandle, kTestEndpointId, value, onSuccessCb, onFailureCb);

ctx.DrainAndServiceIO();

NL_TEST_ASSERT(apSuite, !onSuccessCbInvoked && onFailureCbInvoked);
NL_TEST_ASSERT(apSuite, chip::app::InteractionModelEngine::GetInstance()->GetNumActiveWriteHandlers() == 0);
NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0);
Expand All @@ -209,7 +213,7 @@ nlTestSuite sSuite =
{
"TestWrite",
&sTests[0],
TestContext::Initialize,
TestContext::InitializeAsync,
TestContext::Finalize
};
// clang-format on
Expand Down
3 changes: 3 additions & 0 deletions src/transport/TransportMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class TransportMgr : public TransportMgrBase

private:
Transport::Tuple<TransportTypes...> mTransport;

public:
auto & GetTransport() { return mTransport; }
};

} // namespace chip
7 changes: 7 additions & 0 deletions src/transport/raw/Tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ class Tuple : public Base
CHIP_ERROR InitImpl(RawTransportDelegate * delegate) { return CHIP_NO_ERROR; }

std::tuple<TransportTypes...> mTransports;

public:
template <size_t i>
auto GetImplAtIndex() -> decltype(std::get<i>(mTransports)) &
{
return std::get<i>(mTransports);
}
};

} // namespace Transport
Expand Down
52 changes: 51 additions & 1 deletion src/transport/raw/tests/NetworkTestHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <transport/raw/PeerAddress.h>

#include <nlbyteorder.h>
#include <queue>

namespace chip {
namespace Test {
Expand Down Expand Up @@ -63,6 +64,32 @@ class LoopbackTransport : public Transport::Base
/// Transports are required to have a constructor that takes exactly one argument
CHIP_ERROR Init(const char *) { return CHIP_NO_ERROR; }

/*
* For unit-tests that simulate end-to-end transmission and reception of messages in loopback mode,
* this mode better replicates a real-functioning stack that correctly handles the processing
* of a transmitted message as an asynchronous, bottom half handler dispatched after the current execution context has
* completed. This is achieved using SystemLayer::ScheduleWork.
*/
void EnableAsyncDispatch(System::Layer * aSystemLayer)
{
mSystemLayer = aSystemLayer;
mAsyncMessageDispatch = true;
}

bool HasPendingMessages() { return !mPendingMessageQueue.empty(); }

static void OnMessageReceived(System::Layer * aSystemLayer, void * aAppState)
{
LoopbackTransport * _this = static_cast<LoopbackTransport *>(aAppState);

while (!_this->mPendingMessageQueue.empty())
{
auto item = std::move(_this->mPendingMessageQueue.front());
_this->mPendingMessageQueue.pop();
_this->HandleMessageReceived(item.mDestinationAddress, std::move(item.mPendingMessage));
}
}

CHIP_ERROR SendMessage(const Transport::PeerAddress & address, System::PacketBufferHandle && msgBuf) override
{
ReturnErrorOnFailure(mMessageSendError);
Expand All @@ -71,7 +98,16 @@ class LoopbackTransport : public Transport::Base
if (mNumMessagesToDrop == 0)
{
System::PacketBufferHandle receivedMessage = msgBuf.CloneData();
HandleMessageReceived(address, std::move(receivedMessage));

if (mAsyncMessageDispatch)
{
mPendingMessageQueue.push(PendingMessageItem(address, std::move(receivedMessage)));
mSystemLayer->ScheduleWork(OnMessageReceived, this);
}
else
{
HandleMessageReceived(address, std::move(receivedMessage));
}
}
else
{
Expand All @@ -93,9 +129,23 @@ class LoopbackTransport : public Transport::Base
mMessageSendError = CHIP_NO_ERROR;
}

struct PendingMessageItem
{
PendingMessageItem(const Transport::PeerAddress destinationAddress, System::PacketBufferHandle && pendingMessage) :
mDestinationAddress(destinationAddress), mPendingMessage(std::move(pendingMessage))
{}

const Transport::PeerAddress mDestinationAddress;
System::PacketBufferHandle mPendingMessage;
};

// Hook for subclasses to perform custom logic on message drops.
virtual void MessageDropped() {}

System::Layer * mSystemLayer = nullptr;
bool mAsyncMessageDispatch = false;
std::queue<PendingMessageItem> mPendingMessageQueue;
Transport::PeerAddress mTxAddress;
uint32_t mNumMessagesToDrop = 0;
uint32_t mDroppedMessageCount = 0;
uint32_t mSentMessageCount = 0;
Expand Down

0 comments on commit 1465445

Please sign in to comment.