Skip to content

Commit

Permalink
Fix crash on removal of accessing fabric
Browse files Browse the repository at this point in the history
Because of an access to prior fabric data that is now deleted,
in SessionManager::PrepareMessage, while trying to reply to RemoveFabric,
applications crash when RemoveFabric is done on the accessing fabric.

This crash was awaiting full fix of  project-chip#16748 to be fixed, but that
issue is much bigger scope. We can actually fix the crash with a
suggestion made by @turon
(project-chip#16748 (comment))
to keep the *local node ID* in the SecureSession so that
SessionManager does not try to look-back at the FabricTable
whenever preparing a CASE message where the fabric may be gone.

This is a root cause fix for that very crash, but does not address
the other aspects of project-chip#16748 which relate to completely cleanly
handling fabric removal edge cases.

Issue project-chip#16748
Fixes project-chip#17579
Fixes project-chip#17680
Fixes project-chip#16729

This PR does the following:
- Add local node ID to the SecureSession and fix all associated plumbing
- Use the local node ID for nonce generation in PrepareMessage rather
  than looking-up the fabric table (which may no longer hold the
  fabric that has that prior node ID)
- Improve CASE session establishment logging
- Fix the tests needed
- Fix bad comments in TestPairingSession tests

Testing done:
- Added a YAML test (TestSelfFabricRemoval.yaml) for this case
  - Validated it failed before code fixes with the previously seen
    crash.
  - Validated that it passes with the new fixes
- Added necessary tests to TestPairingSession for new methods
- Unit tests pass
- Cert tests pass
  • Loading branch information
tcarmelveilleux committed Apr 27, 2022
1 parent 3dce396 commit 2744efc
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ bool emberAfOperationalCredentialsClusterRemoveFabricCallback(app::CommandHandle
{
SendNOCResponse(commandObj, commandPath, OperationalCertStatus::kSuccess, fabricBeingRemoved, CharSpan());

// Use a more direct getter for FabricIndex from commandObj
chip::Messaging::ExchangeContext * ec = commandObj->GetExchangeContext();
FabricIndex currentFabricIndex = commandObj->GetAccessingFabricIndex();
if (currentFabricIndex == fabricBeingRemoved)
Expand Down
20 changes: 19 additions & 1 deletion src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ void CASESession::Clear()
Crypto::ClearSecretData(mIPK);

AbortExchange();

mLocalNodeId = kUndefinedNodeId;
mPeerNodeId = kUndefinedNodeId;
mFabricInfo = nullptr;
}

void CASESession::AbortExchange()
Expand Down Expand Up @@ -257,6 +261,10 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric

mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout + mExchangeCtxt->GetSessionHandle()->GetAckTimeout());
mPeerNodeId = peerNodeId;
mLocalNodeId = fabric->GetNodeId();

ChipLogProgress(SecureChannel, "Initiating session on local FabricIndex %u from 0x" ChipLogFormatX64 " -> 0x" ChipLogFormatX64,
static_cast<unsigned>(fabric->GetFabricIndex()), ChipLogValueX64(mLocalNodeId), ChipLogValueX64(mPeerNodeId));

err = SendSigma1();
SuccessOrExit(err);
Expand Down Expand Up @@ -336,9 +344,13 @@ CHIP_ERROR CASESession::RecoverInitiatorIpk()
size_t ipkIndex = (ipkKeySet.num_keys_used > 1) ? ((ipkKeySet.num_keys_used - 1) - 1) : 0;
memcpy(&mIPK[0], ipkKeySet.epoch_keys[ipkIndex].key, sizeof(mIPK));

// Leaving this logging code for debug, but this cannot be enabled at runtime
// since it leaks private security material.
#if 0
ChipLogProgress(SecureChannel, "RecoverInitiatorIpk: GroupDataProvider %p, Got IPK for FabricIndex %u", mGroupDataProvider,
static_cast<unsigned>(mFabricInfo->GetFabricIndex()));
ChipLogByteSpan(SecureChannel, ByteSpan(mIPK));
#endif

return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -492,6 +504,7 @@ CHIP_ERROR CASESession::FindLocalNodeFromDestionationId(const ByteSpan & destina
MutableByteSpan ipkSpan(mIPK);
CopySpanToMutableSpan(candidateIpkSpan, ipkSpan);
mFabricInfo = &fabricInfo;
mLocalNodeId = nodeId;
break;
}
}
Expand Down Expand Up @@ -524,6 +537,7 @@ CHIP_ERROR CASESession::TryResumeSession(SessionResumptionStorage::ConstResumpti
return CHIP_ERROR_INTERNAL;

mPeerNodeId = node.GetNodeId();
mLocalNodeId = mFabricInfo->GetNodeId();

return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -569,11 +583,15 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)
return CHIP_NO_ERROR;
}

// Attempt to match the initiator's desired destination based on local fabric table.
err = FindLocalNodeFromDestionationId(destinationIdentifier, initiatorRandom);
if (err == CHIP_NO_ERROR)
{
ChipLogProgress(SecureChannel, "CASE matched destination ID: fabricIndex %u, NodeID 0x" ChipLogFormatX64,
static_cast<unsigned>(mFabricInfo->GetFabricIndex()), ChipLogValueX64(mFabricInfo->GetNodeId()));
static_cast<unsigned>(mFabricInfo->GetFabricIndex()), ChipLogValueX64(mLocalNodeId));

// Side-effect of FindLocalNodeFromDestionationId success was that mFabricInfo/mLocalNodeId are now
// set to the local fabric and associated NodeId that was targeted by the initiator.
}
else
{
Expand Down
2 changes: 2 additions & 0 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,

Transport::SecureSession::Type GetSecureSessionType() const override { return Transport::SecureSession::Type::kCASE; }
ScopedNodeId GetPeer() const override { return ScopedNodeId(mPeerNodeId, GetFabricIndex()); }
ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(mLocalNodeId, GetFabricIndex()); }
CATValues GetPeerCATs() const override { return mPeerCATs; };

/**
Expand Down Expand Up @@ -253,6 +254,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,
FabricTable * mFabricsTable = nullptr;
const FabricInfo * mFabricInfo = nullptr;
NodeId mPeerNodeId = kUndefinedNodeId;
NodeId mLocalNodeId = kUndefinedNodeId;
CATValues mPeerCATs;

// This field is only used for CASE responder, when during sending sigma2 and waiting for sigma3
Expand Down
7 changes: 7 additions & 0 deletions src/protocols/secure_channel/PASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler,
{
return ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex);
}

ScopedNodeId GetLocalScopedNodeId() const override
{
// For PASE, source is always the undefined node ID
return ScopedNodeId();
}

CATValues GetPeerCATs() const override { return CATValues(); };

CHIP_ERROR OnUnsolicitedMessageReceived(const PayloadHeader & payloadHeader, ExchangeDelegate *& newDelegate) override;
Expand Down
12 changes: 7 additions & 5 deletions src/transport/GroupSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace Transport {
class IncomingGroupSession : public Session
{
public:
IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId sourceNodeId) : mGroupId(group), mSourceNodeId(sourceNodeId)
IncomingGroupSession(GroupId group, FabricIndex fabricIndex, NodeId peerNodeId) : mGroupId(group), mPeerNodeId(peerNodeId)
{
SetFabricIndex(fabricIndex);
}
Expand All @@ -38,7 +38,8 @@ class IncomingGroupSession : public Session
const char * GetSessionTypeString() const override { return "incoming group"; };
#endif

ScopedNodeId GetPeer() const override { return ScopedNodeId(mSourceNodeId, GetFabricIndex()); }
ScopedNodeId GetPeer() const override { return ScopedNodeId(mPeerNodeId, GetFabricIndex()); }
ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(kUndefinedNodeId, GetFabricIndex()); }

Access::SubjectDescriptor GetSubjectDescriptor() const override
{
Expand Down Expand Up @@ -68,11 +69,9 @@ class IncomingGroupSession : public Session

GroupId GetGroupId() const { return mGroupId; }

NodeId GetSourceNodeId() const { return mSourceNodeId; }

private:
const GroupId mGroupId;
const NodeId mSourceNodeId;
const NodeId mPeerNodeId;
};

class OutgoingGroupSession : public Session
Expand All @@ -86,7 +85,10 @@ class OutgoingGroupSession : public Session
const char * GetSessionTypeString() const override { return "outgoing group"; };
#endif

// Peer node ID is unused: users care about the group, not the node
ScopedNodeId GetPeer() const override { return ScopedNodeId(); }
// Local node ID is unused: users care about the group, not the node
ScopedNodeId GetLocalScopedNodeId() const override { return ScopedNodeId(); }

Access::SubjectDescriptor GetSubjectDescriptor() const override
{
Expand Down
2 changes: 1 addition & 1 deletion src/transport/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &

// Call Activate last, otherwise errors on anything after would lead to
// a partially valid session.
secureSession->Activate(GetSecureSessionType(), GetPeer(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig);
secureSession->Activate(GetSecureSessionType(), GetPeer(), GetLocalScopedNodeId(), GetPeerCATs(), peerSessionId, mRemoteMRPConfig);

ChipLogDetail(Inet, "New secure session created for device " ChipLogFormatScopedNodeId ", LSID:%d PSID:%d!",
ChipLogValueScopedNodeId(GetPeer()), secureSession->GetLocalSessionId(), peerSessionId);
Expand Down
1 change: 1 addition & 0 deletions src/transport/PairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class DLL_EXPORT PairingSession

virtual Transport::SecureSession::Type GetSecureSessionType() const = 0;
virtual ScopedNodeId GetPeer() const = 0;
virtual ScopedNodeId GetLocalScopedNodeId() const = 0;
virtual CATValues GetPeerCATs() const = 0;

Optional<uint16_t> GetLocalSessionId() const
Expand Down
5 changes: 0 additions & 5 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
namespace chip {
namespace Transport {

ScopedNodeId SecureSession::GetPeer() const
{
return ScopedNodeId(mPeerNodeId, GetFabricIndex());
}

Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const
{
Access::SubjectDescriptor subjectDescriptor;
Expand Down
38 changes: 30 additions & 8 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ class SecureSession : public Session

// TODO: This constructor should be private. Tests should allocate a
// kPending session and then call Activate(), just like non-test code does.
SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId,
SecureSession(Type secureSessionType, uint16_t localSessionId, NodeId peerNodeId, NodeId localNodeId, CATValues peerCATs, uint16_t peerSessionId,
FabricIndex fabric, const ReliableMessageProtocolConfig & config) :
mSecureSessionType(secureSessionType),
mPeerNodeId(peerNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId),
mPeerNodeId(peerNodeId), mLocalNodeId(localNodeId), mPeerCATs(peerCATs), mLocalSessionId(localSessionId), mPeerSessionId(peerSessionId),
mLastActivityTime(System::SystemClock().GetMonotonicTimestamp()),
mLastPeerActivityTime(System::SystemClock().GetMonotonicTimestamp()), mMRPConfig(config)
{
Expand All @@ -89,7 +89,7 @@ class SecureSession : public Session
* receives a local session ID, but no other state.
*/
SecureSession(uint16_t localSessionId) :
SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig())
SecureSession(Type::kPending, localSessionId, kUndefinedNodeId, kUndefinedNodeId, CATValues{}, 0, kUndefinedFabricIndex, GetLocalMRPConfig())
{}

/**
Expand All @@ -98,15 +98,25 @@ class SecureSession : public Session
* PASE, setting internal state according to the parameters used and
* discovered during session establishment.
*/
void Activate(Type secureSessionType, const ScopedNodeId & peer, CATValues peerCATs, uint16_t peerSessionId,
const ReliableMessageProtocolConfig & config)
void Activate(Type secureSessionType, const ScopedNodeId & peerNode, const ScopedNodeId & localNode, CATValues peerCATs,
uint16_t peerSessionId, const ReliableMessageProtocolConfig & config)
{
VerifyOrDie(peerNode.GetFabricIndex() == localNode.GetFabricIndex());

// PASE sessions must always start unassociated with a Fabric!
VerifyOrDie(!((secureSessionType == Type::kPASE) && (peerNode.GetFabricIndex() != kUndefinedFabricIndex)));
// CASE sessions must always start "associated" a given Fabric!
VerifyOrDie(!((secureSessionType == Type::kCASE) && (peerNode.GetFabricIndex() == kUndefinedFabricIndex)));
// CASE sessions can only be activated against operational node IDs!
VerifyOrDie(!((secureSessionType == Type::kCASE) && (!IsOperationalNodeId(peerNode.GetNodeId()) || !IsOperationalNodeId(localNode.GetNodeId()))));

mSecureSessionType = secureSessionType;
mPeerNodeId = peer.GetNodeId();
mPeerNodeId = peerNode.GetNodeId();
mLocalNodeId = localNode.GetNodeId();
mPeerCATs = peerCATs;
mPeerSessionId = peerSessionId;
mMRPConfig = config;
SetFabricIndex(peer.GetFabricIndex());
SetFabricIndex(peerNode.GetFabricIndex());
}
~SecureSession() override { NotifySessionReleased(); }

Expand All @@ -120,7 +130,16 @@ class SecureSession : public Session
const char * GetSessionTypeString() const override { return "secure"; };
#endif

ScopedNodeId GetPeer() const override;
ScopedNodeId GetPeer() const override
{
return ScopedNodeId(mPeerNodeId, GetFabricIndex());
}

ScopedNodeId GetLocalScopedNodeId() const override
{
return ScopedNodeId(mLocalNodeId, GetFabricIndex());
}

Access::SubjectDescriptor GetSubjectDescriptor() const override;

bool RequireMRP() const override { return GetPeerAddress().GetTransportType() == Transport::Type::kUdp; }
Expand All @@ -147,6 +166,8 @@ class SecureSession : public Session
bool IsPASESession() const { return GetSecureSessionType() == Type::kPASE; }
bool IsActiveSession() const { return GetSecureSessionType() != Type::kPending; }
NodeId GetPeerNodeId() const { return mPeerNodeId; }
NodeId GetLocalNodeId() const { return mLocalNodeId; }

CATValues GetPeerCATs() const { return mPeerCATs; }

void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }
Expand Down Expand Up @@ -192,6 +213,7 @@ class SecureSession : public Session
private:
Type mSecureSessionType;
NodeId mPeerNodeId;
NodeId mLocalNodeId;
CATValues mPeerCATs;
const uint16_t mLocalSessionId;
uint16_t mPeerSessionId;
Expand Down
30 changes: 27 additions & 3 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class SecureSessionTable
* @param secureSessionType secure session type
* @param localSessionId unique identifier for the local node's secure unicast session context
* @param peerNodeId represents peer Node's ID
* @param localNodeId represents local Node's ID
* @param peerCATs represents peer CASE Authenticated Tags
* @param peerSessionId represents the encryption key ID assigned by peer node
* @param fabric represents fabric ID for the session
Expand All @@ -61,11 +62,34 @@ class SecureSessionTable
*/
CHECK_RETURN_VALUE
Optional<SessionHandle> CreateNewSecureSessionForTest(SecureSession::Type secureSessionType, uint16_t localSessionId,
NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId,
FabricIndex fabric, const ReliableMessageProtocolConfig & config)
NodeId peerNodeId, NodeId localNodeId, CATValues peerCATs, uint16_t peerSessionId,
FabricIndex fabricIndex, const ReliableMessageProtocolConfig & config)
{
if (secureSessionType == SecureSession::Type::kCASE)
{
if ((fabricIndex == kUndefinedFabricIndex) || (localNodeId == kUndefinedNodeId) || (peerNodeId == kUndefinedNodeId))
{
return Optional<SessionHandle>::Missing();
}
}
else if (secureSessionType == SecureSession::Type::kPASE)
{
if ((fabricIndex != kUndefinedFabricIndex) || (localNodeId != kUndefinedNodeId) || (peerNodeId != kUndefinedNodeId))
{
// TODO: This secure session type is infeasible! We must fix the tests
if (false)
{
return Optional<SessionHandle>::Missing();
}
else
{
(void)fabricIndex;
}
}
}

SecureSession * result =
mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config);
mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, localNodeId, peerCATs, peerSessionId, fabricIndex, config);
return result != nullptr ? MakeOptional<SessionHandle>(*result) : Optional<SessionHandle>::Missing();
}

Expand Down
1 change: 1 addition & 0 deletions src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Session
virtual void Release() {}

virtual ScopedNodeId GetPeer() const = 0;
virtual ScopedNodeId GetLocalScopedNodeId() const = 0;
virtual Access::SubjectDescriptor GetSubjectDescriptor() const = 0;
virtual bool RequireMRP() const = 0;
virtual const ReliableMessageProtocolConfig & GetMRPConfig() const = 0;
Expand Down
23 changes: 8 additions & 15 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
mGroupClientCounter.IncrementCounter(isControlMsg);
packetHeader.SetFlags(Header::SecFlagValues::kPrivacyFlag);
packetHeader.SetSessionType(Header::SessionType::kGroupSession);
packetHeader.SetSourceNodeId(fabric->GetNodeId());
NodeId sourceNodeId = fabric->GetNodeId();
packetHeader.SetSourceNodeId(sourceNodeId);

if (!packetHeader.IsValidGroupMsg())
{
Expand All @@ -174,7 +175,7 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P

packetHeader.SetSessionId(keyContext->GetKeyHash());
CryptoContext::NonceStorage nonce;
CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), packetHeader.GetMessageCounter(), fabric->GetNodeId());
CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), packetHeader.GetMessageCounter(), sourceNodeId);
CHIP_ERROR err = SecureMessageCodec::Encrypt(CryptoContext(keyContext), nonce, payloadHeader, packetHeader, message);
keyContext->Release();
ReturnErrorOnFailure(err);
Expand Down Expand Up @@ -203,18 +204,9 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
CHIP_TRACE_MESSAGE_SENT(payloadHeader, packetHeader, message->Start(), message->TotalLength());

CryptoContext::NonceStorage nonce;
if (session->GetSecureSessionType() == SecureSession::Type::kCASE)
{
FabricInfo * fabric = mFabricTable->FindFabricWithIndex(session->GetFabricIndex());
VerifyOrDie(fabric != nullptr);
CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, fabric->GetNodeId());
}
else
{
// PASE Sessions use the undefined node ID of all zeroes, since there is no node ID to use
// and the key is short-lived and always different for each PASE session.
CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, kUndefinedNodeId);
}
NodeId sourceNodeId = session->GetLocalScopedNodeId().GetNodeId();
CryptoContext::BuildNonce(nonce, packetHeader.GetSecurityFlags(), messageCounter, sourceNodeId);

ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session->GetCryptoContext(), nonce, payloadHeader, packetHeader, message));
ReturnErrorOnFailure(counter.Advance());

Expand Down Expand Up @@ -406,9 +398,10 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH
uint16_t peerSessionId, FabricIndex fabric,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role)
{
NodeId localNodeId = kUndefinedNodeId;
Optional<SessionHandle> session =
mSecureSessions.CreateNewSecureSessionForTest(chip::Transport::SecureSession::Type::kPASE, localSessionId, peerNodeId,
CATValues{}, peerSessionId, fabric, GetLocalMRPConfig());
localNodeId, CATValues{}, peerSessionId, fabric, GetLocalMRPConfig());
VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY);
SecureSession * secureSession = session.Value()->AsSecureSession();
secureSession->SetPeerAddress(peerAddress);
Expand Down
4 changes: 2 additions & 2 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate

// Test-only: create a session on the fly.
CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId,
uint16_t peerSessionId, FabricIndex fabric, const Transport::PeerAddress & peerAddress,
CryptoContext::SessionRole role);
uint16_t peerSessionId, FabricIndex fabricIndex,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role);

/**
* @brief
Expand Down
Loading

0 comments on commit 2744efc

Please sign in to comment.