diff --git a/examples/tv-casting-app/linux/main.cpp b/examples/tv-casting-app/linux/main.cpp index d7957cd1e22f23..f8d35f37f57248 100644 --- a/examples/tv-casting-app/linux/main.cpp +++ b/examples/tv-casting-app/linux/main.cpp @@ -328,21 +328,14 @@ class TargetVideoPlayerInfo .clientPool = &gCASEClientPool, }; - PeerId peerID = fabric->GetPeerIdForNode(nodeId); - mOperationalDeviceProxy = chip::Platform::New(initParams, peerID); + PeerId peerID = fabric->GetPeerIdForNode(nodeId); - // TODO: figure out why this doesn't work so that we can remove OperationalDeviceProxy creation above, - // and remove the FindSecureSessionForNode and SetConnectedSession calls below - // mOperationalDeviceProxy = server->GetCASESessionManager()->FindExistingSession(nodeId); + mOperationalDeviceProxy = server->GetCASESessionManager()->FindExistingSession(peerID); if (mOperationalDeviceProxy == nullptr) { - ChipLogError(AppServer, "Failed in creating an instance of OperationalDeviceProxy"); + ChipLogError(AppServer, "Failed to find an existing instance of OperationalDeviceProxy to the peer"); return CHIP_ERROR_INVALID_ARGUMENT; } - ChipLogError(AppServer, "Created an instance of OperationalDeviceProxy"); - - SessionHandle handle = server->GetSecureSessionManager().FindSecureSessionForNode(nodeId); - mOperationalDeviceProxy->SetConnectedSession(handle); mInitialized = true; return CHIP_NO_ERROR; diff --git a/src/app/CASESessionManager.cpp b/src/app/CASESessionManager.cpp index 02f6ea1ced98b8..597be920fe984a 100644 --- a/src/app/CASESessionManager.cpp +++ b/src/app/CASESessionManager.cpp @@ -39,7 +39,7 @@ CHIP_ERROR CASESessionManager::FindOrEstablishSession(PeerId peerId, Callback::C OperationalDeviceProxy * session = FindExistingSession(peerId); if (session == nullptr) { - ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing session found"); + ChipLogDetail(CASESessionManager, "FindOrEstablishSession: No existing OperationalDeviceProxy instance found"); session = mConfig.devicePool->Allocate(mConfig.sessionInitParams, peerId); diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index ef7ea1ce0738b6..7682323c7e8f8f 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -29,6 +29,7 @@ #include "CASEClient.h" #include "CommandSender.h" #include "ReadPrepareParams.h" +#include "transport/SecureSession.h" #include #include @@ -57,10 +58,36 @@ void OperationalDeviceProxy::MoveToState(State aTargetState) } } +bool OperationalDeviceProxy::CheckAndLoadExistingSession() +{ + VerifyOrReturnError(mState == State::NeedsAddress || mState == State::Initialized, false); + + SessionHolder existingSession; + ScopedNodeId peerNodeId(mPeerId.GetNodeId(), mFabricInfo->GetFabricIndex()); + + mInitParams.sessionManager->FindSecureSessionForNode(mSecureSession, peerNodeId, Transport::SecureSession::Type::kCASE); + if (mSecureSession) + { + ChipLogProgress(Controller, "Found an existing secure session to [" ChipLogFormatX64 ":" ChipLogFormatX64 "]!", + ChipLogValueX64(mPeerId.GetCompressedFabricId()), ChipLogValueX64(mPeerId.GetNodeId())); + return true; + } + + return false; +} + CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback * onConnection, Callback::Callback * onFailure) { - CHIP_ERROR err = CHIP_NO_ERROR; + CHIP_ERROR err = CHIP_NO_ERROR; + bool isConnected = false; + + // + // Always enqueue our user provided callbacks into our callback list. + // If anything goes wrong below, we'll trigger failures (including any queued from + // a previous iteration which in theory shouldn't happen, but this is written to be more defensive) + // + EnqueueConnectionCallbacks(onConnection, onFailure); switch (mState) { @@ -69,35 +96,47 @@ CHIP_ERROR OperationalDeviceProxy::Connect(Callback::Callback break; case State::NeedsAddress: - err = LookupPeerAddress(); - EnqueueConnectionCallbacks(onConnection, onFailure); + isConnected = CheckAndLoadExistingSession(); + if (!isConnected) + { + err = LookupPeerAddress(); + } + break; case State::Initialized: - err = EstablishConnection(); - if (err == CHIP_NO_ERROR) + isConnected = CheckAndLoadExistingSession(); + if (!isConnected) { - EnqueueConnectionCallbacks(onConnection, onFailure); + err = EstablishConnection(); } + break; + case State::Connecting: - EnqueueConnectionCallbacks(onConnection, onFailure); break; case State::SecureConnected: - if (onConnection != nullptr) - { - onConnection->mCall(onConnection->mContext, this); - } + isConnected = true; break; default: err = CHIP_ERROR_INCORRECT_STATE; } - if (err != CHIP_NO_ERROR && onFailure != nullptr) + if (isConnected) + { + MoveToState(State::SecureConnected); + } + + // + // Dequeue all our callbacks on either encountering an error + // or if we successfully connected. Both should not be set + // simultaneously. + // + if (err != CHIP_NO_ERROR || isConnected) { - onFailure->mCall(onFailure->mContext, mPeerId, err); + DequeueConnectionCallbacks(err); } return err; @@ -133,7 +172,7 @@ CHIP_ERROR OperationalDeviceProxy::UpdateDeviceData(const Transport::PeerAddress err = EstablishConnection(); if (err != CHIP_NO_ERROR) { - OnSessionEstablishmentError(err); + DequeueConnectionCallbacks(err); } } else @@ -194,35 +233,37 @@ void OperationalDeviceProxy::EnqueueConnectionCallbacks(Callback::Callback * cb = Callback::Callback::FromCancelable(ready.mNext); + Callback::Callback * cb = + Callback::Callback::FromCancelable(ready.mNext); cb->Cancel(); - if (executeCallback) + + if (error != CHIP_NO_ERROR) { - cb->mCall(cb->mContext, this); + cb->mCall(cb->mContext, mPeerId, error); } } -} -void OperationalDeviceProxy::DequeueConnectionFailureCallbacks(CHIP_ERROR error, bool executeCallback) -{ - Cancelable ready; - mConnectionFailure.DequeueAll(ready); + mConnectionSuccess.DequeueAll(ready); while (ready.mNext != &ready) { - Callback::Callback * cb = - Callback::Callback::FromCancelable(ready.mNext); + Callback::Callback * cb = Callback::Callback::FromCancelable(ready.mNext); cb->Cancel(); - if (executeCallback) + if (error == CHIP_NO_ERROR) { - cb->mCall(cb->mContext, mPeerId, error); + cb->mCall(cb->mContext, this); } } } @@ -234,13 +275,20 @@ void OperationalDeviceProxy::HandleCASEConnectionFailure(void * context, CASECli ChipLogError(Controller, "HandleCASEConnectionFailure was called while the device was not initialized")); VerifyOrReturn(client == device->mCASEClient, ChipLogError(Controller, "HandleCASEConnectionFailure for unknown CASEClient")); + // + // We don't need to reset the state all the way back to NeedsAddress since all that transpired + // was just CASE connection failure. So let's re-use the cached address to re-do CASE again + // if need-be. + // device->MoveToState(State::Initialized); device->CloseCASESession(); - device->DequeueConnectionSuccessCallbacks(/* executeCallback */ false); - device->DequeueConnectionFailureCallbacks(error, /* executeCallback */ true); - // Do not touch device anymore; it might have been destroyed by a failure + device->DequeueConnectionCallbacks(error); + + // + // Do not touch this instance anymore; it might have been destroyed by a failure // callback. + // } void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * client) @@ -254,19 +302,18 @@ void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * cl if (err != CHIP_NO_ERROR) { device->HandleCASEConnectionFailure(context, client, err); - // Do not touch device anymore; it might have been destroyed by a - // HandleCASEConnectionFailure. } else { device->MoveToState(State::SecureConnected); - device->CloseCASESession(); - device->DequeueConnectionFailureCallbacks(CHIP_NO_ERROR, /* executeCallback */ false); - device->DequeueConnectionSuccessCallbacks(/* executeCallback */ true); - // Do not touch device anymore; it might have been destroyed by a - // success callback. + device->DequeueConnectionCallbacks(CHIP_NO_ERROR); } + + // + // Do not touch this instance anymore; it might have been destroyed by a failure + // callback. + // } CHIP_ERROR OperationalDeviceProxy::Disconnect() @@ -285,12 +332,6 @@ CHIP_ERROR OperationalDeviceProxy::Disconnect() return CHIP_NO_ERROR; } -void OperationalDeviceProxy::SetConnectedSession(const SessionHandle & handle) -{ - mSecureSession.Grab(handle); - MoveToState(State::SecureConnected); -} - void OperationalDeviceProxy::Clear() { if (mCASEClient) @@ -367,8 +408,7 @@ void OperationalDeviceProxy::OnNodeAddressResolutionFailed(const PeerId & peerId ChipLogError(Discovery, "Operational discovery failed for 0x" ChipLogFormatX64 ": %" CHIP_ERROR_FORMAT, ChipLogValueX64(peerId.GetNodeId()), reason.Format()); - DequeueConnectionSuccessCallbacks(/* executeCallback */ false); - DequeueConnectionFailureCallbacks(reason, /* executeCallback */ true); + DequeueConnectionCallbacks(reason); } } // namespace chip diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 7f1b9d715b23f6..fb65eaaa2cf5db 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -91,6 +91,10 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, { public: ~OperationalDeviceProxy() override; + + // + // TODO: Should not be PeerId, but rather, ScopedNodeId + // OperationalDeviceProxy(DeviceProxyInitParams & params, PeerId peerId) : mSecureSession(*this) { mInitParams = params; @@ -159,15 +163,6 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, */ CHIP_ERROR Disconnect() override; - /** - * Use SetConnectedSession if 'this' object is a newly allocated device proxy. - * It will take an existing session, such as the one established - * during commissioning, and use it for this device proxy. - * - * Note: Avoid using this function generally as it is Deprecated - */ - void SetConnectedSession(const SessionHandle & handle); - NodeId GetDeviceId() const override { return mPeerId.GetNodeId(); } /** @@ -268,6 +263,15 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, CHIP_ERROR EstablishConnection(); + /* + * This checks to see if an existing CASE session exists to the peer within the SessionManager + * and if one exists, to load that into mSecureSession. + * + * Returns true if a valid session was found, false otherwise. + * + */ + bool CheckAndLoadExistingSession(); + bool IsSecureConnected() const override { return mState == State::SecureConnected; } static void HandleCASEConnected(void * context, CASEClient * client); @@ -280,8 +284,15 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure); - void DequeueConnectionSuccessCallbacks(bool executeCallback); - void DequeueConnectionFailureCallbacks(CHIP_ERROR error, bool executeCallback); + /* + * This dequeues all failure and success callbacks and appropriately + * invokes either set depending on the value of error. + * + * If error == CHIP_NO_ERROR, only success callbacks are invoked. + * Otherwise, only failure callbacks are invoked. + * + */ + void DequeueConnectionCallbacks(CHIP_ERROR error); }; } // namespace chip diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 7f99ae8cd4fea5..8d1705543af798 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -814,11 +814,13 @@ void SessionManager::ExpiryTimerCallback(System::Layer * layer, void * param) mgr->ScheduleExpiryTimer(); // re-schedule the oneshot timer } -SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId) +void SessionManager::FindSecureSessionForNode(SessionHolder & sessionHolder, ScopedNodeId peerNodeId, + Transport::SecureSession::Type type) { SecureSession * found = nullptr; - mSecureSessions.ForEachSession([&](auto session) { - if (session->GetPeerNodeId() == peerNodeId) + mSecureSessions.ForEachSession([&peerNodeId, type, &found](auto session) { + if (session->GetPeer() == peerNodeId && + (type == SecureSession::Type::kUndefined || type == session->GetSecureSessionType())) { found = session; return Loop::Break; @@ -826,8 +828,12 @@ SessionHandle SessionManager::FindSecureSessionForNode(NodeId peerNodeId) return Loop::Continue; }); - VerifyOrDie(found != nullptr); - return SessionHandle(*found); + sessionHolder.Release(); + + if (found) + { + sessionHolder.Grab(SessionHandle(*found)); + } } /** diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 9b59b9af3fea78..988c6b30e3188e 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -249,9 +249,12 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config); } - // TODO: this is a temporary solution for legacy tests which use nodeId to send packets - // and tv-casting-app that uses the TV's node ID to find the associated secure session - SessionHandle FindSecureSessionForNode(NodeId peerNodeId); + // + // Find an existing secure session given a peer's scoped NodeId and a type of session to match against. + // If matching against all types of sessions is desired, kUnDefined should be passed into type. + // + void FindSecureSessionForNode(SessionHolder & sessionHolder, ScopedNodeId peerNodeId, + Transport::SecureSession::Type type = Transport::SecureSession::Type::kUndefined); using SessionHandleCallback = bool (*)(void * context, SessionHandle & sessionHandle); CHIP_ERROR ForEachSessionHandle(void * context, SessionHandleCallback callback);