From 2358740ee14d2d8388f7b0100c592cb833c2ad61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damian=20Kr=C3=B3lik?= <66667989+Damian-Nordic@users.noreply.github.com> Date: Wed, 14 Dec 2022 19:00:35 +0100 Subject: [PATCH] Simplify CASEClient initialization code (#24079) * Remove CASEClientInitParams from CASEClient to save RAM It is only needed in EstablishSession so there is no point in keeping it as a class member. * Devirtualize CASEServer::GetSession Overriding it in unit tests does not seem to change the test results. * Merge DeviceProxyInitParams with CASEClientInitParams Both structures are almost the same and as we tend to pass more and more interfaces down the stack, translating between all the different structures becomes cumbersome. --- src/app/CASEClient.cpp | 19 +++++----- src/app/CASEClient.h | 22 ++++++++---- src/app/CASEClientPool.h | 4 +-- src/app/CASESessionManager.cpp | 4 +-- src/app/CASESessionManager.h | 3 +- src/app/OperationalSessionSetup.cpp | 12 +++---- src/app/OperationalSessionSetup.h | 35 +++---------------- src/app/OperationalSessionSetupPool.h | 10 +++--- src/app/server/Server.cpp | 4 +-- src/app/tests/TestOperationalDeviceProxy.cpp | 2 +- .../CHIPDeviceControllerFactory.cpp | 6 ++-- src/protocols/secure_channel/CASEServer.h | 2 +- .../secure_channel/tests/TestCASESession.cpp | 11 +----- 13 files changed, 53 insertions(+), 81 deletions(-) diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index f514b3aa4c089c..f6fa82c686e827 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -19,21 +19,20 @@ namespace chip { -CASEClient::CASEClient(const CASEClientInitParams & params) : mInitParams(params) {} - void CASEClient::SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig) { mCASESession.SetRemoteMRPConfig(remoteMRPConfig); } -CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress, +CHIP_ERROR CASEClient::EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer, + const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & remoteMRPConfig, SessionEstablishmentDelegate * delegate) { - VerifyOrReturnError(mInitParams.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(params.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT); // Create a UnauthenticatedSession for CASE pairing. - Optional session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig); + Optional session = params.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig); VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); // Allocate the exchange immediately before calling CASESession::EstablishSession. @@ -42,13 +41,13 @@ CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transpo // free it on error, but can only do this if it is actually called. // Allocating the exchange context right before calling EstablishSession // ensures that if allocation succeeds, CASESession has taken ownership. - Messaging::ExchangeContext * exchange = mInitParams.exchangeMgr->NewContext(session.Value(), &mCASESession); + Messaging::ExchangeContext * exchange = params.exchangeMgr->NewContext(session.Value(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); - mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider); - ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, mInitParams.fabricTable, peer, exchange, - mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy, - delegate, mInitParams.mrpLocalConfig)); + mCASESession.SetGroupDataProvider(params.groupDataProvider); + ReturnErrorOnFailure(mCASESession.EstablishSession(*params.sessionManager, params.fabricTable, peer, exchange, + params.sessionResumptionStorage, params.certificateValidityPolicy, delegate, + params.mrpLocalConfig)); return CHIP_NO_ERROR; } diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index 33dfad16bab0cf..3a5aa8ded7ff08 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -34,23 +34,31 @@ struct CASEClientInitParams Messaging::ExchangeManager * exchangeMgr = nullptr; FabricTable * fabricTable = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; + Optional mrpLocalConfig = Optional::Missing(); - Optional mrpLocalConfig = Optional::Missing(); + CHIP_ERROR Validate() const + { + // sessionResumptionStorage can be nullptr when resumption is disabled. + // certificateValidityPolicy is optional, too. + ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); + ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); + ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); + ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); + + return CHIP_NO_ERROR; + } }; class DLL_EXPORT CASEClient { public: - CASEClient(const CASEClientInitParams & params); - void SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig); - CHIP_ERROR EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress, - const ReliableMessageProtocolConfig & remoteMRPConfig, SessionEstablishmentDelegate * delegate); + CHIP_ERROR EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer, + const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & remoteMRPConfig, + SessionEstablishmentDelegate * delegate); private: - CASEClientInitParams mInitParams; - CASESession mCASESession; }; diff --git a/src/app/CASEClientPool.h b/src/app/CASEClientPool.h index f44d487771c8fc..41b372cd030e5c 100644 --- a/src/app/CASEClientPool.h +++ b/src/app/CASEClientPool.h @@ -25,7 +25,7 @@ namespace chip { class CASEClientPoolDelegate { public: - virtual CASEClient * Allocate(CASEClientInitParams params) = 0; + virtual CASEClient * Allocate() = 0; virtual void Release(CASEClient * client) = 0; @@ -38,7 +38,7 @@ class CASEClientPool : public CASEClientPoolDelegate public: ~CASEClientPool() override { mClientPool.ReleaseAll(); } - CASEClient * Allocate(CASEClientInitParams params) override { return mClientPool.CreateObject(params); } + CASEClient * Allocate() override { return mClientPool.CreateObject(); } void Release(CASEClient * client) override { mClientPool.ReleaseObject(client); } diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index 447d7a663d388f..9d4f3814943e42 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -41,7 +41,7 @@ void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Cal { ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalSessionSetup instance found"); - session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this); + session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this); if (session == nullptr) { @@ -83,7 +83,7 @@ void CASESessionManager::UpdatePeerAddress(ScopedNodeId peerId) { ChipLogDetail(CASESessionManager, "UpdatePeerAddress: No existing OperationalSessionSetup instance found"); - session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this); + session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this); if (session == nullptr) { ChipLogDetail(CASESessionManager, "UpdatePeerAddress: Failed to allocate OperationalSessionSetup instance"); diff --git a/src/app/CASESessionManager.h b/src/app/CASESessionManager.h index ddc971a1e7de86..1e901478aaf6b7 100644 --- a/src/app/CASESessionManager.h +++ b/src/app/CASESessionManager.h @@ -36,7 +36,8 @@ class OperationalSessionSetupPoolDelegate; struct CASESessionManagerConfig { - DeviceProxyInitParams sessionInitParams; + CASEClientInitParams sessionInitParams; + CASEClientPoolDelegate * clientPool = nullptr; OperationalSessionSetupPoolDelegate * sessionSetupPool = nullptr; }; diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 38aa12566c59b8..2997c877c4beee 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -221,12 +221,10 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config) { - mCASEClient = mInitParams.clientPool->Allocate(CASEClientInitParams{ - mInitParams.sessionManager, mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy, - mInitParams.exchangeMgr, mFabricTable, mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); + mCASEClient = mClientPool->Allocate(); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); - CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, config, this); + CHIP_ERROR err = mCASEClient->EstablishSession(mInitParams, mPeerId, mDeviceAddress, config, this); if (err != CHIP_NO_ERROR) { CleanupCASEClient(); @@ -330,7 +328,7 @@ void OperationalSessionSetup::CleanupCASEClient() { if (mCASEClient) { - mInitParams.clientPool->Release(mCASEClient); + mClientPool->Release(mCASEClient); mCASEClient = nullptr; } } @@ -364,7 +362,7 @@ OperationalSessionSetup::~OperationalSessionSetup() if (mCASEClient) { // Make sure we don't leak it. - mInitParams.clientPool->Release(mCASEClient); + mClientPool->Release(mCASEClient); } } @@ -382,7 +380,7 @@ CHIP_ERROR OperationalSessionSetup::LookupPeerAddress() return CHIP_NO_ERROR; } - auto const * fabricInfo = mFabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex()); + auto const * fabricInfo = mInitParams.fabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex()); VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INVALID_FABRIC_INDEX); PeerId peerId(fabricInfo->GetCompressedFabricId(), mPeerId.GetNodeId()); diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index cde2fddc6dac0b..6ed951f46cf654 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -45,31 +45,6 @@ namespace chip { -struct DeviceProxyInitParams -{ - SessionManager * sessionManager = nullptr; - SessionResumptionStorage * sessionResumptionStorage = nullptr; - Credentials::CertificateValidityPolicy * certificateValidityPolicy = nullptr; - Messaging::ExchangeManager * exchangeMgr = nullptr; - FabricTable * fabricTable = nullptr; - CASEClientPoolDelegate * clientPool = nullptr; - Credentials::GroupDataProvider * groupDataProvider = nullptr; - - Optional mrpLocalConfig = Optional::Missing(); - - CHIP_ERROR Validate() const - { - ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); - // sessionResumptionStorage can be nullptr when resumption is disabled - ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(clientPool == nullptr, CHIP_ERROR_INCORRECT_STATE); - - return CHIP_NO_ERROR; - } -}; - class OperationalSessionSetup; /** @@ -171,20 +146,20 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate, public: ~OperationalSessionSetup() override; - OperationalSessionSetup(DeviceProxyInitParams & params, ScopedNodeId peerId, + OperationalSessionSetup(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) : mSecureSession(*this) { mInitParams = params; - if (params.Validate() != CHIP_NO_ERROR || releaseDelegate == nullptr) + if (params.Validate() != CHIP_NO_ERROR || clientPool == nullptr || releaseDelegate == nullptr) { mState = State::Uninitialized; return; } + mClientPool = clientPool; mSystemLayer = params.exchangeMgr->GetSessionManager()->SystemLayer(); mPeerId = peerId; - mFabricTable = params.fabricTable; mReleaseDelegate = releaseDelegate; mState = State::NeedsAddress; mAddressLookupHandle.SetListener(this); @@ -260,8 +235,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate, SecureConnected, // CASE session established. }; - DeviceProxyInitParams mInitParams; - FabricTable * mFabricTable = nullptr; + CASEClientInitParams mInitParams; + CASEClientPoolDelegate * mClientPool = nullptr; System::Layer * mSystemLayer; // mCASEClient is only non-null if we are in State::Connecting or just diff --git a/src/app/OperationalSessionSetupPool.h b/src/app/OperationalSessionSetupPool.h index 50d1fbd1567bca..8f40b37ebf5c4b 100644 --- a/src/app/OperationalSessionSetupPool.h +++ b/src/app/OperationalSessionSetupPool.h @@ -27,8 +27,8 @@ namespace chip { class OperationalSessionSetupPoolDelegate { public: - virtual OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId, - OperationalSessionReleaseDelegate * releaseDelegate) = 0; + virtual OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, + ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) = 0; virtual void Release(OperationalSessionSetup * device) = 0; @@ -47,10 +47,10 @@ class OperationalSessionSetupPool : public OperationalSessionSetupPoolDelegate public: ~OperationalSessionSetupPool() override { mSessionSetupPool.ReleaseAll(); } - OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId, - OperationalSessionReleaseDelegate * releaseDelegate) override + OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, + ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) override { - return mSessionSetupPool.CreateObject(params, peerId, releaseDelegate); + return mSessionSetupPool.CreateObject(params, clientPool, peerId, releaseDelegate); } void Release(OperationalSessionSetup * device) override { mSessionSetupPool.ReleaseObject(device); } diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 3705cb47d3212c..b6c0d571d9e727 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -290,11 +290,11 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) .certificateValidityPolicy = mCertificateValidityPolicy, .exchangeMgr = &mExchangeMgr, .fabricTable = &mFabrics, - .clientPool = &mCASEClientPool, .groupDataProvider = mGroupsProvider, .mrpLocalConfig = GetLocalMRPConfig(), }, - .sessionSetupPool = &mSessionSetupPool, + .clientPool = &mCASEClientPool, + .sessionSetupPool = &mSessionSetupPool, }; err = mCASESessionManager.Init(&DeviceLayer::SystemLayer(), caseSessionManagerConfig); diff --git a/src/app/tests/TestOperationalDeviceProxy.cpp b/src/app/tests/TestOperationalDeviceProxy.cpp index 5d6f928fdc4d7d..3b8252323f920a 100644 --- a/src/app/tests/TestOperationalDeviceProxy.cpp +++ b/src/app/tests/TestOperationalDeviceProxy.cpp @@ -69,7 +69,7 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, VerifyOrDie(groupDataProvider.Init() == CHIP_NO_ERROR); // TODO: Set IPK in groupDataProvider - DeviceProxyInitParams params = { + CASEClientInitParams params = { .sessionManager = &sessionManager, .sessionResumptionStorage = &sessionResumptionStorage, .exchangeMgr = &exchangeMgr, diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 84a7c4388525bd..25457368d945a3 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -245,18 +245,18 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) stateParams.sessionSetupPool = Platform::New(); stateParams.caseClientPool = Platform::New(); - DeviceProxyInitParams deviceInitParams = { + CASEClientInitParams sessionInitParams = { .sessionManager = stateParams.sessionMgr, .sessionResumptionStorage = stateParams.sessionResumptionStorage.get(), .exchangeMgr = stateParams.exchangeMgr, .fabricTable = stateParams.fabricTable, - .clientPool = stateParams.caseClientPool, .groupDataProvider = stateParams.groupDataProvider, .mrpLocalConfig = GetLocalMRPConfig(), }; CASESessionManagerConfig sessionManagerConfig = { - .sessionInitParams = deviceInitParams, + .sessionInitParams = sessionInitParams, + .clientPool = stateParams.caseClientPool, .sessionSetupPool = stateParams.sessionSetupPool, }; diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index 8c8f79f547fc7f..894d496c93bfac 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -69,7 +69,7 @@ class CASEServer : public SessionEstablishmentDelegate, void OnResponseTimeout(Messaging::ExchangeContext * ec) override {} Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return GetSession().GetMessageDispatch(); } - virtual CASESession & GetSession() { return mPairingSession; } + CASESession & GetSession() { return mPairingSession; } private: Messaging::ExchangeManager * mExchangeManager = nullptr; diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 8af555796e5bec..bcfe98f1d57af0 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -124,15 +124,6 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate uint32_t mNumPairingComplete = 0; }; -class CASEServerForTest : public CASEServer -{ -public: - CASESession & GetSession() override { return mCaseSession; } - -private: - CASESession mCaseSession; -}; - class TestOperationalKeystore : public chip::Crypto::OperationalKeystore { public: @@ -469,7 +460,7 @@ void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * i SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, delegateCommissioner); } -CASEServerForTest gPairingServer; +CASEServer gPairingServer; void TestCASESession::SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext) {