Skip to content

Commit

Permalink
[msg] Consolidate use of SetSecurityFlags. (#10921)
Browse files Browse the repository at this point in the history
* [msg] Utilize new SetMessageFlags() and SetSecurityFlags() methods.

* [msg] Flatten SuccessOrExit calls in MessageHeader.cpp.
  • Loading branch information
turon authored and pull[bot] committed Feb 9, 2022
1 parent a391e1a commit 2180396
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 39 deletions.
55 changes: 17 additions & 38 deletions src/transport/raw/MessageHeader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ constexpr uint8_t kMsgFlagsMask = 0x07;
/// Shift to convert to/from a masked version 8bit value to a 4bit version.
constexpr int kVersionShift = 4;

// Mask to extract sessionType
constexpr uint8_t kSessionTypeMask = 0x03;

} // namespace

uint16_t PacketHeader::EncodeSizeBytes() const
Expand Down Expand Up @@ -143,31 +140,23 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1
uint16_t octets_read;

uint8_t msgFlags;
err = reader.Read8(&msgFlags).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read8(&msgFlags).StatusCode());
version = ((msgFlags & kVersionMask) >> kVersionShift);
VerifyOrExit(version == kMsgHeaderVersion, err = CHIP_ERROR_VERSION_MISMATCH);

mMsgFlags.SetRaw(msgFlags);
SetMessageFlags(msgFlags);

uint8_t securityFlags;
err = reader.Read8(&securityFlags).StatusCode();
SuccessOrExit(err);
mSecFlags.SetRaw(securityFlags);

mSessionType = static_cast<Header::SessionType>(securityFlags & kSessionTypeMask);
SuccessOrExit(err = reader.Read8(&securityFlags).StatusCode());
SetSecurityFlags(securityFlags);

err = reader.Read16(&mSessionId).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read16(&mSessionId).StatusCode());

err = reader.Read32(&mMessageCounter).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read32(&mMessageCounter).StatusCode());

if (mMsgFlags.Has(Header::MsgFlagValues::kSourceNodeIdPresent))
{
uint64_t sourceNodeId;
err = reader.Read64(&sourceNodeId).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read64(&sourceNodeId).StatusCode());
mSourceNodeId.SetValue(sourceNodeId);
}
else
Expand All @@ -178,39 +167,33 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1
if (!IsSessionTypeValid())
{
// Reserved.
err = CHIP_ERROR_INTERNAL;
SuccessOrExit(err);
SuccessOrExit(err = CHIP_ERROR_INTERNAL);
}

if (mMsgFlags.HasAll(Header::MsgFlagValues::kDestinationNodeIdPresent, Header::MsgFlagValues::kDestinationGroupIdPresent))
{
// Reserved.
err = CHIP_ERROR_INTERNAL;
SuccessOrExit(err);
SuccessOrExit(err = CHIP_ERROR_INTERNAL);
}
else if (mMsgFlags.Has(Header::MsgFlagValues::kDestinationNodeIdPresent))
{
if (mSessionType != Header::SessionType::kUnicastSession)
{
err = CHIP_ERROR_INTERNAL;
SuccessOrExit(err);
SuccessOrExit(err = CHIP_ERROR_INTERNAL);
}
uint64_t destinationNodeId;
err = reader.Read64(&destinationNodeId).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read64(&destinationNodeId).StatusCode());
mDestinationNodeId.SetValue(destinationNodeId);
mDestinationGroupId.ClearValue();
}
else if (mMsgFlags.Has(Header::MsgFlagValues::kDestinationGroupIdPresent))
{
if (mSessionType != Header::SessionType::kGroupSession)
{
err = CHIP_ERROR_INTERNAL;
SuccessOrExit(err);
SuccessOrExit(err = CHIP_ERROR_INTERNAL);
}
uint16_t destinationGroupId;
err = reader.Read16(&destinationGroupId).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read16(&destinationGroupId).StatusCode());
mDestinationGroupId.SetValue(destinationGroupId);
mDestinationNodeId.ClearValue();
}
Expand Down Expand Up @@ -244,17 +227,15 @@ CHIP_ERROR PayloadHeader::Decode(const uint8_t * const data, uint16_t size, uint
uint8_t header;
uint16_t octets_read;

err = reader.Read8(&header).Read8(&mMessageType).Read16(&mExchangeID).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read8(&header).Read8(&mMessageType).Read16(&mExchangeID).StatusCode());

mExchangeFlags.SetRaw(header);

VendorId vendor_id;
if (HaveVendorId())
{
uint16_t vendor_id_raw;
err = reader.Read16(&vendor_id_raw).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read16(&vendor_id_raw).StatusCode());
vendor_id = static_cast<VendorId>(vendor_id_raw);
}
else
Expand All @@ -263,16 +244,14 @@ CHIP_ERROR PayloadHeader::Decode(const uint8_t * const data, uint16_t size, uint
}

uint16_t protocol_id;
err = reader.Read16(&protocol_id).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read16(&protocol_id).StatusCode());

mProtocolID = Protocols::Id(vendor_id, protocol_id);

if (mExchangeFlags.Has(Header::ExFlagValues::kExchangeFlag_AckMsg))
{
uint32_t ack_message_counter;
err = reader.Read32(&ack_message_counter).StatusCode();
SuccessOrExit(err);
SuccessOrExit(err = reader.Read32(&ack_message_counter).StatusCode());
mAckMessageCounter.SetValue(ack_message_counter);
}
else
Expand Down
11 changes: 10 additions & 1 deletion src/transport/raw/MessageHeader.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ enum class SecFlagValues : uint8_t
kMsgExtensionFlag = 0b00100000,
};

enum SecFlagMask
{
kSessionTypeMask = 0b00000011, ///< Mask to extract sessionType
};

using MsgFlags = BitFlags<MsgFlagValues>;
using SecFlags = BitFlags<SecFlagValues>;

Expand Down Expand Up @@ -164,7 +169,11 @@ class PacketHeader

void SetMessageFlags(uint8_t flags) { mMsgFlags.SetRaw(flags); }

void SetSecurityFlags(uint8_t flags) { mSecFlags.SetRaw(flags); }
void SetSecurityFlags(uint8_t securityFlags)
{
mSecFlags.SetRaw(securityFlags);
mSessionType = static_cast<Header::SessionType>(securityFlags & Header::SecFlagMask::kSessionTypeMask);
}

bool IsGroupSession() const { return mSessionType == Header::SessionType::kGroupSession; }
bool IsUnicastSession() const { return mSessionType == Header::SessionType::kUnicastSession; }
Expand Down

0 comments on commit 2180396

Please sign in to comment.