Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the PakeError message type. #21239

Merged
merged 1 commit into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions examples/common/tracing/decoder/secure_channel/Decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ constexpr const char * kPBKDFParamResponse = "Password-Based Key Derivation Para
constexpr const char * kPASE_Pake1 = "Password Authenticated Session Establishment '1'";
constexpr const char * kPASE_Pake2 = "Password Authenticated Session Establishment '2'";
constexpr const char * kPASE_Pake3 = "Password Authenticated Session Establishment '3'";
constexpr const char * kPASE_PakeError = "Password-authenticated key exchange Error";
constexpr const char * kCASE_Sigma1 = "Certificate Authenticated Session Establishment Sigma '1'";
constexpr const char * kCASE_Sigma2 = "Certificate Authenticated Session Establishment Sigma '2'";
constexpr const char * kCASE_Sigma3 = "Certificate Authenticated Session Establishment Sigma '3'";
Expand All @@ -60,7 +59,6 @@ CHIP_ERROR DecodePBDFKParamResponse(TLV::TLVReader & reader);
CHIP_ERROR DecodePASEPake1(TLV::TLVReader & reader);
CHIP_ERROR DecodePASEPake2(TLV::TLVReader & reader);
CHIP_ERROR DecodePASEPake3(TLV::TLVReader & reader);
CHIP_ERROR DecodePASEPakeError(TLV::TLVReader & reader);
CHIP_ERROR DecodeCASESigma1(TLV::TLVReader & reader);
CHIP_ERROR DecodeCASESigma2(TLV::TLVReader & reader);
CHIP_ERROR DecodeCASESigma3(TLV::TLVReader & reader);
Expand Down Expand Up @@ -92,8 +90,6 @@ const char * ToProtocolMessageTypeName(uint8_t protocolCode)
return kPASE_Pake2;
case to_underlying(MessageType::PASE_Pake3):
return kPASE_Pake3;
case to_underlying(MessageType::PASE_PakeError):
return kPASE_PakeError;
case to_underlying(MessageType::CASE_Sigma1):
return kCASE_Sigma1;
case to_underlying(MessageType::CASE_Sigma2):
Expand Down Expand Up @@ -132,8 +128,6 @@ CHIP_ERROR LogAsProtocolMessage(uint8_t protocolCode, const uint8_t * data, size
return DecodePASEPake2(reader);
case to_underlying(MessageType::PASE_Pake3):
return DecodePASEPake3(reader);
case to_underlying(MessageType::PASE_PakeError):
return DecodePASEPakeError(reader);
case to_underlying(MessageType::CASE_Sigma1):
return DecodeCASESigma1(reader);
case to_underlying(MessageType::CASE_Sigma2):
Expand Down Expand Up @@ -323,11 +317,6 @@ CHIP_ERROR DecodePASEPake3(TLV::TLVReader & reader)
return CHIP_NO_ERROR;
}

CHIP_ERROR DecodePASEPakeError(TLV::TLVReader & reader)
{
return CHIP_ERROR_NOT_IMPLEMENTED;
}

CHIP_ERROR DecodeCASESigma1(TLV::TLVReader & reader)
{
constexpr uint8_t kInitiatorRandomTag = 1;
Expand Down
1 change: 0 additions & 1 deletion src/messaging/ApplicationExchangeDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ bool ApplicationExchangeDispatch::MessagePermitted(Protocols::Id protocol, uint8
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Pake1):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Pake2):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Pake3):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_PakeError):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::CASE_Sigma1):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::CASE_Sigma2):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::CASE_Sigma3):
Expand Down
1 change: 0 additions & 1 deletion src/protocols/secure_channel/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ enum class MsgType : uint8_t
PASE_Pake1 = 0x22,
PASE_Pake2 = 0x23,
PASE_Pake3 = 0x24,
PASE_PakeError = 0x2F,

// Certificate-based session establishment Message Types
CASE_Sigma1 = 0x30,
Expand Down
28 changes: 15 additions & 13 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void PASESession::Clear()
// It's done so that no security related information will be leaked.
memset(&mPASEVerifier, 0, sizeof(mPASEVerifier));
memset(&mKe[0], 0, sizeof(mKe));
mNextExpectedMsg = MsgType::PASE_PakeError;
mNextExpectedMsg.ClearValue();

mSpake2p.Clear();
mCommissioningHash.Clear();
Expand Down Expand Up @@ -185,8 +185,8 @@ CHIP_ERROR PASESession::WaitForPairing(SessionManager & sessionManager, const Sp
memmove(mSalt, salt.data(), mSaltLength);
memmove(&mPASEVerifier, &verifier, sizeof(verifier));

mIterationCount = pbkdf2IterCount;
mNextExpectedMsg = MsgType::PBKDFParamRequest;
mIterationCount = pbkdf2IterCount;
mNextExpectedMsg.SetValue(MsgType::PBKDFParamRequest);
mPairingComplete = false;
mLocalMRPConfig = mrpLocalConfig;

Expand Down Expand Up @@ -234,8 +234,9 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec)
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "PASESession::OnResponseTimeout was called by null exchange"));
VerifyOrReturn(mExchangeCtxt == nullptr || mExchangeCtxt == ec,
ChipLogError(SecureChannel, "PASESession::OnResponseTimeout exchange doesn't match"));
// If we were waiting for something, mNextExpectedMsg had better have a value.
ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %u",
to_underlying(mNextExpectedMsg));
to_underlying(mNextExpectedMsg.Value()));
// Discard the exchange so that Clear() doesn't try closing it. The
// exchange will handle that.
DiscardExchange();
Expand Down Expand Up @@ -293,7 +294,7 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest()
ReturnErrorOnFailure(
mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse)));

mNextExpectedMsg = MsgType::PBKDFParamResponse;
mNextExpectedMsg.SetValue(MsgType::PBKDFParamResponse);

ChipLogDetail(SecureChannel, "Sent PBKDF param request");

Expand Down Expand Up @@ -418,7 +419,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in
mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse)));
ChipLogDetail(SecureChannel, "Sent PBKDF param response");

mNextExpectedMsg = MsgType::PASE_Pake1;
mNextExpectedMsg.SetValue(MsgType::PASE_Pake1);

return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -545,7 +546,7 @@ CHIP_ERROR PASESession::SendMsg1()
mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
ChipLogDetail(SecureChannel, "Sent spake2p msg1");

mNextExpectedMsg = MsgType::PASE_Pake2;
mNextExpectedMsg.SetValue(MsgType::PASE_Pake2);

return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -606,7 +607,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms
err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);

mNextExpectedMsg = MsgType::PASE_Pake3;
mNextExpectedMsg.SetValue(MsgType::PASE_Pake3);
}

ChipLogDetail(SecureChannel, "Sent spake2p msg2");
Expand Down Expand Up @@ -682,7 +683,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms
err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);

mNextExpectedMsg = MsgType::StatusReport;
mNextExpectedMsg.SetValue(MsgType::StatusReport);
}

ChipLogDetail(SecureChannel, "Sent spake2p msg3");
Expand All @@ -703,8 +704,7 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg)

ChipLogDetail(SecureChannel, "Received spake2p msg3");

// We will set NextExpectedMsg to PASE_PakeError in all cases
mNextExpectedMsg = MsgType::PASE_PakeError;
mNextExpectedMsg.ClearValue();

System::PacketBufferTLVReader tlvReader;
TLV::TLVType containerType = TLV::kTLVType_Structure;
Expand Down Expand Up @@ -784,7 +784,8 @@ CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, cons
}

VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(payloadHeader.HasMessageType(mNextExpectedMsg) || payloadHeader.HasMessageType(MsgType::StatusReport),
VerifyOrReturnError((mNextExpectedMsg.HasValue() && payloadHeader.HasMessageType(mNextExpectedMsg.Value())) ||
payloadHeader.HasMessageType(MsgType::StatusReport),
CHIP_ERROR_INVALID_MESSAGE_TYPE);

return CHIP_NO_ERROR;
Expand Down Expand Up @@ -835,7 +836,8 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl
break;

case MsgType::StatusReport:
err = HandleStatusReport(std::move(msg), mNextExpectedMsg == MsgType::StatusReport);
err =
HandleStatusReport(std::move(msg), mNextExpectedMsg.HasValue() && (mNextExpectedMsg.Value() == MsgType::StatusReport));
break;

default:
Expand Down
3 changes: 2 additions & 1 deletion src/protocols/secure_channel/PASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler,

void Finish();

Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_PakeError;
// mNextExpectedMsg is set when we are expecting a message.
Optional<Protocols::SecureChannel::MsgType> mNextExpectedMsg;

#ifdef ENABLE_HSM_SPAKE
Spake2pHSM_P256_SHA256_HKDF_HMAC mSpake2p;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(Protocols::Id protoc
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Pake1):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Pake2):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_Pake3):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::PASE_PakeError):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::CASE_Sigma1):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::CASE_Sigma2):
case static_cast<uint8_t>(Protocols::SecureChannel::MsgType::CASE_Sigma3):
Expand Down