diff --git a/examples/chip-tool/commands/tests/TestCommand.cpp b/examples/chip-tool/commands/tests/TestCommand.cpp index aed75cd215b068..083ed65a7b7ffb 100644 --- a/examples/chip-tool/commands/tests/TestCommand.cpp +++ b/examples/chip-tool/commands/tests/TestCommand.cpp @@ -41,7 +41,7 @@ CHIP_ERROR TestCommand::WaitForCommissionee(chip::NodeId nodeId) // or is just starting out fresh outright. Let's make sure we're not re-using any cached CASE sessions // that will now be stale and mismatched with the peer, causing subsequent interactions to fail. // - CurrentCommissioner().SessionMgr()->ExpireAllPairings(nodeId, fabricIndex); + CurrentCommissioner().SessionMgr()->ExpireAllPairings(chip::ScopedNodeId(nodeId, fabricIndex)); return CurrentCommissioner().GetConnectedDevice(nodeId, &mOnDeviceConnectedCallback, &mOnDeviceConnectionFailureCallback); } diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index 2ec8ce9780646e..f279714a1da261 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -23,12 +23,11 @@ CASEClient::CASEClient(const CASEClientInitParams & params) : mInitParams(params void CASEClient::SetMRPIntervals(const ReliableMessageProtocolConfig & mrpConfig) { - mCASESession.SetMRPConfig(mrpConfig); + mCASESession.SetRemoteMRPConfig(mrpConfig); } CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddress & peerAddress, - const ReliableMessageProtocolConfig & mrpConfig, OnCASEConnected onConnection, - OnCASEConnectionFailure onFailure, void * context) + const ReliableMessageProtocolConfig & mrpConfig, SessionEstablishmentDelegate * delegate) { // Create a UnauthenticatedSession for CASE pairing. // Don't use mSecureSession here, because mSecureSession is for encrypted communication. @@ -45,45 +44,9 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); mCASESession.SetGroupDataProvider(mInitParams.groupDataProvider); - ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, peerAddress, mInitParams.fabricInfo, - peer.GetNodeId(), exchange, mInitParams.sessionResumptionStorage, this, + ReturnErrorOnFailure(mCASESession.EstablishSession(*mInitParams.sessionManager, mInitParams.fabricInfo, peer.GetNodeId(), + exchange, mInitParams.sessionResumptionStorage, delegate, mInitParams.mrpLocalConfig)); - mConnectionSuccessCallback = onConnection; - mConnectionFailureCallback = onFailure; - mConectionContext = context; - mPeerId = peer; - mPeerAddress = peerAddress; - - return CHIP_NO_ERROR; -} - -void CASEClient::OnSessionEstablishmentError(CHIP_ERROR error) -{ - if (mConnectionFailureCallback) - { - mConnectionFailureCallback(mConectionContext, this, error); - } -} - -void CASEClient::OnSessionEstablished() -{ - // On successful CASE connection, the local session ID will be used for the derived secure session. - if (mConnectionSuccessCallback) - { - mConnectionSuccessCallback(mConectionContext, this); - } -} - -CHIP_ERROR CASEClient::DeriveSecureSessionHandle(SessionHolder & handle) -{ - CHIP_ERROR err = mInitParams.sessionManager->NewPairing( - handle, Optional::Value(mPeerAddress), mPeerId.GetNodeId(), &mCASESession, - CryptoContext::SessionRole::kInitiator, mInitParams.fabricInfo->GetFabricIndex()); - if (err != CHIP_NO_ERROR) - { - ChipLogError(Controller, "Failed in setting up CASE secure channel: err %s", ErrorStr(err)); - return err; - } return CHIP_NO_ERROR; } diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index 220e920904252f..ca93d2bdea43f1 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -40,7 +40,7 @@ struct CASEClientInitParams Optional mrpLocalConfig = Optional::Missing(); }; -class DLL_EXPORT CASEClient : public SessionEstablishmentDelegate +class DLL_EXPORT CASEClient { public: CASEClient(const CASEClientInitParams & params); @@ -48,26 +48,12 @@ class DLL_EXPORT CASEClient : public SessionEstablishmentDelegate void SetMRPIntervals(const ReliableMessageProtocolConfig & mrpConfig); CHIP_ERROR EstablishSession(PeerId peer, const Transport::PeerAddress & peerAddress, - const ReliableMessageProtocolConfig & mrpConfig, OnCASEConnected onConnection, - OnCASEConnectionFailure onFailure, void * context); - - // Implementation of SessionEstablishmentDelegate - void OnSessionEstablishmentError(CHIP_ERROR error) override; - - void OnSessionEstablished() override; - - CHIP_ERROR DeriveSecureSessionHandle(SessionHolder & handle); + const ReliableMessageProtocolConfig & mrpConfig, SessionEstablishmentDelegate * delegate); private: CASEClientInitParams mInitParams; CASESession mCASESession; - PeerId mPeerId; - Transport::PeerAddress mPeerAddress; - - OnCASEConnected mConnectionSuccessCallback = nullptr; - OnCASEConnectionFailure mConnectionFailureCallback = nullptr; - void * mConectionContext = nullptr; }; } // namespace chip diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index dd037bdbeac07e..f53f885f5c7463 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -214,8 +214,8 @@ CHIP_ERROR OperationalDeviceProxy::EstablishConnection() CASEClientInitParams{ mInitParams.sessionManager, mInitParams.sessionResumptionStorage, mInitParams.exchangeMgr, mFabricInfo, mInitParams.groupDataProvider, mInitParams.mrpLocalConfig }); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); - CHIP_ERROR err = - mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, HandleCASEConnected, HandleCASEConnectionFailure, this); + + CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, this); if (err != CHIP_NO_ERROR) { CleanupCASEClient(); @@ -282,50 +282,33 @@ void OperationalDeviceProxy::DequeueConnectionCallbacks(CHIP_ERROR error) } } -void OperationalDeviceProxy::HandleCASEConnectionFailure(void * context, CASEClient * client, CHIP_ERROR error) +void OperationalDeviceProxy::OnSessionEstablishmentError(CHIP_ERROR error) { - OperationalDeviceProxy * device = static_cast(context); - VerifyOrReturn(device->mState != State::Uninitialized && device->mState != State::NeedsAddress, + VerifyOrReturn(mState != State::Uninitialized && mState != State::NeedsAddress, ChipLogError(Controller, "HandleCASEConnectionFailure was called while the device was not initialized")); - VerifyOrReturn(client == device->mCASEClient, ChipLogError(Controller, "HandleCASEConnectionFailure for unknown CASEClient")); // // We don't need to reset the state all the way back to NeedsAddress since all that transpired // was just CASE connection failure. So let's re-use the cached address to re-do CASE again // if need-be. // - device->MoveToState(State::Initialized); + MoveToState(State::Initialized); - device->DequeueConnectionCallbacks(error); + DequeueConnectionCallbacks(error); - // - // Do not touch device instance anymore; it might have been destroyed by a failure - // callback. - // + // Do not touch device instance anymore; it might have been destroyed by a failure callback. } -void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * client) +void OperationalDeviceProxy::OnSessionEstablished(const SessionHandle & session) { - OperationalDeviceProxy * device = static_cast(context); - VerifyOrReturn(device->mState != State::Uninitialized, + VerifyOrReturn(mState != State::Uninitialized, ChipLogError(Controller, "HandleCASEConnected was called while the device was not initialized")); - VerifyOrReturn(client == device->mCASEClient, ChipLogError(Controller, "HandleCASEConnected for unknown CASEClient")); - CHIP_ERROR err = client->DeriveSecureSessionHandle(device->mSecureSession); - if (err != CHIP_NO_ERROR) - { - device->HandleCASEConnectionFailure(context, client, err); - } - else - { - device->MoveToState(State::SecureConnected); - device->DequeueConnectionCallbacks(CHIP_NO_ERROR); - } + mSecureSession.Grab(session); + MoveToState(State::SecureConnected); + DequeueConnectionCallbacks(CHIP_NO_ERROR); - // - // Do not touch this instance anymore; it might have been destroyed by a - // callback. - // + // Do not touch this instance anymore; it might have been destroyed by a callback. } CHIP_ERROR OperationalDeviceProxy::Disconnect() diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 5a6243ea18d9d2..d921232bf2f850 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -85,7 +85,7 @@ typedef void (*OnDeviceConnectionFailure)(void * context, PeerId peerId, CHIP_ER * - Expose to consumers the secure session for talking to the device. */ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, - SessionReleaseDelegate, + public SessionReleaseDelegate, public SessionEstablishmentDelegate, public AddressResolve::NodeListener { @@ -140,6 +140,10 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, bool IsConnecting() const { return mState == State::Connecting; } + //////////// SessionEstablishmentDelegate Implementation /////////////// + void OnSessionEstablished(const SessionHandle & session) override; + void OnSessionEstablishmentError(CHIP_ERROR error) override; + /** * Called when a connection is closing. * The object releases all resources associated with the connection. @@ -276,9 +280,6 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, bool IsSecureConnected() const override { return mState == State::SecureConnected; } - static void HandleCASEConnected(void * context, CASEClient * client); - static void HandleCASEConnectionFailure(void * context, CASEClient * client, CHIP_ERROR error); - void CleanupCASEClient(); void EnqueueConnectionCallbacks(Callback::Callback * onConnection, diff --git a/src/app/server/CommissioningWindowManager.cpp b/src/app/server/CommissioningWindowManager.cpp index 339935c94f228e..b3ca9b728c8f31 100644 --- a/src/app/server/CommissioningWindowManager.cpp +++ b/src/app/server/CommissioningWindowManager.cpp @@ -144,19 +144,9 @@ void CommissioningWindowManager::OnSessionEstablishmentStarted() DeviceLayer::SystemLayer().StartTimer(kPASESessionEstablishmentTimeout, HandleSessionEstablishmentTimeout, this); } -void CommissioningWindowManager::OnSessionEstablished() +void CommissioningWindowManager::OnSessionEstablished(const SessionHandle & session) { DeviceLayer::SystemLayer().CancelTimer(HandleSessionEstablishmentTimeout, this); - SessionHolder sessionHolder; - CHIP_ERROR err = mServer->GetSecureSessionManager().NewPairing( - sessionHolder, Optional::Value(mPairingSession.GetPeerAddress()), mPairingSession.GetPeerNodeId(), - &mPairingSession, CryptoContext::SessionRole::kResponder, 0); - if (err != CHIP_NO_ERROR) - { - ChipLogError(AppServer, "Commissioning failed while setting up secure channel: err %s", ErrorStr(err)); - OnSessionEstablishmentError(err); - return; - } ChipLogProgress(AppServer, "Commissioning completed session establishment step"); if (mAppDelegate != nullptr) @@ -177,7 +167,7 @@ void CommissioningWindowManager::OnSessionEstablished() } else { - err = failSafeContext.ArmFailSafe(kUndefinedFabricId, System::Clock::Seconds16(60)); + CHIP_ERROR err = failSafeContext.ArmFailSafe(kUndefinedFabricId, System::Clock::Seconds16(60)); if (err != CHIP_NO_ERROR) { ChipLogError(AppServer, "Error arming failsafe on PASE session establishment completion"); diff --git a/src/app/server/CommissioningWindowManager.h b/src/app/server/CommissioningWindowManager.h index eb5471d7d7abfa..21f41ce2d9435e 100644 --- a/src/app/server/CommissioningWindowManager.h +++ b/src/app/server/CommissioningWindowManager.h @@ -85,7 +85,7 @@ class CommissioningWindowManager : public SessionEstablishmentDelegate, public a //////////// SessionEstablishmentDelegate Implementation /////////////// void OnSessionEstablishmentError(CHIP_ERROR error) override; void OnSessionEstablishmentStarted() override; - void OnSessionEstablished() override; + void OnSessionEstablished(const SessionHandle & session) override; void Shutdown(); void Cleanup(); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 418cf48f6e6555..65bce65531efe0 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -646,7 +646,7 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re exchangeCtxt = mSystemState->ExchangeMgr()->NewContext(session.Value(), &device->GetPairing()); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); - err = device->GetPairing().Pair(*mSystemState->SessionMgr(), params.GetPeerAddress(), params.GetSetupPINCode(), + err = device->GetPairing().Pair(*mSystemState->SessionMgr(), params.GetSetupPINCode(), Optional::Value(GetLocalMRPConfig()), exchangeCtxt, this); SuccessOrExit(err); @@ -827,7 +827,7 @@ void DeviceCommissioner::OnSessionEstablishmentError(CHIP_ERROR err) RendezvousCleanup(err); } -void DeviceCommissioner::OnSessionEstablished() +void DeviceCommissioner::OnSessionEstablished(const SessionHandle & session) { // PASE session established. CommissioneeDeviceProxy * device = mDeviceInPASEEstablishment; @@ -837,12 +837,7 @@ void DeviceCommissioner::OnSessionEstablished() VerifyOrReturn(device != nullptr, OnSessionEstablishmentError(CHIP_ERROR_INVALID_DEVICE_DESCRIPTOR)); - PASESession * pairing = &device->GetPairing(); - - // TODO: the session should know which peer we are trying to connect to when started - pairing->SetPeerNodeId(device->GetDeviceId()); - - CHIP_ERROR err = device->SetConnected(); + CHIP_ERROR err = device->SetConnected(session); if (err != CHIP_NO_ERROR) { ChipLogError(Controller, "Failed in setting up secure channel: err %s", ErrorStr(err)); diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index d2aa1172e05c4f..381b8b37bf9244 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -459,7 +459,7 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, //////////// SessionEstablishmentDelegate Implementation /////////////// void OnSessionEstablishmentError(CHIP_ERROR error) override; - void OnSessionEstablished() override; + void OnSessionEstablished(const SessionHandle & session) override; void RendezvousCleanup(CHIP_ERROR status); diff --git a/src/controller/CommissioneeDeviceProxy.cpp b/src/controller/CommissioneeDeviceProxy.cpp index c7b695e2e5a387..a2c46babe8588e 100644 --- a/src/controller/CommissioneeDeviceProxy.cpp +++ b/src/controller/CommissioneeDeviceProxy.cpp @@ -76,7 +76,7 @@ CHIP_ERROR CommissioneeDeviceProxy::UpdateDeviceData(const Transport::PeerAddres // Initialize PASE session state with any MRP parameters that DNS-SD has provided. // It can be overridden by PASE session protocol messages that include MRP parameters. - mPairing.SetMRPConfig(mMRPConfig); + mPairing.SetRemoteMRPConfig(mMRPConfig); if (!mSecureSession) { @@ -93,26 +93,12 @@ CHIP_ERROR CommissioneeDeviceProxy::UpdateDeviceData(const Transport::PeerAddres return CHIP_NO_ERROR; } -CHIP_ERROR CommissioneeDeviceProxy::SetConnected() +CHIP_ERROR CommissioneeDeviceProxy::SetConnected(const SessionHandle & session) { - if (mState != ConnectionState::Connecting) - { - return CHIP_ERROR_INCORRECT_STATE; - } - - CHIP_ERROR err = mSessionManager->NewPairing(mSecureSession, Optional::Value(mDeviceAddress), - GetDeviceId(), &mPairing, CryptoContext::SessionRole::kInitiator, mFabricIndex); - - if (err == CHIP_NO_ERROR) - { - mState = ConnectionState::SecureConnected; - } - else - { - ChipLogError(Controller, "NewPairing returning error %" CHIP_ERROR_FORMAT, err.Format()); - mState = ConnectionState::NotConnected; - } - return err; + VerifyOrReturnError(mState == ConnectionState::Connecting, CHIP_ERROR_INCORRECT_STATE); + mState = ConnectionState::SecureConnected; + mSecureSession.Grab(session); + return CHIP_NO_ERROR; } void CommissioneeDeviceProxy::Reset() diff --git a/src/controller/CommissioneeDeviceProxy.h b/src/controller/CommissioneeDeviceProxy.h index e33f958b0de864..f83b87116e8959 100644 --- a/src/controller/CommissioneeDeviceProxy.h +++ b/src/controller/CommissioneeDeviceProxy.h @@ -189,7 +189,7 @@ class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegat * * This stores the session details in the session manager. */ - CHIP_ERROR SetConnected(); + CHIP_ERROR SetConnected(const SessionHandle & session); bool IsSecureConnected() const override { return IsActive() && mState == ConnectionState::SecureConnected; } diff --git a/src/lib/support/logging/CHIPLogging.h b/src/lib/support/logging/CHIPLogging.h index 413b93bb6a541b..cb2018ed016050 100644 --- a/src/lib/support/logging/CHIPLogging.h +++ b/src/lib/support/logging/CHIPLogging.h @@ -407,5 +407,9 @@ bool IsCategoryEnabled(uint8_t category); */ #define ChipLogFormatMessageType "0x%x" +/** Logging helpers for scoped node ids, which is a tuple of */ +#define ChipLogFormatScopedNodeId "<" ChipLogFormatX64 ", %d>" +#define ChipLogValueScopedNodeId(id) ChipLogValueX64((id).GetNodeId()), (id).GetFabricIndex() + } // namespace Logging } // namespace chip diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 4cd351099de318..18f3a479ca1f34 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -126,6 +126,9 @@ class MessagingContext : public PlatformMemoryUser SessionHandle GetSessionAliceToBob(); SessionHandle GetSessionBobToFriends(); + const Transport::PeerAddress & GetAliceAddress() { return mAliceAddress; } + const Transport::PeerAddress & GetBobAddress() { return mBobAddress; } + Messaging::ExchangeContext * NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate); Messaging::ExchangeContext * NewUnauthenticatedExchangeToBob(Messaging::ExchangeDelegate * delegate); diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 8beed6da2accaa..7de23c2d04289c 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -127,23 +127,10 @@ void CASEServer::OnSessionEstablishmentError(CHIP_ERROR err) Cleanup(); } -void CASEServer::OnSessionEstablished() +void CASEServer::OnSessionEstablished(const SessionHandle & session) { - ChipLogProgress(Inet, "CASE Session established. Setting up the secure channel."); - mSessionManager->ExpireAllPairings(GetSession().GetPeerNodeId(), GetSession().GetFabricIndex()); - - SessionHolder sessionHolder; - CHIP_ERROR err = mSessionManager->NewPairing( - sessionHolder, Optional::Value(GetSession().GetPeerAddress()), GetSession().GetPeerNodeId(), - &GetSession(), CryptoContext::SessionRole::kResponder, GetSession().GetFabricIndex()); - if (err != CHIP_NO_ERROR) - { - ChipLogError(Inet, "Failed in setting up secure channel: err %s", ErrorStr(err)); - OnSessionEstablishmentError(err); - return; - } - - ChipLogProgress(Inet, "CASE secure channel is available now."); + ChipLogProgress(Inet, "CASE Session established to peer: " ChipLogFormatScopedNodeId, + ChipLogValueScopedNodeId(session->GetPeer())); Cleanup(); } } // namespace chip diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index f0baa12dbc4324..c6ab227d57e007 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -51,7 +51,7 @@ class CASEServer : public SessionEstablishmentDelegate, //////////// SessionEstablishmentDelegate Implementation /////////////// void OnSessionEstablishmentError(CHIP_ERROR error) override; - void OnSessionEstablished() override; + void OnSessionEstablished(const SessionHandle & session) override; //// UnsolicitedMessageHandler Implementation //// CHIP_ERROR OnUnsolicitedMessageReceived(const PayloadHeader & payloadHeader, ExchangeDelegate *& newDelegate) override; diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 07c900e33a2ea2..3bd59609910af1 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -123,14 +123,32 @@ using HKDF_sha_crypto = HKDF_sha; // The session establishment fails if the response is not received within timeout window. static constexpr ExchangeContext::Timeout kSigma_Response_Timeout = System::Clock::Seconds16(30); -CASESession::CASESession() : PairingSession(Transport::SecureSession::Type::kCASE) {} - CASESession::~CASESession() { // Let's clear out any security state stored in the object, before destroying it. Clear(); } +void CASESession::Finish() +{ + mCASESessionEstablished = true; + + Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); + + // Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that. + DiscardExchange(); + + CHIP_ERROR err = ActivateSecureSession(address); + if (err == CHIP_NO_ERROR) + { + mDelegate->OnSessionEstablished(mSecureSessionHolder.Get()); + } + else + { + mDelegate->OnSessionEstablishmentError(err); + } +} + void CASESession::Clear() { // This function zeroes out and resets the memory used by the object. @@ -199,6 +217,7 @@ CASESession::ListenForSessionEstablishment(SessionManager & sessionManager, Fabr VerifyOrReturnError(fabrics != nullptr, CHIP_ERROR_INVALID_ARGUMENT); ReturnErrorOnFailure(Init(sessionManager, delegate)); + mRole = CryptoContext::SessionRole::kResponder; mFabricsTable = fabrics; mSessionResumptionStorage = sessionResumptionStorage; mLocalMRPConfig = mrpConfig; @@ -210,26 +229,21 @@ CASESession::ListenForSessionEstablishment(SessionManager & sessionManager, Fabr return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, - FabricInfo * fabric, NodeId peerNodeId, ExchangeContext * exchangeCtxt, - SessionResumptionStorage * sessionResumptionStorage, +CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, FabricInfo * fabric, NodeId peerNodeId, + ExchangeContext * exchangeCtxt, SessionResumptionStorage * sessionResumptionStorage, SessionEstablishmentDelegate * delegate, Optional mrpConfig) { MATTER_TRACE_EVENT_SCOPE("EstablishSession", "CASESession"); CHIP_ERROR err = CHIP_NO_ERROR; -#if CHIP_PROGRESS_LOGGING - char peerAddrBuff[Transport::PeerAddress::kMaxToStringSize]; - peerAddress.ToString(peerAddrBuff); - ChipLogProgress(SecureChannel, "Establishing CASE session to %s", peerAddrBuff); -#endif - // Return early on error here, as we have not initialized any state yet ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT); err = Init(sessionManager, delegate); + mRole = CryptoContext::SessionRole::kInitiator; + // We are setting the exchange context specifically before checking for error. // This is to make sure the exchange will get closed if Init() returned an error. mExchangeCtxt = exchangeCtxt; @@ -242,8 +256,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, const mLocalMRPConfig = mrpConfig; mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); - SetPeerAddress(peerAddress); - SetPeerNodeId(peerNodeId); + mPeerNodeId = peerNodeId; err = SendSigma1(); SuccessOrExit(err); @@ -269,7 +282,7 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec) mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); } -CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) +CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session) const { size_t saltlen; @@ -292,7 +305,7 @@ CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session, CryptoConte } ReturnErrorOnFailure(session.InitFromSecret(ByteSpan(mSharedSecret, mSharedSecret.Length()), ByteSpan(msg_salt.Get(), saltlen), - CryptoContext::SessionInfoType::kSessionEstablishment, role)); + CryptoContext::SessionInfoType::kSessionEstablishment, mRole)); return CHIP_NO_ERROR; } @@ -378,7 +391,7 @@ CHIP_ERROR CASESession::SendSigma1() MutableByteSpan destinationIdSpan(destinationIdentifier); ReturnErrorOnFailure(GenerateCaseDestinationId(ByteSpan(mIPK), ByteSpan(mInitiatorRandom), rootPubKeySpan, fabricId, - GetPeerNodeId(), destinationIdSpan)); + mPeerNodeId, destinationIdSpan)); } ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(3), destinationIdentifier, sizeof(destinationIdentifier))); @@ -396,7 +409,7 @@ CHIP_ERROR CASESession::SendSigma1() if (mSessionResumptionStorage != nullptr) { SessionResumptionStorage::ResumptionIdStorage resumptionId; - CHIP_ERROR err = mSessionResumptionStorage->FindByScopedNodeId(mFabricInfo->GetScopedNodeIdForNode(GetPeerNodeId()), + CHIP_ERROR err = mSessionResumptionStorage->FindByScopedNodeId(mFabricInfo->GetScopedNodeIdForNode(mPeerNodeId), resumptionId, mSharedSecret, mPeerCATs); if (err == CHIP_NO_ERROR) { @@ -821,27 +834,14 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg) if (mSessionResumptionStorage != nullptr) { - CHIP_ERROR err2 = mSessionResumptionStorage->Save(ScopedNodeId(GetPeerNodeId(), GetFabricIndex()), resumptionId, - mSharedSecret, mPeerCATs); + CHIP_ERROR err2 = mSessionResumptionStorage->Save(GetPeer(), resumptionId, mSharedSecret, mPeerCATs); if (err2 != CHIP_NO_ERROR) ChipLogError(SecureChannel, "Unable to save session resumption state: %" CHIP_ERROR_FORMAT, err2.Format()); } SendStatusReport(mExchangeCtxt, kProtocolCodeSuccess); - // TODO: Set timestamp on the new session, to allow selecting a least-recently-used session for eviction - // on running out of session contexts. - - mCASESessionEstablished = true; - - // Discard the exchange so that Clear() doesn't try closing it. The - // exchange will handle that. - DiscardExchange(); - - // Call delegate to indicate session establishment is successful - // Do this last in case the delegate frees us. - mDelegate->OnSessionEstablished(); - + Finish(); exit: if (err != CHIP_NO_ERROR) { @@ -977,7 +977,7 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) // Verify that responderNodeId (from responderNOC) matches one that was included // in the computation of the Destination Identifier when generating Sigma1. - VerifyOrReturnError(GetPeerNodeId() == responderNodeId, CHIP_ERROR_INVALID_CASE_PARAMETER); + VerifyOrReturnError(mPeerNodeId == responderNodeId, CHIP_ERROR_INVALID_CASE_PARAMETER); // Construct msg_R2_Signed and validate the signature in msg_r2_encrypted msg_r2_signed_len = TLV::EstimateStructOverhead(sizeof(uint16_t), responderNOC.size(), responderICAC.size(), @@ -1245,7 +1245,7 @@ CHIP_ERROR CASESession::HandleSigma3(System::PacketBufferHandle && msg) // Validate initiator identity located in msg->Start() // Constructing responder identity SuccessOrExit(err = ValidatePeerIdentity(initiatorNOC, initiatorICAC, initiatorNodeId, initiatorPublicKey)); - SetPeerNodeId(initiatorNodeId); + mPeerNodeId = initiatorNodeId; // Step 4 - Construct Sigma3 TBS Data msg_r3_signed_len = TLV::EstimateStructOverhead(sizeof(uint16_t), initiatorNOC.size(), initiatorICAC.size(), @@ -1288,27 +1288,14 @@ CHIP_ERROR CASESession::HandleSigma3(System::PacketBufferHandle && msg) if (mSessionResumptionStorage != nullptr) { - CHIP_ERROR err2 = mSessionResumptionStorage->Save(ScopedNodeId(GetPeerNodeId(), GetFabricIndex()), mResumptionId, - mSharedSecret, mPeerCATs); + CHIP_ERROR err2 = mSessionResumptionStorage->Save(GetPeer(), mResumptionId, mSharedSecret, mPeerCATs); if (err2 != CHIP_NO_ERROR) ChipLogError(SecureChannel, "Unable to save session resumption state: %" CHIP_ERROR_FORMAT, err2.Format()); } SendStatusReport(mExchangeCtxt, kProtocolCodeSuccess); - // TODO: Set timestamp on the new session, to allow selecting a least-recently-used session for eviction - // on running out of session contexts. - - mCASESessionEstablished = true; - - // Discard the exchange so that Clear() doesn't try closing it. The - // exchange will handle that. - DiscardExchange(); - - // Call delegate to indicate session establishment is successful - // Do this last in case the delegate frees us. - mDelegate->OnSessionEstablished(); - + Finish(); exit: if (err != CHIP_NO_ERROR) { @@ -1503,24 +1490,13 @@ void CASESession::OnSuccessStatusReport() if (mSessionResumptionStorage != nullptr) { - CHIP_ERROR err2 = mSessionResumptionStorage->Save(ScopedNodeId(GetPeerNodeId(), GetFabricIndex()), mResumptionId, - mSharedSecret, mPeerCATs); + CHIP_ERROR err2 = mSessionResumptionStorage->Save(GetPeer(), mResumptionId, mSharedSecret, mPeerCATs); if (err2 != CHIP_NO_ERROR) ChipLogError(SecureChannel, "Unable to save session resumption state: %" CHIP_ERROR_FORMAT, err2.Format()); } - // Discard the exchange so that Clear() doesn't try closing it. The - // exchange will handle that. - DiscardExchange(); - mState = kInitialized; - - // TODO: Set timestamp on the new session, to allow selecting a least-recently-used session for eviction - // on running out of session contexts. - - // Call delegate to indicate pairing completion. - // Do this last in case the delegate frees us. - mDelegate->OnSessionEstablished(); + Finish(); } CHIP_ERROR CASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 71188883f4b72a..2d15387b0d0b8a 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -60,12 +60,12 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, public PairingSession { public: - CASESession(); - CASESession(CASESession &&) = default; - CASESession(const CASESession &) = default; - ~CASESession() override; + Transport::SecureSession::Type GetSecureSessionType() const override { return Transport::SecureSession::Type::kCASE; } + ScopedNodeId GetPeer() const override { return ScopedNodeId(mPeerNodeId, GetFabricIndex()); } + CATValues GetPeerCATs() const override { return mPeerCATs; }; + /** * @brief * Initialize using configured fabrics and wait for session establishment requests. @@ -86,7 +86,6 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, * Create and send session establishment request using device's operational credentials. * * @param sessionManager session manager from which to allocate a secure session object - * @param peerAddress Address of peer with which to establish a session. * @param fabric The fabric that should be used for connecting with the peer * @param peerNodeId Node id of the peer node * @param exchangeCtxt The exchange context to send and receive messages with the peer @@ -95,9 +94,9 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, * @return CHIP_ERROR The result of initialization */ CHIP_ERROR - EstablishSession(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, FabricInfo * fabric, - NodeId peerNodeId, Messaging::ExchangeContext * exchangeCtxt, - SessionResumptionStorage * sessionResumptionStorage, SessionEstablishmentDelegate * delegate, + EstablishSession(SessionManager & sessionManager, FabricInfo * fabric, NodeId peerNodeId, + Messaging::ExchangeContext * exchangeCtxt, SessionResumptionStorage * sessionResumptionStorage, + SessionEstablishmentDelegate * delegate, Optional mrpConfig = Optional::Missing()); /** @@ -135,15 +134,12 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, /** * @brief - * Derive a secure session from the established session. The API will return error - * if called before session is established. + * Derive a secure session from the established session. The API will return error if called before session is established. * - * @param session Reference to the secure session that will be - * initialized once session establishment is complete - * @param role Role of the new session (initiator or responder) + * @param session Reference to the secure session that will be initialized once session establishment is complete * @return CHIP_ERROR The result of session derivation */ - CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; + CHIP_ERROR DeriveSecureSession(CryptoContext & session) const override; //// UnsolicitedMessageHandler Implementation //// CHIP_ERROR OnUnsolicitedMessageReceived(const PayloadHeader & payloadHeader, ExchangeDelegate *& newDelegate) override @@ -217,6 +213,9 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, void OnSuccessStatusReport() override; CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) override; + // TODO: pull up Finish to PairingSession class + void Finish(); + void AbortExchange(); /** @@ -253,6 +252,8 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, FabricTable * mFabricsTable = nullptr; const FabricInfo * mFabricInfo = nullptr; + NodeId mPeerNodeId = kUndefinedNodeId; + CATValues mPeerCATs; // This field is only used for CASE responder, when during sending sigma2 and waiting for sigma3 SessionResumptionStorage::ResumptionIdStorage mResumptionId; @@ -261,8 +262,6 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, State mState; - Optional mLocalMRPConfig; - protected: bool mCASESessionEstablished = false; }; diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index d1b7e1a6cced7e..cbe67852f698d1 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -63,14 +63,32 @@ const char * kSpake2pR2ISessionInfo = "Commissioning R2I Key"; // The session establishment fails if the response is not received with in timeout window. static constexpr ExchangeContext::Timeout kSpake2p_Response_Timeout = System::Clock::Seconds16(30); -PASESession::PASESession() : PairingSession(Transport::SecureSession::Type::kPASE) {} - PASESession::~PASESession() { // Let's clear out any security state stored in the object, before destroying it. Clear(); } +void PASESession::Finish() +{ + mPairingComplete = true; + + Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress(); + + // Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that. + DiscardExchange(); + + CHIP_ERROR err = ActivateSecureSession(address); + if (err == CHIP_NO_ERROR) + { + mDelegate->OnSessionEstablished(mSecureSessionHolder.Get()); + } + else + { + mDelegate->OnSessionEstablishmentError(err); + } +} + void PASESession::Clear() { // This function zeroes out and resets the memory used by the object. @@ -181,6 +199,8 @@ CHIP_ERROR PASESession::WaitForPairing(SessionManager & sessionManager, const Sp // been initialized SuccessOrExit(err); + mRole = CryptoContext::SessionRole::kResponder; + VerifyOrExit(CanCastTo(salt.size()), err = CHIP_ERROR_INVALID_ARGUMENT); mSaltLength = static_cast(salt.size()); @@ -201,8 +221,6 @@ CHIP_ERROR PASESession::WaitForPairing(SessionManager & sessionManager, const Sp mPairingComplete = false; mLocalMRPConfig = mrpConfig; - SetPeerNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId)); - ChipLogDetail(SecureChannel, "Waiting for PBKDF param request"); exit: @@ -213,7 +231,7 @@ CHIP_ERROR PASESession::WaitForPairing(SessionManager & sessionManager, const Sp return err; } -CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, +CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, uint32_t peerSetUpPINCode, Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate) { @@ -222,13 +240,12 @@ CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, const Transport::P CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate); SuccessOrExit(err); + mRole = CryptoContext::SessionRole::kInitiator; + mExchangeCtxt = exchangeCtxt; mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); - SetPeerAddress(peerAddress); - mLocalMRPConfig = mrpConfig; - SetPeerNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId)); err = SendPBKDFParamRequest(); SuccessOrExit(err); @@ -258,11 +275,11 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec) mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); } -CHIP_ERROR PASESession::DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) +CHIP_ERROR PASESession::DeriveSecureSession(CryptoContext & session) const { VerifyOrReturnError(mPairingComplete, CHIP_ERROR_INCORRECT_STATE); return session.InitFromSecret(ByteSpan(mKe, mKeLen), ByteSpan(nullptr, 0), - CryptoContext::SessionInfoType::kSessionEstablishment, role); + CryptoContext::SessionInfoType::kSessionEstablishment, mRole); } CHIP_ERROR PASESession::SendPBKDFParamRequest() @@ -743,16 +760,7 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) // Send confirmation to peer that we succeeded so they can start using the session. SendStatusReport(mExchangeCtxt, kProtocolCodeSuccess); - mPairingComplete = true; - - // Discard the exchange so that Clear() doesn't try closing it. The - // exchange will handle that. - DiscardExchange(); - - // Call delegate to indicate pairing completion - // Do this last in case the delegate frees us. - mDelegate->OnSessionEstablished(); - + Finish(); exit: if (err != CHIP_NO_ERROR) @@ -764,15 +772,7 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) void PASESession::OnSuccessStatusReport() { - mPairingComplete = true; - - // Discard the exchange so that Clear() doesn't try closing it. The - // exchange will handle that. - DiscardExchange(); - - // Call delegate to indicate pairing completion - // Do this last in case the delegate frees us. - mDelegate->OnSessionEstablished(); + Finish(); } CHIP_ERROR PASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index a1b5a316c0f209..dfa3b826bdaeec 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -71,15 +71,14 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, public PairingSession { public: - PASESession(); - PASESession(PASESession &&) = default; - PASESession(const PASESession &) = delete; - ~PASESession() override; - // TODO: The SetPeerNodeId method should not be exposed; PASE sessions - // should not need to be told their peer node ID - using PairingSession::SetPeerNodeId; + Transport::SecureSession::Type GetSecureSessionType() const override { return Transport::SecureSession::Type::kPASE; } + ScopedNodeId GetPeer() const override + { + return ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex); + } + CATValues GetPeerCATs() const override { return CATValues(); }; CHIP_ERROR OnUnsolicitedMessageReceived(const PayloadHeader & payloadHeader, ExchangeDelegate *& newDelegate) override; @@ -104,7 +103,6 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, * Create a pairing request using peer's setup PIN code. * * @param sessionManager session manager from which to allocate a secure session object - * @param peerAddress Address of peer to pair * @param peerSetUpPINCode Setup PIN code of the peer device * @param exchangeCtxt The exchange context to send and receive messages with the peer * Note: It's expected that the caller of this API hands over the @@ -114,9 +112,8 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, * * @return CHIP_ERROR The result of initialization */ - CHIP_ERROR Pair(SessionManager & sessionManager, const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, - Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, - SessionEstablishmentDelegate * delegate); + CHIP_ERROR Pair(SessionManager & sessionManager, uint32_t peerSetUpPINCode, Optional mrpConfig, + Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate); /** * @brief @@ -135,15 +132,12 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, /** * @brief - * Derive a secure session from the paired session. The API will return error - * if called before pairing is established. + * Derive a secure session from the paired session. The API will return error if called before pairing is established. * - * @param session Reference to the secure session that will be - * initialized once pairing is complete - * @param role Role of the new session (initiator or responder) + * @param session Reference to the secure session that will be initialized once pairing is complete * @return CHIP_ERROR The result of session derivation */ - CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; + CHIP_ERROR DeriveSecureSession(CryptoContext & session) const override; // TODO: remove Clear, we should create a new instance instead reset the old instance. /** @brief This function zeroes out and resets the memory used by the object. @@ -207,6 +201,9 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, void OnSuccessStatusReport() override; CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) override; + // TODO: pull up Finish to PairingSession class + void Finish(); + void CloseExchange(); /** @@ -240,8 +237,6 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, Messaging::ExchangeContext * mExchangeCtxt = nullptr; - Optional mLocalMRPConfig; - struct Spake2pErrorMsg { Spake2pErrorType error; diff --git a/src/protocols/secure_channel/SessionEstablishmentDelegate.h b/src/protocols/secure_channel/SessionEstablishmentDelegate.h index 1c38c8538cc1a3..12753872180359 100644 --- a/src/protocols/secure_channel/SessionEstablishmentDelegate.h +++ b/src/protocols/secure_channel/SessionEstablishmentDelegate.h @@ -47,7 +47,7 @@ class DLL_EXPORT SessionEstablishmentDelegate /** * Called when the new secure session has been established */ - virtual void OnSessionEstablished() {} + virtual void OnSessionEstablished(const SessionHandle & session) {} virtual ~SessionEstablishmentDelegate() {} }; diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 601a9a20062449..44f8a4c68354c5 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -75,7 +75,15 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate public: void OnSessionEstablishmentError(CHIP_ERROR error) override { mNumPairingErrors++; } - void OnSessionEstablished() override { mNumPairingComplete++; } + void OnSessionEstablished(const SessionHandle & session) override + { + mSession.Grab(session); + mNumPairingComplete++; + } + + SessionHolder & GetSessionHolder() { return mSession; } + + SessionHolder mSession; // TODO: Rename mNumPairing* to mNumEstablishment* uint32_t mNumPairingErrors = 0; @@ -182,9 +190,6 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) SessionManager sessionManager; NL_TEST_ASSERT(inSuite, pairing.GetSecureSessionType() == SecureSession::Type::kCASE); - CATValues peerCATs; - peerCATs = pairing.GetPeerCATs(); - NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kUndefinedCATs, sizeof(CATValues)) == 0); pairing.SetGroupDataProvider(&gDeviceGroupDataProvider); NL_TEST_ASSERT(inSuite, @@ -210,18 +215,15 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, - nullptr, nullptr, nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(sessionManager, nullptr, Node01_01, nullptr, nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - nullptr, nullptr, nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(sessionManager, fabric, Node01_01, nullptr, nullptr, nullptr) != CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - context, nullptr, &delegate) == CHIP_NO_ERROR); + pairing.EstablishSession(sessionManager, fabric, Node01_01, context, nullptr, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -241,8 +243,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - context1, nullptr, &delegate) == CHIP_ERROR_BAD_REQUEST); + pairing1.EstablishSession(sessionManager, fabric, Node01_01, context1, nullptr, &delegate) == + CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; @@ -274,8 +276,7 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte pairingAccessory.ListenForSessionEstablishment(sessionManager, &gDeviceFabrics, nullptr, &delegateAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingCommissioner.EstablishSession(sessionManager, Transport::PeerAddress(Transport::Type::kBle), fabric, - Node01_01, contextCommissioner, nullptr, + pairingCommissioner.EstablishSession(sessionManager, fabric, Node01_01, contextCommissioner, nullptr, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -323,29 +324,25 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, fabric != nullptr); NL_TEST_ASSERT(inSuite, - pairingCommissioner->EstablishSession(ctx.GetSecureSessionManager(), - Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - contextCommissioner, nullptr, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner->EstablishSession(ctx.GetSecureSessionManager(), fabric, Node01_01, contextCommissioner, + nullptr, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 5); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); - // Validate that secure session can be created after the pairing - SessionHolder sessionHolder; - NL_TEST_ASSERT(inSuite, - ctx.GetSecureSessionManager().NewPairing(sessionHolder, NullOptional, Node01_01, pairingCommissioner, - CryptoContext::SessionRole::kInitiator, - gCommissionerFabricIndex) == CHIP_NO_ERROR); + // Validate that secure session is created + SessionHolder & holder = delegateCommissioner.GetSessionHolder(); + NL_TEST_ASSERT(inSuite, bool(holder)); + NL_TEST_ASSERT(inSuite, holder->GetPeer() == fabric->GetScopedNodeIdForNode(Node01_01)); auto * pairingCommissioner1 = chip::Platform::New(); pairingCommissioner1->SetGroupDataProvider(&gCommissionerGroupDataProvider); ExchangeContext * contextCommissioner1 = ctx.NewUnauthenticatedExchangeToBob(pairingCommissioner1); NL_TEST_ASSERT(inSuite, - pairingCommissioner1->EstablishSession(ctx.GetSecureSessionManager(), - Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, - contextCommissioner1, nullptr, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner1->EstablishSession(ctx.GetSecureSessionManager(), fabric, Node01_01, contextCommissioner1, + nullptr, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); chip::Platform::Delete(pairingCommissioner); diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 4cc6ffc1cf5a7a..598dee3c53dd5d 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -93,7 +93,7 @@ class TestSecurePairingDelegate : public SessionEstablishmentDelegate public: void OnSessionEstablishmentError(CHIP_ERROR error) override { mNumPairingErrors++; } - void OnSessionEstablished() override { mNumPairingComplete++; } + void OnSessionEstablished(const SessionHandle & session) override { mNumPairingComplete++; } uint32_t mNumPairingErrors = 0; uint32_t mNumPairingComplete = 0; @@ -167,13 +167,13 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, - Optional::Missing(), nullptr, nullptr) != CHIP_NO_ERROR); + pairing.Pair(sessionManager, sTestSpake2p01_PinCode, Optional::Missing(), nullptr, + nullptr) != CHIP_NO_ERROR); gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, - Optional::Missing(), context, &delegate) == CHIP_NO_ERROR); + pairing.Pair(sessionManager, sTestSpake2p01_PinCode, Optional::Missing(), context, + &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -190,9 +190,8 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) PASESession pairing1; ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, - Optional::Missing(), context1, - &delegate) == CHIP_ERROR_BAD_REQUEST); + pairing1.Pair(sessionManager, sTestSpake2p01_PinCode, Optional::Missing(), + context1, &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; @@ -237,8 +236,8 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), sTestSpake2p01_PinCode, - mrpCommissionerConfig, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.Pair(sessionManager, sTestSpake2p01_PinCode, mrpCommissionerConfig, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); while (gLoopback.mMessageDropped) @@ -260,17 +259,21 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P if (mrpCommissionerConfig.HasValue()) { NL_TEST_ASSERT(inSuite, - pairingAccessory.GetMRPConfig().mIdleRetransTimeout == mrpCommissionerConfig.Value().mIdleRetransTimeout); - NL_TEST_ASSERT( - inSuite, pairingAccessory.GetMRPConfig().mActiveRetransTimeout == mrpCommissionerConfig.Value().mActiveRetransTimeout); + pairingAccessory.GetRemoteMRPConfig().mIdleRetransTimeout == + mrpCommissionerConfig.Value().mIdleRetransTimeout); + NL_TEST_ASSERT(inSuite, + pairingAccessory.GetRemoteMRPConfig().mActiveRetransTimeout == + mrpCommissionerConfig.Value().mActiveRetransTimeout); } if (mrpAccessoryConfig.HasValue()) { NL_TEST_ASSERT(inSuite, - pairingCommissioner.GetMRPConfig().mIdleRetransTimeout == mrpAccessoryConfig.Value().mIdleRetransTimeout); - NL_TEST_ASSERT( - inSuite, pairingCommissioner.GetMRPConfig().mActiveRetransTimeout == mrpAccessoryConfig.Value().mActiveRetransTimeout); + pairingCommissioner.GetRemoteMRPConfig().mIdleRetransTimeout == + mrpAccessoryConfig.Value().mIdleRetransTimeout); + NL_TEST_ASSERT(inSuite, + pairingCommissioner.GetRemoteMRPConfig().mActiveRetransTimeout == + mrpAccessoryConfig.Value().mActiveRetransTimeout); } } @@ -366,9 +369,8 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(sessionManager, Transport::PeerAddress(Transport::Type::kBle), 4321, - Optional::Missing(), contextCommissioner, - &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.Pair(sessionManager, 4321, Optional::Missing(), + contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 0); diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp index ccd38d736a1116..275cf0a3eb691e 100644 --- a/src/transport/PairingSession.cpp +++ b/src/transport/PairingSession.cpp @@ -23,19 +23,26 @@ namespace chip { -CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager, uint16_t sessionId) +CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager) { - auto handle = sessionManager.AllocateSession(sessionId); + auto handle = sessionManager.AllocateSession(); VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY); mSecureSessionHolder.Grab(handle.Value()); return CHIP_NO_ERROR; } -CHIP_ERROR PairingSession::AllocateSecureSession(SessionManager & sessionManager) +CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress & peerAddress) { - auto handle = sessionManager.AllocateSession(); - VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_NO_MEMORY); - mSecureSessionHolder.Grab(handle.Value()); + Transport::SecureSession * secureSession = mSecureSessionHolder->AsSecureSession(); + + uint16_t peerSessionId = GetPeerSessionId(); + ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!", + ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId); + secureSession->Activate(GetSecureSessionType(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig); + secureSession->SetPeerAddress(peerAddress); + ReturnErrorOnFailure(DeriveSecureSession(secureSession->GetCryptoContext())); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + return CHIP_NO_ERROR; } @@ -74,7 +81,7 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TL if (TLV::TagNumFromTag(tlvReader.GetTag()) == 1) { ReturnErrorOnFailure(tlvReader.Get(tlvElementValue)); - mMRPConfig.mIdleRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); + mRemoteMRPConfig.mIdleRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); // The next element is optional. If it's not present, return CHIP_NO_ERROR. CHIP_ERROR err = tlvReader.Next(); @@ -87,7 +94,7 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TL VerifyOrReturnError(TLV::TagNumFromTag(tlvReader.GetTag()) == 2, CHIP_ERROR_INVALID_TLV_TAG); ReturnErrorOnFailure(tlvReader.Get(tlvElementValue)); - mMRPConfig.mActiveRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); + mRemoteMRPConfig.mActiveRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); return tlvReader.ExitContainer(containerType); } diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index fee363f0ec714b..fc586e15897a2d 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -38,17 +38,11 @@ namespace chip { class DLL_EXPORT PairingSession { public: - PairingSession(Transport::SecureSession::Type secureSessionType) : mSecureSessionType(secureSessionType) {} virtual ~PairingSession() {} - Transport::SecureSession::Type GetSecureSessionType() const { return mSecureSessionType; } - - // TODO: the session should know which peer we are trying to connect to at start - // mPeerNodeId should be const and assigned at the construction, such that GetPeerNodeId will never return kUndefinedNodeId, and - // SetPeerNodeId is not necessary. - NodeId GetPeerNodeId() const { return mPeerNodeId; } - - CATValues GetPeerCATs() const { return mPeerCATs; } + virtual Transport::SecureSession::Type GetSecureSessionType() const = 0; + virtual ScopedNodeId GetPeer() const = 0; + virtual CATValues GetPeerCATs() const = 0; Optional GetLocalSessionId() const { @@ -61,8 +55,6 @@ class DLL_EXPORT PairingSession return localSessionId; } - auto GetSecureSessionHandle() const { return mSecureSessionHolder.ToOptional(); } - uint16_t GetPeerSessionId() const { VerifyOrDie(mPeerSessionId.HasValue()); @@ -70,29 +62,17 @@ class DLL_EXPORT PairingSession } bool IsValidPeerSessionId() const { return mPeerSessionId.HasValue(); } - // TODO: decouple peer address into transport, such that pairing session do not need to handle peer address - const Transport::PeerAddress & GetPeerAddress() const { return mPeerAddress; } - Transport::PeerAddress & GetPeerAddress() { return mPeerAddress; } - /** * @brief - * Derive a secure session from the paired session. The API will return error - * if called before pairing is established. + * Derive a secure session from the paired session. The API will return error if called before pairing is established. * - * @param session Reference to the secure session that will be - * initialized once pairing is complete - * @param role Role of the new session (initiator or responder) + * @param session Reference to the secure session that will be initialized once pairing is complete * @return CHIP_ERROR The result of session derivation */ - virtual CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) = 0; - - /** - * @brief - * Get the MRP config that was communicated during the session establishment. - */ - virtual const ReliableMessageProtocolConfig & GetMRPConfig() const { return mMRPConfig; } + virtual CHIP_ERROR DeriveSecureSession(CryptoContext & session) const = 0; - void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; } + const ReliableMessageProtocolConfig & GetRemoteMRPConfig() const { return mRemoteMRPConfig; } + void SetRemoteMRPConfig(const ReliableMessageProtocolConfig & config) { mRemoteMRPConfig = config; } /** * Encode the provided MRP parameters using the provided TLV tag. @@ -110,22 +90,9 @@ class DLL_EXPORT PairingSession */ CHIP_ERROR AllocateSecureSession(SessionManager & sessionManager); - /** - * Allocate a secure session object from the passed session manager with the - * specified session ID. - * - * This variant of the interface may be used in test scenarios where - * session IDs need to be predetermined. - - * @param sessionManager session manager from which to allocate a secure session object - * @param sessionId caller-requested session ID - * @return CHIP_ERROR The outcome of the allocation attempt - */ - CHIP_ERROR AllocateSecureSession(SessionManager & sessionManager, uint16_t sessionId); + CHIP_ERROR ActivateSecureSession(const Transport::PeerAddress & peerAddress); - void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; } void SetPeerSessionId(uint16_t id) { mPeerSessionId.SetValue(id); } - void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } virtual void OnSuccessStatusReport() {} virtual CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) { @@ -180,7 +147,7 @@ class DLL_EXPORT PairingSession /** * Try to decode the current element (pointed by the TLV reader) as MRP parameters. - * If the MRP parameters are found, mMRPConfig is updated with the devoded values. + * If the MRP parameters are found, mRemoteMRPConfig is updated with the devoded values. * * MRP parameters are optional. So, if the TLV reader is not pointing to the MRP parameters, * the function is a noop. @@ -193,29 +160,21 @@ class DLL_EXPORT PairingSession // TODO: remove Clear, we should create a new instance instead reset the old instance. void Clear() { - mPeerNodeId = kUndefinedNodeId; - mPeerCATs = kUndefinedCATs; - mPeerAddress = Transport::PeerAddress::Uninitialized(); mPeerSessionId.ClearValue(); mSecureSessionHolder.Release(); } -private: - const Transport::SecureSession::Type mSecureSessionType; - protected: - NodeId mPeerNodeId = kUndefinedNodeId; - CATValues mPeerCATs; - -private: + CryptoContext::SessionRole mRole; SessionHolder mSecureSessionHolder; - // TODO: decouple peer address into transport, such that pairing session do not need to handle peer address - Transport::PeerAddress mPeerAddress = Transport::PeerAddress::Uninitialized(); + // mLocalMRPConfig is our config which is sent to the other end and used by the peer session. + // mRemoteMRPConfig is received from other end and set to our session. + Optional mLocalMRPConfig; + ReliableMessageProtocolConfig mRemoteMRPConfig = GetLocalMRPConfig(); +private: Optional mPeerSessionId; - - ReliableMessageProtocolConfig mMRPConfig = GetLocalMRPConfig(); }; } // namespace chip diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 1db5b6a550a697..a579c58ea001d2 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -97,15 +97,15 @@ class SecureSession : public Session * PASE, setting internal state according to the parameters used and * discovered during session establishment. */ - void Activate(Type secureSessionType, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric, + void Activate(Type secureSessionType, const ScopedNodeId & peer, CATValues peerCATs, uint16_t peerSessionId, const ReliableMessageProtocolConfig & config) { mSecureSessionType = secureSessionType; - mPeerNodeId = peerNodeId; + mPeerNodeId = peer.GetNodeId(); mPeerCATs = peerCATs; mPeerSessionId = peerSessionId; mMRPConfig = config; - SetFabricIndex(fabric); + SetFabricIndex(peer.GetFabricIndex()); } ~SecureSession() override { NotifySessionReleased(); } diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index c71431fe363b5b..f66f47d0964c91 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -349,10 +349,10 @@ void SessionManager::ExpirePairing(const SessionHandle & sessionHandle) mSecureSessions.ReleaseSession(sessionHandle->AsSecureSession()); } -void SessionManager::ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric) +void SessionManager::ExpireAllPairings(const ScopedNodeId & node) { mSecureSessions.ForEachSession([&](auto session) { - if (session->GetPeerNodeId() == peerNodeId && session->GetFabricIndex() == fabric) + if (session->GetPeer() == node) { mSecureSessions.ReleaseSession(session); } @@ -389,18 +389,6 @@ Optional SessionManager::AllocateSession() return mSecureSessions.CreateNewSecureSession(); } -Optional SessionManager::AllocateSession(uint16_t sessionId) -{ - // If we forego SessionManager session ID allocation, we can have a - // collission. In case of such a collission, we must evict first. - Optional oldSession = mSecureSessions.FindSecureSessionByLocalKey(sessionId); - if (oldSession.HasValue()) - { - mSecureSessions.ReleaseSession(oldSession.Value()->AsSecureSession()); - } - return mSecureSessions.CreateNewSecureSession(sessionId); -} - CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabric, const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role) @@ -421,46 +409,6 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH return CHIP_NO_ERROR; } -CHIP_ERROR SessionManager::NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, - NodeId peerNodeId, PairingSession * pairing, CryptoContext::SessionRole direction, - FabricIndex fabric) -{ - uint16_t peerSessionId = pairing->GetPeerSessionId(); - SecureSession * secureSession; - auto handle = pairing->GetSecureSessionHandle(); - VerifyOrReturnError(handle.HasValue(), CHIP_ERROR_INCORRECT_STATE); - VerifyOrReturnError(handle.Value()->IsSecureSession(), CHIP_ERROR_INCORRECT_STATE); - secureSession = handle.Value()->AsSecureSession(); - - ChipLogDetail(Inet, "New secure session created for device 0x" ChipLogFormatX64 ", LSID:%d PSID:%d!", - ChipLogValueX64(peerNodeId), secureSession->GetLocalSessionId(), peerSessionId); - secureSession->Activate(pairing->GetSecureSessionType(), peerNodeId, pairing->GetPeerCATs(), peerSessionId, fabric, - pairing->GetMRPConfig()); - - if (peerAddr.HasValue() && peerAddr.Value().GetIPAddress() != Inet::IPAddress::Any) - { - secureSession->SetPeerAddress(peerAddr.Value()); - } - else if (peerAddr.HasValue() && peerAddr.Value().GetTransportType() == Transport::Type::kBle) - { - secureSession->SetPeerAddress(peerAddr.Value()); - } - else if (peerAddr.HasValue() && - (peerAddr.Value().GetTransportType() == Transport::Type::kTcp || - peerAddr.Value().GetTransportType() == Transport::Type::kUdp)) - { - return CHIP_ERROR_INVALID_ARGUMENT; - } - - ReturnErrorOnFailure(pairing->DeriveSecureSession(secureSession->GetCryptoContext(), direction)); - - secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); - - sessionHolder.Grab(handle.Value()); - - return CHIP_NO_ERROR; -} - void SessionManager::ScheduleExpiryTimer() { CHIP_ERROR err = mSystemLayer->StartTimer(System::Clock::Milliseconds32(CHIP_PEER_CONNECTION_TIMEOUT_CHECK_FREQUENCY_MS), diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index ce94d402482d25..dda5cebaab709d 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -161,18 +161,6 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate uint16_t peerSessionId, FabricIndex fabric, const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role); - /** - * @brief - * Establish a new pairing with a peer node - * - * @details - * This method sets up a new pairing with the peer node. It also - * establishes the security keys for secure communication with the - * peer node. - */ - CHIP_ERROR NewPairing(SessionHolder & sessionHolder, const Optional & peerAddr, NodeId peerNodeId, - PairingSession * pairing, CryptoContext::SessionRole direction, FabricIndex fabric); - /** * @brief * Allocate a secure session and non-colliding session ID in the secure @@ -183,21 +171,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate CHECK_RETURN_VALUE Optional AllocateSession(); - /** - * @brief - * Allocate a secure session in the secure session table at the specified - * session ID. If the session ID collides with an existing session, evict - * it. This variant of the interface may be used in test scenarios where - * session IDs need to be predetermined. - * - * @param localSessionId a unique identifier for the local node's secure unicast session context - * @return SessionHandle with a reference to a SecureSession, else NullOptional on failure - */ - CHECK_RETURN_VALUE - Optional AllocateSession(uint16_t localSessionId); - void ExpirePairing(const SessionHandle & session); - void ExpireAllPairings(NodeId peerNodeId, FabricIndex fabric); + void ExpireAllPairings(const ScopedNodeId & node); void ExpireAllPairingsForFabric(FabricIndex fabric); void ExpireAllPASEPairings(); diff --git a/src/transport/tests/TestPairingSession.cpp b/src/transport/tests/TestPairingSession.cpp index 6a9f89756abb83..6ffd2037f02f06 100644 --- a/src/transport/tests/TestPairingSession.cpp +++ b/src/transport/tests/TestPairingSession.cpp @@ -41,9 +41,13 @@ using namespace chip::System::Clock; class TestPairingSession : public PairingSession { public: - TestPairingSession(Transport::SecureSession::Type secureSessionType) : PairingSession(secureSessionType) {} + Transport::SecureSession::Type GetSecureSessionType() const override { return Transport::SecureSession::Type::kPASE; } + ScopedNodeId GetPeer() const override { return ScopedNodeId(); } + CATValues GetPeerCATs() const override { return CATValues(); }; - CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override { return CHIP_NO_ERROR; } + const ReliableMessageProtocolConfig & GetRemoteMRPConfig() const { return mRemoteMRPConfig; } + + CHIP_ERROR DeriveSecureSession(CryptoContext & session) const override { return CHIP_NO_ERROR; } CHIP_ERROR DecodeMRPParametersIfPresent(TLV::Tag expectedTag, System::PacketBufferTLVReader & tlvReader) { @@ -53,7 +57,7 @@ class TestPairingSession : public PairingSession void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext) { - TestPairingSession session(Transport::SecureSession::Type::kCASE); + TestPairingSession session; ReliableMessageProtocolConfig config(Milliseconds32(100), Milliseconds32(200)); @@ -80,13 +84,13 @@ void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext NL_TEST_ASSERT(inSuite, reader.Next() == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(TLV::ContextTag(1), reader) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == config.mIdleRetransTimeout); - NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == config.mActiveRetransTimeout); + NL_TEST_ASSERT(inSuite, session.GetRemoteMRPConfig().mIdleRetransTimeout == config.mIdleRetransTimeout); + NL_TEST_ASSERT(inSuite, session.GetRemoteMRPConfig().mActiveRetransTimeout == config.mActiveRetransTimeout); } void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inContext) { - TestPairingSession session(Transport::SecureSession::Type::kPASE); + TestPairingSession session; System::PacketBufferHandle buf = System::PacketBufferHandle::New(64, 0); System::PacketBufferTLVWriter writer; @@ -108,8 +112,8 @@ void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inCon NL_TEST_ASSERT(inSuite, reader.Next() == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(TLV::ContextTag(2), reader) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == GetLocalMRPConfig().mIdleRetransTimeout); - NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == GetLocalMRPConfig().mActiveRetransTimeout); + NL_TEST_ASSERT(inSuite, session.GetRemoteMRPConfig().mIdleRetransTimeout == GetLocalMRPConfig().mIdleRetransTimeout); + NL_TEST_ASSERT(inSuite, session.GetRemoteMRPConfig().mActiveRetransTimeout == GetLocalMRPConfig().mActiveRetransTimeout); } // Test Suite diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index cb7e964e6ead19..9f5cb44ea0bd80 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -465,113 +465,6 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) sessionManager.Shutdown(); } -class SecurePairingUsingTestSecret : public PairingSession -{ -public: - SecurePairingUsingTestSecret() : PairingSession(Transport::SecureSession::Type::kPASE) - { - // Do not set to 0 to prevent an unwanted unsecured session - // since the session type is unknown. - SetPeerSessionId(1); - } - - void Init(SessionManager & sessionManager) - { - // Do not set to 0 to prevent an unwanted unsecured session - // since the session type is unknown. - AllocateSecureSession(sessionManager, mLocalSessionId); - } - - SecurePairingUsingTestSecret(uint16_t peerSessionId, uint16_t localSessionId, SessionManager & sessionManager) : - PairingSession(Transport::SecureSession::Type::kPASE), mLocalSessionId(localSessionId) - { - AllocateSecureSession(sessionManager, localSessionId); - SetPeerSessionId(peerSessionId); - } - - CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override - { - size_t secretLen = strlen(kTestSecret); - return session.InitFromSecret(ByteSpan(reinterpret_cast(kTestSecret), secretLen), ByteSpan(nullptr, 0), - CryptoContext::SessionInfoType::kSessionEstablishment, role); - } - -private: - // Do not set to 0 to prevent an unwanted unsecured session - // since the session type is unknown. - uint16_t mLocalSessionId = 1; - const char * kTestSecret = CHIP_CONFIG_TEST_SHARED_SECRET_VALUE; -}; - -void StaleConnectionDropTest(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - constexpr NodeId kSourceNodeId = 123654; - - IPAddress addr; - IPAddress::FromString("::1", addr); - CHIP_ERROR err = CHIP_NO_ERROR; - - TransportMgr transportMgr; - FabricTable fabricTable; - SessionManager sessionManager; - secure_channel::MessageCounterManager gMessageCounterManager; - chip::TestPersistentStorageDelegate deviceStorage; - - NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == transportMgr.Init("LOOPBACK")); - NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTable.Init(&deviceStorage)); - NL_TEST_ASSERT( - inSuite, - CHIP_NO_ERROR == - sessionManager.Init(&ctx.GetSystemLayer(), &transportMgr, &gMessageCounterManager, &deviceStorage, &fabricTable)); - - Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); - TestSessionReleaseCallback callback; - SessionHolderWithDelegate session1(callback); - SessionHolderWithDelegate session2(callback); - SessionHolderWithDelegate session3(callback); - SessionHolderWithDelegate session4(callback); - SessionHolderWithDelegate session5(callback); - - // First pairing - callback.mOldConnectionDropped = false; - SecurePairingUsingTestSecret pairing1(1, 1, sessionManager); - err = sessionManager.NewPairing(session1, peer, kSourceNodeId, &pairing1, CryptoContext::SessionRole::kInitiator, 1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); - - // New pairing with different peer node ID and different local key ID (same peer key ID) - callback.mOldConnectionDropped = false; - SecurePairingUsingTestSecret pairing2(1, 2, sessionManager); - err = sessionManager.NewPairing(session2, peer, kSourceNodeId, &pairing2, CryptoContext::SessionRole::kResponder, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); - - // New pairing with undefined node ID and different local key ID (same peer key ID) - callback.mOldConnectionDropped = false; - SecurePairingUsingTestSecret pairing3(1, 3, sessionManager); - err = sessionManager.NewPairing(session3, peer, kUndefinedNodeId, &pairing3, CryptoContext::SessionRole::kResponder, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, !callback.mOldConnectionDropped); - - // New pairing with same local key ID, and a given node ID - callback.mOldConnectionDropped = false; - SecurePairingUsingTestSecret pairing4(1, 2, sessionManager); - err = sessionManager.NewPairing(session4, peer, kSourceNodeId, &pairing4, CryptoContext::SessionRole::kResponder, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); - - // New pairing with same local key ID, and undefined node ID - callback.mOldConnectionDropped = false; - SecurePairingUsingTestSecret pairing5(1, 1, sessionManager); - err = sessionManager.NewPairing(session5, peer, kUndefinedNodeId, &pairing5, CryptoContext::SessionRole::kResponder, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, callback.mOldConnectionDropped); - - sessionManager.Shutdown(); -} - void SendPacketWithOldCounterTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); @@ -829,14 +722,6 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) sessionId1 = session->AsSecureSession()->GetLocalSessionId(); } - // Allocate a session at a colliding ID, verify eviction. - { - callback.mOldConnectionDropped = false; - auto handle = sessionManager.AllocateSession(sessionId1); - NL_TEST_ASSERT(inSuite, handle.HasValue()); - SessionHolderWithDelegate session(handle.Value(), callback); - } - // Verify that we increment session ID by 1 for each allocation, except for // the wraparound case where we skip session ID 0. auto prevSessionId = sessionId1; @@ -932,7 +817,6 @@ const nlTest sTests[] = NL_TEST_DEF("Message Self Test", CheckMessageTest), NL_TEST_DEF("Send Encrypted Packet Test", SendEncryptedPacketTest), NL_TEST_DEF("Send Bad Encrypted Packet Test", SendBadEncryptedPacketTest), - NL_TEST_DEF("Drop stale connection Test", StaleConnectionDropTest), NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest), NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest), NL_TEST_DEF("Session Allocation Test", SessionAllocationTest),