diff --git a/examples/common/tracing/decoder/secure_channel/Decoder.cpp b/examples/common/tracing/decoder/secure_channel/Decoder.cpp index a1cf96ea187fc2..04ec5522ff8391 100644 --- a/examples/common/tracing/decoder/secure_channel/Decoder.cpp +++ b/examples/common/tracing/decoder/secure_channel/Decoder.cpp @@ -35,7 +35,6 @@ constexpr const char * kPBKDFParamResponse = "Password-Based Key Derivation Para constexpr const char * kPASE_Pake1 = "Password Authenticated Session Establishment '1'"; constexpr const char * kPASE_Pake2 = "Password Authenticated Session Establishment '2'"; constexpr const char * kPASE_Pake3 = "Password Authenticated Session Establishment '3'"; -constexpr const char * kPASE_PakeError = "Password-authenticated key exchange Error"; constexpr const char * kCASE_Sigma1 = "Certificate Authenticated Session Establishment Sigma '1'"; constexpr const char * kCASE_Sigma2 = "Certificate Authenticated Session Establishment Sigma '2'"; constexpr const char * kCASE_Sigma3 = "Certificate Authenticated Session Establishment Sigma '3'"; @@ -60,7 +59,6 @@ CHIP_ERROR DecodePBDFKParamResponse(TLV::TLVReader & reader); CHIP_ERROR DecodePASEPake1(TLV::TLVReader & reader); CHIP_ERROR DecodePASEPake2(TLV::TLVReader & reader); CHIP_ERROR DecodePASEPake3(TLV::TLVReader & reader); -CHIP_ERROR DecodePASEPakeError(TLV::TLVReader & reader); CHIP_ERROR DecodeCASESigma1(TLV::TLVReader & reader); CHIP_ERROR DecodeCASESigma2(TLV::TLVReader & reader); CHIP_ERROR DecodeCASESigma3(TLV::TLVReader & reader); @@ -92,8 +90,6 @@ const char * ToProtocolMessageTypeName(uint8_t protocolCode) return kPASE_Pake2; case to_underlying(MessageType::PASE_Pake3): return kPASE_Pake3; - case to_underlying(MessageType::PASE_PakeError): - return kPASE_PakeError; case to_underlying(MessageType::CASE_Sigma1): return kCASE_Sigma1; case to_underlying(MessageType::CASE_Sigma2): @@ -132,8 +128,6 @@ CHIP_ERROR LogAsProtocolMessage(uint8_t protocolCode, const uint8_t * data, size return DecodePASEPake2(reader); case to_underlying(MessageType::PASE_Pake3): return DecodePASEPake3(reader); - case to_underlying(MessageType::PASE_PakeError): - return DecodePASEPakeError(reader); case to_underlying(MessageType::CASE_Sigma1): return DecodeCASESigma1(reader); case to_underlying(MessageType::CASE_Sigma2): @@ -323,11 +317,6 @@ CHIP_ERROR DecodePASEPake3(TLV::TLVReader & reader) return CHIP_NO_ERROR; } -CHIP_ERROR DecodePASEPakeError(TLV::TLVReader & reader) -{ - return CHIP_ERROR_NOT_IMPLEMENTED; -} - CHIP_ERROR DecodeCASESigma1(TLV::TLVReader & reader) { constexpr uint8_t kInitiatorRandomTag = 1; diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 170e6bb702864d..183e8540e075fb 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -38,7 +38,6 @@ bool ApplicationExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8 case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake1): case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake2): case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake3): - case static_cast(Protocols::SecureChannel::MsgType::PASE_PakeError): case static_cast(Protocols::SecureChannel::MsgType::CASE_Sigma1): case static_cast(Protocols::SecureChannel::MsgType::CASE_Sigma2): case static_cast(Protocols::SecureChannel::MsgType::CASE_Sigma3): diff --git a/src/protocols/secure_channel/Constants.h b/src/protocols/secure_channel/Constants.h index b7c66687e322db..108b547f86e4f8 100644 --- a/src/protocols/secure_channel/Constants.h +++ b/src/protocols/secure_channel/Constants.h @@ -60,7 +60,6 @@ enum class MsgType : uint8_t PASE_Pake1 = 0x22, PASE_Pake2 = 0x23, PASE_Pake3 = 0x24, - PASE_PakeError = 0x2F, // Certificate-based session establishment Message Types CASE_Sigma1 = 0x30, diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index cbe4988aff45f7..fc84abe8cd89f1 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -87,7 +87,7 @@ void PASESession::Clear() // It's done so that no security related information will be leaked. memset(&mPASEVerifier, 0, sizeof(mPASEVerifier)); memset(&mKe[0], 0, sizeof(mKe)); - mNextExpectedMsg = MsgType::PASE_PakeError; + mNextExpectedMsg.ClearValue(); mSpake2p.Clear(); mCommissioningHash.Clear(); @@ -185,8 +185,8 @@ CHIP_ERROR PASESession::WaitForPairing(SessionManager & sessionManager, const Sp memmove(mSalt, salt.data(), mSaltLength); memmove(&mPASEVerifier, &verifier, sizeof(verifier)); - mIterationCount = pbkdf2IterCount; - mNextExpectedMsg = MsgType::PBKDFParamRequest; + mIterationCount = pbkdf2IterCount; + mNextExpectedMsg.SetValue(MsgType::PBKDFParamRequest); mPairingComplete = false; mLocalMRPConfig = mrpLocalConfig; @@ -234,8 +234,9 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec) VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "PASESession::OnResponseTimeout was called by null exchange")); VerifyOrReturn(mExchangeCtxt == nullptr || mExchangeCtxt == ec, ChipLogError(SecureChannel, "PASESession::OnResponseTimeout exchange doesn't match")); + // If we were waiting for something, mNextExpectedMsg had better have a value. ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %u", - to_underlying(mNextExpectedMsg)); + to_underlying(mNextExpectedMsg.Value())); // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. DiscardExchange(); @@ -293,7 +294,7 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() ReturnErrorOnFailure( mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse))); - mNextExpectedMsg = MsgType::PBKDFParamResponse; + mNextExpectedMsg.SetValue(MsgType::PBKDFParamResponse); ChipLogDetail(SecureChannel, "Sent PBKDF param request"); @@ -418,7 +419,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent PBKDF param response"); - mNextExpectedMsg = MsgType::PASE_Pake1; + mNextExpectedMsg.SetValue(MsgType::PASE_Pake1); return CHIP_NO_ERROR; } @@ -545,7 +546,7 @@ CHIP_ERROR PASESession::SendMsg1() mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent spake2p msg1"); - mNextExpectedMsg = MsgType::PASE_Pake2; + mNextExpectedMsg.SetValue(MsgType::PASE_Pake2); return CHIP_NO_ERROR; } @@ -606,7 +607,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); - mNextExpectedMsg = MsgType::PASE_Pake3; + mNextExpectedMsg.SetValue(MsgType::PASE_Pake3); } ChipLogDetail(SecureChannel, "Sent spake2p msg2"); @@ -682,7 +683,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); - mNextExpectedMsg = MsgType::StatusReport; + mNextExpectedMsg.SetValue(MsgType::StatusReport); } ChipLogDetail(SecureChannel, "Sent spake2p msg3"); @@ -703,8 +704,7 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) ChipLogDetail(SecureChannel, "Received spake2p msg3"); - // We will set NextExpectedMsg to PASE_PakeError in all cases - mNextExpectedMsg = MsgType::PASE_PakeError; + mNextExpectedMsg.ClearValue(); System::PacketBufferTLVReader tlvReader; TLV::TLVType containerType = TLV::kTLVType_Structure; @@ -784,7 +784,8 @@ CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, cons } VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(payloadHeader.HasMessageType(mNextExpectedMsg) || payloadHeader.HasMessageType(MsgType::StatusReport), + VerifyOrReturnError((mNextExpectedMsg.HasValue() && payloadHeader.HasMessageType(mNextExpectedMsg.Value())) || + payloadHeader.HasMessageType(MsgType::StatusReport), CHIP_ERROR_INVALID_MESSAGE_TYPE); return CHIP_NO_ERROR; @@ -835,7 +836,8 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl break; case MsgType::StatusReport: - err = HandleStatusReport(std::move(msg), mNextExpectedMsg == MsgType::StatusReport); + err = + HandleStatusReport(std::move(msg), mNextExpectedMsg.HasValue() && (mNextExpectedMsg.Value() == MsgType::StatusReport)); break; default: diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index cb44865075935d..19918b7a3515d5 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -210,7 +210,8 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler, void Finish(); - Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_PakeError; + // mNextExpectedMsg is set when we are expecting a message. + Optional mNextExpectedMsg; #ifdef ENABLE_HSM_SPAKE Spake2pHSM_P256_SHA256_HKDF_HMAC mSpake2p; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 3a26d0df1ee06d..dd76f91a0cd454 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -40,7 +40,6 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(Protocols::Id protoc case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake1): case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake2): case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake3): - case static_cast(Protocols::SecureChannel::MsgType::PASE_PakeError): case static_cast(Protocols::SecureChannel::MsgType::CASE_Sigma1): case static_cast(Protocols::SecureChannel::MsgType::CASE_Sigma2): case static_cast(Protocols::SecureChannel::MsgType::CASE_Sigma3):