Skip to content

Commit

Permalink
Implement SessionHolder auto shifting
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Jun 10, 2022
1 parent 668938e commit 7b54546
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/app/OperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ void OperationalDeviceProxy::OnSessionEstablished(const SessionHandle & session)
return; // Got an invalid session, do not change any state

MoveToState(State::SecureConnected);
mInitParams.sessionManager->ShiftToSession(session);
DequeueConnectionCallbacks(CHIP_NO_ERROR);

// Do not touch this instance anymore; it might have been destroyed by a callback.
Expand Down
8 changes: 6 additions & 2 deletions src/lib/core/CASEAuthTag.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#pragma once

#include <array>

#include <lib/core/CHIPConfig.h>
#include <lib/core/CHIPEncoding.h>
#include <lib/core/NodeId.h>
Expand All @@ -35,11 +37,11 @@ static constexpr size_t kMaxSubjectCATAttributeCount = CHIP_CONFIG_CERT_MAX_RDN_

struct CATValues
{
CASEAuthTag values[kMaxSubjectCATAttributeCount] = { kUndefinedCAT };
std::array<CASEAuthTag, kMaxSubjectCATAttributeCount> values = { kUndefinedCAT };

/* @brief Returns size of the CAT values array.
*/
static constexpr size_t size() { return ArraySize(values); }
static constexpr size_t size() { return std::tuple_size<decltype(values)>::value; }

/* @brief Returns true if subject input checks against one of the CATs in the values array.
*/
Expand All @@ -58,6 +60,8 @@ struct CATValues
return false;
}

bool operator==(const CATValues & that) const { return values == that.values; }

static constexpr size_t kSerializedLength = kMaxSubjectCATAttributeCount * sizeof(CASEAuthTag);
typedef uint8_t Serialized[kSerializedLength];

Expand Down
1 change: 1 addition & 0 deletions src/protocols/secure_channel/CASEServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ void CASEServer::OnSessionEstablished(const SessionHandle & session)
{
ChipLogProgress(Inet, "CASE Session established to peer: " ChipLogFormatScopedNodeId,
ChipLogValueScopedNodeId(session->GetPeer()));
mSessionManager->ShiftToSession(session);
Cleanup();
}
} // namespace chip
9 changes: 9 additions & 0 deletions src/transport/SecureSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,14 @@ void SecureSession::Release()
ReferenceCounted<SecureSession, SecureSessionDeleter, 0, uint16_t>::Release();
}

void SecureSession::TryShiftToSession(const SessionHandle & session)
{
if (GetSecureSessionType() == SecureSession::Type::kCASE && GetPeer() == session->GetPeer() &&
GetPeerCATs() == session->AsSecureSession()->GetPeerCATs())
{
Session::DoShiftToSession(session);
}
}

} // namespace Transport
} // namespace chip
5 changes: 4 additions & 1 deletion src/transport/SecureSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec
NodeId GetPeerNodeId() const { return mPeerNodeId; }
NodeId GetLocalNodeId() const { return mLocalNodeId; }

CATValues GetPeerCATs() const { return mPeerCATs; }
const CATValues & GetPeerCATs() const { return mPeerCATs; }

void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; }

Expand Down Expand Up @@ -228,6 +228,9 @@ class SecureSession : public Session, public ReferenceCounted<SecureSession, Sec

SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; }

// This should be a private API, only meant to be called by SessionManager
void TryShiftToSession(const SessionHandle & session);

private:
enum class State : uint8_t
{
Expand Down
16 changes: 16 additions & 0 deletions src/transport/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,21 @@ System::Clock::Timeout Session::ComputeRoundTripTimeout(System::Clock::Timeout u
return GetAckTimeout() + upperlayerProcessingTimeout;
}

void Session::DoShiftToSession(const SessionHandle & session)
{
// Shift to the new session, checks are performed by the subclass implementation which is the caller.
IntrusiveList<SessionHolder>::Iterator iter = mHolders.begin();
while (iter != mHolders.end())
{
// The iterator can be invalid once it is migrated to another session. So we store its next before it is happening.
IntrusiveList<SessionHolder>::Iterator next = iter;
++next;

iter->ShiftToSession(session);

iter = next;
}
}

} // namespace Transport
} // namespace chip
4 changes: 3 additions & 1 deletion src/transport/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ class Session
SessionHandle session(*this);
while (!mHolders.Empty())
{
mHolders.begin()->OnSessionReleased(); // OnSessionReleased must remove the item from the linked list
mHolders.begin()->SessionReleased(); // OnSessionReleased must remove the item from the linked list
}
}

void SetFabricIndex(FabricIndex index) { mFabricIndex = index; }

void DoShiftToSession(const SessionHandle & session);

private:
IntrusiveList<SessionHolder> mHolders;
FabricIndex mFabricIndex = kUndefinedFabricIndex;
Expand Down
10 changes: 9 additions & 1 deletion src/transport/SessionDelegate.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ class DLL_EXPORT SessionDelegate
* Called when a new secure session to the same peer is established, over the delegate of SessionHolderWithDelegate object. It
* is suggested to shift to the newly created session.
*
* Our security model is built upon Exchanges and Sessions, but not SessionHolders, such that SessionHolders should be able to
* shift to a new sessoin freely. If an application is holding a session which is not intent to be shifted, it can provides
* its shifting policy by override GetNewSessionHandlingPolicy in SessionDelegate. For example SessionHolders inside
* ExchangeContext and PairingSession are not eligible for auto-shifting.
*
* Note: the default implementation orders shifting to the new session, it should be fine for all users, unless the
* SessionHolder object is expected to be sticky to a specified session.
* SessionHolder object is expected to be sticky to a specified session.
*
* Note: the implementation should not modify session pool nor session holders (eg, adding new session, removing old session),
* or else something inconsistent can be happened inside Session::DoShiftToSession.
*/
virtual NewSessionHandlingPolicy GetNewSessionHandlingPolicy() { return NewSessionHandlingPolicy::kShiftToNewSession; }

Expand Down
23 changes: 17 additions & 6 deletions src/transport/SessionHolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,23 @@ namespace chip {
* released when the underlying session is released. One must verify it is available before use. The object can be
* created using SessionHandle.Grab()
*/
class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase<>
class SessionHolder : public IntrusiveListNodeBase<>
{
public:
SessionHolder() {}
~SessionHolder() override;
virtual ~SessionHolder();

SessionHolder(const SessionHolder &);
SessionHolder(SessionHolder && that);
SessionHolder & operator=(const SessionHolder &);
SessionHolder & operator=(SessionHolder && that);

// Implement SessionDelegate
void OnSessionReleased() override { Release(); }
virtual void SessionReleased() { Release(); }
virtual void ShiftToSession(const SessionHandle & session)
{
Release();
Grab(session);
}

bool Contains(const SessionHandle & session) const
{
Expand All @@ -51,7 +55,7 @@ class SessionHolder : public SessionDelegate, public IntrusiveListNodeBase<>
bool Grab(const SessionHandle & session);
void Release();

operator bool() const { return mSession.HasValue(); }
explicit operator bool() const { return mSession.HasValue(); }
Optional<SessionHandle> Get() const
{
//
Expand All @@ -78,17 +82,24 @@ class SessionHolderWithDelegate : public SessionHolder
{
public:
SessionHolderWithDelegate(SessionDelegate & delegate) : mDelegate(delegate) {}
SessionHolderWithDelegate(SessionHolder & holder, SessionDelegate & delegate) : SessionHolder(holder), mDelegate(delegate) {}
SessionHolderWithDelegate(const SessionHandle & handle, SessionDelegate & delegate) : mDelegate(delegate) { Grab(handle); }
operator bool() const { return SessionHolder::operator bool(); }

void OnSessionReleased() override
void SessionReleased() override
{
Release();

// Note, the session is already cleared during mDelegate.OnSessionReleased
mDelegate.OnSessionReleased();
}

void ShiftToSession(const SessionHandle & session) override
{
if (mDelegate.GetNewSessionHandlingPolicy() == SessionDelegate::NewSessionHandlingPolicy::kShiftToNewSession)
SessionHolder::ShiftToSession(session);
}

void DispatchSessionEvent(SessionDelegate::Event event) override { (mDelegate.*event)(); }

private:
Expand Down
38 changes: 38 additions & 0 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,27 @@ CHIP_ERROR SessionManager::InjectPaseSessionWithTestKey(SessionHolder & sessionH
return CHIP_NO_ERROR;
}

CHIP_ERROR SessionManager::InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId,
uint16_t peerSessionId, NodeId localNodeId, NodeId peerNodeId,
FabricIndex fabric, const Transport::PeerAddress & peerAddress,
CryptoContext::SessionRole role, const CATValues & cats)
{
Optional<SessionHandle> session =
mSecureSessions.CreateNewSecureSessionForTest(chip::Transport::SecureSession::Type::kCASE, localSessionId, localNodeId,
peerNodeId, cats, peerSessionId, fabric, GetLocalMRPConfig());
VerifyOrReturnError(session.HasValue(), CHIP_ERROR_NO_MEMORY);
SecureSession * secureSession = session.Value()->AsSecureSession();
secureSession->SetPeerAddress(peerAddress);

size_t secretLen = strlen(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE);
ByteSpan secret(reinterpret_cast<const uint8_t *>(CHIP_CONFIG_TEST_SHARED_SECRET_VALUE), secretLen);
ReturnErrorOnFailure(secureSession->GetCryptoContext().InitFromSecret(
secret, ByteSpan(nullptr, 0), CryptoContext::SessionInfoType::kSessionEstablishment, role));
secureSession->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(LocalSessionMessageCounter::kInitialSyncValue);
sessionHolder.Grab(session.Value());
return CHIP_NO_ERROR;
}

void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle && msg)
{
CHIP_TRACE_PREPARED_MESSAGE_RECEIVED(&peerAddress, &msg);
Expand Down Expand Up @@ -724,6 +745,23 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade
}
}

void SessionManager::ShiftToSession(const SessionHandle & handle)
{
VerifyOrDie(handle->IsSecureSession());
VerifyOrDie(handle->AsSecureSession()->GetSecureSessionType() == SecureSession::Type::kCASE);
mSecureSessions.ForEachSession([&](SecureSession * oldSession) {
if (handle->AsSecureSession() == oldSession)
return Loop::Continue;

// This will update all SessionHolder pointing to oldSession, to the provided handle.
//
// See comment of SessionDelegate::GetNewSessionHandlingPolicy about how session auto-shifting works, and how to disable it
// for specific SessionHolder in specific scenario.
oldSession->TryShiftToSession(handle);
return Loop::Continue;
});
}

Optional<SessionHandle> SessionManager::FindSecureSessionForNode(ScopedNodeId peerNodeId,
const Optional<Transport::SecureSession::Type> & type)
{
Expand Down
7 changes: 7 additions & 0 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
CHIP_ERROR InjectPaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, NodeId peerNodeId,
uint16_t peerSessionId, FabricIndex fabricIndex,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role);
CHIP_ERROR InjectCaseSessionWithTestKey(SessionHolder & sessionHolder, uint16_t localSessionId, uint16_t peerSessionId,
NodeId localNodeId, NodeId peerNodeId, FabricIndex fabric,
const Transport::PeerAddress & peerAddress, CryptoContext::SessionRole role,
const CATValues & cats = CATValues{});

/**
* @brief
Expand Down Expand Up @@ -224,6 +228,9 @@ class DLL_EXPORT SessionManager : public TransportMgrDelegate
return mUnauthenticatedSessions.AllocInitiator(ephemeralInitiatorNodeID, peerAddress, config);
}

// Update existing SessionHolders to shift to the given session.
void ShiftToSession(const SessionHandle & handle);

//
// Find an existing secure session given a peer's scoped NodeId and a type of session to match against.
// If matching against all types of sessions is desired, NullOptional should be passed into type.
Expand Down
72 changes: 72 additions & 0 deletions src/transport/tests/TestSessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,77 @@ void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext)
sessionManager.Shutdown();
}

static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext)
{
IPAddress addr;
IPAddress::FromString("::1", addr);

NodeId aliceNodeId = 0x11223344ull;
NodeId bobNodeId = 0x12344321ull;
FabricIndex aliceFabricIndex = 1;
FabricIndex bobFabricIndex = 1;

SessionManager sessionManager;
secure_channel::MessageCounterManager gMessageCounterManager;
chip::TestPersistentStorageDelegate deviceStorage;

Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));

SessionHolder aliceToBobSession;
CHIP_ERROR err = sessionManager.InjectCaseSessionWithTestKey(aliceToBobSession, 2, 1, aliceNodeId, bobNodeId, aliceFabricIndex,
peer, CryptoContext::SessionRole::kInitiator);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

class StickySessionDelegate : public SessionDelegate
{
public:
NewSessionHandlingPolicy GetNewSessionHandlingPolicy() override { return NewSessionHandlingPolicy::kStayAtOldSession; }
void OnSessionReleased() override {}
} delegate;

SessionHolderWithDelegate stickyAliceToBobSession(aliceToBobSession.Get().Value(), delegate);
NL_TEST_ASSERT(inSuite, aliceToBobSession.Contains(stickyAliceToBobSession.Get().Value()));

SessionHolder bobToAliceSession;
err = sessionManager.InjectCaseSessionWithTestKey(bobToAliceSession, 1, 2, bobNodeId, aliceNodeId, bobFabricIndex, peer,
CryptoContext::SessionRole::kResponder);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

SessionHolder newAliceToBobSession;
err = sessionManager.InjectCaseSessionWithTestKey(newAliceToBobSession, 3, 4, aliceNodeId, bobNodeId, aliceFabricIndex, peer,
CryptoContext::SessionRole::kInitiator);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Here we got 3 sessions, and 4 holders:
// 1. alice -> bob: aliceToBobSession, stickyAliceToBobSession
// 2. alice <- bob: bobToAliceSession
// 3. alice -> bob: newAliceToBobSession

SecureSession * session1 = aliceToBobSession->AsSecureSession();
SecureSession * session2 = bobToAliceSession->AsSecureSession();
SecureSession * session3 = newAliceToBobSession->AsSecureSession();

NL_TEST_ASSERT(inSuite, session1 != session3);
NL_TEST_ASSERT(inSuite, stickyAliceToBobSession->AsSecureSession() == session1);

// Now shift the 1st session to the 3rd one, after shifting, holders should be:
// 1. alice -> bob: stickyAliceToBobSession
// 2. alice <- bob: bobToAliceSession
// 3. alice -> bob: aliceToBobSession, newAliceToBobSession
sessionManager.ShiftToSession(newAliceToBobSession.Get().Value());

NL_TEST_ASSERT(inSuite, aliceToBobSession);
NL_TEST_ASSERT(inSuite, stickyAliceToBobSession);
NL_TEST_ASSERT(inSuite, newAliceToBobSession);

NL_TEST_ASSERT(inSuite, stickyAliceToBobSession->AsSecureSession() == session1);
NL_TEST_ASSERT(inSuite, bobToAliceSession->AsSecureSession() == session2);
NL_TEST_ASSERT(inSuite, aliceToBobSession->AsSecureSession() == session3);
NL_TEST_ASSERT(inSuite, newAliceToBobSession->AsSecureSession() == session3);

sessionManager.Shutdown();
}

// Test Suite

/**
Expand All @@ -852,6 +923,7 @@ const nlTest sTests[] =
NL_TEST_DEF("Too-old counter Test", SendPacketWithTooOldCounterTest),
NL_TEST_DEF("Session Allocation Test", SessionAllocationTest),
NL_TEST_DEF("Session Counter Exhausted Test", SessionCounterExhaustedTest),
NL_TEST_DEF("SessionShiftingTest", SessionShiftingTest),

NL_TEST_SENTINEL()
};
Expand Down

0 comments on commit 7b54546

Please sign in to comment.