diff --git a/src/app/server/CommissioningWindowManager.cpp b/src/app/server/CommissioningWindowManager.cpp index 279a88eba36c11..2f2454f299faf2 100644 --- a/src/app/server/CommissioningWindowManager.cpp +++ b/src/app/server/CommissioningWindowManager.cpp @@ -176,8 +176,9 @@ CHIP_ERROR CommissioningWindowManager::OpenCommissioningWindow() if (mUseECM) { ReturnErrorOnFailure(SetTemporaryDiscriminator(mECMDiscriminator)); - ReturnErrorOnFailure(mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), - mECMPasscodeID, keyID, this)); + ReturnErrorOnFailure( + mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), mECMPasscodeID, + keyID, Optional::Value(gDefaultMRPConfig), this)); // reset all advertising, indicating we are in commissioningMode app::DnssdServer::Instance().StartServer(Dnssd::CommissioningMode::kEnabledEnhanced); @@ -189,7 +190,8 @@ CHIP_ERROR CommissioningWindowManager::OpenCommissioningWindow() ReturnErrorOnFailure(mPairingSession.WaitForPairing( pinCode, kSpake2p_Iteration_Count, - ByteSpan(reinterpret_cast(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt)), keyID, this)); + ByteSpan(reinterpret_cast(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt)), keyID, + Optional::Value(gDefaultMRPConfig), this)); // reset all advertising, indicating we are in commissioningMode app::DnssdServer::Instance().StartServer(Dnssd::CommissioningMode::kEnabledBasic); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 18b6cffd2fe5e7..365dd0472c0008 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -859,7 +859,8 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re // TODO - Remove use of SetActive/IsActive from CommissioneeDeviceProxy device->SetActive(true); - err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, exchangeCtxt, this); + err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, + Optional::Value(mMRPConfig), exchangeCtxt, this); SuccessOrExit(err); // Immediately persist the updated mNextKeyID value diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index 346a36a304518d..6909ff5d541411 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -810,6 +810,8 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, Callback::Callback mDeviceNOCChainCallback; SetUpCodePairer mSetUpCodePairer; + + ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig; }; } // namespace Controller diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index fab7aa2b40d4f7..f2eb5b8ab7b678 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -144,10 +144,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin */ virtual CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; - const char * GetI2RSessionInfo() const override { return "Sigma I2R Key"; } - - const char * GetR2ISessionInfo() const override { return "Sigma R2I Key"; } - /** * @brief Serialize the CASESession to the given cachableSession data structure for secure pairing **/ diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 10340d0ad29ed6..5598f4cf915106 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -251,7 +251,8 @@ CHIP_ERROR PASESession::SetupSpake2p(uint32_t pbkdf2IterCount, const ByteSpan & } CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t mySessionId, SessionEstablishmentDelegate * delegate) + uint16_t mySessionId, Optional mrpConfig, + SessionEstablishmentDelegate * delegate) { // Return early on error here, as we have not initalized any state yet ReturnErrorCodeIf(salt.empty(), CHIP_ERROR_INVALID_ARGUMENT); @@ -281,6 +282,7 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I mNextExpectedMsg = MsgType::PBKDFParamRequest; mPairingComplete = false; mPasscodeID = 0; + mLocalMRPConfig = mrpConfig; ChipLogDetail(SecureChannel, "Waiting for PBKDF param request"); @@ -293,9 +295,10 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I } CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t passcodeID, uint16_t mySessionId, SessionEstablishmentDelegate * delegate) + uint16_t passcodeID, uint16_t mySessionId, Optional mrpConfig, + SessionEstablishmentDelegate * delegate) { - ReturnErrorOnFailure(WaitForPairing(0, pbkdf2IterCount, salt, mySessionId, delegate)); + ReturnErrorOnFailure(WaitForPairing(0, pbkdf2IterCount, salt, mySessionId, mrpConfig, delegate)); memmove(&mPASEVerifier, &verifier, sizeof(verifier)); mComputeVerifier = false; @@ -305,7 +308,8 @@ CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint32_t p } CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, - Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate) + Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, + SessionEstablishmentDelegate * delegate) { ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); CHIP_ERROR err = Init(mySessionId, peerSetUpPINCode, delegate); @@ -316,6 +320,8 @@ CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t SetPeerAddress(peerAddress); + mLocalMRPConfig = mrpConfig; + err = SendPBKDFParamRequest(); SuccessOrExit(err); @@ -355,13 +361,14 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() { ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - const size_t max_msg_len = TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom, + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; + const size_t max_msg_len = TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom, sizeof(uint16_t), // initiatorSessionId sizeof(uint16_t), // passcodeId, - sizeof(uint8_t) // hasPBKDFParameters - /* TLV::EstimateStructOverhead(sizeof(uint16_t), - sizeof(uint16)_t), // initiatorMRPParams */ + sizeof(uint8_t), // hasPBKDFParameters + mrpParamsSize // MRP Parameters ); + System::PacketBufferHandle req = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!req.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -374,9 +381,11 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId())); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), mPasscodeID)); ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters)); - // TODO - Add optional MRP parameter support to PASE - // When we add MRP params here, adjust the TLV::EstimateStructOverhead call - // above accordingly. + if (mLocalMRPConfig.HasValue()) + { + ChipLogDetail(SecureChannel, "Including MRP parameters in PBKDF param request"); + ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); + } ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&req)); @@ -433,7 +442,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters)); - // TODO - Check if optional MRP parameters were sent. If so, cache them. + SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters); SuccessOrExit(err); @@ -453,14 +462,15 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in { ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; const size_t max_msg_len = - TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom - kPBKDFParamRandomNumberSize, // responderRandom - sizeof(uint16_t), // responderSessionId - TLV::EstimateStructOverhead(sizeof(uint32_t), mSaltLength) // pbkdf_parameters - /* TLV::EstimateStructOverhead(sizeof(uint16_t), - sizeof(uint16)_t), // responderMRPParams */ + TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom + kPBKDFParamRandomNumberSize, // responderRandom + sizeof(uint16_t), // responderSessionId + TLV::EstimateStructOverhead(sizeof(uint32_t), mSaltLength), // pbkdf_parameters + mrpParamsSize // MRP Parameters ); + System::PacketBufferHandle resp = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!resp.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -483,8 +493,11 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in ReturnErrorOnFailure(tlvWriter.EndContainer(pbkdfParamContainer)); } - // When we add MRP params here, adjust the TLV::EstimateStructOverhead call - // above accordingly. + if (mLocalMRPConfig.HasValue()) + { + ChipLogDetail(SecureChannel, "Including MRP parameters in PBKDF param response"); + ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); + } ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&resp)); @@ -549,6 +562,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m if (mHavePBKDFParameters) { + SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + // TODO - Add a unit test that exercises mHavePBKDFParameters path err = SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength)); SuccessOrExit(err); @@ -568,6 +583,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m saltLength = tlvReader.GetLength(); SuccessOrExit(err = tlvReader.GetDataPtr(salt)); + SuccessOrExit(err = tlvReader.ExitContainer(containerType)); + + SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + err = SetupSpake2p(iterCount, ByteSpan(salt, saltLength)); SuccessOrExit(err); } diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 7a7e86de881738..7cd218c1ef9035 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -107,7 +107,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const ByteSpan & salt, uint16_t mySessionId, - SessionEstablishmentDelegate * delegate); + Optional mrpConfig, SessionEstablishmentDelegate * delegate); /** * @brief @@ -123,7 +123,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR WaitForPairing(const PASEVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, uint16_t passcodeID, - uint16_t mySessionId, SessionEstablishmentDelegate * delegate); + uint16_t mySessionId, Optional mrpConfig, + SessionEstablishmentDelegate * delegate); /** * @brief @@ -141,7 +142,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, - Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate); + Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, + SessionEstablishmentDelegate * delegate); /** * @brief @@ -170,10 +172,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin */ CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override; - const char * GetI2RSessionInfo() const override { return kSpake2pI2RSessionInfo; } - - const char * GetR2ISessionInfo() const override { return kSpake2pR2ISessionInfo; } - /** @brief Serialize the Pairing Session to a string. * * @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise @@ -304,6 +302,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin SessionEstablishmentExchangeDispatch mMessageDispatch; + Optional mLocalMRPConfig; + struct Spake2pErrorMsg { Spake2pErrorType error; @@ -369,10 +369,6 @@ class SecurePairingUsingTestSecret : public PairingSession return CHIP_NO_ERROR; } - const char * GetI2RSessionInfo() const override { return "i2r"; } - - const char * GetR2ISessionInfo() const override { return "r2i"; } - private: const char * kTestSecret = CHIP_CONFIG_TEST_SHARED_SECRET_VALUE; }; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index befcd1f292e486..800afa97e879cf 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -98,16 +98,20 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); - NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(1234, 500, ByteSpan(nullptr, 0), 0, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, + pairing.WaitForPairing(1234, 500, ByteSpan(nullptr, 0), 0, Optional::Missing(), + &delegate) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, nullptr) == - CHIP_ERROR_INVALID_ARGUMENT); + pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, + Optional::Missing(), + nullptr) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, &delegate) == CHIP_NO_ERROR); + pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, + Optional::Missing(), &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); } @@ -126,11 +130,13 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, nullptr, nullptr) != CHIP_NO_ERROR); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, + Optional::Missing(), nullptr, nullptr) != CHIP_NO_ERROR); gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context, &delegate) == CHIP_NO_ERROR); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, + Optional::Missing(), context, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -148,14 +154,17 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) == - CHIP_ERROR_BAD_REQUEST); + pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, + Optional::Missing(), context1, + &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; } void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, + Optional mrpCommissionerConfig, + Optional mrpAccessoryConfig, TestSecurePairingDelegate & delegateCommissioner) { TestContext & ctx = *reinterpret_cast(inContext); @@ -190,12 +199,12 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSALTsaltSALT", 16), 0, - &delegateAccessory) == CHIP_NO_ERROR); + mrpAccessoryConfig, &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, contextCommissioner, - &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, mrpCommissionerConfig, + contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); while (gLoopback.mMessageDropped) @@ -213,6 +222,22 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount >= 5); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); + + if (mrpCommissionerConfig.HasValue()) + { + NL_TEST_ASSERT(inSuite, + pairingAccessory.GetMRPConfig().mIdleRetransTimeout == mrpCommissionerConfig.Value().mIdleRetransTimeout); + NL_TEST_ASSERT( + inSuite, pairingAccessory.GetMRPConfig().mActiveRetransTimeout == mrpCommissionerConfig.Value().mActiveRetransTimeout); + } + + if (mrpAccessoryConfig.HasValue()) + { + NL_TEST_ASSERT(inSuite, + pairingCommissioner.GetMRPConfig().mIdleRetransTimeout == mrpAccessoryConfig.Value().mIdleRetransTimeout); + NL_TEST_ASSERT( + inSuite, pairingCommissioner.GetMRPConfig().mActiveRetransTimeout == mrpAccessoryConfig.Value().mActiveRetransTimeout); + } } void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) @@ -220,7 +245,41 @@ void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + Optional::Missing(), delegateCommissioner); +} + +void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, + Optional::Value(config), + Optional::Missing(), delegateCommissioner); +} + +void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + Optional::Value(config), delegateCommissioner); +} + +void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32); + ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, + Optional::Value(commissionerConfig), + Optional::Value(deviceConfig), delegateCommissioner); } void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext) @@ -229,7 +288,8 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo PASESession pairingCommissioner; gLoopback.Reset(); gLoopback.mNumMessagesToDrop = 2; - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + Optional::Missing(), delegateCommissioner); NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 2); NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0); } @@ -269,11 +329,13 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSALTsaltSALT", 16), 0, + Optional::Missing(), &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, contextCommissioner, + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, + Optional::Missing(), contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -309,7 +371,8 @@ void SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); - SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, delegateCommissioner); + SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, Optional::Missing(), + Optional::Missing(), delegateCommissioner); SecurePairingDeserialize(inSuite, inContext, *testPairingSession1, *testPairingSession2); const uint8_t plain_text[] = { 0x86, 0x74, 0x64, 0xe5, 0x0b, 0xd4, 0x0d, 0x90, 0xe1, 0x17, 0xa3, 0x2d, 0x4b, 0xd4, 0xe1, 0xe6 }; @@ -358,6 +421,9 @@ static const nlTest sTests[] = NL_TEST_DEF("WaitInit", SecurePairingWaitTest), NL_TEST_DEF("Start", SecurePairingStartTest), NL_TEST_DEF("Handshake", SecurePairingHandshakeTest), + NL_TEST_DEF("Handshake with Commissioner MRP Parameters", SecurePairingHandshakeWithCommissionerMRPTest), + NL_TEST_DEF("Handshake with Device MRP Parameters", SecurePairingHandshakeWithDeviceMRPTest), + NL_TEST_DEF("Handshake with Both MRP Parameters", SecurePairingHandshakeWithAllMRPTest), NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest), NL_TEST_DEF("Failed Handshake", SecurePairingFailedHandshake), NL_TEST_DEF("Serialize", SecurePairingSerializeTest), diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index 7516067eb7c9eb..47ad0fe6e16753 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -25,6 +25,7 @@ static_library("transport") { "MessageCounter.cpp", "MessageCounter.h", "MessageCounterManagerInterface.h", + "PairingSession.cpp", "PeerMessageCounter.h", "SecureMessageCodec.cpp", "SecureMessageCodec.h", diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp new file mode 100644 index 00000000000000..affc27f51367a2 --- /dev/null +++ b/src/transport/PairingSession.cpp @@ -0,0 +1,81 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +namespace chip { + +CHIP_ERROR PairingSession::EncodeMRPParameters(TLV::Tag tag, const ReliableMessageProtocolConfig & mrpConfig, + TLV::TLVWriter & tlvWriter) +{ + VerifyOrReturnError(CanCastTo(mrpConfig.mIdleRetransTimeout.count()), CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(CanCastTo(mrpConfig.mActiveRetransTimeout.count()), CHIP_ERROR_INVALID_ARGUMENT); + + TLV::TLVType mrpParamsContainer; + ReturnErrorOnFailure(tlvWriter.StartContainer(tag, TLV::kTLVType_Structure, mrpParamsContainer)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), static_cast(mrpConfig.mIdleRetransTimeout.count()))); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), static_cast(mrpConfig.mActiveRetransTimeout.count()))); + return tlvWriter.EndContainer(mrpParamsContainer); +} + +CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLVReader & tlvReader) +{ + // The MRP parameters are optional. + CHIP_ERROR err = tlvReader.Next(); + if (err == CHIP_END_OF_TLV) + { + return CHIP_NO_ERROR; + } + ReturnErrorOnFailure(err); + + TLV::TLVType containerType = TLV::kTLVType_Structure; + ReturnErrorOnFailure(tlvReader.EnterContainer(containerType)); + + uint16_t tlvElementValue = 0; + + ReturnErrorOnFailure(tlvReader.Next()); + + ChipLogDetail(SecureChannel, "Found MRP parameters in the message"); + + // Both TLV elements in the strucutre are optional. If the first element is present, process it and move + // the TLV reader to the next element. + if (TLV::TagNumFromTag(tlvReader.GetTag()) == 1) + { + ReturnErrorOnFailure(tlvReader.Get(tlvElementValue)); + mMRPConfig.mIdleRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); + + // The next element is optional. If it's not present, return CHIP_NO_ERROR. + err = tlvReader.Next(); + if (err == CHIP_END_OF_TLV) + { + return CHIP_NO_ERROR; + } + ReturnErrorOnFailure(err); + } + + VerifyOrReturnError(TLV::TagNumFromTag(tlvReader.GetTag()) == 2, CHIP_ERROR_INVALID_TLV_TAG); + ReturnErrorOnFailure(tlvReader.Get(tlvElementValue)); + mMRPConfig.mActiveRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); + + return tlvReader.ExitContainer(containerType); +} + +} // namespace chip diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index 8f9b2f242fed21..b0f4769de6af4d 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -26,6 +26,7 @@ #pragma once #include +#include #include #include #include @@ -96,9 +97,11 @@ class DLL_EXPORT PairingSession void SetMRPConfig(const ReliableMessageProtocolConfig & config) { mMRPConfig = config; } - virtual const char * GetI2RSessionInfo() const = 0; - - virtual const char * GetR2ISessionInfo() const = 0; + /** + * Encode the provided MRP parameters using the provided TLV tag. + */ + static CHIP_ERROR EncodeMRPParameters(TLV::Tag tag, const ReliableMessageProtocolConfig & mrpConfig, + TLV::TLVWriter & tlvWriter); protected: void SetSecureSessionType(Transport::SecureSession::Type secureSessionType) { mSecureSessionType = secureSessionType; } @@ -159,6 +162,18 @@ class DLL_EXPORT PairingSession return err; } + /** + * Try to decode the next element (pointed by the TLV reader) as MRP parameters. + * If the MRP parameters are found, mMRPConfig is updated with the devoded values. + * + * MRP parameters are optional. So, if the TLV reader is not pointing to the MRP parameters, + * the function is a noop. + * + * If the parameters are present, but TLV reader fails to correctly parse it, the function will + * return the corresponding error. + */ + CHIP_ERROR DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLVReader & tlvReader); + // TODO: remove Clear, we should create a new instance instead reset the old instance. void Clear() { diff --git a/src/transport/tests/BUILD.gn b/src/transport/tests/BUILD.gn index aa7a40160d7020..69be6a5a1a1ecc 100644 --- a/src/transport/tests/BUILD.gn +++ b/src/transport/tests/BUILD.gn @@ -23,6 +23,7 @@ chip_test_suite("tests") { output_name = "libTransportLayerTests" test_sources = [ + "TestPairingSession.cpp", "TestPeerConnections.cpp", "TestSecureSession.cpp", "TestSessionHandle.cpp", diff --git a/src/transport/tests/TestPairingSession.cpp b/src/transport/tests/TestPairingSession.cpp new file mode 100644 index 00000000000000..7d7869805930f7 --- /dev/null +++ b/src/transport/tests/TestPairingSession.cpp @@ -0,0 +1,166 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file implements unit tests for the CryptoContext implementation. + */ + +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +using namespace chip; +using namespace chip::System::Clock; + +class TestPairingSession : public PairingSession +{ +public: + CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override { return CHIP_NO_ERROR; } + + CHIP_ERROR DecodeMRPParametersIfPresent(System::PacketBufferTLVReader & tlvReader) + { + return PairingSession::DecodeMRPParametersIfPresent(tlvReader); + } +}; + +void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext) +{ + TestPairingSession session; + + ReliableMessageProtocolConfig config(Milliseconds32(100), Milliseconds32(200)); + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(64, 0); + System::PacketBufferTLVWriter writer; + writer.Init(buf.Retain()); + + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + NL_TEST_ASSERT(inSuite, writer.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType) == CHIP_NO_ERROR); + + CHIP_ERROR err = PairingSession::EncodeMRPParameters(TLV::ContextTag(1), config, writer); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + NL_TEST_ASSERT(inSuite, writer.EndContainer(outerContainerType) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, writer.Finalize(&buf) == CHIP_NO_ERROR); + + System::PacketBufferTLVReader reader; + TLV::TLVType containerType = TLV::kTLVType_Structure; + + reader.Init(std::move(buf)); + NL_TEST_ASSERT(inSuite, reader.Next(containerType, TLV::AnonymousTag) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, reader.EnterContainer(containerType) == CHIP_NO_ERROR); + + err = session.DecodeMRPParametersIfPresent(reader); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == config.mIdleRetransTimeout); + NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == config.mActiveRetransTimeout); +} + +void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inContext) +{ + TestPairingSession session; + + System::PacketBufferHandle buf = System::PacketBufferHandle::New(64, 0); + System::PacketBufferTLVWriter writer; + writer.Init(buf.Retain()); + + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + NL_TEST_ASSERT(inSuite, writer.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, writer.EndContainer(outerContainerType) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, writer.Finalize(&buf) == CHIP_NO_ERROR); + + System::PacketBufferTLVReader reader; + TLV::TLVType containerType = TLV::kTLVType_Structure; + + reader.Init(std::move(buf)); + NL_TEST_ASSERT(inSuite, reader.Next(containerType, TLV::AnonymousTag) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, reader.EnterContainer(containerType) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(reader) == CHIP_NO_ERROR); + + NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == gDefaultMRPConfig.mIdleRetransTimeout); + NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == gDefaultMRPConfig.mActiveRetransTimeout); +} + +// Test Suite + +/** + * Test Suite that lists all the test functions. + */ +// clang-format off +static const nlTest sTests[] = +{ + NL_TEST_DEF("Encode and Decode MRP params", PairingSessionEncodeDecodeMRPParams), + NL_TEST_DEF("Decode missing MRP params", PairingSessionTryDecodeMissingMRPParams), + + NL_TEST_SENTINEL() +}; +// clang-format on + +/** + * Set up the test suite. + */ +int TestPairingSession_Setup(void * inContext) +{ + CHIP_ERROR error = chip::Platform::MemoryInit(); + if (error != CHIP_NO_ERROR) + return FAILURE; + return SUCCESS; +} + +/** + * Tear down the test suite. + */ +int TestPairingSession_Teardown(void * inContext) +{ + chip::Platform::MemoryShutdown(); + return SUCCESS; +} + +// clang-format off +static nlTestSuite sSuite = +{ + "Test-CHIP-PairingSession", + &sTests[0], + TestPairingSession_Setup, + TestPairingSession_Teardown +}; +// clang-format on + +/** + * Main + */ +int TestPairingSessionInit() +{ + // Run test suit against one context + nlTestRunner(&sSuite, nullptr); + + return (nlTestRunnerStats(&sSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestPairingSessionInit)