diff --git a/Cargo.toml b/Cargo.toml index e3b4ed4a120..9acad187586 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ libp2p-identity = { workspace = true } log = "0.4" rand = "0.8" quick-protobuf = "0.8" +quick-protobuf-codec = { workspace = true } +asynchronous-codec = "0.6.2" [dev-dependencies] async-std = { version = "1.10", features = ["attributes"] } diff --git a/src/protocol.rs b/src/protocol.rs index a63fd8cdf4d..904af6473e2 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -20,12 +20,13 @@ use crate::proto; use async_trait::async_trait; -use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use libp2p_core::{upgrade, Multiaddr}; +use asynchronous_codec::{FramedRead, FramedWrite}; +use futures::io::{AsyncRead, AsyncWrite}; +use futures::{SinkExt, StreamExt}; +use libp2p_core::Multiaddr; use libp2p_identity::PeerId; use libp2p_request_response::{self as request_response}; use libp2p_swarm::StreamProtocol; -use quick_protobuf::{BytesReader, Writer}; use std::{convert::TryFrom, io}; /// The protocol name used for negotiating with multistream-select. @@ -44,8 +45,12 @@ impl request_response::Codec for AutoNatCodec { where T: AsyncRead + Send + Unpin, { - let bytes = upgrade::read_length_prefixed(io, 1024).await?; - let request = DialRequest::from_bytes(&bytes)?; + let message = FramedRead::new(io, codec()) + .next() + .await + .ok_or(io::ErrorKind::UnexpectedEof)??; + let request = DialRequest::from_proto(message)?; + Ok(request) } @@ -57,8 +62,12 @@ impl request_response::Codec for AutoNatCodec { where T: AsyncRead + Send + Unpin, { - let bytes = upgrade::read_length_prefixed(io, 1024).await?; - let response = DialResponse::from_bytes(&bytes)?; + let message = FramedRead::new(io, codec()) + .next() + .await + .ok_or(io::ErrorKind::UnexpectedEof)??; + let response = DialResponse::from_proto(message)?; + Ok(response) } @@ -71,8 +80,11 @@ impl request_response::Codec for AutoNatCodec { where T: AsyncWrite + Send + Unpin, { - upgrade::write_length_prefixed(io, data.into_bytes()).await?; - io.close().await + let mut framed = FramedWrite::new(io, codec()); + framed.send(data.into_proto()).await?; + framed.close().await?; + + Ok(()) } async fn write_response( @@ -84,11 +96,18 @@ impl request_response::Codec for AutoNatCodec { where T: AsyncWrite + Send + Unpin, { - upgrade::write_length_prefixed(io, data.into_bytes()).await?; - io.close().await + let mut framed = FramedWrite::new(io, codec()); + framed.send(data.into_proto()).await?; + framed.close().await?; + + Ok(()) } } +fn codec() -> quick_protobuf_codec::Codec { + quick_protobuf_codec::Codec::::new(1024) +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct DialRequest { pub peer_id: PeerId, @@ -96,12 +115,7 @@ pub struct DialRequest { } impl DialRequest { - pub fn from_bytes(bytes: &[u8]) -> Result { - use quick_protobuf::MessageRead; - - let mut reader = BytesReader::from_bytes(bytes); - let msg = proto::Message::from_reader(&mut reader, bytes) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + pub fn from_proto(msg: proto::Message) -> Result { if msg.type_pb != Some(proto::MessageType::DIAL) { return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type")); } @@ -143,9 +157,7 @@ impl DialRequest { }) } - pub fn into_bytes(self) -> Vec { - use quick_protobuf::MessageWrite; - + pub fn into_proto(self) -> proto::Message { let peer_id = self.peer_id.to_bytes(); let addrs = self .addresses @@ -153,7 +165,7 @@ impl DialRequest { .map(|addr| addr.to_vec()) .collect(); - let msg = proto::Message { + proto::Message { type_pb: Some(proto::MessageType::DIAL), dial: Some(proto::Dial { peer: Some(proto::PeerInfo { @@ -162,12 +174,7 @@ impl DialRequest { }), }), dialResponse: None, - }; - - let mut buf = Vec::with_capacity(msg.get_size()); - let mut writer = Writer::new(&mut buf); - msg.write_message(&mut writer).expect("Encoding to succeed"); - buf + } } } @@ -217,12 +224,7 @@ pub struct DialResponse { } impl DialResponse { - pub fn from_bytes(bytes: &[u8]) -> Result { - use quick_protobuf::MessageRead; - - let mut reader = BytesReader::from_bytes(bytes); - let msg = proto::Message::from_reader(&mut reader, bytes) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + pub fn from_proto(msg: proto::Message) -> Result { if msg.type_pb != Some(proto::MessageType::DIAL_RESPONSE) { return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type")); } @@ -258,9 +260,7 @@ impl DialResponse { }) } - pub fn into_bytes(self) -> Vec { - use quick_protobuf::MessageWrite; - + pub fn into_proto(self) -> proto::Message { let dial_response = match self.result { Ok(addr) => proto::DialResponse { status: Some(proto::ResponseStatus::OK), @@ -274,23 +274,17 @@ impl DialResponse { }, }; - let msg = proto::Message { + proto::Message { type_pb: Some(proto::MessageType::DIAL_RESPONSE), dial: None, dialResponse: Some(dial_response), - }; - - let mut buf = Vec::with_capacity(msg.get_size()); - let mut writer = Writer::new(&mut buf); - msg.write_message(&mut writer).expect("Encoding to succeed"); - buf + } } } #[cfg(test)] mod tests { use super::*; - use quick_protobuf::MessageWrite; #[test] fn test_request_encode_decode() { @@ -301,8 +295,8 @@ mod tests { "/ip4/192.168.1.42/tcp/30333".parse().unwrap(), ], }; - let bytes = request.clone().into_bytes(); - let request2 = DialRequest::from_bytes(&bytes).unwrap(); + let proto = request.clone().into_proto(); + let request2 = DialRequest::from_proto(proto).unwrap(); assert_eq!(request, request2); } @@ -312,8 +306,8 @@ mod tests { result: Ok("/ip4/8.8.8.8/tcp/30333".parse().unwrap()), status_text: None, }; - let bytes = response.clone().into_bytes(); - let response2 = DialResponse::from_bytes(&bytes).unwrap(); + let proto = response.clone().into_proto(); + let response2 = DialResponse::from_proto(proto).unwrap(); assert_eq!(response, response2); } @@ -323,8 +317,8 @@ mod tests { result: Err(ResponseError::DialError), status_text: Some("dial failed".to_string()), }; - let bytes = response.clone().into_bytes(); - let response2 = DialResponse::from_bytes(&bytes).unwrap(); + let proto = response.clone().into_proto(); + let response2 = DialResponse::from_proto(proto).unwrap(); assert_eq!(response, response2); } @@ -350,11 +344,7 @@ mod tests { dialResponse: None, }; - let mut bytes = Vec::with_capacity(msg.get_size()); - let mut writer = Writer::new(&mut bytes); - msg.write_message(&mut writer).expect("Encoding to succeed"); - - let request = DialRequest::from_bytes(&bytes).expect("not to fail"); + let request = DialRequest::from_proto(msg).expect("not to fail"); assert_eq!(request.addresses, vec![valid_multiaddr]) }