diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index a69c3ed6f4aef9..de0f3da6380b4d 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -31,7 +31,6 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres SessionEstablishmentDelegate * delegate) { // Create a UnauthenticatedSession for CASE pairing. - // Don't use mSecureSession here, because mSecureSession is for encrypted communication. Optional session = mInitParams.sessionManager->CreateUnauthenticatedSession(peerAddress, remoteMRPConfig); VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index e01e8cf9adf9cd..6ab85815cab10c 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -85,7 +85,7 @@ typedef void (*OnDeviceConnectionFailure)(void * context, PeerId peerId, CHIP_ER * - Expose to consumers the secure session for talking to the device. */ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, - public SessionReleaseDelegate, + public SessionDelegate, public SessionEstablishmentDelegate, public AddressResolve::NodeListener { diff --git a/src/controller/CommissioneeDeviceProxy.h b/src/controller/CommissioneeDeviceProxy.h index 547868fe9f3dee..859fbe5a05add3 100644 --- a/src/controller/CommissioneeDeviceProxy.h +++ b/src/controller/CommissioneeDeviceProxy.h @@ -55,7 +55,7 @@ struct ControllerDeviceInitParams Messaging::ExchangeManager * exchangeMgr = nullptr; }; -class CommissioneeDeviceProxy : public DeviceProxy, public SessionReleaseDelegate +class CommissioneeDeviceProxy : public DeviceProxy, public SessionDelegate { public: ~CommissioneeDeviceProxy() override; diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index ec4b4916070038..c0211c857384e4 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -57,7 +57,7 @@ class ExchangeContextDeletor */ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public ReferenceCounted, - public SessionReleaseDelegate + public SessionDelegate { friend class ExchangeManager; friend class ExchangeContextDeletor; @@ -81,7 +81,8 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, bool IsGroupExchangeContext() const { return mSession && mSession->IsGroupSession(); } - // Implement SessionReleaseDelegate + // Implement SessionDelegate + NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; } void OnSessionReleased() override; /** diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index f90741c8d3fa08..1914f3a2cc3507 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -154,7 +154,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, void OnResponseTimeout(Messaging::ExchangeContext * ec) override; Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); } - //// SessionReleaseDelegate //// + //// SessionDelegate //// void OnSessionReleased() override; FabricIndex GetFabricIndex() const { return mFabricInfo != nullptr ? mFabricInfo->GetFabricIndex() : kUndefinedFabricIndex; } diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 8bf4b0d61dbf59..69492209f508c6 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -178,7 +178,7 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, Messaging::ExchangeMessageDispatch & GetMessageDispatch() override { return SessionEstablishmentExchangeDispatch::Instance(); } - //// SessionReleaseDelegate //// + //// SessionDelegate //// void OnSessionReleased() override; private: diff --git a/src/protocols/secure_channel/PairingSession.h b/src/protocols/secure_channel/PairingSession.h index 525546df7e5705..25e0c9f95bd06b 100644 --- a/src/protocols/secure_channel/PairingSession.h +++ b/src/protocols/secure_channel/PairingSession.h @@ -38,7 +38,7 @@ namespace chip { class SessionManager; -class DLL_EXPORT PairingSession : public SessionReleaseDelegate +class DLL_EXPORT PairingSession : public SessionDelegate { public: PairingSession() : mSecureSessionHolder(*this) {} @@ -49,6 +49,9 @@ class DLL_EXPORT PairingSession : public SessionReleaseDelegate virtual ScopedNodeId GetLocalScopedNodeId() const = 0; virtual CATValues GetPeerCATs() const = 0; + // Implement SessionDelegate + NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; } + Optional GetLocalSessionId() const { Optional localSessionId; diff --git a/src/transport/SessionDelegate.h b/src/transport/SessionDelegate.h index eca796a972b91a..5d10fe3fbf0c94 100644 --- a/src/transport/SessionDelegate.h +++ b/src/transport/SessionDelegate.h @@ -20,10 +20,26 @@ namespace chip { -class DLL_EXPORT SessionReleaseDelegate +class DLL_EXPORT SessionDelegate { public: - virtual ~SessionReleaseDelegate() {} + virtual ~SessionDelegate() {} + + enum class NewSessionHandlingPolicy : uint8_t + { + kShiftToNewSession, + kStayAtOldSession, + }; + + /** + * @brief + * Called when a new secure session to the same peer is established, over the delegate of SessionHolderWithDelegate object. It + * is suggested to shift to the newly created session. + * + * Note: the default implementation orders shifting to the new session, it should be fine for all users, unless the + * SessionHolder object is expected to be sticky to a specified session. + */ + virtual NewSessionHandlingPolicy GetNewSessionHandlingPolicy() { return NewSessionHandlingPolicy::kShiftToNewSession; } /** * @brief diff --git a/src/transport/SessionHolder.h b/src/transport/SessionHolder.h index f30dabc9d09bcc..28321d9e71c2e9 100644 --- a/src/transport/SessionHolder.h +++ b/src/transport/SessionHolder.h @@ -28,7 +28,7 @@ namespace chip { * released when the underlying session is released. One must verify it is available before use. The object can be * created using SessionHandle.Grab() */ -class SessionHolder : public SessionReleaseDelegate, public IntrusiveListNodeBase +class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase { public: SessionHolder() {} @@ -39,7 +39,7 @@ class SessionHolder : public SessionReleaseDelegate, public IntrusiveListNodeBas SessionHolder & operator=(const SessionHolder &); SessionHolder & operator=(SessionHolder && that); - // Implement SessionReleaseDelegate + // Implement SessionDelegate void OnSessionReleased() override { Release(); } bool Contains(const SessionHandle & session) const @@ -67,11 +67,8 @@ class SessionHolder : public SessionReleaseDelegate, public IntrusiveListNodeBas class SessionHolderWithDelegate : public SessionHolder { public: - SessionHolderWithDelegate(SessionReleaseDelegate & delegate) : mDelegate(delegate) {} - SessionHolderWithDelegate(const SessionHandle & handle, SessionReleaseDelegate & delegate) : mDelegate(delegate) - { - Grab(handle); - } + SessionHolderWithDelegate(SessionDelegate & delegate) : mDelegate(delegate) {} + SessionHolderWithDelegate(const SessionHandle & handle, SessionDelegate & delegate) : mDelegate(delegate) { Grab(handle); } operator bool() const { return SessionHolder::operator bool(); } void OnSessionReleased() override @@ -83,7 +80,7 @@ class SessionHolderWithDelegate : public SessionHolder } private: - SessionReleaseDelegate & mDelegate; + SessionDelegate & mDelegate; }; } // namespace chip diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 25090b7a547876..298d240b232d90 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -59,7 +59,7 @@ const char PAYLOAD[] = "Hello!"; const char LARGE_PAYLOAD[kMaxAppMessageLen + 1] = "test message"; -class TestSessionReleaseCallback : public SessionReleaseDelegate +class TestSessionReleaseCallback : public SessionDelegate { public: void OnSessionReleased() override { mOldConnectionDropped = true; }