diff --git a/src/app/server/CommissioningWindowManager.cpp b/src/app/server/CommissioningWindowManager.cpp index adf501554c8cd2..dfe2b63e5cf785 100644 --- a/src/app/server/CommissioningWindowManager.cpp +++ b/src/app/server/CommissioningWindowManager.cpp @@ -176,8 +176,9 @@ CHIP_ERROR CommissioningWindowManager::OpenCommissioningWindow() if (mUseECM) { ReturnErrorOnFailure(SetTemporaryDiscriminator(mECMDiscriminator)); - ReturnErrorOnFailure(mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), - mECMPasscodeID, keyID, this)); + ReturnErrorOnFailure( + mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), mECMPasscodeID, + keyID, Optional::Value(gDefaultMRPConfig), this)); // reset all advertising, indicating we are in commissioningMode app::DnssdServer::Instance().StartServer(Dnssd::CommissioningMode::kEnabledEnhanced); @@ -189,7 +190,8 @@ CHIP_ERROR CommissioningWindowManager::OpenCommissioningWindow() ReturnErrorOnFailure(mPairingSession.WaitForPairing( pinCode, kSpake2p_Iteration_Count, - ByteSpan(reinterpret_cast(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt)), keyID, this)); + ByteSpan(reinterpret_cast(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt)), keyID, + Optional::Value(gDefaultMRPConfig), this)); // reset all advertising, indicating we are in commissioningMode app::DnssdServer::Instance().StartServer(Dnssd::CommissioningMode::kEnabledBasic); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 18b6cffd2fe5e7..365dd0472c0008 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -859,7 +859,8 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re // TODO - Remove use of SetActive/IsActive from CommissioneeDeviceProxy device->SetActive(true); - err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, exchangeCtxt, this); + err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, + Optional::Value(mMRPConfig), exchangeCtxt, this); SuccessOrExit(err); // Immediately persist the updated mNextKeyID value diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index 346a36a304518d..6909ff5d541411 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -810,6 +810,8 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, Callback::Callback mDeviceNOCChainCallback; SetUpCodePairer mSetUpCodePairer; + + ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig; }; } // namespace Controller diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 10340d0ad29ed6..5598f4cf915106 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -251,7 +251,8 @@ CHIP_ERROR PASESession::SetupSpake2p(uint32_t pbkdf2IterCount, const ByteSpan & } CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t mySessionId, SessionEstablishmentDelegate * delegate) + uint16_t mySessionId, Optional mrpConfig, + SessionEstablishmentDelegate * delegate) { // Return early on error here, as we have not initalized any state yet ReturnErrorCodeIf(salt.empty(), CHIP_ERROR_INVALID_ARGUMENT); @@ -281,6 +282,7 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I mNextExpectedMsg = MsgType::PBKDFParamRequest; mPairingComplete = false; mPasscodeID = 0; + mLocalMRPConfig = mrpConfig; ChipLogDetail(SecureChannel, "Waiting for PBKDF param request"); @@ -293,9 +295,10 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I } CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, - uint16_t passcodeID, uint16_t mySessionId, SessionEstablishmentDelegate * delegate) + uint16_t passcodeID, uint16_t mySessionId, Optional mrpConfig, + SessionEstablishmentDelegate * delegate) { - ReturnErrorOnFailure(WaitForPairing(0, pbkdf2IterCount, salt, mySessionId, delegate)); + ReturnErrorOnFailure(WaitForPairing(0, pbkdf2IterCount, salt, mySessionId, mrpConfig, delegate)); memmove(&mPASEVerifier, &verifier, sizeof(verifier)); mComputeVerifier = false; @@ -305,7 +308,8 @@ CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint32_t p } CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, - Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate) + Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, + SessionEstablishmentDelegate * delegate) { ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); CHIP_ERROR err = Init(mySessionId, peerSetUpPINCode, delegate); @@ -316,6 +320,8 @@ CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t SetPeerAddress(peerAddress); + mLocalMRPConfig = mrpConfig; + err = SendPBKDFParamRequest(); SuccessOrExit(err); @@ -355,13 +361,14 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() { ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - const size_t max_msg_len = TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom, + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; + const size_t max_msg_len = TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom, sizeof(uint16_t), // initiatorSessionId sizeof(uint16_t), // passcodeId, - sizeof(uint8_t) // hasPBKDFParameters - /* TLV::EstimateStructOverhead(sizeof(uint16_t), - sizeof(uint16)_t), // initiatorMRPParams */ + sizeof(uint8_t), // hasPBKDFParameters + mrpParamsSize // MRP Parameters ); + System::PacketBufferHandle req = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!req.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -374,9 +381,11 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId())); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), mPasscodeID)); ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters)); - // TODO - Add optional MRP parameter support to PASE - // When we add MRP params here, adjust the TLV::EstimateStructOverhead call - // above accordingly. + if (mLocalMRPConfig.HasValue()) + { + ChipLogDetail(SecureChannel, "Including MRP parameters in PBKDF param request"); + ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); + } ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&req)); @@ -433,7 +442,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters)); - // TODO - Check if optional MRP parameters were sent. If so, cache them. + SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters); SuccessOrExit(err); @@ -453,14 +462,15 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in { ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; const size_t max_msg_len = - TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom - kPBKDFParamRandomNumberSize, // responderRandom - sizeof(uint16_t), // responderSessionId - TLV::EstimateStructOverhead(sizeof(uint32_t), mSaltLength) // pbkdf_parameters - /* TLV::EstimateStructOverhead(sizeof(uint16_t), - sizeof(uint16)_t), // responderMRPParams */ + TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom + kPBKDFParamRandomNumberSize, // responderRandom + sizeof(uint16_t), // responderSessionId + TLV::EstimateStructOverhead(sizeof(uint32_t), mSaltLength), // pbkdf_parameters + mrpParamsSize // MRP Parameters ); + System::PacketBufferHandle resp = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!resp.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -483,8 +493,11 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in ReturnErrorOnFailure(tlvWriter.EndContainer(pbkdfParamContainer)); } - // When we add MRP params here, adjust the TLV::EstimateStructOverhead call - // above accordingly. + if (mLocalMRPConfig.HasValue()) + { + ChipLogDetail(SecureChannel, "Including MRP parameters in PBKDF param response"); + ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); + } ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&resp)); @@ -549,6 +562,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m if (mHavePBKDFParameters) { + SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + // TODO - Add a unit test that exercises mHavePBKDFParameters path err = SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength)); SuccessOrExit(err); @@ -568,6 +583,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m saltLength = tlvReader.GetLength(); SuccessOrExit(err = tlvReader.GetDataPtr(salt)); + SuccessOrExit(err = tlvReader.ExitContainer(containerType)); + + SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + err = SetupSpake2p(iterCount, ByteSpan(salt, saltLength)); SuccessOrExit(err); } diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 7a7e86de881738..5c23183babebfd 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -107,7 +107,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const ByteSpan & salt, uint16_t mySessionId, - SessionEstablishmentDelegate * delegate); + Optional mrpConfig, SessionEstablishmentDelegate * delegate); /** * @brief @@ -123,7 +123,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR WaitForPairing(const PASEVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, uint16_t passcodeID, - uint16_t mySessionId, SessionEstablishmentDelegate * delegate); + uint16_t mySessionId, Optional mrpConfig, + SessionEstablishmentDelegate * delegate); /** * @brief @@ -141,7 +142,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId, - Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate); + Optional mrpConfig, Messaging::ExchangeContext * exchangeCtxt, + SessionEstablishmentDelegate * delegate); /** * @brief @@ -304,6 +306,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin SessionEstablishmentExchangeDispatch mMessageDispatch; + Optional mLocalMRPConfig; + struct Spake2pErrorMsg { Spake2pErrorType error; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index befcd1f292e486..800afa97e879cf 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -98,16 +98,20 @@ void SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); - NL_TEST_ASSERT(inSuite, pairing.WaitForPairing(1234, 500, ByteSpan(nullptr, 0), 0, &delegate) == CHIP_ERROR_INVALID_ARGUMENT); + NL_TEST_ASSERT(inSuite, + pairing.WaitForPairing(1234, 500, ByteSpan(nullptr, 0), 0, Optional::Missing(), + &delegate) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, nullptr) == - CHIP_ERROR_INVALID_ARGUMENT); + pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, + Optional::Missing(), + nullptr) == CHIP_ERROR_INVALID_ARGUMENT); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, &delegate) == CHIP_NO_ERROR); + pairing.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSalt", 8), 0, + Optional::Missing(), &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); } @@ -126,11 +130,13 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, nullptr, nullptr) != CHIP_NO_ERROR); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, + Optional::Missing(), nullptr, nullptr) != CHIP_NO_ERROR); gLoopback.Reset(); NL_TEST_ASSERT(inSuite, - pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context, &delegate) == CHIP_NO_ERROR); + pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, + Optional::Missing(), context, &delegate) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -148,14 +154,17 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) == - CHIP_ERROR_BAD_REQUEST); + pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, + Optional::Missing(), context1, + &delegate) == CHIP_ERROR_BAD_REQUEST); ctx.DrainAndServiceIO(); gLoopback.mMessageSendError = CHIP_NO_ERROR; } void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, PASESession & pairingCommissioner, + Optional mrpCommissionerConfig, + Optional mrpAccessoryConfig, TestSecurePairingDelegate & delegateCommissioner) { TestContext & ctx = *reinterpret_cast(inContext); @@ -190,12 +199,12 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSALTsaltSALT", 16), 0, - &delegateAccessory) == CHIP_NO_ERROR); + mrpAccessoryConfig, &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, contextCommissioner, - &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, mrpCommissionerConfig, + contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); while (gLoopback.mMessageDropped) @@ -213,6 +222,22 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount >= 5); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); + + if (mrpCommissionerConfig.HasValue()) + { + NL_TEST_ASSERT(inSuite, + pairingAccessory.GetMRPConfig().mIdleRetransTimeout == mrpCommissionerConfig.Value().mIdleRetransTimeout); + NL_TEST_ASSERT( + inSuite, pairingAccessory.GetMRPConfig().mActiveRetransTimeout == mrpCommissionerConfig.Value().mActiveRetransTimeout); + } + + if (mrpAccessoryConfig.HasValue()) + { + NL_TEST_ASSERT(inSuite, + pairingCommissioner.GetMRPConfig().mIdleRetransTimeout == mrpAccessoryConfig.Value().mIdleRetransTimeout); + NL_TEST_ASSERT( + inSuite, pairingCommissioner.GetMRPConfig().mActiveRetransTimeout == mrpAccessoryConfig.Value().mActiveRetransTimeout); + } } void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) @@ -220,7 +245,41 @@ void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext) TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; gLoopback.Reset(); - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + Optional::Missing(), delegateCommissioner); +} + +void SecurePairingHandshakeWithCommissionerMRPTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, + Optional::Value(config), + Optional::Missing(), delegateCommissioner); +} + +void SecurePairingHandshakeWithDeviceMRPTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + ReliableMessageProtocolConfig config(1000_ms32, 10000_ms32); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + Optional::Value(config), delegateCommissioner); +} + +void SecurePairingHandshakeWithAllMRPTest(nlTestSuite * inSuite, void * inContext) +{ + TestSecurePairingDelegate delegateCommissioner; + PASESession pairingCommissioner; + gLoopback.Reset(); + ReliableMessageProtocolConfig commissionerConfig(1000_ms32, 10000_ms32); + ReliableMessageProtocolConfig deviceConfig(2000_ms32, 7000_ms32); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, + Optional::Value(commissionerConfig), + Optional::Value(deviceConfig), delegateCommissioner); } void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inContext) @@ -229,7 +288,8 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo PASESession pairingCommissioner; gLoopback.Reset(); gLoopback.mNumMessagesToDrop = 2; - SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, delegateCommissioner); + SecurePairingHandshakeTestCommon(inSuite, inContext, pairingCommissioner, Optional::Missing(), + Optional::Missing(), delegateCommissioner); NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 2); NL_TEST_ASSERT(inSuite, gLoopback.mNumMessagesToDrop == 0); } @@ -269,11 +329,13 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, pairingAccessory.WaitForPairing(1234, 500, ByteSpan((const uint8_t *) "saltSALTsaltSALT", 16), 0, + Optional::Missing(), &delegateAccessory) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); NL_TEST_ASSERT(inSuite, - pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, contextCommissioner, + pairingCommissioner.Pair(Transport::PeerAddress(Transport::Type::kBle), 4321, 0, + Optional::Missing(), contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); ctx.DrainAndServiceIO(); @@ -309,7 +371,8 @@ void SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); - SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, delegateCommissioner); + SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, Optional::Missing(), + Optional::Missing(), delegateCommissioner); SecurePairingDeserialize(inSuite, inContext, *testPairingSession1, *testPairingSession2); const uint8_t plain_text[] = { 0x86, 0x74, 0x64, 0xe5, 0x0b, 0xd4, 0x0d, 0x90, 0xe1, 0x17, 0xa3, 0x2d, 0x4b, 0xd4, 0xe1, 0xe6 }; @@ -358,6 +421,9 @@ static const nlTest sTests[] = NL_TEST_DEF("WaitInit", SecurePairingWaitTest), NL_TEST_DEF("Start", SecurePairingStartTest), NL_TEST_DEF("Handshake", SecurePairingHandshakeTest), + NL_TEST_DEF("Handshake with Commissioner MRP Parameters", SecurePairingHandshakeWithCommissionerMRPTest), + NL_TEST_DEF("Handshake with Device MRP Parameters", SecurePairingHandshakeWithDeviceMRPTest), + NL_TEST_DEF("Handshake with Both MRP Parameters", SecurePairingHandshakeWithAllMRPTest), NL_TEST_DEF("Handshake with packet loss", SecurePairingHandshakeWithPacketLossTest), NL_TEST_DEF("Failed Handshake", SecurePairingFailedHandshake), NL_TEST_DEF("Serialize", SecurePairingSerializeTest), diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index 7516067eb7c9eb..47ad0fe6e16753 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -25,6 +25,7 @@ static_library("transport") { "MessageCounter.cpp", "MessageCounter.h", "MessageCounterManagerInterface.h", + "PairingSession.cpp", "PeerMessageCounter.h", "SecureMessageCodec.cpp", "SecureMessageCodec.h", diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp new file mode 100644 index 00000000000000..bca568c4bd40c7 --- /dev/null +++ b/src/transport/PairingSession.cpp @@ -0,0 +1,134 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines a common interface to access various types of secure + * pairing sessions (e.g. PASE, CASE) + * + */ + +#include + +#include +#include + +namespace chip { + +void PairingSession::SendStatusReport(Messaging::ExchangeContext * exchangeCtxt, uint16_t protocolCode) +{ + Protocols::SecureChannel::GeneralStatusCode generalCode = (protocolCode == Protocols::SecureChannel::kProtocolCodeSuccess) + ? Protocols::SecureChannel::GeneralStatusCode::kSuccess + : Protocols::SecureChannel::GeneralStatusCode::kFailure; + uint32_t protocolId = Protocols::SecureChannel::Id.ToFullyQualifiedSpecForm(); + + ChipLogDetail(SecureChannel, "Sending status report. Protocol code %d, exchange %d", protocolCode, + exchangeCtxt->GetExchangeId()); + + Protocols::SecureChannel::StatusReport statusReport(generalCode, protocolId, protocolCode); + + Encoding::LittleEndian::PacketBufferWriter bbuf(System::PacketBufferHandle::New(statusReport.Size())); + statusReport.WriteToBuffer(bbuf); + + System::PacketBufferHandle msg = bbuf.Finalize(); + VerifyOrReturn(!msg.IsNull(), ChipLogError(SecureChannel, "Failed to allocate status report message")); + + CHIP_ERROR err = exchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::StatusReport, std::move(msg)); + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Failed to send status report message. %s", ErrorStr(err)); + } +} + +CHIP_ERROR PairingSession::HandleStatusReport(System::PacketBufferHandle && msg, bool successExpected) +{ + Protocols::SecureChannel::StatusReport report; + CHIP_ERROR err = report.Parse(std::move(msg)); + ReturnErrorOnFailure(err); + VerifyOrReturnError(report.GetProtocolId() == Protocols::SecureChannel::Id.ToFullyQualifiedSpecForm(), + CHIP_ERROR_INVALID_ARGUMENT); + + if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kSuccess && + report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeSuccess && successExpected) + { + OnSuccessStatusReport(); + } + else + { + err = OnFailureStatusReport(report.GetGeneralCode(), report.GetProtocolCode()); + } + + return err; +} + +CHIP_ERROR PairingSession::EncodeMRPParameters(TLV::Tag tag, const ReliableMessageProtocolConfig & mrpConfig, + System::PacketBufferTLVWriter & tlvWriter) +{ + VerifyOrReturnError(CanCastTo(mrpConfig.mIdleRetransTimeout.count()), CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(CanCastTo(mrpConfig.mActiveRetransTimeout.count()), CHIP_ERROR_INVALID_ARGUMENT); + + TLV::TLVType mrpParamsContainer; + ReturnErrorOnFailure(tlvWriter.StartContainer(tag, TLV::kTLVType_Structure, mrpParamsContainer)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), static_cast(mrpConfig.mIdleRetransTimeout.count()))); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), static_cast(mrpConfig.mActiveRetransTimeout.count()))); + return tlvWriter.EndContainer(mrpParamsContainer); +} + +CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(System::PacketBufferTLVReader & tlvReader) +{ + // The MRP parameters are optional. + CHIP_ERROR err = tlvReader.Next(); + if (err == CHIP_END_OF_TLV) + { + return CHIP_NO_ERROR; + } + ReturnErrorOnFailure(err); + + TLV::TLVType containerType = TLV::kTLVType_Structure; + ReturnErrorOnFailure(tlvReader.EnterContainer(containerType)); + + uint16_t tlvElementValue = 0; + + ReturnErrorOnFailure(tlvReader.Next()); + + ChipLogDetail(SecureChannel, "Found MRP parameters in the message"); + + // Both TLV elements in the strucutre are optional. If the first element is present, process it and move + // the TLV reader to the next element. + if (TLV::TagNumFromTag(tlvReader.GetTag()) == 1) + { + ReturnErrorOnFailure(tlvReader.Get(tlvElementValue)); + mMRPConfig.mIdleRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); + + // The next element is optional. If it's not present, return CHIP_NO_ERROR. + err = tlvReader.Next(); + if (err == CHIP_END_OF_TLV) + { + return CHIP_NO_ERROR; + } + ReturnErrorOnFailure(err); + } + + VerifyOrReturnError(TLV::TagNumFromTag(tlvReader.GetTag()) == 2, CHIP_ERROR_INVALID_TLV_TAG); + ReturnErrorOnFailure(tlvReader.Get(tlvElementValue)); + mMRPConfig.mActiveRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); + + return tlvReader.ExitContainer(containerType); +} + +} // namespace chip diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index 8f9b2f242fed21..945dfc0d73356b 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -113,51 +114,14 @@ class DLL_EXPORT PairingSession return CHIP_ERROR_INTERNAL; } - void SendStatusReport(Messaging::ExchangeContext * exchangeCtxt, uint16_t protocolCode) - { - Protocols::SecureChannel::GeneralStatusCode generalCode = (protocolCode == Protocols::SecureChannel::kProtocolCodeSuccess) - ? Protocols::SecureChannel::GeneralStatusCode::kSuccess - : Protocols::SecureChannel::GeneralStatusCode::kFailure; - uint32_t protocolId = Protocols::SecureChannel::Id.ToFullyQualifiedSpecForm(); - - ChipLogDetail(SecureChannel, "Sending status report. Protocol code %d, exchange %d", protocolCode, - exchangeCtxt->GetExchangeId()); + void SendStatusReport(Messaging::ExchangeContext * exchangeCtxt, uint16_t protocolCode); - Protocols::SecureChannel::StatusReport statusReport(generalCode, protocolId, protocolCode); + CHIP_ERROR HandleStatusReport(System::PacketBufferHandle && msg, bool successExpected); - Encoding::LittleEndian::PacketBufferWriter bbuf(System::PacketBufferHandle::New(statusReport.Size())); - statusReport.WriteToBuffer(bbuf); + static CHIP_ERROR EncodeMRPParameters(TLV::Tag tag, const ReliableMessageProtocolConfig & mrpConfig, + System::PacketBufferTLVWriter & tlvWriter); - System::PacketBufferHandle msg = bbuf.Finalize(); - VerifyOrReturn(!msg.IsNull(), ChipLogError(SecureChannel, "Failed to allocate status report message")); - - CHIP_ERROR err = exchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::StatusReport, std::move(msg)); - if (err != CHIP_NO_ERROR) - { - ChipLogError(SecureChannel, "Failed to send status report message. %s", ErrorStr(err)); - } - } - - CHIP_ERROR HandleStatusReport(System::PacketBufferHandle && msg, bool successExpected) - { - Protocols::SecureChannel::StatusReport report; - CHIP_ERROR err = report.Parse(std::move(msg)); - ReturnErrorOnFailure(err); - VerifyOrReturnError(report.GetProtocolId() == Protocols::SecureChannel::Id.ToFullyQualifiedSpecForm(), - CHIP_ERROR_INVALID_ARGUMENT); - - if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kSuccess && - report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeSuccess && successExpected) - { - OnSuccessStatusReport(); - } - else - { - err = OnFailureStatusReport(report.GetGeneralCode(), report.GetProtocolCode()); - } - - return err; - } + CHIP_ERROR DecodeMRPParametersIfPresent(System::PacketBufferTLVReader & tlvReader); // TODO: remove Clear, we should create a new instance instead reset the old instance. void Clear()