Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SessionHolder auto shifting #18107

Merged
merged 6 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/lib/core/CASEAuthTag.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#pragma once

#include <array>

#include <lib/core/CHIPConfig.h>
#include <lib/core/CHIPEncoding.h>
#include <lib/core/NodeId.h>
Expand All @@ -35,11 +37,11 @@ static constexpr size_t kMaxSubjectCATAttributeCount = CHIP_CONFIG_CERT_MAX_RDN_

struct CATValues
{
CASEAuthTag values[kMaxSubjectCATAttributeCount] = { kUndefinedCAT };
std::array<CASEAuthTag, kMaxSubjectCATAttributeCount> values = { kUndefinedCAT };

/* @brief Returns size of the CAT values array.
*/
static constexpr size_t size() { return ArraySize(values); }
static constexpr size_t size() { return std::tuple_size<decltype(values)>::value; }

/* @brief Returns true if subject input checks against one of the CATs in the values array.
*/
Expand All @@ -58,6 +60,8 @@ struct CATValues
return false;
}

bool operator==(const CATValues & that) const { return values == that.values; }

static constexpr size_t kSerializedLength = kMaxSubjectCATAttributeCount * sizeof(CASEAuthTag);
typedef uint8_t Serialized[kSerializedLength];

Expand Down
47 changes: 47 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@ void SecureSessionDeleter::Release(SecureSession * entry)
entry->mTable.ReleaseSession(entry);
}

void SecureSession::Activate(const ScopedNodeId & localNode, const ScopedNodeId & peerNode, CATValues peerCATs,
uint16_t peerSessionId, const ReliableMessageProtocolConfig & config)
{
VerifyOrDie(mState == State::kEstablishing);
VerifyOrDie(peerNode.GetFabricIndex() == localNode.GetFabricIndex());

// PASE sessions must always start unassociated with a Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kPASE) && (peerNode.GetFabricIndex() != kUndefinedFabricIndex)));
// CASE sessions must always start "associated" a given Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) && (peerNode.GetFabricIndex() == kUndefinedFabricIndex)));
// CASE sessions can only be activated against operational node IDs!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) &&
(!IsOperationalNodeId(peerNode.GetNodeId()) || !IsOperationalNodeId(localNode.GetNodeId()))));

mPeerNodeId = peerNode.GetNodeId();
mLocalNodeId = localNode.GetNodeId();
mPeerCATs = peerCATs;
mPeerSessionId = peerSessionId;
mMRPConfig = config;
SetFabricIndex(peerNode.GetFabricIndex());

Retain(); // This ref is released inside MarkForEviction
MoveToState(State::kActive);

if (mSecureSessionType == Type::kCASE)
mTable.NewerSessionAvailable(this);

ChipLogDetail(Inet, "SecureSession[%p]: Activated - Type:%d LSID:%d", this, to_underlying(mSecureSessionType), mLocalSessionId);
}

const char * SecureSession::StateToString(State state) const
{
switch (state)
Expand Down Expand Up @@ -200,5 +230,22 @@ void SecureSession::Release()
ReferenceCounted<SecureSession, SecureSessionDeleter, 0, uint16_t>::Release();
}

void SecureSession::NewerSessionAvailable(const SessionHandle & session)
{
// Shift to the new session, checks are performed by the the caller SecureSessionTable::NewerSessionAvailable.
IntrusiveList<SessionHolder>::Iterator iter = mHolders.begin();
while (iter != mHolders.end())
{
// The iterator can be invalid once the session holder is migrated to another session. So we store its next value before
// notifying the holder.
IntrusiveList<SessionHolder>::Iterator next = iter;
++next;

iter->ShiftToSession(session);

iter = next;
}
}

} // namespace Transport
} // namespace chip
34 changes: 8 additions & 26 deletions src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,8 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec
* discovered during session establishment.
*/
void Activate(const ScopedNodeId & localNode, const ScopedNodeId & peerNode, CATValues peerCATs, uint16_t peerSessionId,
const ReliableMessageProtocolConfig & config)
{
VerifyOrDie(mState == State::kEstablishing);
VerifyOrDie(peerNode.GetFabricIndex() == localNode.GetFabricIndex());

// PASE sessions must always start unassociated with a Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kPASE) && (peerNode.GetFabricIndex() != kUndefinedFabricIndex)));
// CASE sessions must always start "associated" a given Fabric!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) && (peerNode.GetFabricIndex() == kUndefinedFabricIndex)));
// CASE sessions can only be activated against operational node IDs!
VerifyOrDie(!((mSecureSessionType == Type::kCASE) &&
(!IsOperationalNodeId(peerNode.GetNodeId()) || !IsOperationalNodeId(localNode.GetNodeId()))));

mPeerNodeId = peerNode.GetNodeId();
mLocalNodeId = localNode.GetNodeId();
mPeerCATs = peerCATs;
mPeerSessionId = peerSessionId;
mMRPConfig = config;
SetFabricIndex(peerNode.GetFabricIndex());

Retain(); // This ref is released inside MarkForEviction
MoveToState(State::kActive);
ChipLogDetail(Inet, "SecureSession[%p]: Activated - Type:%d LSID:%d", this, to_underlying(mSecureSessionType),
mLocalSessionId);
}
const ReliableMessageProtocolConfig & config);

~SecureSession() override
{
ChipLogDetail(Inet, "SecureSession[%p]: Released - Type:%d LSID:%d", this, to_underlying(mSecureSessionType),
Expand Down Expand Up @@ -213,7 +190,7 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec
NodeId GetPeerNodeId() const { return mPeerNodeId; }
NodeId GetLocalNodeId() const { return mLocalNodeId; }

CATValues GetPeerCATs() const { return mPeerCATs; }
const CATValues & GetPeerCATs() const { return mPeerCATs; }

void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }

Expand Down Expand Up @@ -262,6 +239,11 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec

SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; }

// This should be a private API, only meant to be called by SecureSessionTable
// Session holders to this session may shift to the target session regarding SessionDelegate::GetNewSessionHandlingPolicy.
// It requires that the target sessoin is also a CASE session, having the same peer and CATs as this session.
void NewerSessionAvailable(const SessionHandle & session);
kghost marked this conversation as resolved.
Show resolved Hide resolved

private:
enum class State : uint8_t
{
Expand Down
26 changes: 26 additions & 0 deletions src/transport/SecureSessionTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,32 @@ class SecureSessionTable
CHECK_RETURN_VALUE
Optional<SessionHandle> FindSecureSessionByLocalKey(uint16_t localSessionId);

// Select SessionHolders which are pointing to a session with the same peer as the given session. Shift them to the given
// session.
// This is an internal API, using raw pointer to a session is allowed here.
void NewerSessionAvailable(SecureSession * session)
{
VerifyOrDie(session->GetSecureSessionType() == SecureSession::Type::kCASE);
mEntries.ForEachActiveObject([&](SecureSession * oldSession) {
if (session == oldSession)
return Loop::Continue;

SessionHandle ref(*oldSession);

// This will give all SessionHolders pointing to oldSession a chance to switch to the provided session
//
// See documentation for SessionDelegate::GetNewSessionHandlingPolicy about how session auto-shifting works, and how
// to disable it for a specific SessionHolder in a specific scenario.
if (oldSession->GetSecureSessionType() == SecureSession::Type::kCASE && oldSession->GetPeer() == session->GetPeer() &&
oldSession->GetPeerCATs() == session->GetPeerCATs())
{
oldSession->NewerSessionAvailable(SessionHandle(*session));
}

return Loop::Continue;
});
}

private:
friend class TestSecureSessionTable;

Expand Down
5 changes: 3 additions & 2 deletions src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,15 @@ class Session
SessionHandle session(*this);
while (!mHolders.Empty())
{
mHolders.begin()->OnSessionReleased(); // OnSessionReleased must remove the item from the linked list
mHolders.begin()->SessionReleased(); // SessionReleased must remove the item from the linked list
}
}

void SetFabricIndex(FabricIndex index) { mFabricIndex = index; }

private:
IntrusiveList<SessionHolder> mHolders;

private:
FabricIndex mFabricIndex = kUndefinedFabricIndex;
};

Expand Down
10 changes: 9 additions & 1 deletion src/transport/SessionDelegate.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ class DLL_EXPORT SessionDelegate
* Called when a new secure session to the same peer is established, over the delegate of SessionHolderWithDelegate object. It
* is suggested to shift to the newly created session.
*
* Our security model is built upon Exchanges and Sessions, but not SessionHolders, such that SessionHolders should be able to
* shift to a new session freely. If an application is holding a session which is not intended to be shifted, it can provide
* its shifting policy by overriding GetNewSessionHandlingPolicy in SessionDelegate. For example SessionHolders inside
* ExchangeContext and PairingSession are not eligible for auto-shifting.
*
* Note: the default implementation orders shifting to the new session, it should be fine for all users, unless the
* SessionHolder object is expected to be sticky to a specified session.
* SessionHolder object is expected to be sticky to a specified session.
*
* Note: the implementation MUST NOT modify the session pool or the state of session holders (eg, adding new session, removing
* old session) from inside this callback.
*/
virtual NewSessionHandlingPolicy GetNewSessionHandlingPolicy() { return NewSessionHandlingPolicy::kShiftToNewSession; }

Expand Down
22 changes: 16 additions & 6 deletions src/transport/SessionHolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,23 @@ namespace chip {
* released when the underlying session is released. One must verify it is available before use. The object can be
* created using SessionHandle.Grab()
*/
class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase<>
class SessionHolder : public IntrusiveListNodeBase<>
{
public:
SessionHolder() {}
~SessionHolder() override;
virtual ~SessionHolder();

SessionHolder(const SessionHolder &);
SessionHolder(SessionHolder && that);
SessionHolder & operator=(const SessionHolder &);
SessionHolder & operator=(SessionHolder && that);

// Implement SessionDelegate
void OnSessionReleased() override { Release(); }
virtual void SessionReleased() { Release(); }
virtual void ShiftToSession(const SessionHandle & session)
{
Release();
Grab(session);
}

bool Contains(const SessionHandle & session) const
{
Expand All @@ -51,7 +55,7 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase<>
bool Grab(const SessionHandle & session);
void Release();

operator bool() const { return mSession.HasValue(); }
explicit operator bool() const { return mSession.HasValue(); }
Optional<SessionHandle> Get() const
{
//
Expand Down Expand Up @@ -81,14 +85,20 @@ class SessionHolderWithDelegate : public SessionHolder
SessionHolderWithDelegate(const SessionHandle & handle, SessionDelegate & delegate) : mDelegate(delegate) { Grab(handle); }
operator bool() const { return SessionHolder::operator bool(); }

void OnSessionReleased() override
void SessionReleased() override
{
Release();

// Note, the session is already cleared during mDelegate.OnSessionReleased
mDelegate.OnSessionReleased();
}

void ShiftToSession(const SessionHandle & session) override
{
if (mDelegate.GetNewSessionHandlingPolicy() == SessionDelegate::NewSessionHandlingPolicy::kShiftToNewSession)
SessionHolder::ShiftToSession(session);
}

void DispatchSessionEvent(SessionDelegate::Event event) override { (mDelegate.*event)(); }

private:
Expand Down
21 changes: 21 additions & 0 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,27 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH
return CHIP_NO_ERROR;
}

CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId,
uint16_t peerSessionId, NodeId localNodeId, NodeId peerNodeId,
FabricIndex fabric, const Transport::PeerAddress & peerAddress,
CryptoContext::SessionRole role, const CATValues & cats)
{
Optional<SessionHandle> session =
mSecureSessions.CreateNewSecureSessionForTest(chip::Transport::SecureSession::Type::kCASE, localSessionId, localNodeId,
peerNodeId, cats, peerSessionId, fabric, GetLocalMRPConfig());
VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY);
SecureSession * secureSession = session.Value()->AsSecureSession();
secureSession->SetPeerAddress(peerAddress);

size_t secretLen = strlen(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE);
ByteSpan secret(reinterpret_cast<const uint8_t *>(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen);
ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret(
secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role));
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(Transport::PeerMessageCounter::kInitialSyncValue);
sessionHolder.Grab(session.Value());
return CHIP_NO_ERROR;
}

void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg)
{
CHIP_TRACE_PREPARED_MESSAGE_RECEIVED(&peerAddress, &msg);
Expand Down
5 changes: 5 additions & 0 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId,
uint16_t peerSessionId, FabricIndex fabricIndex,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role);
CHIP_ERROR InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, uint16_t peerSessionId,
NodeId localNodeId, NodeId peerNodeId, FabricIndex fabric,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role,
const CATValues & cats = CATValues{});

/**
* @brief
Expand Down Expand Up @@ -210,6 +214,7 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
void FabricRemoved(FabricIndex fabricIndex);

TransportMgrBase * GetTransportManager() const { return mTransportMgr; }
Transport::SecureSessionTable & GetSecureSessions() { return mSecureSessions; }

/**
* @brief
Expand Down
Loading