diff --git a/src/app/CommandHandler.cpp b/src/app/CommandHandler.cpp index c2ada186d9e53d..26cdc6364f229b 100644 --- a/src/app/CommandHandler.cpp +++ b/src/app/CommandHandler.cpp @@ -65,25 +65,19 @@ CHIP_ERROR CommandHandler::AllocateBuffer() CHIP_ERROR CommandHandler::OnInvokeCommandRequest(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader, System::PacketBufferHandle && payload) { - CHIP_ERROR err = CHIP_NO_ERROR; System::PacketBufferHandle response; - VerifyOrReturnError(mState == CommandState::Idle, CHIP_ERROR_INCORRECT_STATE); // NOTE: we already know this is an InvokeCommand Request message because we explicitly registered with the // Exchange Manager for unsolicited InvokeCommand Requests. - mpExchangeCtx = ec; - err = ProcessInvokeRequest(std::move(payload)); - SuccessOrExit(err); - - err = SendCommandResponse(); - SuccessOrExit(err); + // Use the RAII feature, if this is the only Handle when this function returns, DecRef will trigger sending response. + Handle workHandle(this); + ReturnErrorOnFailure(ProcessInvokeRequest(std::move(payload))); + mpExchangeCtx->WillSendMessage(); -exit: - Close(); - return err; + return CHIP_NO_ERROR; } CHIP_ERROR CommandHandler::ProcessInvokeRequest(System::PacketBufferHandle && payload) @@ -124,6 +118,10 @@ void CommandHandler::Close() { mSuppressResponse = false; MoveToState(CommandState::AwaitingDestruction); + // We must finish all async work before we can shut down a CommandHandler. The actual CommandHandler MUST finish their work in + // reasonable time or there is a bug. + VerifyOrDieWithMsg(mRefCount == 0, DataManagement, "CommandHandler::Close() called with %zu unfinished async work items", + mRefCount); Command::Close(); @@ -133,10 +131,37 @@ void CommandHandler::Close() } } +void CommandHandler::IncRef() +{ + mRefCount++; +} + +void CommandHandler::DecRef() +{ + mRefCount--; + ChipLogDetail(DataManagement, "Decreasing reference count for CommandHandler, remaining %zu", mRefCount); + if (mRefCount != 0) + { + return; + } + CHIP_ERROR err = SendCommandResponse(); + if (err != CHIP_NO_ERROR) + { + ChipLogError(DataManagement, "Failed to send command response: %s", err.AsString()); + // We marked the exchange as "WillSendMessage", need to shutdown the exchange manually to avoid leaking exchanges. + if (mpExchangeCtx != nullptr) + { + mpExchangeCtx->Close(); + } + } + Close(); +} + CHIP_ERROR CommandHandler::SendCommandResponse() { System::PacketBufferHandle commandPacket; + VerifyOrReturnError(mRefCount == 0, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mState == CommandState::AddedCommand, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/app/CommandHandler.h b/src/app/CommandHandler.h index 3bcd67df051f96..e592f8427ca09b 100644 --- a/src/app/CommandHandler.h +++ b/src/app/CommandHandler.h @@ -59,6 +59,53 @@ class CommandHandler : public Command virtual void OnDone(CommandHandler * apCommandObj) = 0; }; + class Handle + { + public: + Handle() {} + Handle(const Handle & handle) = delete; + Handle(Handle && handle) + { + mpHandler = handle.mpHandler; + handle.mpHandler = nullptr; + } + Handle(decltype(nullptr)) {} + Handle(CommandHandler * handle) + { + handle->IncRef(); + mpHandler = handle; + } + ~Handle() { Release(); } + + Handle & operator=(Handle && handle) + { + Release(); + mpHandler = handle.mpHandler; + handle.mpHandler = nullptr; + return *this; + } + + Handle & operator=(decltype(nullptr)) + { + Release(); + return *this; + } + + CommandHandler * operator->() { return mpHandler; } + + void Release() + { + if (mpHandler != nullptr) + { + mpHandler->DecRef(); + mpHandler = nullptr; + } + } + + private: + CommandHandler * mpHandler = nullptr; + }; + /* * Constructor. * @@ -80,6 +127,8 @@ class CommandHandler : public Command CHIP_ERROR AddClusterSpecificFailure(const ConcreteCommandPath & aCommandPath, ClusterStatus aClusterStatus) override; + size_t RefCount() const { return mRefCount; } + CHIP_ERROR ProcessInvokeRequest(System::PacketBufferHandle && payload); CHIP_ERROR PrepareCommand(const CommandPathParams & aCommandPathParams, bool aStartDataStruct = true); CHIP_ERROR FinishCommand(bool aStartDataStruct = true); @@ -110,6 +159,22 @@ class CommandHandler : public Command private: friend class TestCommandInteraction; + friend class CommandHandler::Handle; + + /** + * IncRef will increase the inner refcount of the CommandHandler. + * + * Users should use CommandHandler::Handle for management the lifespan of the CommandHandler. + * DefRef should be released in reasonable time, and Close() should only be called when the refcount reached 0. + */ + void IncRef(); + + /** + * DefRef is used by CommandHandler::Handle for decreasing the refcount of the CommandHandler. + * When refcount reached 0, CommandHandler will send the response to the peer and shutdown. + */ + void DecRef(); + /* * Allocates a packet buffer used for encoding an invoke response payload. * @@ -134,6 +199,7 @@ class CommandHandler : public Command Callback * mpCallback = nullptr; InvokeResponseMessage::Builder mInvokeResponseBuilder; TLV::TLVType mDataElementContainerType = TLV::kTLVType_NotSpecified; + size_t mRefCount = 0; bool mSuppressResponse = false; bool mTimedRequest = false; }; diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index b248171eb30427..836eba68419e98 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -52,6 +52,7 @@ namespace { bool isCommandDispatched = false; bool sendResponse = true; +bool asyncCommand = false; constexpr EndpointId kTestEndpointId = 1; constexpr ClusterId kTestClusterId = 3; @@ -61,6 +62,9 @@ constexpr CommandId kTestNonExistCommandId = 0; } // namespace namespace app { + +CommandHandler::Handle asyncCommandHandle; + bool ServerClusterCommandExists(const ConcreteCommandPath & aCommandPath) { // Mock cluster catalog, only support one command on one cluster on one endpoint. @@ -75,6 +79,12 @@ void DispatchSingleClusterCommand(const ConcreteCommandPath & aCommandPath, chip "Received Cluster Command: Endpoint=%" PRIx16 " Cluster=" ChipLogFormatMEI " Command=" ChipLogFormatMEI, aCommandPath.mEndpointId, ChipLogValueMEI(aCommandPath.mClusterId), ChipLogValueMEI(aCommandPath.mCommandId)); + if (asyncCommand) + { + asyncCommandHandle = apCommandObj; + asyncCommand = false; + } + if (sendResponse) { if (aCommandPath.mCommandId == kTestCommandId) @@ -162,6 +172,7 @@ class TestCommandInteraction static void TestCommandHandlerWithProcessReceivedEmptyDataMsg(nlTestSuite * apSuite, void * apContext); static void TestCommandSenderCommandSuccessResponseFlow(nlTestSuite * apSuite, void * apContext); + static void TestCommandSenderCommandAsyncSuccessResponseFlow(nlTestSuite * apSuite, void * apContext); static void TestCommandSenderCommandFailureResponseFlow(nlTestSuite * apSuite, void * apContext); static void TestCommandSenderCommandSpecificResponseFlow(nlTestSuite * apSuite, void * apContext); @@ -588,6 +599,36 @@ void TestCommandInteraction::TestCommandSenderCommandSuccessResponseFlow(nlTestS NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); } +void TestCommandInteraction::TestCommandSenderCommandAsyncSuccessResponseFlow(nlTestSuite * apSuite, void * apContext) +{ + TestContext & ctx = *static_cast(apContext); + CHIP_ERROR err = CHIP_NO_ERROR; + + mockCommandSenderDelegate.ResetCounter(); + app::CommandSender commandSender(&mockCommandSenderDelegate, &ctx.GetExchangeManager()); + + AddInvokeRequestData(apSuite, apContext, &commandSender); + asyncCommand = true; + err = commandSender.SendCommandRequest(ctx.GetSessionBobToAlice()); + + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + NL_TEST_ASSERT(apSuite, + mockCommandSenderDelegate.onResponseCalledTimes == 0 && mockCommandSenderDelegate.onFinalCalledTimes == 0 && + mockCommandSenderDelegate.onErrorCalledTimes == 0); + + NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 1); + NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 2); + + // Decrease CommandHandler refcount and send response + asyncCommandHandle = nullptr; + NL_TEST_ASSERT(apSuite, + mockCommandSenderDelegate.onResponseCalledTimes == 1 && mockCommandSenderDelegate.onFinalCalledTimes == 1 && + mockCommandSenderDelegate.onErrorCalledTimes == 0); + + NL_TEST_ASSERT(apSuite, GetNumActiveHandlerObjects() == 0); + NL_TEST_ASSERT(apSuite, ctx.GetExchangeManager().GetNumActiveExchanges() == 0); +} + void TestCommandInteraction::TestCommandSenderCommandSpecificResponseFlow(nlTestSuite * apSuite, void * apContext) { TestContext & ctx = *static_cast(apContext); @@ -687,6 +728,7 @@ const nlTest sTests[] = NL_TEST_DEF("TestCommandHandlerWithProcessReceivedEmptyDataMsg", chip::app::TestCommandInteraction::TestCommandHandlerWithProcessReceivedEmptyDataMsg), NL_TEST_DEF("TestCommandSenderCommandSuccessResponseFlow", chip::app::TestCommandInteraction::TestCommandSenderCommandSuccessResponseFlow), + NL_TEST_DEF("TestCommandSenderCommandAsyncSuccessResponseFlow", chip::app::TestCommandInteraction::TestCommandSenderCommandAsyncSuccessResponseFlow), NL_TEST_DEF("TestCommandSenderCommandSpecificResponseFlow", chip::app::TestCommandInteraction::TestCommandSenderCommandSpecificResponseFlow), NL_TEST_DEF("TestCommandSenderCommandFailureResponseFlow", chip::app::TestCommandInteraction::TestCommandSenderCommandFailureResponseFlow), NL_TEST_DEF("TestCommandSenderAbruptDestruction", chip::app::TestCommandInteraction::TestCommandSenderAbruptDestruction),