diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index 6c6a6080de02c2..4807e6e08e0932 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -54,9 +54,10 @@ void OperationalDeviceProxy::MoveToState(State aTargetState) ChipLogDetail(Controller, "OperationalDeviceProxy[" ChipLogFormatX64 ":" ChipLogFormatX64 "]: State change %d --> %d", ChipLogValueX64(mPeerId.GetCompressedFabricId()), ChipLogValueX64(mPeerId.GetNodeId()), to_underlying(mState), to_underlying(aTargetState)); + mState = aTargetState; - if (aTargetState != State::Connecting) + if (aTargetState != State::Connecting && aTargetState != State::Recovering) { CleanupCASEClient(); } @@ -128,6 +129,10 @@ void OperationalDeviceProxy::Connect(Callback::Callback * onC if (!isConnected) { err = EstablishConnection(); + if (err == CHIP_NO_ERROR) + { + MoveToState(State::Connecting); + } } break; @@ -139,6 +144,9 @@ void OperationalDeviceProxy::Connect(Callback::Callback * onC isConnected = true; break; + case State::Recovering: + break; + default: err = CHIP_ERROR_INCORRECT_STATE; } @@ -189,7 +197,11 @@ void OperationalDeviceProxy::UpdateDeviceData(const Transport::PeerAddress & add { MoveToState(State::HasAddress); err = EstablishConnection(); - if (err != CHIP_NO_ERROR) + if (err == CHIP_NO_ERROR) + { + MoveToState(State::Connecting); + } + else { DequeueConnectionCallbacks(err); } @@ -223,8 +235,6 @@ CHIP_ERROR OperationalDeviceProxy::EstablishConnection() return err; } - MoveToState(State::Connecting); - return CHIP_NO_ERROR; } @@ -285,31 +295,43 @@ void OperationalDeviceProxy::DequeueConnectionCallbacks(CHIP_ERROR error) void OperationalDeviceProxy::OnSessionEstablishmentError(CHIP_ERROR error) { - VerifyOrReturn(mState != State::Uninitialized && mState != State::NeedsAddress, + VerifyOrReturn(mState == State::Connecting || mState == State::Recovering, ChipLogError(Controller, "HandleCASEConnectionFailure was called while the device was not initialized")); - // - // We don't need to reset the state all the way back to NeedsAddress since all that transpired - // was just CASE connection failure. So let's re-use the cached address to re-do CASE again - // if need-be. - // - MoveToState(State::HasAddress); + if (mState == State::Connecting) + { + // + // We don't need to reset the state all the way back to NeedsAddress since all that transpired + // was just CASE connection failure. So let's re-use the cached address to re-do CASE again + // if need-be. + // + MoveToState(State::HasAddress); - DequeueConnectionCallbacks(error); + DequeueConnectionCallbacks(error); + } + else if (mState == State::Recovering) + { + mSecureSession.Get().Value()->DispatchSessionEvent(&SessionDelegate::OnRecoveryFailed); + } // Do not touch device instance anymore; it might have been destroyed by a failure callback. } void OperationalDeviceProxy::OnSessionEstablished(const SessionHandle & session) { - VerifyOrReturn(mState != State::Uninitialized, + VerifyOrReturn(mState == State::Connecting || mState == State::Recovering, ChipLogError(Controller, "HandleCASEConnected was called while the device was not initialized")); + bool report = (mState == State::Connecting); + if (!mSecureSession.Grab(session)) return; // Got an invalid session, do not change any state MoveToState(State::SecureConnected); - DequeueConnectionCallbacks(CHIP_NO_ERROR); + mInitParams.sessionManager->ShiftToSession(session); + + if (report) + DequeueConnectionCallbacks(CHIP_NO_ERROR); // Do not touch this instance anymore; it might have been destroyed by a callback. } @@ -345,9 +367,23 @@ void OperationalDeviceProxy::OnFirstMessageDeliveryFailed() LookupPeerAddress(); } -void OperationalDeviceProxy::OnSessionHang() +void OperationalDeviceProxy::OnRequestRecovery() { - // TODO: establish a new session + TrySessionRecovery(); +} + +void OperationalDeviceProxy::TrySessionRecovery() +{ + MoveToState(State::Recovering); + CHIP_ERROR err = EstablishConnection(); + if (err == CHIP_NO_ERROR) + { + MoveToState(State::Recovering); + } + else + { + mSecureSession.Get().Value()->DispatchSessionEvent(&SessionDelegate::OnRecoveryFailed); + } } CHIP_ERROR OperationalDeviceProxy::ShutdownSubscriptions() @@ -370,11 +406,7 @@ OperationalDeviceProxy::~OperationalDeviceProxy() } } - if (mCASEClient) - { - // Make sure we don't leak it. - mInitParams.clientPool->Release(mCASEClient); - } + MoveToState(State::Uninitialized); } CHIP_ERROR OperationalDeviceProxy::LookupPeerAddress() diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 74731440c09c5f..4e3e49bcd69562 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -151,9 +151,9 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, void OnSessionReleased() override; // Called when a message is not acked within first retrans timer, try to refresh the peer address void OnFirstMessageDeliveryFailed() override; - // Called when a connection is hanging. Try to re-establish another session, and shift to the new session when done, the + // Triggered by application layer. Try to re-establish another session, and shift to the new session when done, the // original session won't be touched during the period. - void OnSessionHang() override; + void OnRequestRecovery() override; /** * Mark any open session with the device as expired. @@ -218,6 +218,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, ResolvingAddress, // Address lookup in progress. HasAddress, // Have an address, CASE handshake not started yet. Connecting, // CASE handshake in progress. + Recovering, // CASE session hang, trying to establish a new one, the old session is hanging but left untouched. SecureConnected, // CASE session established. }; @@ -277,6 +278,8 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, * This function will set new IP address, port and MRP retransmission intervals of the device. */ void UpdateDeviceData(const Transport::PeerAddress & addr, const ReliableMessageProtocolConfig & config); + + void TrySessionRecovery(); }; } // namespace chip diff --git a/src/lib/core/CASEAuthTag.h b/src/lib/core/CASEAuthTag.h index cd490d70d28334..5740d19f80021c 100644 --- a/src/lib/core/CASEAuthTag.h +++ b/src/lib/core/CASEAuthTag.h @@ -17,6 +17,8 @@ #pragma once +#include + #include #include #include @@ -35,11 +37,11 @@ static constexpr size_t kMaxSubjectCATAttributeCount = CHIP_CONFIG_CERT_MAX_RDN_ struct CATValues { - CASEAuthTag values[kMaxSubjectCATAttributeCount] = { kUndefinedCAT }; + std::array values = { kUndefinedCAT }; /* @brief Returns size of the CAT values array. */ - static constexpr size_t size() { return ArraySize(values); } + static constexpr size_t size() { return std::tuple_size::value; } /* @brief Returns true if subject input checks against one of the CATs in the values array. */ @@ -58,6 +60,8 @@ struct CATValues return false; } + bool operator==(const CATValues & that) const { return values == that.values; } + static constexpr size_t kSerializedLength = kMaxSubjectCATAttributeCount * sizeof(CASEAuthTag); typedef uint8_t Serialized[kSerializedLength]; diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 4ecbd28547ad8b..02b338c183b345 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -113,6 +113,7 @@ void CASEServer::OnSessionEstablished(const SessionHandle & session) { ChipLogProgress(Inet, "CASE Session established to peer: " ChipLogFormatScopedNodeId, ChipLogValueScopedNodeId(session->GetPeer())); + mSessionManager->ShiftToSession(session); Cleanup(); } } // namespace chip diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp index 2902a9df5a9790..d88e32d025816b 100644 --- a/src/transport/SecureSession.cpp +++ b/src/transport/SecureSession.cpp @@ -77,5 +77,14 @@ Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const return subjectDescriptor; } +void SecureSession::TryShiftToSession(const SessionHandle & session) +{ + if (GetSecureSessionType() == SecureSession::Type::kCASE && GetPeer() == session->GetPeer() && + GetPeerCATs() == session->AsSecureSession()->GetPeerCATs()) + { + Session::DoShiftToSession(session); + } +} + } // namespace Transport } // namespace chip diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 7083ed34d5078d..b5b61797ecedbd 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -176,7 +176,7 @@ class SecureSession : public Session, public ReferenceCounted(this); } +void Session::DoShiftToSession(const SessionHandle & session) +{ + // Shift to the new session, checks are performed by the subclass implementation which is the caller. + IntrusiveList::Iterator iter = mHolders.begin(); + while (iter != mHolders.end()) + { + // The iterator can be invalid once it is migrated to another session. So we store its next before it is happening. + IntrusiveList::Iterator next = iter; + ++next; + + iter->ShiftToSession(session); + + iter = next; + } +} + } // namespace Transport } // namespace chip diff --git a/src/transport/Session.h b/src/transport/Session.h index d1eb81d03c6e7e..5ae7e08b749157 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -108,12 +108,14 @@ class Session SessionHandle session(*this); while (!mHolders.Empty()) { - mHolders.begin()->OnSessionReleased(); // OnSessionReleased must remove the item from the linked list + mHolders.begin()->SessionReleased(); // OnSessionReleased must remove the item from the linked list } } void SetFabricIndex(FabricIndex index) { mFabricIndex = index; } + void DoShiftToSession(const SessionHandle & session); + private: IntrusiveList mHolders; FabricIndex mFabricIndex = kUndefinedFabricIndex; diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h index 763848dfbb85eb..3896d001d73583 100644 --- a/src/transport/SessionDelegate.h +++ b/src/transport/SessionDelegate.h @@ -36,8 +36,16 @@ class DLL_EXPORT SessionDelegate * Called when a new secure session to the same peer is established, over the delegate of SessionHolderWithDelegate object. It * is suggested to shift to the newly created session. * + * Our security model is built upon Exchanges and Sessions, but not SessionHolders, such that SessionHolders should be able to + * shift to a new sessoin freely. If an application is holding a session which is not intent to be shifted, it can provides + * its shifting policy by override GetNewSessionHandlingPolicy in SessionDelegate. For example SessionHolders inside + * ExchangeContext and PairingSession are not eligible for auto-shifting. + * * Note: the default implementation orders shifting to the new session, it should be fine for all users, unless the - * SessionHolder object is expected to be sticky to a specified session. + * SessionHolder object is expected to be sticky to a specified session. + * + * Note: the implementation should not modify session pool nor session holders (eg, adding new session, removing old session), + * or else something inconsistent can be happened inside Session::DoShiftToSession. */ virtual NewSessionHandlingPolicy GetNewSessionHandlingPolicy() { return NewSessionHandlingPolicy::kShiftToNewSession; } @@ -64,6 +72,22 @@ class DLL_EXPORT SessionDelegate * Note: the implementation must not do anything that will destroy the session or change the SessionHolder. */ virtual void OnSessionHang() {} + + /** + * @brief + * Called when an application requests to recover a session. + * + * Note: the implementation must not do anything that will destroy the session or change the SessionHolder. + */ + virtual void OnRequestRecovery() {} + + /** + * @brief + * Called when a pairing fails to recover a session. + * + * Note: the implementation must not do anything that will destroy the session or change the SessionHolder. + */ + virtual void OnRecoveryFailed() {} }; } // namespace chip diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h index ac7cc507b8602a..53b43e6c854d91 100644 --- a/src/transport/SessionHolder.h +++ b/src/transport/SessionHolder.h @@ -28,19 +28,23 @@ namespace chip { * released when the underlying session is released. One must verify it is available before use. The object can be * created using SessionHandle.Grab() */ -class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase +class SessionHolder : public IntrusiveListNodeBase { public: SessionHolder() {} - ~SessionHolder() override; + virtual ~SessionHolder(); SessionHolder(const SessionHolder &); SessionHolder(SessionHolder && that); SessionHolder & operator=(const SessionHolder &); SessionHolder & operator=(SessionHolder && that); - // Implement SessionDelegate - void OnSessionReleased() override { Release(); } + virtual void SessionReleased() { Release(); } + virtual void ShiftToSession(const SessionHandle & session) + { + Release(); + Grab(session); + } bool Contains(const SessionHandle & session) const { @@ -51,7 +55,7 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase bool Grab(const SessionHandle & session); void Release(); - operator bool() const { return mSession.HasValue(); } + explicit operator bool() const { return mSession.HasValue(); } Optional Get() const { // @@ -78,10 +82,11 @@ class SessionHolderWithDelegate : public SessionHolder { public: SessionHolderWithDelegate(SessionDelegate & delegate) : mDelegate(delegate) {} + SessionHolderWithDelegate(SessionHolder & holder, SessionDelegate & delegate) : SessionHolder(holder), mDelegate(delegate) {} SessionHolderWithDelegate(const SessionHandle & handle, SessionDelegate & delegate) : mDelegate(delegate) { Grab(handle); } operator bool() const { return SessionHolder::operator bool(); } - void OnSessionReleased() override + void SessionReleased() override { Release(); @@ -89,6 +94,12 @@ class SessionHolderWithDelegate : public SessionHolder mDelegate.OnSessionReleased(); } + void ShiftToSession(const SessionHandle & session) override + { + if (mDelegate.GetNewSessionHandlingPolicy() == SessionDelegate::NewSessionHandlingPolicy::kShiftToNewSession) + SessionHolder::ShiftToSession(session); + } + void DispatchSessionEvent(SessionDelegate::Event event) override { (mDelegate.*event)(); } private: diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index cc3d8e2759de17..976134e4608382 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -409,6 +409,27 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH return CHIP_NO_ERROR; } +CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, + uint16_t peerSessionId, NodeId localNodeId, NodeId peerNodeId, + FabricIndex fabric, const Transport::PeerAddress & peerAddress, + CryptoContext::SessionRole role, const CATValues & cats) +{ + Optional session = + mSecureSessions.CreateNewSecureSessionForTest(chip::Transport::SecureSession::Type::kCASE, localSessionId, localNodeId, + peerNodeId, cats, peerSessionId, fabric, GetLocalMRPConfig()); + VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); + SecureSession * secureSession = session.Value()->AsSecureSession(); + secureSession->SetPeerAddress(peerAddress); + + size_t secretLen = strlen(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE); + ByteSpan secret(reinterpret_cast(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen); + ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret( + secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role)); + secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue); + sessionHolder.Grab(session.Value()); + return CHIP_NO_ERROR; +} + void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg) { CHIP_TRACE_PREPARED_MESSAGE_RECEIVED(&peerAddress, &msg); @@ -720,6 +741,23 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade } } +void SessionManager::ShiftToSession(const SessionHandle & handle) +{ + VerifyOrDie(handle->IsSecureSession()); + VerifyOrDie(handle->AsSecureSession()->GetSecureSessionType() == SecureSession::Type::kCASE); + mSecureSessions.ForEachSession([&](SecureSession * oldSession) { + if (handle->AsSecureSession() == oldSession) + return Loop::Continue; + + // This will update all SessionHolder pointing to oldSession, to the provided handle. + // + // See comment of SessionDelegate::GetNewSessionHandlingPolicy about how session auto-shifting works, and how to disable it + // for specific SessionHolder in specific scenario. + oldSession->TryShiftToSession(handle); + return Loop::Continue; + }); +} + Optional SessionManager::FindSecureSessionForNode(ScopedNodeId peerNodeId, const Optional & type) { diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index a3ed1f5b440474..a404628169375f 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -154,6 +154,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, uint16_t peerSessionId, FabricIndex fabricIndex, const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role); + CHIP_ERROR InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, uint16_t peerSessionId, + NodeId localNodeId, NodeId peerNodeId, FabricIndex fabric, + const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role, + const CATValues & cats = CATValues{}); /** * @brief @@ -224,6 +228,9 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config); } + // Update existing SessionHolders to shift to the given session. + void ShiftToSession(const SessionHandle & handle); + // // Find an existing secure session given a peer's scoped NodeId and a type of session to match against. // If matching against all types of sessions is desired, NullOptional should be passed into type. diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index cc9179a92b2330..fbbba6311f8597 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -662,7 +662,7 @@ static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager } } -void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) +static void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) { SessionManager sessionManager; @@ -759,6 +759,77 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) sessionManager.Shutdown(); } +static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext) +{ + IPAddress addr; + IPAddress::FromString("::1", addr); + + NodeId aliceNodeId = 0x11223344ull; + NodeId bobNodeId = 0x12344321ull; + FabricIndex aliceFabricIndex = 1; + FabricIndex bobFabricIndex = 1; + + SessionManager sessionManager; + secure_channel::MessageCounterManager gMessageCounterManager; + chip::TestPersistentStorageDelegate deviceStorage; + + Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); + + SessionHolder aliceToBobSession; + CHIP_ERROR err = sessionManager.InjectCaseSessionWithTestKey(aliceToBobSession, 2, 1, aliceNodeId, bobNodeId, aliceFabricIndex, + peer, CryptoContext::SessionRole::kInitiator); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + class StickySessionDelegate : public SessionDelegate + { + public: + NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; } + void OnSessionReleased() override {} + } delegate; + + SessionHolderWithDelegate stickyAliceToBobSession(aliceToBobSession.Get().Value(), delegate); + NL_TEST_ASSERT(inSuite, aliceToBobSession.Contains(stickyAliceToBobSession.Get().Value())); + + SessionHolder bobToAliceSession; + err = sessionManager.InjectCaseSessionWithTestKey(bobToAliceSession, 1, 2, bobNodeId, aliceNodeId, bobFabricIndex, peer, + CryptoContext::SessionRole::kResponder); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + SessionHolder newAliceToBobSession; + err = sessionManager.InjectCaseSessionWithTestKey(newAliceToBobSession, 3, 4, aliceNodeId, bobNodeId, aliceFabricIndex, peer, + CryptoContext::SessionRole::kInitiator); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // Here we got 3 sessions, and 4 holders: + // 1. alice -> bob: aliceToBobSession, stickyAliceToBobSession + // 2. alice <- bob: bobToAliceSession + // 3. alice -> bob: newAliceToBobSession + + SecureSession * session1 = aliceToBobSession->AsSecureSession(); + SecureSession * session2 = bobToAliceSession->AsSecureSession(); + SecureSession * session3 = newAliceToBobSession->AsSecureSession(); + + NL_TEST_ASSERT(inSuite, session1 != session3); + NL_TEST_ASSERT(inSuite, stickyAliceToBobSession->AsSecureSession() == session1); + + // Now shift the 1st session to the 3rd one, after shifting, holders should be: + // 1. alice -> bob: stickyAliceToBobSession + // 2. alice <- bob: bobToAliceSession + // 3. alice -> bob: aliceToBobSession, newAliceToBobSession + sessionManager.ShiftToSession(newAliceToBobSession.Get().Value()); + + NL_TEST_ASSERT(inSuite, aliceToBobSession); + NL_TEST_ASSERT(inSuite, stickyAliceToBobSession); + NL_TEST_ASSERT(inSuite, newAliceToBobSession); + + NL_TEST_ASSERT(inSuite, stickyAliceToBobSession->AsSecureSession() == session1); + NL_TEST_ASSERT(inSuite, bobToAliceSession->AsSecureSession() == session2); + NL_TEST_ASSERT(inSuite, aliceToBobSession->AsSecureSession() == session3); + NL_TEST_ASSERT(inSuite, newAliceToBobSession->AsSecureSession() == session3); + + sessionManager.Shutdown(); +} + // Test Suite /** @@ -774,6 +845,7 @@ const nlTest sTests[] = NL_TEST_DEF("Old counter Test", SendPacketWithOldCounterTest), NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest), NL_TEST_DEF("Session Allocation Test", SessionAllocationTest), + NL_TEST_DEF("SessionShiftingTest", SessionShiftingTest), NL_TEST_SENTINEL() };