Skip to content

Commit

Permalink
Simplify CASEClient initialization code (#24079)
Browse files Browse the repository at this point in the history
* Remove CASEClientInitParams from CASEClient to save RAM

It is only needed in EstablishSession so there is no point
in keeping it as a class member.

* Devirtualize CASEServer::GetSession

Overriding it in unit tests does not seem to change the test
results.

* Merge DeviceProxyInitParams with CASEClientInitParams

Both structures are almost the same and as we tend to pass
more and more interfaces down the stack, translating between
all the different structures becomes cumbersome.
  • Loading branch information
Damian-Nordic authored and pull[bot] committed Nov 3, 2023
1 parent e1ee016 commit 3424641
Show file tree
Hide file tree
Showing 13 changed files with 53 additions and 81 deletions.
19 changes: 9 additions & 10 deletions src/app/CASEClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,20 @@

namespace chip {

CASEClient::CASEClient(const CASEClientInitParams & params) : mInitParams(params) {}

void CASEClient::SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig)
{
mCASESession.SetRemoteMRPConfig(remoteMRPConfig);
}

CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress,
CHIP_ERROR CASEClient::EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer,
const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & remoteMRPConfig,
SessionEstablishmentDelegate * delegate)
{
VerifyOrReturnError(mInitParams.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(params.fabricTable != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

// Create a UnauthenticatedSession for CASE pairing.
Optional<SessionHandle> session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig);
Optional<SessionHandle> session = params.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig);
VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY);

// Allocate the exchange immediately before calling CASESession::EstablishSession.
Expand All @@ -42,13 +41,13 @@ CHIP_ERROR CASEClient::EstablishSession(const ScopedNodeId & peer, const Transpo
// free it on error, but can only do this if it is actually called.
// Allocating the exchange context right before calling EstablishSession
// ensures that if allocation succeeds, CASESession has taken ownership.
Messaging::ExchangeContext * exchange = mInitParams.exchangeMgr->NewContext(session.Value(), &mCASESession);
Messaging::ExchangeContext * exchange = params.exchangeMgr->NewContext(session.Value(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);

mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider);
ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, mInitParams.fabricTable, peer, exchange,
mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy,
delegate, mInitParams.mrpLocalConfig));
mCASESession.SetGroupDataProvider(params.groupDataProvider);
ReturnErrorOnFailure(mCASESession.EstablishSession(*params.sessionManager, params.fabricTable, peer, exchange,
params.sessionResumptionStorage, params.certificateValidityPolicy, delegate,
params.mrpLocalConfig));

return CHIP_NO_ERROR;
}
Expand Down
22 changes: 15 additions & 7 deletions src/app/CASEClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,31 @@ struct CASEClientInitParams
Messaging::ExchangeManager * exchangeMgr = nullptr;
FabricTable * fabricTable = nullptr;
Credentials::GroupDataProvider * groupDataProvider = nullptr;
Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();

Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();
CHIP_ERROR Validate() const
{
// sessionResumptionStorage can be nullptr when resumption is disabled.
// certificateValidityPolicy is optional, too.
ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE);

return CHIP_NO_ERROR;
}
};

class DLL_EXPORT CASEClient
{
public:
CASEClient(const CASEClientInitParams & params);

void SetRemoteMRPIntervals(const ReliableMessageProtocolConfig & remoteMRPConfig);

CHIP_ERROR EstablishSession(const ScopedNodeId & peer, const Transport::PeerAddress & peerAddress,
const ReliableMessageProtocolConfig & remoteMRPConfig, SessionEstablishmentDelegate * delegate);
CHIP_ERROR EstablishSession(const CASEClientInitParams & params, const ScopedNodeId & peer,
const Transport::PeerAddress & peerAddress, const ReliableMessageProtocolConfig & remoteMRPConfig,
SessionEstablishmentDelegate * delegate);

private:
CASEClientInitParams mInitParams;

CASESession mCASESession;
};

Expand Down
4 changes: 2 additions & 2 deletions src/app/CASEClientPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace chip {
class CASEClientPoolDelegate
{
public:
virtual CASEClient * Allocate(CASEClientInitParams params) = 0;
virtual CASEClient * Allocate() = 0;

virtual void Release(CASEClient * client) = 0;

Expand All @@ -38,7 +38,7 @@ class CASEClientPool : public CASEClientPoolDelegate
public:
~CASEClientPool() override { mClientPool.ReleaseAll(); }

CASEClient * Allocate(CASEClientInitParams params) override { return mClientPool.CreateObject(params); }
CASEClient * Allocate() override { return mClientPool.CreateObject(); }

void Release(CASEClient * client) override { mClientPool.ReleaseObject(client); }

Expand Down
4 changes: 2 additions & 2 deletions src/app/CASESessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void CASESessionManager::FindOrEstablishSession(const ScopedNodeId & peerId, Cal
{
ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalSessionSetup instance found");

session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this);
session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this);

if (session == nullptr)
{
Expand Down Expand Up @@ -83,7 +83,7 @@ void CASESessionManager::UpdatePeerAddress(ScopedNodeId peerId)
{
ChipLogDetail(CASESessionManager, "UpdatePeerAddress: No existing OperationalSessionSetup instance found");

session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, peerId, this);
session = mConfig.sessionSetupPool->Allocate(mConfig.sessionInitParams, mConfig.clientPool, peerId, this);
if (session == nullptr)
{
ChipLogDetail(CASESessionManager, "UpdatePeerAddress: Failed to allocate OperationalSessionSetup instance");
Expand Down
3 changes: 2 additions & 1 deletion src/app/CASESessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class OperationalSessionSetupPoolDelegate;

struct CASESessionManagerConfig
{
DeviceProxyInitParams sessionInitParams;
CASEClientInitParams sessionInitParams;
CASEClientPoolDelegate * clientPool = nullptr;
OperationalSessionSetupPoolDelegate * sessionSetupPool = nullptr;
};

Expand Down
12 changes: 5 additions & 7 deletions src/app/OperationalSessionSetup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,10 @@ void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & ad

CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessageProtocolConfig & config)
{
mCASEClient = mInitParams.clientPool->Allocate(CASEClientInitParams{
mInitParams.sessionManager, mInitParams.sessionResumptionStorage, mInitParams.certificateValidityPolicy,
mInitParams.exchangeMgr, mFabricTable, mInitParams.groupDataProvider, mInitParams.mrpLocalConfig });
mCASEClient = mClientPool->Allocate();
ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY);

CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, config, this);
CHIP_ERROR err = mCASEClient->EstablishSession(mInitParams, mPeerId, mDeviceAddress, config, this);
if (err != CHIP_NO_ERROR)
{
CleanupCASEClient();
Expand Down Expand Up @@ -330,7 +328,7 @@ void OperationalSessionSetup::CleanupCASEClient()
{
if (mCASEClient)
{
mInitParams.clientPool->Release(mCASEClient);
mClientPool->Release(mCASEClient);
mCASEClient = nullptr;
}
}
Expand Down Expand Up @@ -364,7 +362,7 @@ OperationalSessionSetup::~OperationalSessionSetup()
if (mCASEClient)
{
// Make sure we don't leak it.
mInitParams.clientPool->Release(mCASEClient);
mClientPool->Release(mCASEClient);
}
}

Expand All @@ -382,7 +380,7 @@ CHIP_ERROR OperationalSessionSetup::LookupPeerAddress()
return CHIP_NO_ERROR;
}

auto const * fabricInfo = mFabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex());
auto const * fabricInfo = mInitParams.fabricTable->FindFabricWithIndex(mPeerId.GetFabricIndex());
VerifyOrReturnError(fabricInfo != nullptr, CHIP_ERROR_INVALID_FABRIC_INDEX);

PeerId peerId(fabricInfo->GetCompressedFabricId(), mPeerId.GetNodeId());
Expand Down
35 changes: 5 additions & 30 deletions src/app/OperationalSessionSetup.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,6 @@

namespace chip {

struct DeviceProxyInitParams
{
SessionManager * sessionManager = nullptr;
SessionResumptionStorage * sessionResumptionStorage = nullptr;
Credentials::CertificateValidityPolicy * certificateValidityPolicy = nullptr;
Messaging::ExchangeManager * exchangeMgr = nullptr;
FabricTable * fabricTable = nullptr;
CASEClientPoolDelegate * clientPool = nullptr;
Credentials::GroupDataProvider * groupDataProvider = nullptr;

Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();

CHIP_ERROR Validate() const
{
ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE);
// sessionResumptionStorage can be nullptr when resumption is disabled
ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorCodeIf(clientPool == nullptr, CHIP_ERROR_INCORRECT_STATE);

return CHIP_NO_ERROR;
}
};

class OperationalSessionSetup;

/**
Expand Down Expand Up @@ -171,20 +146,20 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate,
public:
~OperationalSessionSetup() override;

OperationalSessionSetup(DeviceProxyInitParams & params, ScopedNodeId peerId,
OperationalSessionSetup(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, ScopedNodeId peerId,
OperationalSessionReleaseDelegate * releaseDelegate) :
mSecureSession(*this)
{
mInitParams = params;
if (params.Validate() != CHIP_NO_ERROR || releaseDelegate == nullptr)
if (params.Validate() != CHIP_NO_ERROR || clientPool == nullptr || releaseDelegate == nullptr)
{
mState = State::Uninitialized;
return;
}

mClientPool = clientPool;
mSystemLayer = params.exchangeMgr->GetSessionManager()->SystemLayer();
mPeerId = peerId;
mFabricTable = params.fabricTable;
mReleaseDelegate = releaseDelegate;
mState = State::NeedsAddress;
mAddressLookupHandle.SetListener(this);
Expand Down Expand Up @@ -260,8 +235,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionDelegate,
SecureConnected, // CASE session established.
};

DeviceProxyInitParams mInitParams;
FabricTable * mFabricTable = nullptr;
CASEClientInitParams mInitParams;
CASEClientPoolDelegate * mClientPool = nullptr;
System::Layer * mSystemLayer;

// mCASEClient is only non-null if we are in State::Connecting or just
Expand Down
10 changes: 5 additions & 5 deletions src/app/OperationalSessionSetupPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace chip {
class OperationalSessionSetupPoolDelegate
{
public:
virtual OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId,
OperationalSessionReleaseDelegate * releaseDelegate) = 0;
virtual OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool,
ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) = 0;

virtual void Release(OperationalSessionSetup * device) = 0;

Expand All @@ -47,10 +47,10 @@ class OperationalSessionSetupPool : public OperationalSessionSetupPoolDelegate
public:
~OperationalSessionSetupPool() override { mSessionSetupPool.ReleaseAll(); }

OperationalSessionSetup * Allocate(DeviceProxyInitParams & params, ScopedNodeId peerId,
OperationalSessionReleaseDelegate * releaseDelegate) override
OperationalSessionSetup * Allocate(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool,
ScopedNodeId peerId, OperationalSessionReleaseDelegate * releaseDelegate) override
{
return mSessionSetupPool.CreateObject(params, peerId, releaseDelegate);
return mSessionSetupPool.CreateObject(params, clientPool, peerId, releaseDelegate);
}

void Release(OperationalSessionSetup * device) override { mSessionSetupPool.ReleaseObject(device); }
Expand Down
4 changes: 2 additions & 2 deletions src/app/server/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,11 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams)
.certificateValidityPolicy = mCertificateValidityPolicy,
.exchangeMgr = &mExchangeMgr,
.fabricTable = &mFabrics,
.clientPool = &mCASEClientPool,
.groupDataProvider = mGroupsProvider,
.mrpLocalConfig = GetLocalMRPConfig(),
},
.sessionSetupPool = &mSessionSetupPool,
.clientPool = &mCASEClientPool,
.sessionSetupPool = &mSessionSetupPool,
};

err = mCASESessionManager.Init(&DeviceLayer::SystemLayer(), caseSessionManagerConfig);
Expand Down
2 changes: 1 addition & 1 deletion src/app/tests/TestOperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite,
VerifyOrDie(groupDataProvider.Init() == CHIP_NO_ERROR);
// TODO: Set IPK in groupDataProvider

DeviceProxyInitParams params = {
CASEClientInitParams params = {
.sessionManager = &sessionManager,
.sessionResumptionStorage = &sessionResumptionStorage,
.exchangeMgr = &exchangeMgr,
Expand Down
6 changes: 3 additions & 3 deletions src/controller/CHIPDeviceControllerFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,18 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params)
stateParams.sessionSetupPool = Platform::New<DeviceControllerSystemStateParams::SessionSetupPool>();
stateParams.caseClientPool = Platform::New<DeviceControllerSystemStateParams::CASEClientPool>();

DeviceProxyInitParams deviceInitParams = {
CASEClientInitParams sessionInitParams = {
.sessionManager = stateParams.sessionMgr,
.sessionResumptionStorage = stateParams.sessionResumptionStorage.get(),
.exchangeMgr = stateParams.exchangeMgr,
.fabricTable = stateParams.fabricTable,
.clientPool = stateParams.caseClientPool,
.groupDataProvider = stateParams.groupDataProvider,
.mrpLocalConfig = GetLocalMRPConfig(),
};

CASESessionManagerConfig sessionManagerConfig = {
.sessionInitParams = deviceInitParams,
.sessionInitParams = sessionInitParams,
.clientPool = stateParams.caseClientPool,
.sessionSetupPool = stateParams.sessionSetupPool,
};

Expand Down
2 changes: 1 addition & 1 deletion src/protocols/secure_channel/CASEServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class CASEServer : public SessionEstablishmentDelegate,
void OnResponseTimeout(Messaging::ExchangeContext * ec) override {}
Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return GetSession().GetMessageDispatch(); }

virtual CASESession & GetSession() { return mPairingSession; }
CASESession & GetSession() { return mPairingSession; }

private:
Messaging::ExchangeManager * mExchangeManager = nullptr;
Expand Down
11 changes: 1 addition & 10 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,6 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate
uint32_t mNumPairingComplete = 0;
};

class CASEServerForTest : public CASEServer
{
public:
CASESession & GetSession() override { return mCaseSession; }

private:
CASESession mCaseSession;
};

class TestOperationalKeystore : public chip::Crypto::OperationalKeystore
{
public:
Expand Down Expand Up @@ -469,7 +460,7 @@ void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * i
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, delegateCommissioner);
}

CASEServerForTest gPairingServer;
CASEServer gPairingServer;

void TestCASESession::SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext)
{
Expand Down

0 comments on commit 3424641

Please sign in to comment.