From 3eadd0b8abf44a147d0c849364baf0122d94c01e Mon Sep 17 00:00:00 2001 From: Boris Zbarsky Date: Thu, 14 Jul 2022 14:34:50 -0400 Subject: [PATCH] Fix session allocation loop when shutting down with an open commissioning window. (#20715) * Fix session allocation loop when shutting down with an open commissioning window. After #20487, if we shut down with a commissioning window open, we end up in a loop: the session manager shutdown marks the tentative PASE session for eviction, we treat that as a commissioning error and start listening for PASE again, which creates a new session, and so on. With a heap pool this happens to work, in that we keep evicting the new sessions until we hit the 20-attempt limit and close the commissioning window. With a non-heap pool, it is not clear what happens. The fix here is in two parts, either of which is enough on its own to fix the behavior described above: 1) Shut down the commissioning window manager earlier, before we shut down the session manager, and correspondingly move its initialization later during server init. 2) Once the session manager starts shutting down, refuse to create any new sessions. * Fix unit tests. 
--- src/app/server/Server.cpp | 14 ++-- .../secure_channel/tests/TestCASESession.cpp | 48 ++++++++++-- .../secure_channel/tests/TestPASESession.cpp | 74 ++++++++++++++----- src/transport/SessionManager.cpp | 5 +- src/transport/tests/TestSessionManager.cpp | 42 ++++++++++- 5 files changed, 148 insertions(+), 35 deletions(-) diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 4330446f410698..6e8799911c554b 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -125,9 +125,6 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) // TODO(16969): Remove chip::Platform::MemoryInit() call from Server class, it belongs to outer code chip::Platform::MemoryInit(); - SuccessOrExit(err = mCommissioningWindowManager.Init(this)); - mCommissioningWindowManager.SetAppDelegate(initParams.appDelegate); - // Initialize PersistentStorageDelegate-based storage mDeviceStorage = initParams.persistentStorageDelegate; mSessionResumptionStorage = initParams.sessionResumptionStorage; @@ -165,9 +162,6 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) mAclStorage = initParams.aclStorage; SuccessOrExit(err = mAclStorage->Init(*mDeviceStorage, mFabrics.begin(), mFabrics.end())); - app::DnssdServer::Instance().SetFabricTable(&mFabrics); - app::DnssdServer::Instance().SetCommissioningModeProvider(&mCommissioningWindowManager); - mGroupsProvider = initParams.groupDataProvider; SetGroupDataProvider(mGroupsProvider); @@ -221,6 +215,12 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) err = mUnsolicitedStatusHandler.Init(&mExchangeMgr); SuccessOrExit(err); + SuccessOrExit(err = mCommissioningWindowManager.Init(this)); + mCommissioningWindowManager.SetAppDelegate(initParams.appDelegate); + + app::DnssdServer::Instance().SetFabricTable(&mFabrics); + app::DnssdServer::Instance().SetCommissioningModeProvider(&mCommissioningWindowManager); + err = chip::app::InteractionModelEngine::GetInstance()->Init(&mExchangeMgr, 
&GetFabricTable()); SuccessOrExit(err); @@ -424,6 +424,7 @@ void Server::Shutdown() chip::Dnssd::Resolver::Instance().Shutdown(); chip::app::InteractionModelEngine::GetInstance()->Shutdown(); + mCommissioningWindowManager.Shutdown(); mMessageCounterManager.Shutdown(); mExchangeMgr.Shutdown(); mSessions.Shutdown(); @@ -431,7 +432,6 @@ void Server::Shutdown() mAccessControl.Finish(); Credentials::SetGroupDataProvider(nullptr); mAttributePersister.Shutdown(); - mCommissioningWindowManager.Shutdown(); // TODO(16969): Remove chip::Platform::MemoryInit() call from Server class, it belongs to outer code chip::Platform::MemoryShutdown(); } diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 792f60af032149..9c62a48e046908 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -57,6 +57,39 @@ using TestContext = Test::LoopbackMessagingContext; namespace chip { namespace { +class TemporarySessionManager +{ +public: + TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx) + { + NL_TEST_ASSERT(suite, + CHIP_NO_ERROR == + mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(), + &mStorage, &ctx.GetFabricTable())); + // The setup here is really weird: we are using one session manager for + // the actual messages we send (the PASE handshake, so the + // unauthenticated sessions) and a different one for allocating the PASE + // sessions. Since our Init() set us up as the thing to handle messages + // on the transport manager, undo that. + mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + ~TemporarySessionManager() + { + mSessionManager.Shutdown(); + // Reset the session manager on the transport again, just in case + // shutdown messed with it. 
+ mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + operator SessionManager &() { return mSessionManager; } + +private: + TestContext & mCtx; + TestPersistentStorageDelegate mStorage; + SessionManager mSessionManager; +}; + CHIP_ERROR InitFabricTable(chip::FabricTable & fabricTable, chip::TestPersistentStorageDelegate * testStorage, chip::Crypto::OperationalKeystore * opKeyStore, chip::Credentials::PersistentStorageOpCertStore * opCertStore) @@ -281,7 +314,8 @@ class TestCASESession void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; @@ -312,7 +346,7 @@ void TestCASESession::SecurePairingWaitTest(nlTestSuite * inSuite, void * inCont void TestCASESession::SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; @@ -425,7 +459,9 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S void TestCASESession::SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestCASESecurePairingDelegate delegateCommissioner; CASESession pairingCommissioner; pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider); @@ -902,13 +938,13 @@ void TestCASESession::SessionResumptionStorage(nlTestSuite * inSuite, void * inC #if CONFIG_BUILD_FOR_HOST_UNIT_TEST void TestCASESession::SimulateUpdateNOCInvalidatePendingEstablishment(nlTestSuite * inSuite, void * inContext) { 
- SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestCASESecurePairingDelegate delegateCommissioner; CASESession pairingCommissioner; pairingCommissioner.SetGroupDataProvider(&gCommissionerGroupDataProvider); - TestContext & ctx = *reinterpret_cast(inContext); - TestCASESecurePairingDelegate delegateAccessory; CASESession pairingAccessory; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index cbf3bbe1a39f5b..932d665a22b2c8 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -111,12 +111,45 @@ class MockAppDelegate : public ExchangeDelegate void OnResponseTimeout(ExchangeContext * ec) override {} }; +class TemporarySessionManager +{ +public: + TemporarySessionManager(nlTestSuite * suite, TestContext & ctx) : mCtx(ctx) + { + NL_TEST_ASSERT(suite, + CHIP_NO_ERROR == + mSessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &ctx.GetMessageCounterManager(), + &mStorage, &ctx.GetFabricTable())); + // The setup here is really weird: we are using one session manager for + // the actual messages we send (the PASE handshake, so the + // unauthenticated sessions) and a different one for allocating the PASE + // sessions. Since our Init() set us up as the thing to handle messages + // on the transport manager, undo that. + mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + ~TemporarySessionManager() + { + mSessionManager.Shutdown(); + // Reset the session manager on the transport again, just in case + // shutdown messed with it. 
+ mCtx.GetTransportMgr().SetSessionManager(&mCtx.GetSecureSessionManager()); + } + + operator SessionManager &() { return mSessionManager; } + +private: + TestContext & mCtx; + TestPersistentStorageDelegate mStorage; + SessionManager mSessionManager; +}; + using namespace System::Clock::Literals; void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; @@ -157,7 +190,7 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); // Test all combinations of invalid parameters TestSecurePairingDelegate delegate; @@ -285,11 +318,12 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, S void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, Optional::Missing(), @@ -298,11 +332,12 @@ void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate 
delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, @@ -312,11 +347,12 @@ void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, @@ -326,11 +362,12 @@ void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inCon void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32); ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32); @@ -341,11 +378,12 @@ void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContex void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext) { - SessionManager sessionManager; + 
TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; - TestContext & ctx = *reinterpret_cast(inContext); - auto & loopback = ctx.GetLoopback(); + auto & loopback = ctx.GetLoopback(); loopback.Reset(); loopback.mNumMessagesToDrop = 2; SecurePairingHandshakeTestCommon(inSuite, inContext, sessionManager, pairingCommissioner, @@ -358,7 +396,7 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); - SessionManager sessionManager; + TemporarySessionManager sessionManager(inSuite, ctx); TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index bd84ab651ba84d..bdc8573059b81d 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -112,6 +112,9 @@ void SessionManager::Shutdown() mFabricTable = nullptr; } + // Ensure that we don't create new sessions as we iterate our session table. 
+ mState = State::kNotReady; + mSecureSessions.ForEachSession([&](auto session) { session->MarkForEviction(); return Loop::Continue; @@ -119,7 +122,6 @@ void SessionManager::Shutdown() mMessageCounterManager = nullptr; - mState = State::kNotReady; mSystemLayer = nullptr; mTransportMgr = nullptr; mCB = nullptr; @@ -386,6 +388,7 @@ void SessionManager::ExpireAllPASESessions() Optional SessionManager::AllocateSession(SecureSession::Type secureSessionType, const ScopedNodeId & sessionEvictionHint) { + VerifyOrReturnValue(mState == State::kInitialized, NullOptional); return mSecureSessions.CreateNewSecureSession(secureSessionType, sessionEvictionHint); } diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 45c5468ca61a46..da302ea9bb571c 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -702,7 +702,19 @@ static void RandomSessionIdAllocatorOffset(nlTestSuite * inSuite, SessionManager void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) { + TestContext & ctx = *reinterpret_cast(inContext); + + FabricTableHolder fabricTableHolder; + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init()); + + secure_channel::MessageCounterManager messageCounterManager; + TestPersistentStorageDelegate deviceStorage1, deviceStorage2; + SessionManager sessionManager; + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage1, + &fabricTableHolder.GetFabricTable())); // Allocate a session. uint16_t sessionId1; @@ -735,10 +747,24 @@ void SessionAllocationTest(nlTestSuite * inSuite, void * inContext) } // Reconstruct the Session Manager to reset state. 
+ sessionManager.Shutdown(); sessionManager.~SessionManager(); new (&sessionManager) SessionManager(); + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage2, + &fabricTableHolder.GetFabricTable())); + + // Allocate a single session so we know what random id we are starting at. + { + auto handle = sessionManager.AllocateSession( + Transport::SecureSession::Type::kPASE, + ScopedNodeId(NodeIdFromPAKEKeyId(kDefaultCommissioningPasscodeId), kUndefinedFabricIndex)); + NL_TEST_ASSERT(inSuite, handle.HasValue()); + prevSessionId = handle.Value()->AsSecureSession()->GetLocalSessionId(); + handle.Value()->AsSecureSession()->MarkForEviction(); + } - prevSessionId = 0; // Verify that we increment session ID by 1 for each allocation (except for // the wraparound case where we skip session ID 0), even when allocated // sessions are immediately freed. @@ -886,6 +912,8 @@ void SessionCounterExhaustedTest(nlTestSuite * inSuite, void * inContext) static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext) { + TestContext & ctx = *reinterpret_cast(inContext); + IPAddress addr; IPAddress::FromString("::1", addr); @@ -894,9 +922,17 @@ static void SessionShiftingTest(nlTestSuite * inSuite, void * inContext) FabricIndex aliceFabricIndex = 1; FabricIndex bobFabricIndex = 1; + FabricTableHolder fabricTableHolder; + secure_channel::MessageCounterManager messageCounterManager; + TestPersistentStorageDelegate deviceStorage; + SessionManager sessionManager; - secure_channel::MessageCounterManager gMessageCounterManager; - chip::TestPersistentStorageDelegate deviceStorage; + + NL_TEST_ASSERT(inSuite, CHIP_NO_ERROR == fabricTableHolder.Init()); + NL_TEST_ASSERT(inSuite, + CHIP_NO_ERROR == + sessionManager.Init(&ctx.GetSystemLayer(), &ctx.GetTransportMgr(), &messageCounterManager, &deviceStorage, + &fabricTableHolder.GetFabricTable())); Transport::PeerAddress 
peer(Transport::PeerAddress::UDP(addr, CHIP_PORT));