diff --git a/Cargo.toml b/Cargo.toml index 0a4b5c687..6bba8bd5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,13 +12,13 @@ categories = ["network-programming", "asynchronous"] exclude = [".gitignore", ".github/*"] [dependencies] -enr = { version = "0.10", features = ["k256", "ed25519"] } +enr = { version = "0.11.0", features = ["k256", "ed25519"] } tokio = { version = "1", features = ["net", "sync", "macros", "rt"] } libp2p = { version = "0.53", features = ["ed25519", "secp256k1"], optional = true } zeroize = { version = "1", features = ["zeroize_derive"] } futures = "0.3" uint = { version = "0.9", default-features = false } -rlp = "0.5" +alloy-rlp = { version = "0.3.4", default-features = true } # This version must be kept up to date do it uses the same dependencies as ENR hkdf = "0.12" hex = "0.4" diff --git a/examples/request_enr.rs b/examples/request_enr.rs index d3daac640..cc3248334 100644 --- a/examples/request_enr.rs +++ b/examples/request_enr.rs @@ -17,7 +17,7 @@ use discv5::ConfigBuilder; #[cfg(feature = "libp2p")] use discv5::ListenConfig; #[cfg(feature = "libp2p")] -use discv5::{enr, enr::CombinedKey, Discv5}; +use discv5::{enr::CombinedKey, Discv5}; #[cfg(feature = "libp2p")] use std::net::Ipv4Addr; diff --git a/src/discv5.rs b/src/discv5.rs index e33c60f25..d28d3038f 100644 --- a/src/discv5.rs +++ b/src/discv5.rs @@ -23,7 +23,7 @@ use crate::{ service::{QueryKind, Service, ServiceRequest, TalkRequest}, Config, DefaultProtocolId, Enr, IpMode, }; -use enr::{CombinedKey, EnrError, EnrKey, NodeId}; +use enr::{CombinedKey, EnrKey, Error as EnrError, NodeId}; use parking_lot::RwLock; use std::{ future::Future, @@ -433,7 +433,7 @@ impl Discv5

{ } /// Allows application layer to insert an arbitrary field into the local ENR. - pub fn enr_insert( + pub fn enr_insert( &self, key: &str, value: &T, diff --git a/src/discv5/test.rs b/src/discv5/test.rs index f6cd70225..db5edf2a0 100644 --- a/src/discv5/test.rs +++ b/src/discv5/test.rs @@ -1,6 +1,7 @@ #![cfg(test)] use crate::{socket::ListenConfig, Discv5, *}; +use alloy_rlp::bytes::Bytes; use enr::{k256, CombinedKey, Enr, EnrKey, NodeId}; use rand_core::{RngCore, SeedableRng}; use std::{ @@ -14,7 +15,7 @@ fn init() { .try_init(); } -fn update_enr(discv5: &mut Discv5, key: &str, value: &T) -> bool { +fn update_enr(discv5: &mut Discv5, key: &str, value: &T) -> bool { discv5.enr_insert(key, value).is_ok() } @@ -675,8 +676,9 @@ async fn test_predicate_search() { // Update `num_nodes` with the required attnet value let num_nodes = total_nodes / 2; - let required_attnet_value = vec![1, 0, 0, 0]; - let unwanted_attnet_value = vec![0, 0, 0, 0]; + + let required_attnet_value = Bytes::copy_from_slice(&[1, 0, 0, 0]); + let unwanted_attnet_value = Bytes::copy_from_slice(&[0, 0, 0, 0]); println!("Bootstrap node: {}", bootstrap_node.local_enr().node_id()); println!("Target node: {}", target_node.local_enr().node_id()); @@ -703,7 +705,7 @@ async fn test_predicate_search() { // Predicate function for filtering enrs let predicate = move |enr: &Enr| { if let Some(v) = enr.get("attnets") { - v == required_attnet_value.as_slice() + v == required_attnet_value.to_vec().as_slice() } else { false } diff --git a/src/error.rs b/src/error.rs index 35b2b4768..f1d34300d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,5 @@ use crate::{handler::Challenge, node_info::NonContactable}; -use rlp::DecoderError; +use alloy_rlp::Error as DecoderError; use std::fmt; #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index 99b2eb256..5008070bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,7 +53,7 @@ //! A simple example of creating this service is as follows: //! //! ```rust -//! use discv5::{enr, enr::{CombinedKey, NodeId}, TokioExecutor, Discv5, ConfigBuilder}; +//! use discv5::{enr, enr::{CombinedKey, Enr, NodeId}, TokioExecutor, Discv5, ConfigBuilder}; //! use discv5::socket::ListenConfig; //! use std::net::{Ipv4Addr, SocketAddr}; //! diff --git a/src/packet/mod.rs b/src/packet/mod.rs index f071263bd..15ad802ae 100644 --- a/src/packet/mod.rs +++ b/src/packet/mod.rs @@ -14,6 +14,7 @@ use aes::{ cipher::{generic_array::GenericArray, NewCipher, StreamCipher}, Aes128Ctr, }; +use alloy_rlp::Decodable; use enr::NodeId; use rand::Rng; use std::convert::TryInto; @@ -173,7 +174,7 @@ impl PacketKind { } => { let sig_size = id_nonce_sig.len(); let pubkey_size = ephem_pubkey.len(); - let node_record = enr_record.as_ref().map(rlp::encode); + let node_record = enr_record.as_ref().map(alloy_rlp::encode); let expected_len = 34 + sig_size + pubkey_size @@ -259,7 +260,7 @@ impl PacketKind { let enr_record = if remaining_data.len() > total_size { Some( - rlp::decode::(&remaining_data[total_size..]) + ::decode(&mut &remaining_data[total_size..]) .map_err(PacketError::InvalidEnr)?, ) } else { diff --git a/src/rpc.rs b/src/rpc.rs index d59968abb..dfa385cc1 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -1,5 +1,8 @@ +use alloy_rlp::{ + bytes::{Buf, Bytes, BytesMut}, + Decodable, Encodable, Error as DecoderError, Header, +}; use enr::{CombinedKey, Enr}; -use rlp::{DecoderError, RlpStream}; use std::{ convert::TryInto, net::{IpAddr, Ipv6Addr}, @@ -126,28 +129,40 @@ impl Request { let id = &self.id; match self.body { RequestBody::Ping { enr_seq } => { - let mut s = RlpStream::new(); - s.begin_list(2); - s.append(&id.as_bytes()); - s.append(&enr_seq); - buf.extend_from_slice(&s.out()); + let mut list = Vec::::new(); + id.as_bytes().encode(&mut list); + enr_seq.encode(&mut list); + let header = Header { + list: true, + payload_length: list.len(), + }; + header.encode(&mut buf); + buf.extend_from_slice(&list); buf } RequestBody::FindNode { distances } => { - let mut s = RlpStream::new(); - s.begin_list(2); - s.append(&id.as_bytes()); - s.append_list(&distances); - buf.extend_from_slice(&s.out()); + let mut list = Vec::::new(); + id.as_bytes().encode(&mut list); + distances.encode(&mut list); + let header = Header { + list: true, + payload_length: list.len(), + }; + header.encode(&mut buf); + buf.extend_from_slice(&list); buf } RequestBody::Talk { protocol, request } => { - let mut s = RlpStream::new(); - s.begin_list(3); - s.append(&id.as_bytes()); - s.append(&protocol); - s.append(&request); - buf.extend_from_slice(&s.out()); + let mut list = Vec::::new(); + id.as_bytes().encode(&mut list); + protocol.encode(&mut list); + request.encode(&mut list); + let header = Header { + list: true, + payload_length: list.len(), + }; + header.encode(&mut buf); + buf.extend_from_slice(&list); buf } } @@ -182,41 +197,62 @@ impl Response { let id = &self.id; match self.body { ResponseBody::Pong { enr_seq, ip, port } => { - let mut s = RlpStream::new(); - s.begin_list(4); - s.append(&id.as_bytes()); - s.append(&enr_seq); + let mut list = Vec::::new(); + id.as_bytes().encode(&mut list); + enr_seq.encode(&mut list); match ip { - IpAddr::V4(addr) => s.append(&(&addr.octets() as &[u8])), - IpAddr::V6(addr) => s.append(&(&addr.octets() as &[u8])), + IpAddr::V4(addr) => addr.encode(&mut list), + IpAddr::V6(addr) => addr.encode(&mut list), }; - s.append(&port.get()); - buf.extend_from_slice(&s.out()); + port.get().encode(&mut list); + let header = Header { + list: true, + payload_length: list.len(), + }; + header.encode(&mut buf); + buf.extend_from_slice(&list); buf } ResponseBody::Nodes { total, nodes } => { - let mut s = RlpStream::new(); - s.begin_list(3); - s.append(&id.as_bytes()); - s.append(&total); - - if nodes.is_empty() { - s.begin_list(0); - } else { - s.begin_list(nodes.len()); - for node in nodes { - s.append(&node); + let mut list = Vec::::new(); + id.as_bytes().encode(&mut list); + total.encode(&mut list); + if !nodes.is_empty() { + let mut out = BytesMut::new(); + for node in nodes.clone() { + node.encode(&mut out); } + let tmp_header = Header { + list: true, + payload_length: out.len(), + }; + let mut tmp_out = BytesMut::new(); + tmp_header.encode(&mut tmp_out); + tmp_out.extend_from_slice(&out); + list.extend_from_slice(&tmp_out); + } else { + let mut out = BytesMut::new(); + nodes.encode(&mut out); + list.extend_from_slice(&out); } - buf.extend_from_slice(&s.out()); + let header = Header { + list: true, + payload_length: list.len(), + }; + header.encode(&mut buf); + buf.extend_from_slice(&list); buf } ResponseBody::Talk { response } => { - let mut s = RlpStream::new(); - s.begin_list(2); - s.append(&id.as_bytes()); - s.append(&response); - buf.extend_from_slice(&s.out()); + let mut list = Vec::::new(); + id.as_bytes().encode(&mut list); + response.as_slice().encode(&mut list); + let header = Header { + list: true, + payload_length: list.len(), + }; + header.encode(&mut buf); + buf.extend_from_slice(&list); buf } } @@ -304,57 +340,41 @@ impl Message { pub fn decode(data: &[u8]) -> Result { if data.len() < 3 { - return Err(DecoderError::RlpIsTooShort); + return Err(DecoderError::InputTooShort); } let msg_type = data[0]; - let data = &data[1..]; - let rlp = rlp::Rlp::new(data); + let payload = &mut &data[1..]; - let list_len = rlp.item_count().and_then(|size| { - if size < 2 { - Err(DecoderError::RlpIncorrectListLen) - } else { - Ok(size) - } - })?; + let header = Header::decode(payload)?; + if !header.list { + return Err(DecoderError::Custom("Invalid format of header")); + } - // verify there is no extra data - let payload_info = rlp.payload_info()?; - if data.len() != payload_info.header_len + payload_info.value_len { - return Err(DecoderError::RlpInconsistentLengthAndData); + if header.payload_length != payload.len() { + return Err(DecoderError::Custom("Reject the extra data")); } - let id = RequestId::decode(rlp.val_at::>(0)?)?; + let id_bytes = Bytes::decode(payload)?; + let id = RequestId(id_bytes.to_vec()); let message = match msg_type { 1 => { // PingRequest - if list_len != 2 { - debug!( - "Ping Request has an invalid RLP list length. Expected 2, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); + let enr_seq = u64::decode(payload)?; + if !payload.is_empty() { + return Err(DecoderError::Custom("Payload should be empty")); } Message::Request(Request { id, - body: RequestBody::Ping { - enr_seq: rlp.val_at::(1)?, - }, + body: RequestBody::Ping { enr_seq }, }) } 2 => { // PingResponse - if list_len != 4 { - debug!( - "Ping Response has an invalid RLP list length. Expected 4, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - let ip_bytes = rlp.val_at::>(2)?; + let enr_seq = u64::decode(payload)?; + let ip_bytes = Bytes::decode(payload)?; let ip = match ip_bytes.len() { 4 => { let mut ip = [0u8; 4]; @@ -379,18 +399,17 @@ impl Message { } _ => { debug!("Ping Response has incorrect byte length for IP"); - return Err(DecoderError::RlpIncorrectListLen); + return Err(DecoderError::Custom("Incorrect List Length")); } }; - let raw_port = rlp.val_at::(3)?; + let raw_port = u16::decode(payload)?; if let Ok(port) = raw_port.try_into() { + if !payload.is_empty() { + return Err(DecoderError::Custom("Payload should be empty")); + } Message::Response(Response { id, - body: ResponseBody::Pong { - enr_seq: rlp.val_at::(1)?, - ip, - port, - }, + body: ResponseBody::Pong { enr_seq, ip, port }, }) } else { debug!("The port number should be non zero: {raw_port}"); @@ -399,14 +418,7 @@ impl Message { } 3 => { // FindNodeRequest - if list_len != 2 { - debug!( - "FindNode Request has an invalid RLP list length. Expected 2, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - let distances = rlp.list_at::(1)?; + let distances = Vec::::decode(payload)?; for distance in distances.iter() { if distance > &256u64 { @@ -417,7 +429,9 @@ impl Message { return Err(DecoderError::Custom("FINDNODE request distance invalid")); } } - + if !payload.is_empty() { + return Err(DecoderError::Custom("Payload should be empty")); + } Message::Request(Request { id, body: RequestBody::FindNode { distances }, @@ -425,42 +439,51 @@ impl Message { } 4 => { // NodesResponse - if list_len != 3 { - debug!( - "Nodes Response has an invalid RLP list length. Expected 3, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); - } - + let total = u64::decode(payload)?; let nodes = { - let enr_list_rlp = rlp.at(2)?; + let header = Header::decode(payload)?; + if !header.list { + return Err(DecoderError::Custom("Invalid format of header")); + } + let mut enr_list_rlp = Vec::>::new(); + while !payload.is_empty() { + let node_header = Header::decode(&mut &payload[..])?; + if !node_header.list { + return Err(DecoderError::Custom("Invalid format of header")); + } + if node_header.payload_length + 2 > payload.len() { + return Err(DecoderError::Custom( + "Payload size is smaller than payload_length", + )); + } + let enr_rlp = Enr::::decode( + &mut &payload[..node_header.payload_length + 2], + )?; + payload.advance(enr_rlp.size()); + enr_list_rlp.append(&mut vec![enr_rlp]); + } if enr_list_rlp.is_empty() { // no records vec![] } else { - enr_list_rlp.as_list::>()? + enr_list_rlp } }; + if !payload.is_empty() { + return Err(DecoderError::Custom("Payload should be empty")); + } Message::Response(Response { id, - body: ResponseBody::Nodes { - total: rlp.val_at::(1)?, - nodes, - }, + body: ResponseBody::Nodes { total, nodes }, }) } 5 => { // Talk Request - if list_len != 3 { - debug!( - "Talk Request has an invalid RLP list length. Expected 3, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); + let protocol = Vec::::decode(payload)?; + let request = Vec::::decode(payload)?; + if !payload.is_empty() { + return Err(DecoderError::Custom("Payload should be empty")); } - let protocol = rlp.val_at::>(1)?; - let request = rlp.val_at::>(2)?; Message::Request(Request { id, body: RequestBody::Talk { protocol, request }, @@ -468,17 +491,15 @@ impl Message { } 6 => { // Talk Response - if list_len != 2 { - debug!( - "Talk Response has an invalid RLP list length. Expected 2, found {}", - list_len - ); - return Err(DecoderError::RlpIncorrectListLen); + let response = Bytes::decode(payload)?; + if !payload.is_empty() { + return Err(DecoderError::Custom("Payload should be empty")); } - let response = rlp.val_at::>(1)?; Message::Response(Response { id, - body: ResponseBody::Talk { response }, + body: ResponseBody::Talk { + response: response.to_vec(), + }, }) } _ => { @@ -774,11 +795,21 @@ mod tests { body: ResponseBody::Talk { response: vec![75] } }) ); + assert_eq!(data.to_vec(), msg.encode()); let data2 = [6, 193, 0, 75, 252]; Message::decode(&data2).expect_err("should reject extra data"); let data3 = [6, 194, 0, 75, 252]; Message::decode(&data3).expect_err("should reject extra data"); + + let data4 = [6, 193, 0, 63]; + Message::decode(&data4).expect_err("should reject extra data"); + + let data5 = [6, 193, 128, 75]; + Message::decode(&data5).expect_err("should reject extra data"); + + let data6 = [6, 193, 128, 128]; + Message::decode(&data6).expect_err("should reject extra data"); } } diff --git a/src/service.rs b/src/service.rs index 225a1bd67..af968fdb7 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1055,7 +1055,7 @@ impl Service { let mut rpc_index = 0; to_send_nodes.push(Vec::new()); for enr in nodes_to_send.into_iter() { - let entry_size = rlp::encode(&enr).len(); + let entry_size = alloy_rlp::encode(&enr).len(); // Responses assume that a session is established. Thus, on top of the encoded // ENR's the packet should be a regular message. A regular message has an IV (16 // bytes), and a header of 55 bytes. The find-nodes RPC requires 16 bytes for the ID and the