diff --git a/fuzz/src/onion_message.rs b/fuzz/src/onion_message.rs index 61eeb288b5c..c54947e828b 100644 --- a/fuzz/src/onion_message.rs +++ b/fuzz/src/onion_message.rs @@ -10,8 +10,8 @@ use lightning::ln::msgs::{self, DecodeError, OnionMessageHandler}; use lightning::ln::script::ShutdownScript; use lightning::util::enforcing_trait_impls::EnforcingSigner; use lightning::util::logger::Logger; -use lightning::util::ser::{Readable, Writeable, Writer}; -use lightning::onion_message::{CustomOnionMessageHandler, CustomMessageReadable, OnionMessageContents, OnionMessenger}; +use lightning::util::ser::{MaybeReadableArgs, Readable, Writeable, Writer}; +use lightning::onion_message::{CustomOnionMessageHandler, OnionMessageContents, OnionMessenger}; use utils::test_logger; @@ -67,8 +67,8 @@ impl Writeable for TestCustomMessage { } } -impl CustomMessageReadable for TestCustomMessage { - fn read(_message_type: u64, buffer: &mut R) -> Result, DecodeError> where Self: Sized { +impl MaybeReadableArgs for TestCustomMessage { + fn read(buffer: &mut R, _message_type: u64,) -> Result, DecodeError> where Self: Sized { let mut buf = Vec::new(); buffer.read_to_end(&mut buf)?; return Ok(Some(TestCustomMessage {})) diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 3ff0c32550c..7dfe12edff6 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -21,11 +21,11 @@ use ln::features::{InitFeatures, NodeFeatures}; use ln::msgs; use ln::msgs::{ChannelMessageHandler, LightningError, NetAddress, OnionMessageHandler, RoutingMessageHandler}; use ln::channelmanager::{SimpleArcChannelManager, SimpleRefChannelManager}; -use util::ser::{VecWriter, Writeable, Writer}; +use util::ser::{MaybeReadableArgs, VecWriter, Writeable, Writer}; use ln::peer_channel_encryptor::{PeerChannelEncryptor,NextNoiseStep}; use ln::wire; use ln::wire::Encode; -use onion_message::{CustomOnionMessageHandler, SimpleArcOnionMessenger, SimpleRefOnionMessenger}; +use onion_message::{CustomOnionMessageHandler, OnionMessageContents, SimpleArcOnionMessenger, SimpleRefOnionMessenger}; use routing::gossip::{NetworkGraph, P2PGossipSync}; use util::atomic_counter::AtomicCounter; use util::crypto::sign; @@ -96,9 +96,22 @@ impl OnionMessageHandler for IgnoringMessageHandler { } } impl CustomOnionMessageHandler for IgnoringMessageHandler { - type CustomMessage = (); - fn handle_custom_message(&self, _msg: Self::CustomMessage) {} + type CustomMessage = Infallible; + fn handle_custom_message(&self, _msg: Self::CustomMessage) { + // Since we always return `None` in the read the handle method should never be called. + unreachable!(); + } +} +impl MaybeReadableArgs for Infallible { + fn read(_buffer: &mut R, _msg_type: u64) -> Result, msgs::DecodeError> where Self: Sized { + Ok(None) + } } + +impl OnionMessageContents for Infallible { + fn tlv_type(&self) -> u64 { unreachable!(); } +} + impl Deref for IgnoringMessageHandler { type Target = IgnoringMessageHandler; fn deref(&self) -> &Self { self } diff --git a/lightning/src/onion_message/functional_tests.rs b/lightning/src/onion_message/functional_tests.rs index c04ad535960..3fb127890dd 100644 --- a/lightning/src/onion_message/functional_tests.rs +++ b/lightning/src/onion_message/functional_tests.rs @@ -12,9 +12,9 @@ use chain::keysinterface::{KeysInterface, Recipient}; use ln::features::InitFeatures; use ln::msgs::{self, DecodeError, OnionMessageHandler}; -use super::{BlindedRoute, CustomOnionMessageHandler, CustomMessageReadable, Destination, Message, OnionMessageContents, OnionMessenger, SendError}; +use super::{BlindedRoute, CustomOnionMessageHandler, Destination, Message, OnionMessageContents, OnionMessenger, SendError}; use util::enforcing_trait_impls::EnforcingSigner; -use util::ser::{Writeable, Writer}; +use util::ser::{MaybeReadableArgs, Writeable, Writer}; use util::test_utils; use bitcoin::network::constants::Network; @@ -54,8 +54,8 @@ impl Writeable for TestCustomMessage { } } -impl CustomMessageReadable for TestCustomMessage { - fn read(message_type: u64, buffer: &mut R) -> Result, DecodeError> where Self: Sized { +impl MaybeReadableArgs for TestCustomMessage { + fn read(buffer: &mut R, message_type: u64) -> Result, DecodeError> where Self: Sized { if message_type == CUSTOM_MESSAGE_TYPE { let mut buf = Vec::new(); buffer.read_to_end(&mut buf)?; diff --git a/lightning/src/onion_message/messenger.rs b/lightning/src/onion_message/messenger.rs index 42a8c7bad58..bd09b76516e 100644 --- a/lightning/src/onion_message/messenger.rs +++ b/lightning/src/onion_message/messenger.rs @@ -21,7 +21,7 @@ use ln::msgs::{self, OnionMessageHandler}; use ln::onion_utils; use ln::peer_handler::IgnoringMessageHandler; use super::blinded_route::{BlindedRoute, ForwardTlvs, ReceiveTlvs}; -pub use super::packet::{CustomMessageReadable, Message, OnionMessageContents}; +pub use super::packet::{Message, OnionMessageContents}; use super::packet::{BIG_PACKET_HOP_DATA_LEN, ForwardControlTlvs, Packet, Payload, ReceiveControlTlvs, SMALL_PACKET_HOP_DATA_LEN}; use super::utils; use util::events::OnionMessageProvider; @@ -43,9 +43,12 @@ use prelude::*; /// # use bitcoin::hashes::_export::_core::time::Duration; /// # use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; /// # use lightning::chain::keysinterface::{InMemorySigner, KeysManager, KeysInterface}; +/// # use lightning::ln::msgs::DecodeError; /// # use lightning::ln::peer_handler::IgnoringMessageHandler; -/// # use lightning::onion_message::{BlindedRoute, Destination, Message, OnionMessenger}; +/// # use lightning::onion_message::{BlindedRoute, Destination, Message, OnionMessageContents, OnionMessenger}; /// # use lightning::util::logger::{Logger, Record}; +/// # use lightning::util::ser::{MaybeReadableArgs, Writeable, Writer}; +/// # use lightning::io; /// # use std::sync::Arc; /// # struct FakeLogger {}; /// # impl Logger for FakeLogger { @@ -58,18 +61,37 @@ use prelude::*; /// # let node_secret = SecretKey::from_slice(&hex::decode("0101010101010101010101010101010101010101010101010101010101010101").unwrap()[..]).unwrap(); /// # let secp_ctx = Secp256k1::new(); /// # let hop_node_id1 = PublicKey::from_secret_key(&secp_ctx, &node_secret); -/// # let (hop_node_id2, hop_node_id3, hop_node_id4) = (hop_node_id1, hop_node_id1, -/// hop_node_id1); +/// # let (hop_node_id2, hop_node_id3, hop_node_id4) = (hop_node_id1, hop_node_id1, hop_node_id1); /// # let destination_node_id = hop_node_id1; /// # let your_custom_message_handler = IgnoringMessageHandler {}; /// // Create the onion messenger. This must use the same `keys_manager` as is passed to your /// // ChannelManager. /// let onion_messenger = OnionMessenger::new(&keys_manager, logger, your_custom_message_handler); /// -/// // Send an custom onion message to a node id. +/// # struct YourCustomMessage {} +/// impl Writeable for YourCustomMessage { +/// fn write(&self, w: &mut W) -> Result<(), io::Error> { +/// # Ok(()) +/// // Write your custom onion message to `w` +/// } +/// } +/// impl OnionMessageContents for YourCustomMessage { +/// fn tlv_type(&self) -> u64 { +/// # let your_custom_message_type = 42; +/// your_custom_message_type +/// } +/// } +/// impl MaybeReadableArgs for YourCustomMessage { +/// fn read(r: &mut R, message_type: u64) -> Result, DecodeError> { +/// # unreachable!() +/// // Read your custom onion message of type `message_type` from `r`, or return `None` +/// // if the message type is unknown +/// } +/// } +/// // Send a custom onion message to a node id. /// let intermediate_hops = [hop_node_id1, hop_node_id2]; /// let reply_path = None; -/// # let your_custom_message = (); +/// # let your_custom_message = YourCustomMessage {}; /// let message = Message::Custom(your_custom_message); /// onion_messenger.send_onion_message(&intermediate_hops, Destination::Node(destination_node_id), message, reply_path); /// @@ -81,6 +103,7 @@ use prelude::*; /// // Send a custom onion message to a blinded route. /// # let intermediate_hops = [hop_node_id1, hop_node_id2]; /// let reply_path = None; +/// # let your_custom_message = YourCustomMessage {}; /// let message = Message::Custom(your_custom_message); /// onion_messenger.send_onion_message(&intermediate_hops, Destination::BlindedRoute(blinded_route), message, reply_path); /// ``` @@ -419,8 +442,8 @@ pub type SimpleRefOnionMessenger<'a, 'b, L> = OnionMessenger( - secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], destination: Destination, message: - Message, mut reply_path: Option, session_priv: &SecretKey + secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], destination: Destination, + message: Message, mut reply_path: Option, session_priv: &SecretKey ) -> Result<(Vec<(Payload, [u8; 32])>, Vec), secp256k1::Error> { let num_hops = unblinded_path.len() + destination.num_hops(); let mut payloads = Vec::with_capacity(num_hops); diff --git a/lightning/src/onion_message/mod.rs b/lightning/src/onion_message/mod.rs index 4775bff5a3b..f240064bd01 100644 --- a/lightning/src/onion_message/mod.rs +++ b/lightning/src/onion_message/mod.rs @@ -29,5 +29,5 @@ mod functional_tests; // Re-export structs so they can be imported with just the `onion_message::` module prefix. pub use self::blinded_route::{BlindedRoute, BlindedHop}; -pub use self::messenger::{CustomOnionMessageHandler, CustomMessageReadable, Destination, Message, OnionMessageContents, OnionMessenger, SendError, SimpleArcOnionMessenger, SimpleRefOnionMessenger}; +pub use self::messenger::{CustomOnionMessageHandler, Destination, Message, OnionMessageContents, OnionMessenger, SendError, SimpleArcOnionMessenger, SimpleRefOnionMessenger}; pub(crate) use self::packet::Packet; diff --git a/lightning/src/onion_message/packet.rs b/lightning/src/onion_message/packet.rs index 23855f6a9f9..2a0943932bb 100644 --- a/lightning/src/onion_message/packet.rs +++ b/lightning/src/onion_message/packet.rs @@ -16,7 +16,7 @@ use ln::msgs::DecodeError; use ln::onion_utils; use super::blinded_route::{BlindedRoute, ForwardTlvs, ReceiveTlvs}; use util::chacha20poly1305rfc::{ChaChaPolyReadAdapter, ChaChaPolyWriteAdapter}; -use util::ser::{BigSize, FixedLengthReader, LengthRead, LengthReadable, LengthReadableArgs, Readable, ReadableArgs, Writeable, Writer}; +use util::ser::{BigSize, FixedLengthReader, LengthRead, LengthReadable, LengthReadableArgs, MaybeReadableArgs, Readable, ReadableArgs, Writeable, Writer}; use core::cmp; use io::{self, Read}; @@ -103,15 +103,6 @@ pub(super) enum Payload { } } -/// Trait to be implemented by custom onion messages. -pub trait CustomMessageReadable { - /// Decodes a custom message given the message type. If the given message type is known to the - /// implementation and the message could be decoded, must return `Ok(Some(message))`. If the - /// message type is unknown to the implementation, must return `Ok(None)`. If a decoding error - /// occur, must return `Err(DecodeError::X)` where `X` details the encountered error. - fn read(message_type: u64, buffer: &mut R) -> Result, DecodeError> where Self: Sized; -} - #[derive(Debug)] /// The contents of an onion message. In the context of offers, this would be the invoice, invoice /// request, or invoice error. @@ -142,21 +133,11 @@ impl Writeable for Message { } /// Defines a type identifier for encoding an onion message's contents. -pub trait OnionMessageContents: Writeable + CustomMessageReadable { +pub trait OnionMessageContents: Writeable + MaybeReadableArgs { /// Returns the type identifying the message payload. fn tlv_type(&self) -> u64; } -impl CustomMessageReadable for () { - fn read(_message_type: u64, _buffer: &mut R) -> Result, DecodeError> where Self: Sized { - Ok(None) - } -} - -impl OnionMessageContents for () { - fn tlv_type(&self) -> u64 { u64::max_value() } -} - /// Forward control TLVs in their blinded and unblinded form. pub(super) enum ForwardControlTlvs { /// If we're sending to a blinded route, the node that constructed the blinded route has provided @@ -223,27 +204,25 @@ impl ReadableArgs for Payload { let mut reply_path: Option = None; let mut read_adapter: Option> = None; let rho = onion_utils::gen_rho_from_shared_secret(&encrypted_tlvs_ss.secret_bytes()); - let (mut message_type, mut message) = (None, None); + let mut message_type: Option = None; + let mut message = None; decode_tlv_stream!(&mut rd, { (2, reply_path, option), (4, read_adapter, (option: LengthReadableArgs, rho)), }, |msg_type, msg_reader| { - if msg_type >= 64 { - // Don't allow reading more than one data TLV from an onion message. - if message.is_some() || message_type.is_some() { - return Err(DecodeError::InvalidValue) - } - message_type = Some(msg_type); - match T::read(msg_type, msg_reader) { - Ok(Some(msg)) => { - message = Some(msg); - return Ok(true) - }, - Ok(None) => return Ok(false), - Err(e) => return Err(e), - } + if msg_type < 64 { return Ok(false) } + // Don't allow reading more than one data TLV from an onion message. + if message_type.is_some() { return Err(DecodeError::InvalidValue) } + + message_type = Some(msg_type); + match T::read(msg_reader, msg_type) { + Ok(Some(msg)) => { + message = Some(msg); + Ok(true) + }, + Ok(None) => Ok(false), + Err(e) => Err(e), } - Ok(false) }); rd.eat_remaining().map_err(|_| DecodeError::ShortRead)?; diff --git a/lightning/src/util/ser.rs b/lightning/src/util/ser.rs index 852aa8f1589..40ab35eb8c1 100644 --- a/lightning/src/util/ser.rs +++ b/lightning/src/util/ser.rs @@ -269,6 +269,14 @@ impl MaybeReadable for T { } } +/// A trait that various rust-lightning types implement allowing them to (maybe) be read in from a Read, given some additional set of arguments which is required to deserialize. +/// +/// (C-not exported) as we only export serialization to/from byte arrays instead +pub trait MaybeReadableArgs

{ + /// Reads a Self in from the given Read + fn read(reader: &mut R, params: P) -> Result, DecodeError> where Self: Sized; +} + pub(crate) struct OptionDeserWrapper(pub Option); impl Readable for OptionDeserWrapper { #[inline]