diff --git a/src/commissioner/States.h b/src/commissioner/States.h index 778bbf3ec4f1dc..fe274ca3d9faf9 100644 --- a/src/commissioner/States.h +++ b/src/commissioner/States.h @@ -161,23 +161,43 @@ struct PasscodeAuthenticatedSessionEstablishment : CommissionableNodeDiscovery session; + Messaging::ExchangeContext * exchange = nullptr; + Optional sessionId; mPairing = Platform::MakeShared(); VerifyOrExit(mPairing.get() != nullptr, err = CHIP_ERROR_NO_MEMORY); session = this->mCommissionee.mSystemState->SessionMgr()->CreateUnauthenticatedSession( this->mCommissionee.mCommissionableNodeAddress.Value(), this->mCommissionee.mMrpConfig.ValueOr(gDefaultMRPConfig)); VerifyOrExit(session.HasValue(), err = CHIP_ERROR_NO_MEMORY); + { + uint16_t allocatedSessionId = 0; + SuccessOrExit(err = this->mCommissionee.mSystemState->GetSessionIDAllocator()->Allocate(allocatedSessionId)); + sessionId.SetValue(allocatedSessionId); + } + // Allocate the exchange immediately before calling PASESession::Pair. + // + // PASESession::Pair takes ownership of the exchange and will free it on + // error, but can only do this if it is actually called. Allocating the + // exchange context right before calling Pair ensures that if allocation + // succeeds, PASESession has taken ownership. exchange = this->mCommissionee.mSystemState->ExchangeMgr()->NewContext(session.Value(), this->mPairing.get()); VerifyOrExit(exchange != nullptr, err = CHIP_ERROR_INTERNAL); - SuccessOrExit(err = this->mCommissionee.mSystemState->GetSessionIDAllocator()->Allocate(sessionId)); SuccessOrExit(err = mPairing.get()->Pair(this->mCommissionee.mCommissionableNodeAddress.Value(), - this->mPayload.get()->setUpPINCode, sessionId, this->mCommissionee.mMrpConfig, - exchange, this)); + this->mPayload.get()->setUpPINCode, sessionId.Value(), + this->mCommissionee.mMrpConfig, exchange, this)); exit: + if (err != CHIP_NO_ERROR) + { + // See above. If exchange was allocated, PASESession::Pair will + // have freed it on error. Hence, no exchange cleanup is needed + // here. However, we do need to free our session ID on error. + if (sessionId.HasValue()) + { + this->mCommissionee.mSystemState->GetSessionIDAllocator()->Free(sessionId.Value()); + } + } return err; } @@ -752,11 +772,11 @@ struct CertificateAuthenticatedSessionEstablishment : Base, SessionEst CHIP_ERROR TryCase() { - CHIP_ERROR err = CHIP_NO_ERROR; - uint16_t sessionId = 0; - Messaging::ExchangeContext * exchange = nullptr; - Optional session; + CHIP_ERROR err = CHIP_NO_ERROR; FabricInfo * fabric; + Optional session; + Messaging::ExchangeContext * exchange = nullptr; + Optional sessionId; fabric = this->mCommissionee.mSystemState->Fabrics()->FindFabricWithCompressedId( this->mCommissionee.mOperationalId.Value().GetCompressedFabricId()); @@ -766,13 +786,33 @@ struct CertificateAuthenticatedSessionEstablishment : Base, SessionEst session = this->mCommissionee.mSystemState->SessionMgr()->CreateUnauthenticatedSession( this->mCommissionee.mOperationalAddress.Value(), this->mCommissionee.mMrpConfig.ValueOr(gDefaultMRPConfig)); VerifyOrExit(session.HasValue(), err = CHIP_ERROR_NO_MEMORY); + { + uint16_t allocatedSessionId = 0; + SuccessOrExit(err = this->mCommissionee.mSystemState->GetSessionIDAllocator()->Allocate(allocatedSessionId)); + sessionId.SetValue(allocatedSessionId); + } + // Allocate the exchange immediately before calling EstablishSession. + // + // CASESession::EstablishSession takes ownership of the exchange and + // will free it on error, but can only do this if it is actually called. + // Allocating the exchange context right before calling EstablishSession + // ensures that if allocation succeeds, CASESession has taken ownership. exchange = this->mCommissionee.mSystemState->ExchangeMgr()->NewContext(session.Value(), this->mPairing.get()); VerifyOrExit(exchange != nullptr, err = CHIP_ERROR_INTERNAL); - SuccessOrExit(err = this->mCommissionee.mSystemState->GetSessionIDAllocator()->Allocate(sessionId)); SuccessOrExit(err = mPairing.get()->EstablishSession(this->mCommissionee.mOperationalAddress.Value(), fabric, - this->mCommissionee.mOperationalId.Value().GetNodeId(), sessionId, - exchange, this, this->mCommissionee.mMrpConfig)); + this->mCommissionee.mOperationalId.Value().GetNodeId(), + sessionId.Value(), exchange, this, this->mCommissionee.mMrpConfig)); exit: + if (err != CHIP_NO_ERROR) + { + // See above. If exchange was allocated, PASESession::Pair will + // have freed it on error. Hence, no exchange cleanup is needed + // here. However, we do need to free our session ID on error. + if (sessionId.HasValue()) + { + this->mCommissionee.mSystemState->GetSessionIDAllocator()->Free(sessionId.Value()); + } + } return err; }