diff --git a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp index 95a2bc73cdecc0..71f935590fb00e 100644 --- a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp +++ b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp @@ -468,7 +468,6 @@ bool emberAfOperationalCredentialsClusterRemoveFabricCallback(app::CommandHandle { SendNOCResponse(commandObj, commandPath, OperationalCertStatus::kSuccess, fabricBeingRemoved, CharSpan()); - // Use a more direct getter for FabricIndex from commandObj chip::Messaging::ExchangeContext * ec = commandObj->GetExchangeContext(); FabricIndex currentFabricIndex = commandObj->GetAccessingFabricIndex(); if (currentFabricIndex == fabricBeingRemoved) diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 21b3b52cfdef50..06eb17fabd4238 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -161,6 +161,10 @@ void CASESession::Clear() Crypto::ClearSecretData(mIPK); AbortExchange(); + + mLocalNodeId = kUndefinedNodeId; + mPeerNodeId = kUndefinedNodeId; + mFabricInfo = nullptr; } void CASESession::AbortExchange() @@ -257,6 +261,10 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout()); mPeerNodeId = peerNodeId; + mLocalNodeId = fabric->GetNodeId(); + + ChipLogProgress(SecureChannel, "Initiating session on local FabricIndex %u from 0x" ChipLogFormatX64 " -> 0x" ChipLogFormatX64, + static_cast(fabric->GetFabricIndex()), ChipLogValueX64(mLocalNodeId), ChipLogValueX64(mPeerNodeId)); err = SendSigma1(); SuccessOrExit(err); @@ -336,9 +344,13 @@ CHIP_ERROR CASESession::RecoverInitiatorIpk() size_t ipkIndex = (ipkKeySet.num_keys_used > 1) ? ((ipkKeySet.num_keys_used - 1) - 1) : 0; memcpy(&mIPK[0], ipkKeySet.epoch_keys[ipkIndex].key, sizeof(mIPK)); + // Leaving this logging code for debug, but this cannot be enabled at runtime + // since it leaks private security material. +#if 0 ChipLogProgress(SecureChannel, "RecoverInitiatorIpk: GroupDataProvider %p, Got IPK for FabricIndex %u", mGroupDataProvider, static_cast(mFabricInfo->GetFabricIndex())); ChipLogByteSpan(SecureChannel, ByteSpan(mIPK)); +#endif return CHIP_NO_ERROR; } @@ -492,6 +504,7 @@ CHIP_ERROR CASESession::FindLocalNodeFromDestionationId(const ByteSpan & destina MutableByteSpan ipkSpan(mIPK); CopySpanToMutableSpan(candidateIpkSpan, ipkSpan); mFabricInfo = &fabricInfo; + mLocalNodeId = nodeId; break; } } @@ -524,6 +537,7 @@ CHIP_ERROR CASESession::TryResumeSession(SessionResumptionStorage::ConstResumpti return CHIP_ERROR_INTERNAL; mPeerNodeId = node.GetNodeId(); + mLocalNodeId = mFabricInfo->GetNodeId(); return CHIP_NO_ERROR; } @@ -569,11 +583,15 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) return CHIP_NO_ERROR; } + // Attempt to match the initiator's desired destination based on local fabric table. err = FindLocalNodeFromDestionationId(destinationIdentifier, initiatorRandom); if (err == CHIP_NO_ERROR) { ChipLogProgress(SecureChannel, "CASE matched destination ID: fabricIndex %u, NodeID 0x" ChipLogFormatX64, - static_cast(mFabricInfo->GetFabricIndex()), ChipLogValueX64(mFabricInfo->GetNodeId())); + static_cast(mFabricInfo->GetFabricIndex()), ChipLogValueX64(mLocalNodeId)); + + // Side-effect of FindLocalNodeFromDestionationId success was that mFabricInfo/mLocalNodeId are now + // set to the local fabric and associated NodeId that was targeted by the initiator. } else { diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index a1c136acfa6b26..7321ed8c273735 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -64,6 +64,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, Transport::SecureSession::Type GetSecureSessionType() const override { return Transport::SecureSession::Type::kCASE; } ScopedNodeId GetPeer() const override { return ScopedNodeId(mPeerNodeId, GetFabricIndex()); } + ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(mLocalNodeId, GetFabricIndex()); } CATValues GetPeerCATs() const override { return mPeerCATs; }; /** @@ -253,6 +254,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, FabricTable * mFabricsTable = nullptr; const FabricInfo * mFabricInfo = nullptr; NodeId mPeerNodeId = kUndefinedNodeId; + NodeId mLocalNodeId = kUndefinedNodeId; CATValues mPeerCATs; // This field is only used for CASE responder, when during sending sigma2 and waiting for sigma3 diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index dfa3b826bdaeec..1168be478b4519 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -78,6 +78,13 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, { return ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex); } + + ScopedNodeId GetLocalScopedNodeId() const override + { + // For PASE, source is always the undefined node ID + return ScopedNodeId(); + } + CATValues GetPeerCATs() const override { return CATValues(); }; CHIP_ERROR OnUnsolicitedMessageReceived(const PayloadHeader & payloadHeader, ExchangeDelegate *& newDelegate) override; diff --git a/src/transport/GroupSession.h b/src/transport/GroupSession.h index c3118653b9e384..46bcad14b3d9eb 100644 --- a/src/transport/GroupSession.h +++ b/src/transport/GroupSession.h @@ -27,7 +27,7 @@ namespace Transport { class IncomingGroupSession : public Session { public: - IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId sourceNodeId) : mGroupId(group), mSourceNodeId(sourceNodeId) + IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId peerNodeId) : mGroupId(group), mPeerNodeId(peerNodeId) { SetFabricIndex(fabricIndex); } @@ -38,7 +38,8 @@ class IncomingGroupSession : public Session const char * GetSessionTypeString() const override { return "incoming group"; }; #endif - ScopedNodeId GetPeer() const override { return ScopedNodeId(mSourceNodeId, GetFabricIndex()); } + ScopedNodeId GetPeer() const override { return ScopedNodeId(mPeerNodeId, GetFabricIndex()); } + ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(kUndefinedNodeId, GetFabricIndex()); } Access::SubjectDescriptor GetSubjectDescriptor() const override { @@ -68,11 +69,9 @@ class IncomingGroupSession : public Session GroupId GetGroupId() const { return mGroupId; } - NodeId GetSourceNodeId() const { return mSourceNodeId; } - private: const GroupId mGroupId; - const NodeId mSourceNodeId; + const NodeId mPeerNodeId; }; class OutgoingGroupSession : public Session @@ -86,7 +85,10 @@ class OutgoingGroupSession : public Session const char * GetSessionTypeString() const override { return "outgoing group"; }; #endif + // Peer node ID is unused: users care about the group, not the node ScopedNodeId GetPeer() const override { return ScopedNodeId(); } + // Local node ID is unused: users care about the group, not the node + ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(); } Access::SubjectDescriptor GetSubjectDescriptor() const override { diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp index b7173a330f3c3d..d8008aa152ca4e 100644 --- a/src/transport/PairingSession.cpp +++ b/src/transport/PairingSession.cpp @@ -48,7 +48,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress & // Call Activate last, otherwise errors on anything after would lead to // a partially valid session. - secureSession->Activate(GetSecureSessionType(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig); + secureSession->Activate(GetSecureSessionType(), GetPeer(), GetLocalScopedNodeId(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig); ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!", ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId); diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index c2d7576c20ada9..72318ecd0e79b1 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -44,6 +44,7 @@ class DLL_EXPORT PairingSession virtual Transport::SecureSession::Type GetSecureSessionType() const = 0; virtual ScopedNodeId GetPeer() const = 0; + virtual ScopedNodeId GetLocalScopedNodeId() const = 0; virtual CATValues GetPeerCATs() const = 0; Optional GetLocalSessionId() const diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp index 2a55e03e5bbd9a..e2688869bc509a 100644 --- a/src/transport/SecureSession.cpp +++ b/src/transport/SecureSession.cpp @@ -20,11 +20,6 @@ namespace chip { namespace Transport { -ScopedNodeId SecureSession::GetPeer() const -{ - return ScopedNodeId(mPeerNodeId, GetFabricIndex()); -} - Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const { Access::SubjectDescriptor subjectDescriptor; diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index 49fea973b87127..6d4dc7b85770aa 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -72,10 +72,10 @@ class SecureSession : public Session // TODO: This constructor should be private. Tests should allocate a // kPending session and then call Activate(), just like non-test code does. - SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, + SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, NodeId localNodeId, CATValues peerCATs, uint16_t peerSessionId, FabricIndex fabric, const ReliableMessageProtocolConfig & config) : mSecureSessionType(secureSessionType), - mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), + mPeerNodeId(peerNodeId), mLocalNodeId(localNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId), mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()), mLastPeerActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config) { @@ -89,7 +89,7 @@ class SecureSession : public Session * receives a local session ID, but no other state. */ SecureSession(uint16_t localSessionId) : - SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig()) + SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig()) {} /** @@ -98,15 +98,25 @@ class SecureSession : public Session * PASE, setting internal state according to the parameters used and * discovered during session establishment. */ - void Activate(Type secureSessionType, const ScopedNodeId & peer, CATValues peerCATs, uint16_t peerSessionId, - const ReliableMessageProtocolConfig & config) + void Activate(Type secureSessionType, const ScopedNodeId & peerNode, const ScopedNodeId & localNode, CATValues peerCATs, + uint16_t peerSessionId, const ReliableMessageProtocolConfig & config) { + VerifyOrDie(peerNode.GetFabricIndex() == localNode.GetFabricIndex()); + + // PASE sessions must always start unassociated with a Fabric! + VerifyOrDie(!((secureSessionType == Type::kPASE) && (peerNode.GetFabricIndex() != kUndefinedFabricIndex))); + // CASE sessions must always start "associated" a given Fabric! + VerifyOrDie(!((secureSessionType == Type::kCASE) && (peerNode.GetFabricIndex() == kUndefinedFabricIndex))); + // CASE sessions can only be activated against operational node IDs! + VerifyOrDie(!((secureSessionType == Type::kCASE) && (!IsOperationalNodeId(peerNode.GetNodeId()) || !IsOperationalNodeId(localNode.GetNodeId())))); + mSecureSessionType = secureSessionType; - mPeerNodeId = peer.GetNodeId(); + mPeerNodeId = peerNode.GetNodeId(); + mLocalNodeId = localNode.GetNodeId(); mPeerCATs = peerCATs; mPeerSessionId = peerSessionId; mMRPConfig = config; - SetFabricIndex(peer.GetFabricIndex()); + SetFabricIndex(peerNode.GetFabricIndex()); } ~SecureSession() override { NotifySessionReleased(); } @@ -120,7 +130,16 @@ class SecureSession : public Session const char * GetSessionTypeString() const override { return "secure"; }; #endif - ScopedNodeId GetPeer() const override; + ScopedNodeId GetPeer() const override + { + return ScopedNodeId(mPeerNodeId, GetFabricIndex()); + } + + ScopedNodeId GetLocalScopedNodeId() const override + { + return ScopedNodeId(mLocalNodeId, GetFabricIndex()); + } + Access::SubjectDescriptor GetSubjectDescriptor() const override; bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; } @@ -147,6 +166,8 @@ class SecureSession : public Session bool IsPASESession() const { return GetSecureSessionType() == Type::kPASE; } bool IsActiveSession() const { return GetSecureSessionType() != Type::kPending; } NodeId GetPeerNodeId() const { return mPeerNodeId; } + NodeId GetLocalNodeId() const { return mLocalNodeId; } + CATValues GetPeerCATs() const { return mPeerCATs; } void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; } @@ -192,6 +213,7 @@ class SecureSession : public Session private: Type mSecureSessionType; NodeId mPeerNodeId; + NodeId mLocalNodeId; CATValues mPeerCATs; const uint16_t mLocalSessionId; uint16_t mPeerSessionId; diff --git a/src/transport/SecureSessionTable.h b/src/transport/SecureSessionTable.h index 1804ea412c88d9..a55c71ac6881f7 100644 --- a/src/transport/SecureSessionTable.h +++ b/src/transport/SecureSessionTable.h @@ -49,6 +49,7 @@ class SecureSessionTable * @param secureSessionType secure session type * @param localSessionId unique identifier for the local node's secure unicast session context * @param peerNodeId represents peer Node's ID + * @param localNodeId represents local Node's ID * @param peerCATs represents peer CASE Authenticated Tags * @param peerSessionId represents the encryption key ID assigned by peer node * @param fabric represents fabric ID for the session @@ -61,11 +62,34 @@ class SecureSessionTable */ CHECK_RETURN_VALUE Optional CreateNewSecureSessionForTest(SecureSession::Type secureSessionType, uint16_t localSessionId, - NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId, - FabricIndex fabric, const ReliableMessageProtocolConfig & config) + NodeId peerNodeId, NodeId localNodeId, CATValues peerCATs, uint16_t peerSessionId, + FabricIndex fabricIndex, const ReliableMessageProtocolConfig & config) { + if (secureSessionType == SecureSession::Type::kCASE) + { + if ((fabricIndex == kUndefinedFabricIndex) || (localNodeId == kUndefinedNodeId) || (peerNodeId == kUndefinedNodeId)) + { + return Optional::Missing(); + } + } + else if (secureSessionType == SecureSession::Type::kPASE) + { + if ((fabricIndex != kUndefinedFabricIndex) || (localNodeId != kUndefinedNodeId) || (peerNodeId != kUndefinedNodeId)) + { + // TODO: This secure session type is infeasible! We must fix the tests + if (false) + { + return Optional::Missing(); + } + else + { + (void)fabricIndex; + } + } + } + SecureSession * result = - mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config); + mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, localNodeId, peerCATs, peerSessionId, fabricIndex, config); return result != nullptr ? MakeOptional(*result) : Optional::Missing(); } diff --git a/src/transport/Session.h b/src/transport/Session.h index b51c78964cf3d2..17225b57c4a887 100644 --- a/src/transport/Session.h +++ b/src/transport/Session.h @@ -70,6 +70,7 @@ class Session virtual void Release() {} virtual ScopedNodeId GetPeer() const = 0; + virtual ScopedNodeId GetLocalScopedNodeId() const = 0; virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0; virtual bool RequireMRP() const = 0; virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0; diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index f8da0f43a550b3..8e4937c53c5bcd 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -158,7 +158,8 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P mGroupClientCounter.IncrementCounter(isControlMsg); packetHeader.SetFlags(Header::SecFlagValues::kPrivacyFlag); packetHeader.SetSessionType(Header::SessionType::kGroupSession); - packetHeader.SetSourceNodeId(fabric->GetNodeId()); + NodeId sourceNodeId = fabric->GetNodeId(); + packetHeader.SetSourceNodeId(sourceNodeId); if (!packetHeader.IsValidGroupMsg()) { @@ -174,7 +175,7 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P packetHeader.SetSessionId(keyContext->GetKeyHash()); CryptoContext::NonceStorage nonce; - CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), packetHeader.GetMessageCounter(), fabric->GetNodeId()); + CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), packetHeader.GetMessageCounter(), sourceNodeId); CHIP_ERROR err = SecureMessageCodec::Encrypt(CryptoContext(keyContext), nonce, payloadHeader, packetHeader, message); keyContext->Release(); ReturnErrorOnFailure(err); @@ -203,18 +204,9 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P CHIP_TRACE_MESSAGE_SENT(payloadHeader, packetHeader, message->Start(), message->TotalLength()); CryptoContext::NonceStorage nonce; - if (session->GetSecureSessionType() == SecureSession::Type::kCASE) - { - FabricInfo * fabric = mFabricTable->FindFabricWithIndex(session->GetFabricIndex()); - VerifyOrDie(fabric != nullptr); - CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, fabric->GetNodeId()); - } - else - { - // PASE Sessions use the undefined node ID of all zeroes, since there is no node ID to use - // and the key is short-lived and always different for each PASE session. - CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, kUndefinedNodeId); - } + NodeId sourceNodeId = session->GetLocalScopedNodeId().GetNodeId(); + CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, sourceNodeId); + ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session->GetCryptoContext(), nonce, payloadHeader, packetHeader, message)); ReturnErrorOnFailure(counter.Advance()); @@ -406,9 +398,10 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH uint16_t peerSessionId, FabricIndex fabric, const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role) { + NodeId localNodeId = kUndefinedNodeId; Optional session = mSecureSessions.CreateNewSecureSessionForTest(chip::Transport::SecureSession::Type::kPASE, localSessionId, peerNodeId, - CATValues{}, peerSessionId, fabric, GetLocalMRPConfig()); + localNodeId, CATValues{}, peerSessionId, fabric, GetLocalMRPConfig()); VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY); SecureSession * secureSession = session.Value()->AsSecureSession(); secureSession->SetPeerAddress(peerAddress); diff --git a/src/transport/SessionManager.h b/src/transport/SessionManager.h index 9b535229e828f3..260db7866a395b 100644 --- a/src/transport/SessionManager.h +++ b/src/transport/SessionManager.h @@ -158,8 +158,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate // Test-only: create a session on the fly. CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId, - uint16_t peerSessionId, FabricIndex fabric, const Transport::PeerAddress & peerAddress, - CryptoContext::SessionRole role); + uint16_t peerSessionId, FabricIndex fabricIndex, + const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role); /** * @brief diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h index 15a118d47ccdd1..896168933df2c6 100644 --- a/src/transport/UnauthenticatedSessionTable.h +++ b/src/transport/UnauthenticatedSessionTable.h @@ -82,7 +82,8 @@ class UnauthenticatedSession : public Session, public ReferenceCounted::Retain(); } void Release() override { ReferenceCounted::Release(); } - ScopedNodeId GetPeer() const override { return ScopedNodeId(kUndefinedNodeId, GetFabricIndex()); } + ScopedNodeId GetPeer() const override { return ScopedNodeId(GetPeerNodeId(), kUndefinedFabricIndex); } + ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(kUndefinedFabricIndex, kUndefinedFabricIndex); } Access::SubjectDescriptor GetSubjectDescriptor() const override { diff --git a/src/transport/tests/TestPairingSession.cpp b/src/transport/tests/TestPairingSession.cpp index 6ffd2037f02f06..b40c044d4d893a 100644 --- a/src/transport/tests/TestPairingSession.cpp +++ b/src/transport/tests/TestPairingSession.cpp @@ -43,6 +43,7 @@ class TestPairingSession : public PairingSession public: Transport::SecureSession::Type GetSecureSessionType() const override { return Transport::SecureSession::Type::kPASE; } ScopedNodeId GetPeer() const override { return ScopedNodeId(); } + ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(); } CATValues GetPeerCATs() const override { return CATValues(); }; const ReliableMessageProtocolConfig & GetRemoteMRPConfig() const { return mRemoteMRPConfig; } diff --git a/src/transport/tests/TestPeerConnections.cpp b/src/transport/tests/TestPeerConnections.cpp index 8aea395d003414..a4587b7c374e60 100644 --- a/src/transport/tests/TestPeerConnections.cpp +++ b/src/transport/tests/TestPeerConnections.cpp @@ -46,15 +46,14 @@ PeerAddress AddressFromString(const char * str) const PeerAddress kPeer1Addr = AddressFromString("fe80::1"); const PeerAddress kPeer2Addr = AddressFromString("fe80::2"); -const PeerAddress kPeer3Addr = AddressFromString("fe80::3"); +const PeerAddress kPasePeerAddr = AddressFromString("fe80::3"); -const NodeId kPeer1NodeId = 123; -const NodeId kPeer2NodeId = 6; -const NodeId kPeer3NodeId = 81; +const NodeId kLocalNodeId = 0xC439A991071292DB; +const NodeId kCasePeer1NodeId = 123; +const NodeId kCasePeer2NodeId = 6; +const FabricIndex kFabricIndex = 8; -const SecureSession::Type kPeer1SessionType = SecureSession::Type::kCASE; -const SecureSession::Type kPeer2SessionType = SecureSession::Type::kCASE; -const SecureSession::Type kPeer3SessionType = SecureSession::Type::kPASE; +const NodeId kPasePeerNodeId = kUndefinedNodeId; // PASE is always undefined const CATValues kPeer1CATs = { { 0xABCD0001, 0xABCE0100, 0xABCD0020 } }; const CATValues kPeer2CATs = { { 0xABCD0012, kUndefinedCAT, kUndefinedCAT } }; @@ -69,28 +68,34 @@ void TestBasicFunctionality(nlTestSuite * inSuite, void * inContext) clock.SetMonotonic(100_ms64); CATValues peerCATs; - // Node ID 1, peer key 1, local key 2 - auto optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // First node, peer session id 1, local session id 2 + auto optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 2, kCasePeer1NodeId, kLocalNodeId, kPeer1CATs, 1, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer1SessionType); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer1NodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == SecureSession::Type::kCASE); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kCasePeer1NodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLocalNodeId() == kLocalNodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->GetPeer() == ScopedNodeId(kCasePeer1NodeId, kFabricIndex)); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->GetLocalScopedNodeId() == ScopedNodeId(kLocalNodeId, kFabricIndex)); peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer1CATs, sizeof(CATValues)) == 0); - // Node ID 2, peer key 3, local key 4 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // Second node, peer session id 3, local session id 4 + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 4, kCasePeer2NodeId, kLocalNodeId, kPeer2CATs, 3, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == kPeer2SessionType); - NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kPeer2NodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetSecureSessionType() == SecureSession::Type::kCASE); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetPeerNodeId() == kCasePeer2NodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLocalNodeId() == kLocalNodeId); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->GetPeer() == ScopedNodeId(kCasePeer2NodeId, kFabricIndex)); + NL_TEST_ASSERT(inSuite, optionalSession.Value()->GetLocalScopedNodeId() == ScopedNodeId(kLocalNodeId, kFabricIndex)); NL_TEST_ASSERT(inSuite, optionalSession.Value()->AsSecureSession()->GetLastActivityTime() == 100_ms64); peerCATs = optionalSession.Value()->AsSecureSession()->GetPeerCATs(); NL_TEST_ASSERT(inSuite, memcmp(&peerCATs, &kPeer2CATs, sizeof(CATValues)) == 0); // Insufficient space for new connections. Object is max size 2 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, - 0 /* fabricIndex */, GetLocalMRPConfig()); + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kPASE, 6, kPasePeerNodeId, kLocalNodeId, kPeer3CATs, 5, + kUndefinedFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, !optionalSession.HasValue()); System::Clock::Internal::SetSystemClockForTesting(realClock); } @@ -102,17 +107,17 @@ void TestFindByKeyId(nlTestSuite * inSuite, void * inContext) System::Clock::ClockBase * realClock = &System::SystemClock(); System::Clock::Internal::SetSystemClockForTesting(&clock); - // Node ID 1, peer key 1, local key 2 - auto optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // First node, peer session id 1, local session id 2 + auto optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 2, kCasePeer1NodeId, kLocalNodeId, kPeer1CATs, 1, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(1).HasValue()); NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue()); - // Node ID 2, peer key 3, local key 4 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // Second node, peer session id 3, local session id 4 + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 4, kCasePeer2NodeId, kLocalNodeId, kPeer2CATs, 3, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(3).HasValue()); @@ -139,23 +144,23 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) clock.SetMonotonic(100_ms64); - // Node ID 1, peer key 1, local key 2 - auto optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // First node, peer session id 1, local session id 2 + auto optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 2, kCasePeer1NodeId, kLocalNodeId, kPeer1CATs, 1, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer1Addr); clock.SetMonotonic(200_ms64); - // Node ID 2, peer key 3, local key 4 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer2SessionType, 4, kPeer2NodeId, kPeer2CATs, 3, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // Second node, peer session id 3, local session id 4 + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 4, kCasePeer2NodeId, kLocalNodeId, kPeer2CATs, 3, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer2Addr); // cannot add before expiry clock.SetMonotonic(300_ms64); - optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, - 0 /* fabricIndex */, GetLocalMRPConfig()); + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kPASE, 6, kPasePeerNodeId, kLocalNodeId, kPeer3CATs, 5, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, !optionalSession.HasValue()); // at time 300, this expires ip addr 1 @@ -165,17 +170,17 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) callInfo.lastCallPeerAddress = state.GetPeerAddress(); }); NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); - NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer1NodeId); + NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kCasePeer1NodeId); NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer1Addr); NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); // now that the connections were expired, we can add peer3 clock.SetMonotonic(300_ms64); - // Node ID 3, peer key 5, local key 6 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer3SessionType, 6, kPeer3NodeId, kPeer3CATs, 5, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // Third node (PASE session), peer session id 5, local session id 6 + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kPASE, 6, kPasePeerNodeId, kLocalNodeId, kPeer3CATs, 5, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); - optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPeer3Addr); + optionalSession.Value()->AsSecureSession()->SetPeerAddress(kPasePeerAddr); clock.SetMonotonic(400_ms64); optionalSession = connections.FindSecureSessionByLocalKey(4); @@ -198,15 +203,15 @@ void TestExpireConnections(nlTestSuite * inSuite, void * inContext) // peer 2 stays active NL_TEST_ASSERT(inSuite, callInfo.callCount == 1); - NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPeer3NodeId); - NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPeer3Addr); + NL_TEST_ASSERT(inSuite, callInfo.lastCallNodeId == kPasePeerNodeId); + NL_TEST_ASSERT(inSuite, callInfo.lastCallPeerAddress == kPasePeerAddr); NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(2).HasValue()); NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue()); NL_TEST_ASSERT(inSuite, !connections.FindSecureSessionByLocalKey(6).HasValue()); - // Node ID 1, peer key 1, local key 2 - optionalSession = connections.CreateNewSecureSessionForTest(kPeer1SessionType, 2, kPeer1NodeId, kPeer1CATs, 1, - 0 /* fabricIndex */, GetLocalMRPConfig()); + // First node, peer session id 1, local session id 2 + optionalSession = connections.CreateNewSecureSessionForTest(SecureSession::Type::kCASE, 2, kCasePeer1NodeId, kLocalNodeId, kPeer1CATs, 1, + kFabricIndex, GetLocalMRPConfig()); NL_TEST_ASSERT(inSuite, optionalSession.HasValue()); NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(2).HasValue()); NL_TEST_ASSERT(inSuite, connections.FindSecureSessionByLocalKey(4).HasValue());