Skip to content

Commit

Permalink
Fix session allocation loop when shutting down with an open commissio…
Browse files Browse the repository at this point in the history
…ning window. (#20715)

* Fix session allocation loop when shutting down with an open commissioning window.

After #20487 if we shut down with a commissioning window open we end up in a
loop where the session manager shutdown marks the tentative PASE session for
eviction, we treat that as a commissioning error and start listening for PASE
again, creating a new session, etc.  With a heap pool this ends up happening to
work in that we keep evicting the new sessions until we hit the 20-attempt limit
and close the commissioning window.  With a non-heap pool, I sort of wonder what
happens, exactly.

The fix here is in two parts, with either part enough on its own to fix the
behavior described above:

1) Shut down the commissioning window manager earlier, before we shut down the
   session manager.  And correspondingly move its initialization during server
   init later.

2) Once session manager starts shutdown, refuse to create any new sessions.

* Fix unit tests.
  • Loading branch information
bzbarsky-apple authored and pull[bot] committed Oct 19, 2023
1 parent 8bdae9d commit 3eadd0b
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 35 deletions.
14 changes: 7 additions & 7 deletions src/app/server/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams)
// TODO(16969): Remove chip::Platform::MemoryInit() call from Server class, it belongs to outer code
chip::Platform::MemoryInit();

SuccessOrExit(err = mCommissioningWindowManager.Init(this));
mCommissioningWindowManager.SetAppDelegate(initParams.appDelegate);

// Initialize PersistentStorageDelegate-based storage
mDeviceStorage = initParams.persistentStorageDelegate;
mSessionResumptionStorage = initParams.sessionResumptionStorage;
Expand Down Expand Up @@ -165,9 +162,6 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams)
mAclStorage = initParams.aclStorage;
SuccessOrExit(err = mAclStorage->Init(*mDeviceStorage, mFabrics.begin(), mFabrics.end()));

app::DnssdServer::Instance().SetFabricTable(&mFabrics);
app::DnssdServer::Instance().SetCommissioningModeProvider(&mCommissioningWindowManager);

mGroupsProvider = initParams.groupDataProvider;
SetGroupDataProvider(mGroupsProvider);

Expand Down Expand Up @@ -221,6 +215,12 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams)
err = mUnsolicitedStatusHandler.Init(&mExchangeMgr);
SuccessOrExit(err);

SuccessOrExit(err = mCommissioningWindowManager.Init(this));
mCommissioningWindowManager.SetAppDelegate(initParams.appDelegate);

app::DnssdServer::Instance().SetFabricTable(&mFabrics);
app::DnssdServer::Instance().SetCommissioningModeProvider(&mCommissioningWindowManager);

err = chip::app::InteractionModelEngine::GetInstance()->Init(&mExchangeMgr, &GetFabricTable());
SuccessOrExit(err);

Expand Down Expand Up @@ -424,14 +424,14 @@ void Server::Shutdown()

chip::Dnssd::Resolver::Instance().Shutdown();
chip::app::InteractionModelEngine::GetInstance()->Shutdown();
mCommissioningWindowManager.Shutdown();
mMessageCounterManager.Shutdown();
mExchangeMgr.Shutdown();
mSessions.Shutdown();
mTransports.Close();
mAccessControl.Finish();
Credentials::SetGroupDataProvider(nullptr);
mAttributePersister.Shutdown();
mCommissioningWindowManager.Shutdown();
// TODO(16969): Remove chip::Platform::MemoryInit() call from Server class, it belongs to outer code
chip::Platform::MemoryShutdown();
}
Expand Down
48 changes: 42 additions & 6 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,39 @@ using TestContext = Test::LoopbackMessagingContext;
namespace chip {
namespace {

class TemporarySessionManager
{
public:
TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx)
{
NL_TEST_ASSERT(suite,
CHIP_NO_ERROR ==
mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(),
&mStorage, &ctx.GetFabricTable()));
// The setup here is really weird: we are using one session manager for
// the actual messages we send (the PASE handshake, so the
// unauthenticated sessions) and a different one for allocating the PASE
// sessions. Since our Init() set us up as the thing to handle messages
// on the transport manager, undo that.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

~TemporarySessionManager()
{
mSessionManager.Shutdown();
// Reset the session manager on the transport again, just in case
// shutdown messed with it.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

operator SessionManager &() { return mSessionManager; }

private:
TestContext & mCtx;
TestPersistentStorageDelegate mStorage;
SessionManager mSessionManager;
};

CHIP_ERROR InitFabricTable(chip::FabricTable & fabricTable, chip::TestPersistentStorageDelegate * testStorage,
chip::Crypto::OperationalKeystore * opKeyStore,
chip::Credentials::PersistentStorageOpCertStore * opCertStore)
Expand Down Expand Up @@ -281,7 +314,8 @@ class TestCASESession

void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegate;
Expand Down Expand Up @@ -312,7 +346,7 @@ void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inCont
void TestCASESession::SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestCASESecurePairingDelegate delegate;
Expand Down Expand Up @@ -425,7 +459,9 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S

void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestCASESecurePairingDelegate delegateCommissioner;
CASESession pairingCommissioner;
pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider);
Expand Down Expand Up @@ -902,13 +938,13 @@ void TestCASESession::SessionResumptionStorage(nlTestSuite * inSuite, void * inC
#if CONFIG_BUILD_FOR_HOST_UNIT_TEST
void TestCASESession::SimulateUpdateNOCInvalidatePendingEstablishment(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestCASESecurePairingDelegate delegateCommissioner;
CASESession pairingCommissioner;
pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider);

TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

TestCASESecurePairingDelegate delegateAccessory;
CASESession pairingAccessory;

Expand Down
74 changes: 56 additions & 18 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,45 @@ class MockAppDelegate : public ExchangeDelegate
void OnResponseTimeout(ExchangeContext * ec) override {}
};

class TemporarySessionManager
{
public:
TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx)
{
NL_TEST_ASSERT(suite,
CHIP_NO_ERROR ==
mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(),
&mStorage, &ctx.GetFabricTable()));
// The setup here is really weird: we are using one session manager for
// the actual messages we send (the PASE handshake, so the
// unauthenticated sessions) and a different one for allocating the PASE
// sessions. Since our Init() set us up as the thing to handle messages
// on the transport manager, undo that.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

~TemporarySessionManager()
{
mSessionManager.Shutdown();
// Reset the session manager on the transport again, just in case
// shutdown messed with it.
mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager());
}

operator SessionManager &() { return mSessionManager; }

private:
TestContext & mCtx;
TestPersistentStorageDelegate mStorage;
SessionManager mSessionManager;
};

using namespace System::Clock::Literals;

void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestSecurePairingDelegate delegate;
Expand Down Expand Up @@ -157,7 +190,7 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext)
void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

// Test all combinations of invalid parameters
TestSecurePairingDelegate delegate;
Expand Down Expand Up @@ -285,11 +318,12 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S

void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Optional<ReliableMessageProtocolConfig>::Missing(),
Expand All @@ -298,11 +332,12 @@ void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext)

void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Expand All @@ -312,11 +347,12 @@ void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void *

void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32);
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Expand All @@ -326,11 +362,12 @@ void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inCon

void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32);
ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32);
Expand All @@ -341,11 +378,12 @@ void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContex

void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext)
{
SessionManager sessionManager;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
auto & loopback = ctx.GetLoopback();
auto & loopback = ctx.GetLoopback();
loopback.Reset();
loopback.mNumMessagesToDrop = 2;
SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner,
Expand All @@ -358,7 +396,7 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo
void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);
SessionManager sessionManager;
TemporarySessionManager sessionManager(inSuite, ctx);

TestSecurePairingDelegate delegateCommissioner;
PASESession pairingCommissioner;
Expand Down
5 changes: 4 additions & 1 deletion src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,16 @@ void SessionManager::Shutdown()
mFabricTable = nullptr;
}

// Ensure that we don't create new sessions as we iterate our session table.
mState = State::kNotReady;

mSecureSessions.ForEachSession([&](auto session) {
session->MarkForEviction();
return Loop::Continue;
});

mMessageCounterManager = nullptr;

mState = State::kNotReady;
mSystemLayer = nullptr;
mTransportMgr = nullptr;
mCB = nullptr;
Expand Down Expand Up @@ -386,6 +388,7 @@ void SessionManager::ExpireAllPASESessions()
Optional<SessionHandle> SessionManager::AllocateSession(SecureSession::Type secureSessionType,
const ScopedNodeId & sessionEvictionHint)
{
VerifyOrReturnValue(mState == State::kInitialized, NullOptional);
return mSecureSessions.CreateNewSecureSession(secureSessionType, sessionEvictionHint);
}

Expand Down
42 changes: 39 additions & 3 deletions src/transport/tests/TestSessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,19 @@ static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager

void SessionAllocationTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

FabricTableHolder fabricTableHolder;
NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init());

secure_channel::MessageCounterManager messageCounterManager;
TestPersistentStorageDelegate deviceStorage1, deviceStorage2;

SessionManager sessionManager;
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage1,
&fabricTableHolder.GetFabricTable()));

// Allocate a session.
uint16_t sessionId1;
Expand Down Expand Up @@ -735,10 +747,24 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext)
}

// Reconstruct the Session Manager to reset state.
sessionManager.Shutdown();
sessionManager.~SessionManager();
new (&sessionManager) SessionManager();
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage2,
&fabricTableHolder.GetFabricTable()));

// Allocate a single session so we know what random id we are starting at.
{
auto handle = sessionManager.AllocateSession(
Transport::SecureSession::Type::kPASE,
ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex));
NL_TEST_ASSERT(inSuite, handle.HasValue());
prevSessionId = handle.Value()->AsSecureSession()->GetLocalSessionId();
handle.Value()->AsSecureSession()->MarkForEviction();
}

prevSessionId = 0;
// Verify that we increment session ID by 1 for each allocation (except for
// the wraparound case where we skip session ID 0), even when allocated
// sessions are immediately freed.
Expand Down Expand Up @@ -886,6 +912,8 @@ void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext)

static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext)
{
TestContext & ctx = *reinterpret_cast<TestContext *>(inContext);

IPAddress addr;
IPAddress::FromString("::1", addr);

Expand All @@ -894,9 +922,17 @@ static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext)
FabricIndex aliceFabricIndex = 1;
FabricIndex bobFabricIndex = 1;

FabricTableHolder fabricTableHolder;
secure_channel::MessageCounterManager messageCounterManager;
TestPersistentStorageDelegate deviceStorage;

SessionManager sessionManager;
secure_channel::MessageCounterManager gMessageCounterManager;
chip::TestPersistentStorageDelegate deviceStorage;

NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init());
NL_TEST_ASSERT(inSuite,
CHIP_NO_ERROR ==
sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage,
&fabricTableHolder.GetFabricTable()));

Transport::PeerAddress peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));

Expand Down

0 comments on commit 3eadd0b

Please sign in to comment.