From 41a431dcc7c123699818636afd01b57412a6f3a4 Mon Sep 17 00:00:00 2001 From: Michael Sandstedt Date: Sun, 3 Apr 2022 22:54:51 -0500 Subject: [PATCH] Fix Session ID Allocation (#16895) * Fix Session ID Allocation Secure session ID allocation currently suffers the following problems: * fragmentation has worst case behavior where there may be as few as 2 outstanding sessions, but the allocator will become full * there is no formal coupling to the session manager object, but yet there can only be one allocator per session manager; the current solution is for the allocator to share static state * IDs are proposed *to* the session manager, which means the session manager can only prevent collisions by failing on session creation or by evicting sessions; it currently does the latter * session ID allocation is manual, so leaks are likely This commit solves these problems by moving ID allocation into the session manager and by leveraging the session table itself as the source of truth for available IDs. Whereas the old flow was: * allocate session ID * initiate PASE or CASE * (on success) allocate session table entry The new flow is: * allocate session table entry with non-colliding ID * initiate PASE or CASE * activate the session in the table Allocation uses a next-session-ID clue, which is 1 more than the most recent allocation, and also searches the session table to make sure a non-colliding value is returned. Allocation time complexity is O(kMaxSessionCount^2/64) in the current implementation. Lifecycle of the pending session table entries also leverages the SessionHolder object to alleviate our need to manually free resources. Fixes: #7835, #12821 Testing: * Added an allocation test to TestSessionManager * All other code paths are heavily integrated into existing tests * All existing unit tests pass * fix Werror=conversion * fix VerifyOrExit lint error * fix Werror=unused-but-set-variable with detail logging disabled * fix tv-casting-app build * pass SessionHolder by reference to reduce stack size * bypass -Wstack-usage= in TestPASESession to fix spurious stack warning * Update src/transport/SecureSession.h Co-authored-by: Boris Zbarsky * Update src/transport/SecureSessionTable.h Co-authored-by: Boris Zbarsky * per bzbarsky-apple, document that allocation of sessions to caller-specified IDs is for testing * Update src/transport/PairingSession.h Co-authored-by: Boris Zbarsky * per mrjerryjohns, delegate allocation of secure sessions to base PairingSession object * remove use of Optional session IDs Overloads make this superfluous. * init session ID randomly; add more checks for invalid 0 session ID * Update src/transport/PairingSession.h Co-authored-by: Boris Zbarsky * per bzbarsky-apple, pass reference type to SessionHolderWithDelegate * restyle * per kghost, pass SessionHandle, not SessionHolder * make sure to grab the session in NewPairing * fixup doxy params * fix up comments * add a test case to verify the session ID allocator does not have collisions * reduce number of session ID allocator collision test iterations to fix CI timeout * fix loop sentinel in TestSessionManager * increase gcc_debug test phase timeout * increase gcc-debug total timeout to 65 minutes Co-authored-by: Boris Zbarsky --- .github/workflows/build.yaml | 4 +- examples/tv-casting-app/linux/main.cpp | 1 - src/app/CASEClient.cpp | 9 +- src/app/CASEClient.h | 2 - src/app/OperationalDeviceProxy.cpp | 6 +- src/app/OperationalDeviceProxy.h | 3 - src/app/server/CommissioningWindowManager.cpp | 14 +- src/app/server/CommissioningWindowManager.h | 4 - src/app/server/Server.cpp | 2 - src/app/server/Server.h | 4 - src/app/tests/TestOperationalDeviceProxy.cpp | 3 - .../tests/integration/chip_im_initiator.cpp | 1 + .../tests/integration/chip_im_responder.cpp | 1 + src/controller/CHIPDeviceController.cpp | 9 +- .../CHIPDeviceControllerFactory.cpp | 9 +- .../CHIPDeviceControllerSystemState.h | 11 +- src/controller/CommissioneeDeviceProxy.h | 5 - src/messaging/tests/MessagingContext.cpp | 8 + src/messaging/tests/MessagingContext.h | 5 +- src/messaging/tests/echo/echo_requester.cpp | 1 + src/messaging/tests/echo/echo_responder.cpp | 1 + src/protocols/secure_channel/BUILD.gn | 2 - src/protocols/secure_channel/CASEServer.cpp | 5 +- src/protocols/secure_channel/CASEServer.h | 3 - src/protocols/secure_channel/CASESession.cpp | 31 ++-- src/protocols/secure_channel/CASESession.h | 12 +- src/protocols/secure_channel/PASESession.cpp | 94 ++--------- src/protocols/secure_channel/PASESession.h | 97 ++++-------- .../secure_channel/SessionIDAllocator.cpp | 83 ---------- .../secure_channel/SessionIDAllocator.h | 55 ------- src/protocols/secure_channel/tests/BUILD.gn | 1 - .../secure_channel/tests/TestCASESession.cpp | 43 ++--- .../secure_channel/tests/TestPASESession.cpp | 115 ++++---------- .../tests/TestSessionIDAllocator.cpp | 141 ----------------- src/transport/PairingSession.cpp | 16 ++ src/transport/PairingSession.h | 47 ++++-- src/transport/SecureSession.h | 43 ++++- src/transport/SecureSessionTable.h | 141 +++++++++++++++-- src/transport/Session.h | 2 + src/transport/SessionHolder.h | 4 + src/transport/SessionManager.cpp | 46 ++++-- src/transport/SessionManager.h | 23 +++ src/transport/tests/TestSessionManager.cpp | 149 ++++++++++++++++-- 43 files changed, 578 insertions(+), 678 deletions(-) delete mode 100644 src/protocols/secure_channel/SessionIDAllocator.cpp delete mode 100644 src/protocols/secure_channel/SessionIDAllocator.h delete mode 100644 src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 57f3765c2c206a..ece9336a084665 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -26,7 +26,7 @@ concurrency: jobs: build_linux_gcc_debug: name: Build on Linux (gcc_debug) - timeout-minutes: 60 + timeout-minutes: 65 runs-on: ubuntu-latest if: github.actor != 'restyled-io[bot]' @@ -73,7 +73,7 @@ jobs: timeout-minutes: 20 run: scripts/run_in_build_env.sh "ninja -C ./out" - name: Run Tests - timeout-minutes: 10 + timeout-minutes: 15 run: scripts/tests/gn_tests.sh # TODO Log Upload https://github.com/project-chip/connectedhomeip/issues/2227 # TODO https://github.com/project-chip/connectedhomeip/issues/1512 diff --git a/examples/tv-casting-app/linux/main.cpp b/examples/tv-casting-app/linux/main.cpp index 820b5c73857eb3..3aec560c0c06c5 100644 --- a/examples/tv-casting-app/linux/main.cpp +++ b/examples/tv-casting-app/linux/main.cpp @@ -321,7 +321,6 @@ class TargetVideoPlayerInfo chip::DeviceProxyInitParams initParams = { .sessionManager = &(server->GetSecureSessionManager()), .exchangeMgr = &(server->GetExchangeManager()), - .idAllocator = &(server->GetSessionIDAllocator()), .fabricTable = &(server->GetFabricTable()), .clientPool = &gCASEClientPool, }; diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index 32b842a2ba2dee..94666eb4a8d503 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -35,9 +35,6 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres Optional session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, mrpConfig); VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); - uint16_t keyID = 0; - ReturnErrorOnFailure(mInitParams.idAllocator->Allocate(keyID)); - // Allocate the exchange immediately before calling CASESession::EstablishSession. // // CASESession::EstablishSession takes ownership of the exchange and will @@ -48,8 +45,8 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider); - ReturnErrorOnFailure(mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), keyID, exchange, this, - mInitParams.mrpLocalConfig)); + ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, peerAddress, mInitParams.fabricInfo, + peer.GetNodeId(), exchange, this, mInitParams.mrpLocalConfig)); mConnectionSuccessCallback = onConnection; mConnectionFailureCallback = onFailure; mConectionContext = context; @@ -61,8 +58,6 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres void CASEClient::OnSessionEstablishmentError(CHIP_ERROR error) { - mInitParams.idAllocator->Free(mCASESession.GetLocalSessionId()); - if (mConnectionFailureCallback) { mConnectionFailureCallback(mConectionContext, this, error); diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index 6a1c708fc3a7de..fba48029a0ecca 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -21,7 +21,6 @@ #include #include #include -#include namespace chip { @@ -34,7 +33,6 @@ struct CASEClientInitParams { SessionManager * sessionManager = nullptr; Messaging::ExchangeManager * exchangeMgr = nullptr; - SessionIDAllocator * idAllocator = nullptr; FabricInfo * fabricInfo = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index efc18d11c3a9b5..0ce22499730342 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -167,9 +167,9 @@ bool OperationalDeviceProxy::GetAddress(Inet::IPAddress & addr, uint16_t & port) CHIP_ERROR OperationalDeviceProxy::EstablishConnection() { - mCASEClient = mInitParams.clientPool->Allocate( - CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr, mInitParams.idAllocator, mFabricInfo, - mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); + mCASEClient = + mInitParams.clientPool->Allocate(CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr, mFabricInfo, + mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, HandleCASEConnected, HandleCASEConnectionFailure, this); diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 463f05061716ad..16e4f6e63b6594 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -38,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -51,7 +50,6 @@ struct DeviceProxyInitParams { SessionManager * sessionManager = nullptr; Messaging::ExchangeManager * exchangeMgr = nullptr; - SessionIDAllocator * idAllocator = nullptr; FabricTable * fabricTable = nullptr; CASEClientPoolDelegate * clientPool = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; @@ -62,7 +60,6 @@ struct DeviceProxyInitParams { ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(exchangeMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - ReturnErrorCodeIf(idAllocator == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(fabricTable == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(groupDataProvider == nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorCodeIf(clientPool == nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/app/server/CommissioningWindowManager.cpp b/src/app/server/CommissioningWindowManager.cpp index a6fc8ae7c9f5a3..58b1638ba67c5d 100644 --- a/src/app/server/CommissioningWindowManager.cpp +++ b/src/app/server/CommissioningWindowManager.cpp @@ -176,9 +176,6 @@ CHIP_ERROR CommissioningWindowManager::AdvertiseAndListenForPASE() { VerifyOrReturnError(mCommissioningTimeoutTimerArmed, CHIP_ERROR_INCORRECT_STATE); - uint16_t keyID = 0; - ReturnErrorOnFailure(mIDAllocator->Allocate(keyID)); - mPairingSession.Clear(); ReturnErrorOnFailure(mServer->GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( @@ -188,9 +185,9 @@ CHIP_ERROR CommissioningWindowManager::AdvertiseAndListenForPASE() if (mUseECM) { ReturnErrorOnFailure(SetTemporaryDiscriminator(mECMDiscriminator)); - ReturnErrorOnFailure( - mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), keyID, - Optional::Value(GetLocalMRPConfig()), this)); + ReturnErrorOnFailure(mPairingSession.WaitForPairing( + mServer->GetSecureSessionManager(), mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), + Optional::Value(GetLocalMRPConfig()), this)); } else { @@ -211,8 +208,9 @@ CHIP_ERROR CommissioningWindowManager::AdvertiseAndListenForPASE() ReturnErrorOnFailure(verifier.Deserialize(ByteSpan(serializedVerifier))); - ReturnErrorOnFailure(mPairingSession.WaitForPairing( - verifier, iterationCount, saltSpan, keyID, Optional::Value(GetLocalMRPConfig()), this)); + ReturnErrorOnFailure(mPairingSession.WaitForPairing(mServer->GetSecureSessionManager(), verifier, iterationCount, saltSpan, + Optional::Value(GetLocalMRPConfig()), + this)); } ReturnErrorOnFailure(StartAdvertisement()); diff --git a/src/app/server/CommissioningWindowManager.h b/src/app/server/CommissioningWindowManager.h index fffe78b4ade5cc..6cc557eaf9cd75 100644 --- a/src/app/server/CommissioningWindowManager.h +++ b/src/app/server/CommissioningWindowManager.h @@ -23,7 +23,6 @@ #include #include #include -#include #include namespace chip { @@ -65,8 +64,6 @@ class CommissioningWindowManager : public SessionEstablishmentDelegate, public a void SetAppDelegate(AppDelegate * delegate) { mAppDelegate = delegate; } - void SetSessionIDAllocator(SessionIDAllocator * idAllocator) { mIDAllocator = idAllocator; } - /** * Open the pairing window using default configured parameters. */ @@ -146,7 +143,6 @@ class CommissioningWindowManager : public SessionEstablishmentDelegate, public a bool mIsBLE = true; - SessionIDAllocator * mIDAllocator = nullptr; PASESession mPairingSession; uint8_t mFailedCommissioningAttempts = 0; diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 43a8eeae53ae2a..65ad11086bd036 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -111,7 +111,6 @@ CHIP_ERROR Server::Init(AppDelegate * delegate, uint16_t secureServicePort, uint SuccessOrExit(err = mCommissioningWindowManager.Init(this)); mCommissioningWindowManager.SetAppDelegate(delegate); - mCommissioningWindowManager.SetSessionIDAllocator(&mSessionIDAllocator); // Set up attribute persistence before we try to bring up the data model // handler. @@ -241,7 +240,6 @@ CHIP_ERROR Server::Init(AppDelegate * delegate, uint16_t secureServicePort, uint .sessionInitParams = { .sessionManager = &mSessions, .exchangeMgr = &mExchangeMgr, - .idAllocator = &mSessionIDAllocator, .fabricTable = &mFabrics, .clientPool = &mCASEClientPool, .groupDataProvider = &mGroupsProvider, diff --git a/src/app/server/Server.h b/src/app/server/Server.h index 78adc35bfb292e..afba9c4b17d3a0 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -80,8 +80,6 @@ class Server Messaging::ExchangeManager & GetExchangeManager() { return mExchangeMgr; } - SessionIDAllocator & GetSessionIDAllocator() { return mSessionIDAllocator; } - SessionManager & GetSecureSessionManager() { return mSessions; } TransportMgrBase & GetTransportManager() { return mTransports; } @@ -248,12 +246,10 @@ class Server Messaging::ExchangeManager mExchangeMgr; FabricTable mFabrics; - SessionIDAllocator mSessionIDAllocator; secure_channel::MessageCounterManager mMessageCounterManager; #if CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY_CLIENT chip::Protocols::UserDirectedCommissioning::UserDirectedCommissioningClient gUDCClient; #endif // CHIP_DEVICE_CONFIG_ENABLE_COMMISSIONER_DISCOVERY_CLIENT - SecurePairingUsingTestSecret mTestPairing; CommissioningWindowManager mCommissioningWindowManager; // Both PersistentStorageDelegate, and GroupDataProvider should be injected by the applications diff --git a/src/app/tests/TestOperationalDeviceProxy.cpp b/src/app/tests/TestOperationalDeviceProxy.cpp index cf31ff169520ee..0a32373f57c443 100644 --- a/src/app/tests/TestOperationalDeviceProxy.cpp +++ b/src/app/tests/TestOperationalDeviceProxy.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -56,7 +55,6 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, VerifyOrDie(fabric != nullptr); secure_channel::MessageCounterManager messageCounterManager; chip::TestPersistentStorageDelegate deviceStorage; - SessionIDAllocator idAllocator; GroupDataProviderImpl groupDataProvider; systemLayer.Init(); @@ -72,7 +70,6 @@ void TestOperationalDeviceProxy_EstablishSessionDirectly(nlTestSuite * inSuite, DeviceProxyInitParams params = { .sessionManager = &sessionManager, .exchangeMgr = &exchangeMgr, - .idAllocator = &idAllocator, .fabricInfo = fabric, .groupDataProvider = &groupDataProvider, }; diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp index a016509340f082..90bb9caf41027e 100644 --- a/src/app/tests/integration/chip_im_initiator.cpp +++ b/src/app/tests/integration/chip_im_initiator.cpp @@ -433,6 +433,7 @@ CHIP_ERROR EstablishSecureSession() chip::SecurePairingUsingTestSecret * testSecurePairingSecret = chip::Platform::New(); VerifyOrExit(testSecurePairingSecret != nullptr, err = CHIP_ERROR_NO_MEMORY); + testSecurePairingSecret->Init(gSessionManager); // Attempt to connect to the peer. err = gSessionManager.NewPairing(gSession, diff --git a/src/app/tests/integration/chip_im_responder.cpp b/src/app/tests/integration/chip_im_responder.cpp index fe793f407c8a0f..b05fb611e4a73a 100644 --- a/src/app/tests/integration/chip_im_responder.cpp +++ b/src/app/tests/integration/chip_im_responder.cpp @@ -197,6 +197,7 @@ int main(int argc, char * argv[]) InitializeEventLogging(&gExchangeManager); + gTestPairing.Init(gSessionManager); err = gSessionManager.NewPairing(gSession, peer, chip::kTestControllerNodeId, &gTestPairing, chip::CryptoContext::SessionRole::kResponder, gFabricIndex); SuccessOrExit(err); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index d7a8ac8301e5ae..9dd4abbae67d81 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -413,7 +413,6 @@ ControllerDeviceInitParams DeviceController::GetControllerDeviceInitParams() .exchangeMgr = mSystemState->ExchangeMgr(), .udpEndPointManager = mSystemState->UDPEndPointManager(), .storageDelegate = mStorageDelegate, - .idAllocator = mSystemState->SessionIDAlloc(), .fabricsTable = mSystemState->Fabrics(), }; } @@ -610,8 +609,7 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re Messaging::ExchangeContext * exchangeCtxt = nullptr; Optional session; - - uint16_t keyID = 0; + SessionHolder secureSessionHolder; VerifyOrExit(mState == State::Initialized, err = CHIP_ERROR_INCORRECT_STATE); VerifyOrExit(mDeviceInPASEEstablishment == nullptr, err = CHIP_ERROR_INCORRECT_STATE); @@ -677,9 +675,6 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re session = mSystemState->SessionMgr()->CreateUnauthenticatedSession(params.GetPeerAddress(), device->GetMRPConfig()); VerifyOrExit(session.HasValue(), err = CHIP_ERROR_NO_MEMORY); - err = mSystemState->SessionIDAlloc()->Allocate(keyID); - SuccessOrExit(err); - // TODO - Remove use of SetActive/IsActive from CommissioneeDeviceProxy device->SetActive(true); @@ -692,7 +687,7 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re exchangeCtxt = mSystemState->ExchangeMgr()->NewContext(session.Value(), &device->GetPairing()); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); - err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, + err = device->GetPairing().Pair(*mSystemState->SessionMgr(), params.GetPeerAddress(), params.GetSetupPINCode(), Optional::Value(GetLocalMRPConfig()), exchangeCtxt, this); SuccessOrExit(err); diff --git a/src/controller/CHIPDeviceControllerFactory.cpp b/src/controller/CHIPDeviceControllerFactory.cpp index 3f0c03cf1f7717..f5e5aab66bb278 100644 --- a/src/controller/CHIPDeviceControllerFactory.cpp +++ b/src/controller/CHIPDeviceControllerFactory.cpp @@ -214,14 +214,12 @@ CHIP_ERROR DeviceControllerFactory::InitSystemState(FactoryInitParams params) chip::app::DnssdServer::Instance().StartServer(); } - stateParams.sessionIDAllocator = Platform::New(); stateParams.operationalDevicePool = Platform::New(); stateParams.caseClientPool = Platform::New(); DeviceProxyInitParams deviceInitParams = { .sessionManager = stateParams.sessionMgr, .exchangeMgr = stateParams.exchangeMgr, - .idAllocator = stateParams.sessionIDAllocator, .fabricTable = stateParams.fabricTable, .clientPool = stateParams.caseClientPool, .groupDataProvider = stateParams.groupDataProvider, @@ -336,13 +334,8 @@ CHIP_ERROR DeviceControllerSystemState::Shutdown() mCASESessionManager = nullptr; } - // mSessionIDAllocator, mCASEClientPool, and mDevicePool must be deallocated + // mCASEClientPool and mDevicePool must be deallocated // after mCASESessionManager, which uses them. - if (mSessionIDAllocator != nullptr) - { - Platform::Delete(mSessionIDAllocator); - mSessionIDAllocator = nullptr; - } if (mOperationalDevicePool != nullptr) { diff --git a/src/controller/CHIPDeviceControllerSystemState.h b/src/controller/CHIPDeviceControllerSystemState.h index df921bbd2a0138..acd00705bed6da 100644 --- a/src/controller/CHIPDeviceControllerSystemState.h +++ b/src/controller/CHIPDeviceControllerSystemState.h @@ -36,7 +36,6 @@ #include #include #include -#include #include #include @@ -88,7 +87,6 @@ struct DeviceControllerSystemStateParams FabricTable * fabricTable = nullptr; CASEServer * caseServer = nullptr; CASESessionManager * caseSessionManager = nullptr; - SessionIDAllocator * sessionIDAllocator = nullptr; OperationalDevicePool * operationalDevicePool = nullptr; CASEClientPool * caseClientPool = nullptr; Credentials::GroupDataProvider * groupDataProvider = nullptr; @@ -109,8 +107,8 @@ class DeviceControllerSystemState mUDPEndPointManager(params.udpEndPointManager), mTransportMgr(params.transportMgr), mSessionMgr(params.sessionMgr), mExchangeMgr(params.exchangeMgr), mMessageCounterManager(params.messageCounterManager), mFabrics(params.fabricTable), mCASEServer(params.caseServer), mCASESessionManager(params.caseSessionManager), - mSessionIDAllocator(params.sessionIDAllocator), mOperationalDevicePool(params.operationalDevicePool), - mCASEClientPool(params.caseClientPool), mGroupDataProvider(params.groupDataProvider) + mOperationalDevicePool(params.operationalDevicePool), mCASEClientPool(params.caseClientPool), + mGroupDataProvider(params.groupDataProvider) { #if CONFIG_NETWORK_LAYER_BLE mBleLayer = params.bleLayer; @@ -143,8 +141,7 @@ class DeviceControllerSystemState { return mSystemLayer != nullptr && mUDPEndPointManager != nullptr && mTransportMgr != nullptr && mSessionMgr != nullptr && mExchangeMgr != nullptr && mMessageCounterManager != nullptr && mFabrics != nullptr && mCASESessionManager != nullptr && - mSessionIDAllocator != nullptr && mOperationalDevicePool != nullptr && mCASEClientPool != nullptr && - mGroupDataProvider != nullptr; + mOperationalDevicePool != nullptr && mCASEClientPool != nullptr && mGroupDataProvider != nullptr; }; System::Layer * SystemLayer() const { return mSystemLayer; }; @@ -159,7 +156,6 @@ class DeviceControllerSystemState Ble::BleLayer * BleLayer() const { return mBleLayer; }; #endif CASESessionManager * CASESessionMgr() const { return mCASESessionManager; } - SessionIDAllocator * SessionIDAlloc() const { return mSessionIDAllocator; } Credentials::GroupDataProvider * GetGroupDataProvider() const { return mGroupDataProvider; } private: @@ -178,7 +174,6 @@ class DeviceControllerSystemState FabricTable * mFabrics = nullptr; CASEServer * mCASEServer = nullptr; CASESessionManager * mCASESessionManager = nullptr; - SessionIDAllocator * mSessionIDAllocator = nullptr; OperationalDevicePool * mOperationalDevicePool = nullptr; CASEClientPool * mCASEClientPool = nullptr; Credentials::GroupDataProvider * mGroupDataProvider = nullptr; diff --git a/src/controller/CommissioneeDeviceProxy.h b/src/controller/CommissioneeDeviceProxy.h index 6d053ffec9b9d2..fa9278c593713f 100644 --- a/src/controller/CommissioneeDeviceProxy.h +++ b/src/controller/CommissioneeDeviceProxy.h @@ -39,7 +39,6 @@ #include #include #include -#include #include #include #include @@ -69,7 +68,6 @@ struct ControllerDeviceInitParams Messaging::ExchangeManager * exchangeMgr = nullptr; Inet::EndPointManager * udpEndPointManager = nullptr; PersistentStorageDelegate * storageDelegate = nullptr; - SessionIDAllocator * idAllocator = nullptr; #if CONFIG_NETWORK_LAYER_BLE Ble::BleLayer * bleLayer = nullptr; #endif @@ -120,7 +118,6 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat mExchangeMgr = params.exchangeMgr; mUDPEndPointManager = params.udpEndPointManager; mFabricIndex = fabric; - mIDAllocator = params.idAllocator; #if CONFIG_NETWORK_LAYER_BLE mBleLayer = params.bleLayer; #endif @@ -287,8 +284,6 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat CHIP_ERROR LoadSecureSessionParametersIfNeeded(bool & didLoad); FabricIndex mFabricIndex = kUndefinedFabricIndex; - - SessionIDAllocator * mIDAllocator = nullptr; }; } // namespace chip diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index 9fe19f46ca78ba..9e09a5b1f7eddc 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -95,6 +95,10 @@ CHIP_ERROR MessagingContext::ShutdownAndRestoreExisting(MessagingContext & exist CHIP_ERROR MessagingContext::CreateSessionBobToAlice() { + if (!mPairingBobToAlice.GetSecureSessionHandle().HasValue()) + { + mPairingBobToAlice.Init(mSessionManager); + } return mSessionManager.NewPairing(mSessionBobToAlice, Optional::Value(mAliceAddress), GetAliceFabric()->GetNodeId(), &mPairingBobToAlice, CryptoContext::SessionRole::kInitiator, mBobFabricIndex); @@ -102,6 +106,10 @@ CHIP_ERROR MessagingContext::CreateSessionBobToAlice() CHIP_ERROR MessagingContext::CreateSessionAliceToBob() { + if (!mPairingAliceToBob.GetSecureSessionHandle().HasValue()) + { + mPairingAliceToBob.Init(mSessionManager); + } return mSessionManager.NewPairing(mSessionAliceToBob, Optional::Value(mBobAddress), GetBobFabric()->GetNodeId(), &mPairingAliceToBob, CryptoContext::SessionRole::kResponder, mAliceFabricIndex); diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index bff1a6e9606f72..902749df580196 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -71,8 +71,9 @@ class MessagingContext : public PlatformMemoryUser public: MessagingContext() : mInitialized(false), mAliceAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)), - mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)), mPairingAliceToBob(kBobKeyId, kAliceKeyId), - mPairingBobToAlice(kAliceKeyId, kBobKeyId) + mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)), + mPairingAliceToBob(kBobKeyId, kAliceKeyId, GetSecureSessionManager()), + mPairingBobToAlice(kAliceKeyId, kBobKeyId, GetSecureSessionManager()) {} ~MessagingContext() { VerifyOrDie(mInitialized == false); } diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index e8c35e798d8937..f090811a0a440c 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -155,6 +155,7 @@ CHIP_ERROR EstablishSecureSession() chip::Optional peerAddr; chip::SecurePairingUsingTestSecret * testSecurePairingSecret = chip::Platform::New(); VerifyOrExit(testSecurePairingSecret != nullptr, err = CHIP_ERROR_NO_MEMORY); + testSecurePairingSecret->Init(gSessionManager); if (gUseTCP) { diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index c8255b1ef2703c..1d1618bfcbbe2a 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -123,6 +123,7 @@ int main(int argc, char * argv[]) SuccessOrExit(err); } + gTestPairing.Init(gSessionManager); err = gSessionManager.NewPairing(gSession, peer, chip::kTestControllerNodeId, &gTestPairing, chip::CryptoContext::SessionRole::kResponder, gFabricIndex); SuccessOrExit(err); diff --git a/src/protocols/secure_channel/BUILD.gn b/src/protocols/secure_channel/BUILD.gn index 5037378f195e9f..4cc84703343fcd 100644 --- a/src/protocols/secure_channel/BUILD.gn +++ b/src/protocols/secure_channel/BUILD.gn @@ -18,8 +18,6 @@ static_library("secure_channel") { "SessionEstablishmentDelegate.h", "SessionEstablishmentExchangeDispatch.cpp", "SessionEstablishmentExchangeDispatch.h", - "SessionIDAllocator.cpp", - "SessionIDAllocator.h", "StatusReport.cpp", "StatusReport.h", ] diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 59dcffe9943f2c..470f6cae3aaf45 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -74,12 +74,10 @@ CHIP_ERROR CASEServer::InitCASEHandshake(Messaging::ExchangeContext * ec) } #endif - ReturnErrorOnFailure(mSessionIDAllocator.Allocate(mSessionKeyId)); - // Setup CASE state machine using the credentials for the current fabric. GetSession().SetGroupDataProvider(mGroupDataProvider); ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment( - mSessionKeyId, mFabrics, this, Optional::Value(GetLocalMRPConfig()))); + *mSessionManager, mFabrics, this, Optional::Value(GetLocalMRPConfig()))); // Hand over the exchange context to the CASE session. ec->SetDelegate(&GetSession()); @@ -123,7 +121,6 @@ void CASEServer::Cleanup() void CASEServer::OnSessionEstablishmentError(CHIP_ERROR err) { ChipLogError(Inet, "CASE Session establishment failed: %s", ErrorStr(err)); - mSessionIDAllocator.Free(mSessionKeyId); Cleanup(); } diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index f57ef94baaf87e..6e93f558a88ee3 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -24,7 +24,6 @@ #include #include #include -#include namespace chip { @@ -63,7 +62,6 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan Messaging::ExchangeManager * mExchangeManager = nullptr; CASESession mPairingSession; - uint16_t mSessionKeyId = 0; SessionManager * mSessionManager = nullptr; #if CONFIG_NETWORK_LAYER_BLE Ble::BleLayer * mBleLayer = nullptr; @@ -71,7 +69,6 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan FabricTable * mFabrics = nullptr; Credentials::GroupDataProvider * mGroupDataProvider = nullptr; - SessionIDAllocator mSessionIDAllocator; CHIP_ERROR InitCASEHandshake(Messaging::ExchangeContext * ec); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 0ef0f76c0f6efc..a3783a977be8b9 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -187,10 +187,9 @@ CHIP_ERROR CASESession::FromCachable(const CASESessionCachable & cachableSession return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::Init(uint16_t localSessionId, SessionEstablishmentDelegate * delegate) +CHIP_ERROR CASESession::Init(SessionManager & sessionManager, SessionEstablishmentDelegate * delegate) { VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(mGroupDataProvider != nullptr, CHIP_ERROR_INVALID_ARGUMENT); Clear(); @@ -198,7 +197,7 @@ CHIP_ERROR CASESession::Init(uint16_t localSessionId, SessionEstablishmentDelega ReturnErrorOnFailure(mCommissioningHash.Begin()); mDelegate = delegate; - SetLocalSessionId(localSessionId); + ReturnErrorOnFailure(AllocateSecureSession(sessionManager)); mValidContext.Reset(); mValidContext.mRequiredKeyUsages.Set(KeyUsageFlags::kDigitalSignature); @@ -208,11 +207,12 @@ CHIP_ERROR CASESession::Init(uint16_t localSessionId, SessionEstablishmentDelega } CHIP_ERROR -CASESession::ListenForSessionEstablishment(uint16_t localSessionId, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, +CASESession::ListenForSessionEstablishment(SessionManager & sessionManager, FabricTable * fabrics, + SessionEstablishmentDelegate * delegate, Optional mrpConfig) { VerifyOrReturnError(fabrics != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - ReturnErrorOnFailure(Init(localSessionId, delegate)); + ReturnErrorOnFailure(Init(sessionManager, delegate)); mFabricsTable = fabrics; mLocalMRPConfig = mrpConfig; @@ -224,8 +224,8 @@ CASESession::ListenForSessionEstablishment(uint16_t localSessionId, FabricTable return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddress, FabricInfo * fabric, NodeId peerNodeId, - uint16_t localSessionId, ExchangeContext * exchangeCtxt, +CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, + FabricInfo * fabric, NodeId peerNodeId, ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate, Optional mrpConfig) { MATTER_TRACE_EVENT_SCOPE("EstablishSession", "CASESession"); @@ -241,7 +241,7 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - err = Init(localSessionId, delegate); + err = Init(sessionManager, delegate); // We are setting the exchange context specifically before checking for error. // This is to make sure the exchange will get closed if Init() returned an error. @@ -358,6 +358,9 @@ CHIP_ERROR CASESession::SendSigma1() TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 }; + // Validate that we have a session ID allocated. + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + // Generate an ephemeral keypair ReturnErrorOnFailure(mEphemeralKey.Initialize()); @@ -372,7 +375,7 @@ CHIP_ERROR CASESession::SendSigma1() ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(mInitiatorRandom))); // Retrieve Session Identifier - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); // Generate a Destination Identifier based on the node we are attempting to reach { ReturnErrorCodeIf(mFabricInfo == nullptr, CHIP_ERROR_INCORRECT_STATE); @@ -582,6 +585,9 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) System::PacketBufferHandle msg_R2_resume; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + // Validate that we have a session ID allocated. + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + msg_R2_resume = System::PacketBufferHandle::New(max_sigma2_resume_data_len); VerifyOrReturnError(!msg_R2_resume.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -600,7 +606,7 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), resumeMICSpan)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId().Value())); if (mLocalMRPConfig.HasValue()) { @@ -625,6 +631,9 @@ CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) CHIP_ERROR CASESession::SendSigma2() { MATTER_TRACE_EVENT_SCOPE("SendSigma2", "CASESession"); + + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(mFabricInfo != nullptr, CHIP_ERROR_INCORRECT_STATE); ByteSpan icaCert; @@ -724,7 +733,7 @@ CHIP_ERROR CASESession::SendSigma2() tlvWriterMsg2.Init(std::move(msg_R2)); ReturnErrorOnFailure(tlvWriterMsg2.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(1), &msg_rand[0], sizeof(msg_rand))); - ReturnErrorOnFailure(tlvWriterMsg2.Put(TLV::ContextTag(2), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriterMsg2.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); ReturnErrorOnFailure( tlvWriterMsg2.PutBytes(TLV::ContextTag(3), mEphemeralKey.Pubkey(), static_cast(mEphemeralKey.Pubkey().Length()))); ReturnErrorOnFailure(tlvWriterMsg2.PutBytes(TLV::ContextTag(4), msg_R2_Encrypted.Get(), diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 63e51073b64606..a6651b1f1fb9a6 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -79,32 +79,32 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin * @brief * Initialize using configured fabrics and wait for session establishment requests. * - * @param mySessionId Session ID to be assigned to the secure session on the peer node + * @param sessionManager session manager from which to allocate a secure session object * @param fabrics Table of fabrics that are currently configured on the device * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ CHIP_ERROR ListenForSessionEstablishment( - uint16_t mySessionId, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, + SessionManager & sessionManager, FabricTable * fabrics, SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** * @brief * Create and send session establishment request using device's operational credentials. * + * @param sessionManager session manager from which to allocate a secure session object * @param peerAddress Address of peer with which to establish a session. * @param fabric The fabric that should be used for connecting with the peer * @param peerNodeId Node id of the peer node - * @param mySessionId Session ID to be assigned to the secure session on the peer node * @param exchangeCtxt The exchange context to send and receive messages with the peer * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ CHIP_ERROR - EstablishSession(const Transport::PeerAddress peerAddress, FabricInfo * fabric, NodeId peerNodeId, uint16_t mySessionId, - Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate, + EstablishSession(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, FabricInfo * fabric, + NodeId peerNodeId, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** @@ -190,7 +190,7 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin kSentSigma2Resume = 4, }; - CHIP_ERROR Init(uint16_t mySessionId, SessionEstablishmentDelegate * delegate); + CHIP_ERROR Init(SessionManager & sessionManager, SessionEstablishmentDelegate * delegate); // On success, sets mIpk to the correct value for outgoing Sigma1 based on internal state CHIP_ERROR RecoverInitiatorIpk(); diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index c714f860225ae4..c23d8faeec69ce 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -117,73 +117,7 @@ void PASESession::DiscardExchange() } } -CHIP_ERROR PASESession::Serialize(PASESessionSerialized & output) -{ - PASESessionSerializable serializable; - VerifyOrReturnError(BASE64_ENCODED_LEN(sizeof(serializable)) <= sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT); - - ReturnErrorOnFailure(ToSerializable(serializable)); - - uint16_t serializedLen = chip::Base64Encode(Uint8::to_const_uchar(reinterpret_cast(&serializable)), - static_cast(sizeof(serializable)), Uint8::to_char(output.inner)); - VerifyOrReturnError(serializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(serializedLen < sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT); - output.inner[serializedLen] = '\0'; - - return CHIP_NO_ERROR; -} - -CHIP_ERROR PASESession::Deserialize(PASESessionSerialized & input) -{ - PASESessionSerializable serializable; - size_t maxlen = BASE64_ENCODED_LEN(sizeof(serializable)); - size_t len = strnlen(Uint8::to_char(input.inner), maxlen); - uint16_t deserializedLen = 0; - - VerifyOrReturnError(len < sizeof(PASESessionSerialized), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_INVALID_ARGUMENT); - - memset(&serializable, 0, sizeof(serializable)); - deserializedLen = - Base64Decode(Uint8::to_const_char(input.inner), static_cast(len), Uint8::to_uchar((uint8_t *) &serializable)); - - VerifyOrReturnError(deserializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(deserializedLen <= sizeof(serializable), CHIP_ERROR_INVALID_ARGUMENT); - - return FromSerializable(serializable); -} - -CHIP_ERROR PASESession::ToSerializable(PASESessionSerializable & serializable) -{ - VerifyOrReturnError(CanCastTo(mKeLen), CHIP_ERROR_INTERNAL); - - memset(&serializable, 0, sizeof(serializable)); - serializable.mKeLen = static_cast(mKeLen); - serializable.mPairingComplete = (mPairingComplete) ? 1 : 0; - serializable.mLocalSessionId = GetLocalSessionId(); - serializable.mPeerSessionId = GetPeerSessionId(); - - memcpy(serializable.mKe, mKe, mKeLen); - - return CHIP_NO_ERROR; -} - -CHIP_ERROR PASESession::FromSerializable(const PASESessionSerializable & serializable) -{ - mPairingComplete = (serializable.mPairingComplete == 1); - mKeLen = static_cast(serializable.mKeLen); - - VerifyOrReturnError(mKeLen <= sizeof(mKe), CHIP_ERROR_INVALID_ARGUMENT); - memset(mKe, 0, sizeof(mKe)); - memcpy(mKe, serializable.mKe, mKeLen); - - SetLocalSessionId(serializable.mLocalSessionId); - SetPeerSessionId(serializable.mPeerSessionId); - - return CHIP_NO_ERROR; -} - -CHIP_ERROR PASESession::Init(uint16_t mySessionId, uint32_t setupCode, SessionEstablishmentDelegate * delegate) +CHIP_ERROR PASESession::Init(SessionManager & sessionManager, uint32_t setupCode, SessionEstablishmentDelegate * delegate) { VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT); @@ -194,9 +128,9 @@ CHIP_ERROR PASESession::Init(uint16_t mySessionId, uint32_t setupCode, SessionEs ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ Uint8::from_const_char(kSpake2pContext), strlen(kSpake2pContext) })); mDelegate = delegate; - - ChipLogDetail(SecureChannel, "Assigned local session key ID %d", mySessionId); - SetLocalSessionId(mySessionId); + ReturnErrorOnFailure(AllocateSecureSession(sessionManager)); + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + ChipLogDetail(SecureChannel, "Assigned local session key ID %u", GetLocalSessionId().Value()); ReturnErrorCodeIf(setupCode >= (1 << kSetupPINCodeFieldLengthInBits), CHIP_ERROR_INVALID_ARGUMENT); mSetupPINCode = setupCode; @@ -232,8 +166,8 @@ CHIP_ERROR PASESession::SetupSpake2p() return CHIP_NO_ERROR; } -CHIP_ERROR PASESession::WaitForPairing(const Spake2pVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t mySessionId, Optional mrpConfig, +CHIP_ERROR PASESession::WaitForPairing(SessionManager & sessionManager, const Spake2pVerifier & verifier, uint32_t pbkdf2IterCount, + const ByteSpan & salt, Optional mrpConfig, SessionEstablishmentDelegate * delegate) { // Return early on error here, as we have not initialized any state yet @@ -242,7 +176,7 @@ CHIP_ERROR PASESession::WaitForPairing(const Spake2pVerifier & verifier, uint32_ ReturnErrorCodeIf(salt.size() < kSpake2p_Min_PBKDF_Salt_Length || salt.size() > kSpake2p_Max_PBKDF_Salt_Length, CHIP_ERROR_INVALID_ARGUMENT); - CHIP_ERROR err = Init(mySessionId, kSetupPINCodeUndefinedValue, delegate); + CHIP_ERROR err = Init(sessionManager, kSetupPINCodeUndefinedValue, delegate); // From here onwards, let's go to exit on error, as some state might have already // been initialized SuccessOrExit(err); @@ -279,13 +213,13 @@ CHIP_ERROR PASESession::WaitForPairing(const Spake2pVerifier & verifier, uint32_ return err; } -CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, +CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate) { MATTER_TRACE_EVENT_SCOPE("Pair", "PASESession"); ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - CHIP_ERROR err = Init(mySessionId, peerSetUpPINCode, delegate); + CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate); SuccessOrExit(err); mExchangeCtxt = exchangeCtxt; @@ -334,6 +268,9 @@ CHIP_ERROR PASESession::DeriveSecureSession(CryptoContext & session, CryptoConte CHIP_ERROR PASESession::SendPBKDFParamRequest() { MATTER_TRACE_EVENT_SCOPE("SendPBKDFParamRequest", "PASESession"); + + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; @@ -353,7 +290,7 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(1), mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId().Value())); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), kDefaultCommissioningPasscodeId)); ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters)); if (mLocalMRPConfig.HasValue()) @@ -442,6 +379,9 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool initiatorHasPBKDFParams) { MATTER_TRACE_EVENT_SCOPE("SendPBKDFParamResponse", "PASESession"); + + VerifyOrReturnError(GetLocalSessionId().HasValue(), CHIP_ERROR_INCORRECT_STATE); + ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; @@ -464,7 +404,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in // The initiator random value is being sent back in the response as required by the specifications ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), initiatorRandom)); ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId())); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalSessionId().Value())); if (!initiatorHasPBKDFParams) { diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index a6f9b687c6566a..b3fdd396bc0314 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -83,34 +83,34 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @brief * Initialize using PASE verifier and wait for pairing requests. * - * @param verifier PASE verifier to be used for SPAKE2P pairing - * @param pbkdf2IterCount Iteration count for PBKDF2 function - * @param salt Salt to be used for SPAKE2P operation - * @param mySessionId Session ID to be assigned to the secure session on the peer node - * @param delegate Callback object + * @param sessionManager session manager from which to allocate a secure session object + * @param verifier PASE verifier to be used for SPAKE2P pairing + * @param pbkdf2IterCount Iteration count for PBKDF2 function + * @param salt Salt to be used for SPAKE2P operation + * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ - CHIP_ERROR WaitForPairing(const Spake2pVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t mySessionId, Optional mrpConfig, + CHIP_ERROR WaitForPairing(SessionManager & sessionManager, const Spake2pVerifier & verifier, uint32_t pbkdf2IterCount, + const ByteSpan & salt, Optional mrpConfig, SessionEstablishmentDelegate * delegate); /** * @brief * Create a pairing request using peer's setup PIN code. * - * @param peerAddress Address of peer to pair - * @param peerSetUpPINCode Setup PIN code of the peer device - * @param mySessionId Session ID to be assigned to the secure session on the peer node - * @param exchangeCtxt The exchange context to send and receive messages with the peer - * Note: It's expected that the caller of this API hands over the - * ownership of the exchangeCtxt to PASESession object. PASESession - * will close the exchange on (successful/failed) handshake completion. - * @param delegate Callback object + * @param sessionManager session manager from which to allocate a secure session object + * @param peerAddress Address of peer to pair + * @param peerSetUpPINCode Setup PIN code of the peer device + * @param exchangeCtxt The exchange context to send and receive messages with the peer + * Note: It's expected that the caller of this API hands over the + * ownership of the exchangeCtxt to PASESession object. PASESession + * will close the exchange on (successful/failed) handshake completion. + * @param delegate Callback object * * @return CHIP_ERROR The result of initialization */ - CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, + CHIP_ERROR Pair(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate); @@ -141,30 +141,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin */ CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; - /** @brief Serialize the Pairing Session to a string. - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR Serialize(PASESessionSerialized & output); - - /** @brief Deserialize the Pairing Session from the string. - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR Deserialize(PASESessionSerialized & input); - - /** @brief Serialize the PASESession to the given serializable data structure for secure pairing - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR ToSerializable(PASESessionSerializable & output); - - /** @brief Reconstruct secure pairing class from the serializable data structure. - * - * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise - **/ - CHIP_ERROR FromSerializable(const PASESessionSerializable & output); - // TODO: remove Clear, we should create a new instance instead reset the old instance. /** @brief This function zeroes out and resets the memory used by the object. **/ @@ -205,7 +181,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin kUnexpected = 0xff, }; - CHIP_ERROR Init(uint16_t mySessionId, uint32_t setupCode, SessionEstablishmentDelegate * delegate); + CHIP_ERROR Init(SessionManager & sessionManager, uint32_t setupCode, SessionEstablishmentDelegate * delegate); CHIP_ERROR ValidateReceivedMessage(Messaging::ExchangeContext * exchange, const PayloadHeader & payloadHeader, const System::PacketBufferHandle & msg); @@ -294,16 +270,22 @@ class SecurePairingUsingTestSecret : public PairingSession public: SecurePairingUsingTestSecret() : PairingSession(Transport::SecureSession::Type::kPASE) { - // Do not set to 0 to prevent unwanted unsecured session + // Do not set to 0 to prevent an unwanted unsecured session // since the session type is unknown. - SetLocalSessionId(1); SetPeerSessionId(1); } - SecurePairingUsingTestSecret(uint16_t peerSessionId, uint16_t localSessionId) : - PairingSession(Transport::SecureSession::Type::kPASE) + void Init(SessionManager & sessionManager) { - SetLocalSessionId(localSessionId); + // Do not set to 0 to prevent an unwanted unsecured session + // since the session type is unknown. + AllocateSecureSession(sessionManager, mLocalSessionId); + } + + SecurePairingUsingTestSecret(uint16_t peerSessionId, uint16_t localSessionId, SessionManager & sessionManager) : + PairingSession(Transport::SecureSession::Type::kPASE), mLocalSessionId(localSessionId) + { + AllocateSecureSession(sessionManager, localSessionId); SetPeerSessionId(peerSessionId); } @@ -314,28 +296,11 @@ class SecurePairingUsingTestSecret : public PairingSession CryptoContext::SessionInfoType::kSessionEstablishment, role); } - CHIP_ERROR ToSerializable(PASESessionSerializable & serializable) - { - size_t secretLen = strlen(kTestSecret); - - memset(&serializable, 0, sizeof(serializable)); - serializable.mKeLen = static_cast(secretLen); - serializable.mPairingComplete = 1; - serializable.mLocalSessionId = GetLocalSessionId(); - serializable.mPeerSessionId = GetPeerSessionId(); - - memcpy(serializable.mKe, kTestSecret, secretLen); - return CHIP_NO_ERROR; - } - private: + // Do not set to 0 to prevent an unwanted unsecured session + // since the session type is unknown. + uint16_t mLocalSessionId = 1; const char * kTestSecret = CHIP_CONFIG_TEST_SHARED_SECRET_VALUE; }; -typedef struct PASESessionSerialized -{ - // Extra uint64_t to account for padding bytes (NULL termination, and some decoding overheads) - uint8_t inner[BASE64_ENCODED_LEN(sizeof(PASESessionSerializable) + sizeof(uint64_t))]; -} PASESessionSerialized; - } // namespace chip diff --git a/src/protocols/secure_channel/SessionIDAllocator.cpp b/src/protocols/secure_channel/SessionIDAllocator.cpp deleted file mode 100644 index 9df635d98d84fd..00000000000000 --- a/src/protocols/secure_channel/SessionIDAllocator.cpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -namespace chip { - -uint16_t SessionIDAllocator::sNextAvailable = 1; - -CHIP_ERROR SessionIDAllocator::Allocate(uint16_t & id) -{ - VerifyOrReturnError(sNextAvailable < kMaxSessionID, CHIP_ERROR_NO_MEMORY); - VerifyOrReturnError(sNextAvailable > kUnsecuredSessionId, CHIP_ERROR_INTERNAL); - id = sNextAvailable; - - // TODO - Update SessionID allocator to use freed session IDs - sNextAvailable++; - - return CHIP_NO_ERROR; -} - -void SessionIDAllocator::Free(uint16_t id) -{ - // As per spec 4.4.1.3 Session ID of 0 is reserved for Unsecure communication - if (sNextAvailable > (kUnsecuredSessionId + 1) && (sNextAvailable - 1) == id) - { - sNextAvailable--; - } -} - -CHIP_ERROR SessionIDAllocator::Reserve(uint16_t id) -{ - VerifyOrReturnError(id < kMaxSessionID, CHIP_ERROR_NO_MEMORY); - if (id >= sNextAvailable) - { - sNextAvailable = id; - sNextAvailable++; - } - - // TODO - Check if ID is already allocated in SessionIDAllocator::Reserve() - - return CHIP_NO_ERROR; -} - -CHIP_ERROR SessionIDAllocator::ReserveUpTo(uint16_t id) -{ - VerifyOrReturnError(id < kMaxSessionID, CHIP_ERROR_NO_MEMORY); - if (id >= sNextAvailable) - { - sNextAvailable = id; - sNextAvailable++; - } - - // TODO - Update ReserveUpTo to mark all IDs in use - // Current SessionIDAllocator only tracks the smallest unused session ID. - // If/when we change it to track all in use IDs, we should also update ReserveUpTo - // to reserve all individual session IDs, instead of just setting the sNextAvailable. - - return CHIP_NO_ERROR; -} - -uint16_t SessionIDAllocator::Peek() -{ - return sNextAvailable; -} - -} // namespace chip diff --git a/src/protocols/secure_channel/SessionIDAllocator.h b/src/protocols/secure_channel/SessionIDAllocator.h deleted file mode 100644 index c56b292c1b3ed8..00000000000000 --- a/src/protocols/secure_channel/SessionIDAllocator.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -// Spec 4.4.1.3 -// ===== Session ID (16 bits) -// An unsigned integer value identifying the session associated with this message. -// The session identifies the particular key used to encrypt a message out of the set of -// available keys (either session or group), and the particular encryption/message -// integrity algorithm to use for the message.The Session ID field is always present. -// A Session ID of 0 SHALL indicate an unsecured session with no encryption or message integrity checking. -// -// The Session ID is allocated from a global numerical space shared across all fabrics and nodes on the resident process instance. -// - -namespace chip { - -class SessionIDAllocator -{ -public: - SessionIDAllocator() {} - ~SessionIDAllocator() {} - - CHIP_ERROR Allocate(uint16_t & id); - void Free(uint16_t id); - CHIP_ERROR Reserve(uint16_t id); - CHIP_ERROR ReserveUpTo(uint16_t id); - uint16_t Peek(); - -private: - static constexpr uint16_t kMaxSessionID = UINT16_MAX; - static constexpr uint16_t kUnsecuredSessionId = 0; - - static uint16_t sNextAvailable; -}; - -} // namespace chip diff --git a/src/protocols/secure_channel/tests/BUILD.gn b/src/protocols/secure_channel/tests/BUILD.gn index c2ce3df2fe54b7..070a757fbdcf97 100644 --- a/src/protocols/secure_channel/tests/BUILD.gn +++ b/src/protocols/secure_channel/tests/BUILD.gn @@ -15,7 +15,6 @@ chip_test_suite("tests") { # TODO - Fix Message Counter Sync to use group key # "TestMessageCounterManager.cpp", "TestPASESession.cpp", - "TestSessionIDAllocator.cpp", "TestStatusReport.cpp", ] diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index c89f68b929cd5f..418591d2a66e7b 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -179,6 +179,7 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) TestCASESecurePairingDelegate delegate; CASESession pairing; FabricTable fabrics; + SessionManager sessionManager; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kCASE); CATValues peerCATs; @@ -186,9 +187,10 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kUndefinedCATs, sizeof(CATValues)) == 0); pairing.SetGroupDataProvider(&gDeviceGroupDataProvider); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(0, nullptr, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(0, nullptr, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); - NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(0, &fabrics, &delegate) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(sessionManager, nullptr, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, + pairing.ListenForSessionEstablishment(sessionManager, nullptr, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(sessionManager, &fabrics, &delegate) == CHIP_NO_ERROR); } void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) @@ -202,22 +204,23 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex); NL_TEST_ASSERT(inSuite, fabric != nullptr); + SessionManager sessionManager; ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, 0, nullptr, - nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, + nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, nullptr, - nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, + nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, context, - &delegate) == CHIP_NO_ERROR); + pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, + context, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -237,8 +240,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, context1, - &delegate) == CHIP_ERROR_BAD_REQUEST); + pairing1.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, + context1, &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; @@ -254,6 +257,7 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte CASESession pairingAccessory; CASESessionCachable serializableCommissioner; CASESessionCachable serializableAccessory; + SessionManager sessionManager; gLoopback.mSentMessageCount = 0; @@ -268,10 +272,11 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte pairingAccessory.SetGroupDataProvider(&gDeviceGroupDataProvider); NL_TEST_ASSERT(inSuite, - pairingAccessory.ListenForSessionEstablishment(0, &gDeviceFabrics, &delegateAccessory) == CHIP_NO_ERROR); + pairingAccessory.ListenForSessionEstablishment(sessionManager, &gDeviceFabrics, &delegateAccessory) == + CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingCommissioner.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, - contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, + Node01_01, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); @@ -304,6 +309,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte auto * pairingCommissioner = chip::Platform::New(); pairingCommissioner->SetGroupDataProvider(&gCommissionerGroupDataProvider); + SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); gLoopback.mSentMessageCount = 0; @@ -322,8 +329,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, fabric != nullptr); NL_TEST_ASSERT(inSuite, - pairingCommissioner->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, - contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner->EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, + Node01_01, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); @@ -334,8 +341,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte ExchangeContext * contextCommissioner1 = ctx.NewUnauthenticatedExchangeToBob(pairingCommissioner1); NL_TEST_ASSERT(inSuite, - pairingCommissioner1->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, - contextCommissioner1, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner1->EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, + Node01_01, contextCommissioner1, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); chip::Platform::Delete(pairingCommissioner); diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 27fb6a9ac875ee..838b24804cdf63 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -35,6 +35,12 @@ #include #include +// This test suite pushes multiple PASESession objects onto the stack for the +// purposes of testing device-to-device communication. However, in the real +// world, these won't live in a single device's memory. Hence, disable stack +// warning. +#pragma GCC diagnostic ignored "-Wstack-usage=" + using namespace chip; using namespace chip::Inet; using namespace chip::Transport; @@ -114,6 +120,7 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; PASESession pairing; + SessionManager sessionManager; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kPASE); CATValues peerCATs; @@ -123,28 +130,29 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(nullptr, 0), 0, - Optional::Missing(), + pairing.WaitForPairing(sessionManager, sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, + ByteSpan(nullptr, 0), Optional::Missing(), &delegate) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, - ByteSpan(reinterpret_cast("saltSalt"), 8), 0, + pairing.WaitForPairing(sessionManager, sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, + ByteSpan(reinterpret_cast("saltSalt"), 8), Optional::Missing(), nullptr) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, - ByteSpan(reinterpret_cast("saltSalt"), 8), 0, + pairing.WaitForPairing(sessionManager, sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, + ByteSpan(reinterpret_cast("saltSalt"), 8), Optional::Missing(), &delegate) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(sTestSpake2p01_Salt), - 0, Optional::Missing(), &delegate) == CHIP_NO_ERROR); + pairing.WaitForPairing(sessionManager, sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, + ByteSpan(sTestSpake2p01_Salt), Optional::Missing(), + &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); } @@ -154,20 +162,20 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; - PASESession pairing; + SessionManager sessionManager; gLoopback.Reset(); ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairing.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, Optional::Missing(), nullptr, nullptr) != CHIP_NO_ERROR); gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairing.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, Optional::Missing(), context, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -185,7 +193,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) PASESession pairing1; ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairing1.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, Optional::Missing(), context1, &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); @@ -202,6 +210,7 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P TestSecurePairingDelegate delegateAccessory; PASESession pairingAccessory; + SessionManager sessionManager; gLoopback.mSentMessageCount = 0; @@ -225,13 +234,13 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P Protocols::SecureChannel::MsgType::PBKDFParamRequest, &pairingAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingAccessory.WaitForPairing(sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, - ByteSpan(sTestSpake2p01_Salt), 0, mrpAccessoryConfig, + pairingAccessory.WaitForPairing(sessionManager, sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, + ByteSpan(sTestSpake2p01_Salt), mrpAccessoryConfig, &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, 0, + pairingCommissioner.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, mrpCommissionerConfig, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -332,6 +341,8 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) TestSecurePairingDelegate delegateAccessory; PASESession pairingAccessory; + SessionManager sessionManager; + gLoopback.Reset(); gLoopback.mSentMessageCount = 0; @@ -353,12 +364,12 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing( - sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(sTestSpake2p01_Salt), 0, + sessionManager, sTestSpake2p01_PASEVerifier, sTestSpake2p01_IterationCount, ByteSpan(sTestSpake2p01_Salt), Optional::Missing(), &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, + pairingCommissioner.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), 4321, Optional::Missing(), contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -369,75 +380,6 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingErrors == 1); } -void SecurePairingDeserialize(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, - PASESession & deserialized) -{ - PASESessionSerialized serialized; - gLoopback.Reset(); - NL_TEST_ASSERT(inSuite, pairingCommissioner.Serialize(serialized) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, deserialized.Deserialize(serialized) == CHIP_NO_ERROR); - - // Serialize from the deserialized session, and check we get the same string back - PASESessionSerialized serialized2; - NL_TEST_ASSERT(inSuite, deserialized.Serialize(serialized2) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, strncmp(Uint8::to_char(serialized.inner), Uint8::to_char(serialized2.inner), sizeof(serialized)) == 0); -} - -void SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) -{ - TestSecurePairingDelegate delegateCommissioner; - - // Allocate on the heap to avoid stack overflow in some restricted test scenarios (e.g. QEMU) - auto * testPairingSession1 = chip::Platform::New(); - auto * testPairingSession2 = chip::Platform::New(); - - gLoopback.Reset(); - - SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, Optional::Missing(), - Optional::Missing(), delegateCommissioner); - SecurePairingDeserialize(inSuite, inContext, *testPairingSession1, *testPairingSession2); - - const uint8_t plain_text[] = { 0x86, 0x74, 0x64, 0xe5, 0x0b, 0xd4, 0x0d, 0x90, 0xe1, 0x17, 0xa3, 0x2d, 0x4b, 0xd4, 0xe1, 0xe6 }; - uint8_t encrypted[64]; - PacketHeader header; - MessageAuthenticationCode mac; - - header.SetSessionId(1); - NL_TEST_ASSERT(inSuite, header.IsEncrypted() == true); - NL_TEST_ASSERT(inSuite, header.MICTagLength() == 16); - - // Let's try encrypting using original session, and decrypting using deserialized - { - CryptoContext session1; - - CHIP_ERROR err = testPairingSession1->DeriveSecureSession(session1, CryptoContext::SessionRole::kInitiator); - - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - CryptoContext::NonceStorage nonce; - CryptoContext::BuildNonce(nonce, header.GetSecurityFlags(), header.GetMessageCounter(), kUndefinedNodeId); - err = session1.Encrypt(plain_text, sizeof(plain_text), encrypted, nonce, header, mac); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - } - - { - CryptoContext session2; - NL_TEST_ASSERT(inSuite, - testPairingSession2->DeriveSecureSession(session2, CryptoContext::SessionRole::kResponder) == CHIP_NO_ERROR); - - uint8_t decrypted[64]; - CryptoContext::NonceStorage nonce; - CryptoContext::BuildNonce(nonce, header.GetSecurityFlags(), header.GetMessageCounter(), kUndefinedNodeId); - NL_TEST_ASSERT(inSuite, session2.Decrypt(encrypted, sizeof(plain_text), decrypted, nonce, header, mac) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, memcmp(plain_text, decrypted, sizeof(plain_text)) == 0); - } - - chip::Platform::Delete(testPairingSession1); - chip::Platform::Delete(testPairingSession2); -} - void PASEVerifierSerializeTest(nlTestSuite * inSuite, void * inContext) { Spake2pVerifier verifier; @@ -474,7 +416,6 @@ static const nlTest sTests[] = NL_TEST_DEF("Handshake with Both MRP Parameters", SecurePairingHandshakeWithAllMRPTest), NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest), NL_TEST_DEF("Failed Handshake", SecurePairingFailedHandshake), - NL_TEST_DEF("Serialize", SecurePairingSerializeTest), NL_TEST_DEF("PASE Verifier Serialize", PASEVerifierSerializeTest), NL_TEST_SENTINEL() diff --git a/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp b/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp deleted file mode 100644 index 55ce38f24b30d8..00000000000000 --- a/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -using namespace chip; - -void TestSessionIDAllocator_Free(nlTestSuite * inSuite, void * inContext) -{ - SessionIDAllocator allocator; - uint16_t i = allocator.Peek(); - - uint16_t id; - - for (uint16_t j = 0; j < 17; j++) - { - CHIP_ERROR err = allocator.Allocate(id); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, id == static_cast(i + j)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + j + 1)); - } - - // Free an intermediate ID - allocator.Free(10); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 17)); - - // Free the last allocated ID - allocator.Free(static_cast(i + 16)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 16)); - - // Free some random unallocated ID - allocator.Free(100); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 16)); -} - -void TestSessionIDAllocator_Reserve(nlTestSuite * inSuite, void * inContext) -{ - SessionIDAllocator allocator; - uint16_t i = allocator.Peek(); - uint16_t id; - - for (uint16_t j = 0; j < 17; j++) - { - CHIP_ERROR err = allocator.Allocate(id); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, id == static_cast(i + j)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + j + 1)); - } - - i = allocator.Peek(); - allocator.Reserve(static_cast(i + 100)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 101)); -} - -void TestSessionIDAllocator_ReserveUpTo(nlTestSuite * inSuite, void * inContext) -{ - SessionIDAllocator allocator; - uint16_t i = allocator.Peek(); - - i = allocator.Peek(); - allocator.Reserve(static_cast(i + 100)); - NL_TEST_ASSERT(inSuite, allocator.Peek() == static_cast(i + 101)); -} - -// Test Suite - -/** - * Test Suite that lists all the test functions. - */ -// clang-format off -static const nlTest sTests[] = -{ - NL_TEST_DEF("SessionIDAllocator_Free", TestSessionIDAllocator_Free), - NL_TEST_DEF("SessionIDAllocator_Reserve", TestSessionIDAllocator_Reserve), - NL_TEST_DEF("SessionIDAllocator_ReserveUpTo", TestSessionIDAllocator_ReserveUpTo), - - NL_TEST_SENTINEL() -}; -// clang-format on - -/** - * Set up the test suite. - */ -static int TestSetup(void * inContext) -{ - CHIP_ERROR error = chip::Platform::MemoryInit(); - if (error != CHIP_NO_ERROR) - return FAILURE; - return SUCCESS; -} - -/** - * Tear down the test suite. - */ -static int TestTeardown(void * inContext) -{ - chip::Platform::MemoryShutdown(); - return SUCCESS; -} - -// clang-format off -static nlTestSuite sSuite = -{ - "Test-CHIP-SessionIDAllocator", - &sTests[0], - TestSetup, - TestTeardown, -}; -// clang-format on - -/** - * Main - */ -int TestSessionIDAllocator() -{ - // Run test suit against one context - nlTestRunner(&sSuite, nullptr); - - return (nlTestRunnerStats(&sSuite)); -} - -CHIP_REGISTER_TEST_SUITE(TestSessionIDAllocator) diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp index cbbd8bc0c73273..ccd38d736a1116 100644 --- a/src/transport/PairingSession.cpp +++ b/src/transport/PairingSession.cpp @@ -23,6 +23,22 @@ namespace chip { +CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager, uint16_t sessionId) +{ + auto handle = sessionManager.AllocateSession(sessionId); + VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY); + mSecureSessionHolder.Grab(handle.Value()); + return CHIP_NO_ERROR; +} + +CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager) +{ + auto handle = sessionManager.AllocateSession(); + VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY); + mSecureSessionHolder.Grab(handle.Value()); + return CHIP_NO_ERROR; +} + CHIP_ERROR PairingSession::EncodeMRPParameters(TLV::Tag tag, const ReliableMessageProtocolConfig & mrpConfig, TLV::TLVWriter & tlvWriter) { diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index 1a8682f8578f4b..2b256857850248 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -50,11 +50,18 @@ class DLL_EXPORT PairingSession CATValues GetPeerCATs() const { return mPeerCATs; } - // TODO: the local key id should be allocateed at start - // mLocalSessionId should be const and assigned at the construction, such that GetLocalSessionId will always return a valid key - // id , and SetLocalSessionId is not necessary. - uint16_t GetLocalSessionId() const { return mLocalSessionId; } - bool IsValidLocalSessionId() const { return mLocalSessionId != kInvalidKeyId; } + Optional GetLocalSessionId() const + { + Optional localSessionId; + VerifyOrExit(mSecureSessionHolder, localSessionId = NullOptional); + VerifyOrExit(mSecureSessionHolder->GetSessionType() == Transport::Session::SessionType::kSecure, + localSessionId = Optional::Missing()); + localSessionId.SetValue(mSecureSessionHolder->AsSecureSession()->GetLocalSessionId()); + exit: + return localSessionId; + } + + auto GetSecureSessionHandle() const { return mSecureSessionHolder.ToOptional(); } uint16_t GetPeerSessionId() const { @@ -94,10 +101,31 @@ class DLL_EXPORT PairingSession TLV::TLVWriter & tlvWriter); protected: + /** + * Allocate a secure session object from the passed session manager for the + * pending session establishment operation. + * + * @param sessionManager session manager from which to allocate a secure session object + * @return CHIP_ERROR The outcome of the allocation attempt + */ + CHIP_ERROR AllocateSecureSession(SessionManager & sessionManager); + + /** + * Allocate a secure session object from the passed session manager with the + * specified session ID. + * + * This variant of the interface may be used in test scenarios where + * session IDs need to be predetermined. + + * @param sessionManager session manager from which to allocate a secure session object + * @param sessionId caller-requested session ID + * @return CHIP_ERROR The outcome of the allocation attempt + */ + CHIP_ERROR AllocateSecureSession(SessionManager & sessionManager, uint16_t sessionId); + void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; } void SetPeerCATs(CATValues peerCATs) { mPeerCATs = peerCATs; } void SetPeerSessionId(uint16_t id) { mPeerSessionId.SetValue(id); } - void SetLocalSessionId(uint16_t id) { mLocalSessionId = id; } void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } virtual void OnSuccessStatusReport() {} virtual CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) @@ -170,7 +198,7 @@ class DLL_EXPORT PairingSession mPeerCATs = kUndefinedCATs; mPeerAddress = Transport::PeerAddress::Uninitialized(); mPeerSessionId.ClearValue(); - mLocalSessionId = kInvalidKeyId; + mSecureSessionHolder.Release(); } private: @@ -178,10 +206,7 @@ class DLL_EXPORT PairingSession NodeId mPeerNodeId = kUndefinedNodeId; CATValues mPeerCATs; - // TODO: the local key id should be allocateed at start - // then we can remove kInvalidKeyId - static constexpr uint16_t kInvalidKeyId = UINT16_MAX; - uint16_t mLocalSessionId = kInvalidKeyId; + SessionHolder mSecureSessionHolder; // TODO: decouple peer address into transport, such that pairing session do not need to handle peer address Transport::PeerAddress mPeerAddress = Transport::PeerAddress::Uninitialized(); diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 8146884cf15373..b8e8a423d14f29 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -61,6 +61,14 @@ class SecureSession : public Session kUndefined = 0, kPASE = 1, kCASE = 2, + // kPending denotes a secure session object that is internally + // reserved by the stack before and during session establishment. + // + // Although the stack can tolerate eviction of these (releasing one + // out from under the holder would exhibit as CHIP_ERROR_INCORRECT_STATE + // during CASE or PASE), intent is that we should not and would leave + // these untouched until CASE or PASE complete. + kPending = 3, }; SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, @@ -71,6 +79,33 @@ class SecureSession : public Session { SetFabricIndex(fabric); } + + /** + * @brief + * Construct a secure session object to associate with a pending secure + * session establishment attempt. The object for the pending session + * receives a local session ID, but no other state. + */ + SecureSession(uint16_t localSessionId) : + SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig()) + {} + + /** + * @brief + * Activate a pending Secure Session that had been reserved during CASE or + * PASE, setting internal state according to the parameters used and + * discovered during session establishment. + */ + void Activate(Type secureSessionType, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric, + const ReliableMessageProtocolConfig & config) + { + mSecureSessionType = secureSessionType; + mPeerNodeId = peerNodeId; + mPeerCATs = peerCATs; + mPeerSessionId = peerSessionId; + mMRPConfig = config; + SetFabricIndex(fabric); + } ~SecureSession() override { NotifySessionReleased(); } SecureSession(SecureSession &&) = delete; @@ -141,11 +176,11 @@ class SecureSession : public Session SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; } private: - const Type mSecureSessionType; - const NodeId mPeerNodeId; - const CATValues mPeerCATs; + Type mSecureSessionType; + NodeId mPeerNodeId; + CATValues mPeerCATs; const uint16_t mLocalSessionId; - const uint16_t mPeerSessionId; + uint16_t mPeerSessionId; PeerAddress mPeerAddress; System::Clock::Timestamp mLastActivityTime; diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index 9dce2198c3bbc4..d023ccb62fdbdb 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -25,10 +25,8 @@ namespace chip { namespace Transport { -// TODO; use 0xffff to match any key id, this is a temporary solution for -// InteractionModel, where key id is not obtainable. This will be removed when -// InteractionModel is migrated to messaging layer -constexpr const uint16_t kAnyKeyId = 0xffff; +constexpr uint16_t kMaxSessionID = UINT16_MAX; +constexpr uint16_t kUnsecuredSessionId = 0; /** * Handles a set of sessions. @@ -43,11 +41,13 @@ class SecureSessionTable public: ~SecureSessionTable() { mEntries.ReleaseAll(); } + void Init() { mNextSessionId = chip::Crypto::GetRandU16(); } + /** - * Allocates a new secure session out of the internal resource pool. + * Allocate a new secure session out of the internal resource pool. * * @param secureSessionType secure session type - * @param localSessionId represents the encryption key ID assigned by local node + * @param localSessionId unique identifier for the local node's secure unicast session context * @param peerNodeId represents peer Node's ID * @param peerCATs represents peer CASE Authenticated Tags * @param peerSessionId represents the encryption key ID assigned by peer node @@ -69,6 +69,57 @@ class SecureSessionTable return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } + /** + * Allocate a new secure session out of the internal resource pool with the + * specified session ID. The returned secure session will not become active + * until the call to SecureSession::Activate. If there is a resident + * session at the passed ID, an empty Optional will be returned to signal + * the error. + * + * This variant of the interface is primarily useful in testing, where + * session IDs may need to be predetermined. + * + * @param localSessionId unique identifier for the local node's secure unicast session context + * @returns allocated session, or NullOptional on failure + */ + CHECK_RETURN_VALUE + Optional CreateNewSecureSession(uint16_t localSessionId) + { + Optional rv = Optional::Missing(); + SecureSession * allocated = nullptr; + VerifyOrExit(localSessionId != kUnsecuredSessionId, rv = NullOptional); + VerifyOrExit(!FindSecureSessionByLocalKey(localSessionId).HasValue(), rv = NullOptional); + allocated = mEntries.CreateObject(localSessionId); + VerifyOrExit(allocated != nullptr, rv = Optional::Missing()); + rv = MakeOptional(*allocated); + exit: + return rv; + } + + /** + * Allocate a new secure session out of the internal resource pool with a + * non-colliding session ID and increments mNextSessionId to give a clue to + * the allocator for the next allocation. The secure session session will + * not become active until the call to SecureSession::Activate. + * + * @returns allocated session, or NullOptional on failure + */ + CHECK_RETURN_VALUE + Optional CreateNewSecureSession() + { + Optional rv = Optional::Missing(); + auto sessionId = FindUnusedSessionId(); + SecureSession * allocated = nullptr; + VerifyOrExit(sessionId.HasValue(), rv = Optional::Missing()); + allocated = mEntries.CreateObject(sessionId.Value()); + VerifyOrExit(allocated != nullptr, rv = Optional::Missing()); + rv = MakeOptional(*allocated); + mNextSessionId = sessionId.Value() == kMaxSessionID ? static_cast(kUnsecuredSessionId + 1) + : static_cast(sessionId.Value() + 1); + exit: + return rv; + } + void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); } template @@ -78,11 +129,11 @@ class SecureSessionTable } /** - * Get a secure session given a Node Id and Peer's Encryption Key Id. + * Get a secure session given its session ID. * - * @param localSessionId Encryption key ID used by the local node. + * @param localSessionId the identifier of a secure unicast session context within the local node * - * @return the state found, nullptr if not found + * @return the session if found, NullOptional if not found */ CHECK_RETURN_VALUE Optional FindSecureSessionByLocalKey(uint16_t localSessionId) @@ -109,7 +160,8 @@ class SecureSessionTable void ExpireInactiveSessions(System::Clock::Timestamp maxIdleTime, Callback callback) { mEntries.ForEachActiveObject([&](auto session) { - if (session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp()) + if (session->GetSecureSessionType() != SecureSession::Type::kPending && + session->GetLastActivityTime() + maxIdleTime < System::SystemClock().GetMonotonicTimestamp()) { callback(*session); ReleaseSession(session); @@ -119,7 +171,76 @@ class SecureSessionTable } private: + /** + * Find an available session ID that is unused in the secure session table. + * + * The search algorithm iterates over the session ID space in the outer loop + * and the session table in the inner loop to locate an available session ID + * from the starting mNextSessionId clue. + * + * The outer-loop considers 64 session IDs in each iteration to give a + * runtime complexity of O(kMaxSessionCount^2/64). Speed up could be + * achieved with a sorted session table or additional storage. + * + * @return an unused session ID if any is found, else NullOptional + */ + CHECK_RETURN_VALUE + Optional FindUnusedSessionId() + { + uint16_t candidate_base = 0; + uint64_t candidate_mask = 0; + for (uint32_t i = 0; i <= kMaxSessionID; i += 64) + { + // candidate_base is the base session ID we are searching from. + // We have a 64-bit mask anchored at this ID and iterate over the + // whole session table, setting bits in the mask for in-use IDs. + // If we can iterate through the entire session table and have + // any bits clear in the mask, we have available session IDs. + candidate_base = static_cast(i + mNextSessionId); + candidate_mask = 0; + { + uint16_t shift = static_cast(kUnsecuredSessionId - candidate_base); + if (shift <= 63) + { + candidate_mask |= (1ULL << shift); // kUnsecuredSessionId is never available + } + } + mEntries.ForEachActiveObject([&](auto session) { + uint16_t shift = static_cast(session->GetLocalSessionId() - candidate_base); + if (shift <= 63) + { + candidate_mask |= (1ULL << shift); + } + if (candidate_mask == UINT64_MAX) + { + return Loop::Break; // No bits clear means this bucket is full. + } + return Loop::Continue; + }); + if (candidate_mask != UINT64_MAX) + { + break; // Any bit clear means we have an available ID in this bucket. + } + } + if (candidate_mask != UINT64_MAX) + { + uint16_t offset = 0; + while (candidate_mask & 1) + { + candidate_mask >>= 1; + ++offset; + } + uint16_t available = static_cast(candidate_base + offset); + return MakeOptional(available); + } + else + { + return NullOptional; + } + } + BitMapObjectPool mEntries; + uint16_t mNextSessionId = 0; }; } // namespace Transport diff --git a/src/transport/Session.h b/src/transport/Session.h index 8c1ea3291e9bc9..f10954b1f07801 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -85,6 +85,8 @@ class Session return GetSessionType() == SessionType::kGroupIncoming || GetSessionType() == SessionType::kGroupOutgoing; } + bool IsSecureSession() const { return GetSessionType() == SessionType::kSecure; } + protected: // This should be called by sub-classes at the very beginning of the destructor, before any data field is disposed, such that // the session is still functional during the callback. diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h index ded017e210da9f..f30dabc9d09bcc 100644 --- a/src/transport/SessionHolder.h +++ b/src/transport/SessionHolder.h @@ -68,6 +68,10 @@ class SessionHolderWithDelegate : public SessionHolder { public: SessionHolderWithDelegate(SessionReleaseDelegate & delegate) : mDelegate(delegate) {} + SessionHolderWithDelegate(const SessionHandle & handle, SessionReleaseDelegate & delegate) : mDelegate(delegate) + { + Grab(handle); + } operator bool() const { return SessionHolder::operator bool(); } void OnSessionReleased() override diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 652193a7d93fc1..7ca5fca3234ae2 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -90,6 +90,8 @@ CHIP_ERROR SessionManager::Init(System::Layer * systemLayer, TransportMgrBase * mMessageCounterManager = messageCounterManager; mFabricTable = fabricTable; + mSecureSessions.Init(); + // TODO: Handle error from mGlobalEncryptedMessageCounter! Unit tests currently crash if you do! (void) mGlobalEncryptedMessageCounter.Init(); mGlobalUnencryptedMessageCounter.Init(); @@ -372,27 +374,39 @@ void SessionManager::ExpireAllPairingsForFabric(FabricIndex fabric) }); } -CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, - NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, - FabricIndex fabric) +Optional SessionManager::AllocateSession() { - uint16_t peerSessionId = pairing->GetPeerSessionId(); - uint16_t localSessionId = pairing->GetLocalSessionId(); - Optional session = mSecureSessions.FindSecureSessionByLocalKey(localSessionId); + return mSecureSessions.CreateNewSecureSession(); +} - // Find any existing connection with the same local key ID - if (session.HasValue()) +Optional SessionManager::AllocateSession(uint16_t sessionId) +{ + // If we forego SessionManager session ID allocation, we can have a + // collission. In case of such a collission, we must evict first. + Optional oldSession = mSecureSessions.FindSecureSessionByLocalKey(sessionId); + if (oldSession.HasValue()) { - mSecureSessions.ReleaseSession(session.Value()->AsSecureSession()); + mSecureSessions.ReleaseSession(oldSession.Value()->AsSecureSession()); } + return mSecureSessions.CreateNewSecureSession(sessionId); +} + +CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, + NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, + FabricIndex fabric) +{ + uint16_t peerSessionId = pairing->GetPeerSessionId(); + SecureSession * secureSession; + auto handle = pairing->GetSecureSessionHandle(); + VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(handle.Value()->IsSecureSession(), CHIP_ERROR_INCORRECT_STATE); + secureSession = handle.Value()->AsSecureSession(); ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", LSID:%d PSID:%d!", - ChipLogValueX64(peerNodeId), localSessionId, peerSessionId); - session = mSecureSessions.CreateNewSecureSession(pairing->GetSecureSessionType(), localSessionId, peerNodeId, - pairing->GetPeerCATs(), peerSessionId, fabric, pairing->GetMRPConfig()); - ReturnErrorCodeIf(!session.HasValue(), CHIP_ERROR_NO_MEMORY); + ChipLogValueX64(peerNodeId), secureSession->GetLocalSessionId(), peerSessionId); + secureSession->Activate(pairing->GetSecureSessionType(), peerNodeId, pairing->GetPeerCATs(), peerSessionId, fabric, + pairing->GetMRPConfig()); - Transport::SecureSession * secureSession = session.Value()->AsSecureSession(); if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any) { secureSession->SetPeerAddress(peerAddr.Value()); @@ -411,7 +425,9 @@ CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optio ReturnErrorOnFailure(pairing->DeriveSecureSession(secureSession->GetCryptoContext(), direction)); secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); - sessionHolder.Grab(session.Value()); + + sessionHolder.Grab(handle.Value()); + return CHIP_NO_ERROR; } diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 87aef5649a4959..57de6e0d14b0a4 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -168,6 +168,29 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate CHIP_ERROR NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, FabricIndex fabric); + /** + * @brief + * Allocate a secure session and non-colliding session ID in the secure + * session table. + * + * @return SessionHandle with a reference to a SecureSession, else NullOptional on failure + */ + CHECK_RETURN_VALUE + Optional AllocateSession(); + + /** + * @brief + * Allocate a secure session in the secure session table at the specified + * session ID. If the session ID collides with an existing session, evict + * it. This variant of the interface may be used in test scenarios where + * session IDs need to be predetermined. + * + * @param localSessionId a unique identifier for the local node's secure unicast session context + * @return SessionHandle with a reference to a SecureSession, else NullOptional on failure + */ + CHECK_RETURN_VALUE + Optional AllocateSession(uint16_t localSessionId); + void ExpirePairing(const SessionHandle & session); void ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric); void ExpireAllPairingsForFabric(FabricIndex fabric); diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 4cda4fcaa699f3..255fe5d3bb5691 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -168,12 +168,12 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -283,12 +283,12 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -384,12 +384,12 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -497,36 +497,36 @@ void StaleConnectionDropTest(nlTestSuite * inSuite, void * inContext) SessionHolderWithDelegate session5(callback); // First pairing - SecurePairingUsingTestSecret pairing1(1, 1); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing1(1, 1, sessionManager); err = sessionManager.NewPairing(session1, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with different peer node ID and different local key ID (same peer key ID) - SecurePairingUsingTestSecret pairing2(1, 2); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing2(1, 2, sessionManager); err = sessionManager.NewPairing(session2, peer, kSourceNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with undefined node ID and different local key ID (same peer key ID) - SecurePairingUsingTestSecret pairing3(1, 3); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing3(1, 3, sessionManager); err = sessionManager.NewPairing(session3, peer, kUndefinedNodeId, &pairing3, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); // New pairing with same local key ID, and a given node ID - SecurePairingUsingTestSecret pairing4(1, 2); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing4(1, 2, sessionManager); err = sessionManager.NewPairing(session4, peer, kSourceNodeId, &pairing4, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); // New pairing with same local key ID, and undefined node ID - SecurePairingUsingTestSecret pairing5(1, 1); callback.mOldConnectionDropped = false; + SecurePairingUsingTestSecret pairing5(1, 1, sessionManager); err = sessionManager.NewPairing(session5, peer, kUndefinedNodeId, &pairing5, CryptoContext::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); @@ -590,12 +590,12 @@ void SendPacketWithOldCounterTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -704,12 +704,12 @@ void SendPacketWithTooOldCounterTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.AddNewFabric(bobFabric, &bobFabricIndex)); SessionHolder aliceToBobSession; - SecurePairingUsingTestSecret aliceToBobPairing(1, 2); + SecurePairingUsingTestSecret aliceToBobPairing(1, 2, sessionManager); err = sessionManager.NewPairing(aliceToBobSession, peer, fabricTable.FindFabricWithIndex(bobFabricIndex)->GetNodeId(), &aliceToBobPairing, CryptoContext::SessionRole::kInitiator, aliceFabricIndex); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SecurePairingUsingTestSecret bobToAlicePairing(2, 1); + SecurePairingUsingTestSecret bobToAlicePairing(2, 1, sessionManager); SessionHolder bobToAliceSession; err = sessionManager.NewPairing(bobToAliceSession, peer, fabricTable.FindFabricWithIndex(aliceFabricIndex)->GetNodeId(), &bobToAlicePairing, CryptoContext::SessionRole::kResponder, bobFabricIndex); @@ -764,6 +764,124 @@ void SendPacketWithTooOldCounterTest(nlTestSuite * inSuite, void * inContext) sessionManager.Shutdown(); } +static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager & sessionManager, int max) +{ + // Allocate + free a pseudo-random number of sessions to create a + // pseudo-random offset in mNextSessionId. + const int bound = rand() % max; + for (int i = 0; i < bound; ++i) + { + auto handle = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + sessionManager.ExpirePairing(handle.Value()); + } +} + +void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) +{ + SessionManager sessionManager; + TestSessionReleaseCallback callback; + + // Allocate a session. + uint16_t sessionId1; + { + auto handle = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + SessionHolderWithDelegate session(handle.Value(), callback); + sessionId1 = session->AsSecureSession()->GetLocalSessionId(); + } + + // Allocate a session at a colliding ID, verify eviction. + { + callback.mOldConnectionDropped = false; + auto handle = sessionManager.AllocateSession(sessionId1); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + SessionHolderWithDelegate session(handle.Value(), callback); + } + + // Verify that we increment session ID by 1 for each allocation, except for + // the wraparound case where we skip session ID 0. + auto prevSessionId = sessionId1; + for (uint32_t i = 0; i < 10; ++i) + { + auto handle = sessionManager.AllocateSession(); + if (!handle.HasValue()) + { + break; + } + auto sessionId = handle.Value()->AsSecureSession()->GetLocalSessionId(); + NL_TEST_ASSERT(inSuite, sessionId - prevSessionId == 1 || (sessionId == 1 && prevSessionId == 65535)); + NL_TEST_ASSERT(inSuite, sessionId != 0); + prevSessionId = sessionId; + } + + // Reconstruct the Session Manager to reset state. + sessionManager.~SessionManager(); + new (&sessionManager) SessionManager(); + + prevSessionId = 0; + // Verify that we increment session ID by 1 for each allocation (except for + // the wraparound case where we skip session ID 0), even when allocated + // sessions are immediately freed. + for (uint32_t i = 0; i < UINT16_MAX + 10; ++i) + { + auto handle = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + auto sessionId = handle.Value()->AsSecureSession()->GetLocalSessionId(); + NL_TEST_ASSERT(inSuite, sessionId - prevSessionId == 1 || (sessionId == 1 && prevSessionId == 65535)); + NL_TEST_ASSERT(inSuite, sessionId != 0); + prevSessionId = sessionId; + sessionManager.ExpirePairing(handle.Value()); + } + + // Verify that the allocator does not give colliding IDs. + constexpr int collisionTestIterations = 1; + for (int i = 0; i < collisionTestIterations; ++i) + { + // Allocate some session handles at pseudo-random offsets in the session + // ID space. + constexpr size_t numHandles = CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE - 1; + Optional handles[numHandles]; + uint16_t sessionIds[numHandles]; + for (size_t h = 0; h < numHandles; ++h) + { + constexpr int maxOffset = 5000; + handles[h] = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, handles[h].HasValue()); + sessionIds[h] = handles[h].Value()->AsSecureSession()->GetLocalSessionId(); + RandomSessionIdAllocatorOffset(inSuite, sessionManager, maxOffset); + } + + // Verify that none collide each other. + for (size_t h = 0; h < numHandles; ++h) + { + NL_TEST_ASSERT(inSuite, sessionIds[h] != sessionIds[(h + 1) % numHandles]); + } + + // Allocate through the entire session ID space and verify that none of + // these collide either. + for (int j = 0; j < UINT16_MAX; ++j) + { + auto handle = sessionManager.AllocateSession(); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + auto potentialCollision = handle.Value()->AsSecureSession()->GetLocalSessionId(); + for (size_t h = 0; h < numHandles; ++h) + { + NL_TEST_ASSERT(inSuite, potentialCollision != sessionIds[h]); + } + sessionManager.ExpirePairing(handle.Value()); + } + + // Free our allocated sessions. + for (size_t h = 0; h < numHandles; ++h) + { + sessionManager.ExpirePairing(handles[h].Value()); + } + } + + sessionManager.Shutdown(); +} + // Test Suite /** @@ -779,6 +897,7 @@ const nlTest sTests[] = NL_TEST_DEF("Drop stale connection Test", StaleConnectionDropTest), NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest), NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest), + NL_TEST_DEF("Session Allocation Test", SessionAllocationTest), NL_TEST_SENTINEL() };