Skip to content

Commit

Permalink
Expose BIP324CipherSuite AAD via transport classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruv committed Dec 7, 2022
1 parent 8c35286 commit 10126d8
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 16 deletions.
16 changes: 11 additions & 5 deletions src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
// decompose a transport agnostic CNetMessage from the deserializer
bool reject_message{false};
bool disconnect{false};
CNetMessage msg = m_deserializer->GetMessage(time, reject_message, disconnect);
CNetMessage msg = m_deserializer->GetMessage(time, reject_message, disconnect, {});

if (disconnect) {
// v2 p2p incorrect MAC tag. Disconnect from peer.
Expand Down Expand Up @@ -772,7 +772,10 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
return data_hash;
}

CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message, bool& disconnect)
CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time,
bool& reject_message,
bool& disconnect,
Span<const std::byte> aad)
{
// Initialize out parameter
reject_message = false;
Expand Down Expand Up @@ -877,7 +880,10 @@ int V2TransportDeserializer::readData(Span<const uint8_t> pkt_bytes)
return copy_bytes;
}

CNetMessage V2TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message, bool& disconnect)
CNetMessage V2TransportDeserializer::GetMessage(const std::chrono::microseconds time,
bool& reject_message,
bool& disconnect,
Span<const std::byte> aad)
{
const size_t min_contents_size = 1; // BIP324 1-byte message type id is the minimum contents

Expand All @@ -896,7 +902,7 @@ CNetMessage V2TransportDeserializer::GetMessage(const std::chrono::microseconds

BIP324HeaderFlags flags;
size_t msg_type_size = 1; // at least one byte needed for message type
if (m_cipher_suite->Crypt({},
if (m_cipher_suite->Crypt(aad,
Span{reinterpret_cast<const std::byte*>(vRecv.data() + BIP324_LENGTH_FIELD_LEN), BIP324_HEADER_LEN + m_contents_size + RFC8439_EXPANSION},
Span{reinterpret_cast<std::byte*>(vRecv.data()), m_contents_size}, flags, false)) {
// MAC check was successful
Expand Down Expand Up @@ -984,7 +990,7 @@ bool V2TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vec

BIP324HeaderFlags flags{BIP324_NONE};
// encrypt the payload, this should always succeed (controlled buffers, don't check the MAC during encrypting)
auto success = m_cipher_suite->Crypt({},
auto success = m_cipher_suite->Crypt(msg.aad,
Span{reinterpret_cast<const std::byte*>(msg.data.data()), contents_size},
Span{reinterpret_cast<std::byte*>(msg.data.data()), encrypted_pkt_size},
flags, true);
Expand Down
16 changes: 13 additions & 3 deletions src/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ struct CSerializedNetMsg {
}

std::vector<unsigned char> data;
std::vector<std::byte> aad; // associated authenticated data for encrypted BIP324 (v2) transport
std::string m_type;
};

Expand Down Expand Up @@ -260,7 +261,10 @@ class TransportDeserializer {
/** read and deserialize data, advances msg_bytes data pointer */
virtual int Read(Span<const uint8_t>& msg_bytes) = 0;
// decomposes a message from the context
virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message, bool& disconnect) = 0;
virtual CNetMessage GetMessage(std::chrono::microseconds time,
bool& reject_message,
bool& disconnect,
Span<const std::byte> aad) = 0;
virtual ~TransportDeserializer() {}
};

Expand Down Expand Up @@ -324,7 +328,10 @@ class V1TransportDeserializer final : public TransportDeserializer
}
return ret;
}
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message, bool& disconnect) override;
CNetMessage GetMessage(std::chrono::microseconds time,
bool& reject_message,
bool& disconnect,
Span<const std::byte> aad) override;
};

/** V2TransportDeserializer is a transport deserializer after BIP324 */
Expand Down Expand Up @@ -383,7 +390,10 @@ class V2TransportDeserializer final : public TransportDeserializer
}
return ret;
}
CNetMessage GetMessage(const std::chrono::microseconds time, bool& reject_message, bool& disconnect) override;
CNetMessage GetMessage(const std::chrono::microseconds time,
bool& reject_message,
bool& disconnect,
Span<const std::byte> aad) override;
};

/** The TransportSerializer prepares messages for the network transport
Expand Down
2 changes: 1 addition & 1 deletion src/test/fuzz/p2p_transport_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serializa
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
bool reject_message{false};
bool disconnect{false};
CNetMessage msg = deserializer.GetMessage(m_time, reject_message, disconnect);
CNetMessage msg = deserializer.GetMessage(m_time, reject_message, disconnect, {});
assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE);
assert(msg.m_raw_message_size <= mutable_msg_bytes.size());
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
Expand Down
15 changes: 9 additions & 6 deletions src/test/fuzz/p2p_v2_transport_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization)

// There is no sense in providing a mac assist if the length is incorrect.
bool mac_assist = length_assist && fdp.ConsumeBool();
auto aad = fdp.ConsumeBytes<std::byte>(fdp.ConsumeIntegralInRange(0, 1024));
auto encrypted_packet = fdp.ConsumeRemainingBytes<uint8_t>();
bool is_decoy_packet{false};

Expand All @@ -56,17 +57,18 @@ FUZZ_TARGET(p2p_v2_transport_serialization)

if (mac_assist) {
std::array<std::byte, RFC8439_EXPANSION> tag;
ComputeRFC8439Tag(GetPoly1305Key(c20), {},
ComputeRFC8439Tag(GetPoly1305Key(c20), aad,
{reinterpret_cast<std::byte*>(encrypted_packet.data()) + BIP324_LENGTH_FIELD_LEN,
encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN - RFC8439_EXPANSION}, tag);
encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN - RFC8439_EXPANSION},
tag);
memcpy(encrypted_packet.data() + encrypted_packet.size() - RFC8439_EXPANSION, tag.data(), RFC8439_EXPANSION);

std::vector<std::byte> dec_header_and_contents(
encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN - RFC8439_EXPANSION);
RFC8439Decrypt({}, key_P, nonce,
encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN - RFC8439_EXPANSION);
RFC8439Decrypt(aad, key_P, nonce,
{reinterpret_cast<std::byte*>(encrypted_packet.data() + BIP324_LENGTH_FIELD_LEN),
encrypted_packet.size() - BIP324_LENGTH_FIELD_LEN},
dec_header_and_contents);
dec_header_and_contents);
if (BIP324HeaderFlags((uint8_t)dec_header_and_contents.at(0) & BIP324_IGNORE) != BIP324_NONE) {
is_decoy_packet = true;
}
Expand All @@ -83,7 +85,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization)
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
bool reject_message{true};
bool disconnect{true};
CNetMessage result{deserializer.GetMessage(m_time, reject_message, disconnect)};
CNetMessage result{deserializer.GetMessage(m_time, reject_message, disconnect, aad)};

if (mac_assist) {
assert(!disconnect);
Expand All @@ -104,6 +106,7 @@ FUZZ_TARGET(p2p_v2_transport_serialization)

std::vector<unsigned char> header;
auto msg = CNetMsgMaker{result.m_recv.GetVersion()}.Make(result.m_type, MakeUCharSpan(result.m_recv));
msg.aad = aad;
// if decryption succeeds, encryption must succeed
assert(serializer.prepareForTransport(msg, header));
}
Expand Down
2 changes: 1 addition & 1 deletion src/test/net_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ void message_serialize_deserialize_test(bool v2, const std::vector<CSerializedNe

bool reject_message{true};
bool disconnect{true};
CNetMessage result{deserializer->GetMessage(GetTime<std::chrono::microseconds>(), reject_message, disconnect)};
CNetMessage result{deserializer->GetMessage(GetTime<std::chrono::microseconds>(), reject_message, disconnect, {})};
BOOST_CHECK(!reject_message);
BOOST_CHECK(!disconnect);
BOOST_CHECK_EQUAL(result.m_type, msg_orig.m_type);
Expand Down

0 comments on commit 10126d8

Please sign in to comment.