Skip to content

Commit

Permalink
WIP: Limit TLV stream decoding to type ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
jkczyz committed Sep 30, 2022
1 parent 3db190e commit 91c2906
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 18 deletions.
40 changes: 37 additions & 3 deletions lightning/src/offers/invoice_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ use offers::merkle::{SignatureTlvStream, SignatureTlvStreamRef, self};
use offers::offer::{Amount, Offer, OfferContents, OfferTlvStream, OfferTlvStreamRef};
use offers::parse::{ParseError, SemanticError};
use offers::payer::{PayerContents, PayerTlvStream, PayerTlvStreamRef};
use util::ser::{HighZeroBytesDroppedBigSize, Readable, WithoutLength, Writeable, Writer};
use util::ser::{HighZeroBytesDroppedBigSize, SeekReadable, WithoutLength, Writeable, Writer};

use prelude::*;

Expand Down Expand Up @@ -341,12 +341,14 @@ impl TryFrom<Vec<u8>> for InvoiceRequest {
type Error = ParseError;

fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
let tlv_stream: FullInvoiceRequestTlvStream = Readable::read(&mut &bytes[..])?;
let mut cursor = io::Cursor::new(bytes);
let tlv_stream: FullInvoiceRequestTlvStream = SeekReadable::read(&mut cursor)?;
let bytes = cursor.into_inner();
InvoiceRequest::try_from((bytes, tlv_stream))
}
}

tlv_stream!(InvoiceRequestTlvStream, InvoiceRequestTlvStreamRef, {
tlv_stream!(InvoiceRequestTlvStream, InvoiceRequestTlvStreamRef, 80..160, {
(80, chain: ChainHash),
(82, amount: (u64, HighZeroBytesDroppedBigSize)),
(84, features: OfferFeatures),
Expand Down Expand Up @@ -448,3 +450,35 @@ impl TryFrom<PartialInvoiceRequestTlvStream> for InvoiceRequestContents {
})
}
}

#[cfg(test)]
mod tests {
use super::InvoiceRequest;

use bitcoin::secp256k1::{KeyPair, Secp256k1, SecretKey};
use core::convert::TryFrom;
use offers::offer::OfferBuilder;
use util::ser::Writeable;

#[test]
fn builds_invoice_request_with_amount_msats() {
let secp_ctx = Secp256k1::new();
let keys = KeyPair::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
let invoice_request = OfferBuilder::new("foo".into(), keys.public_key())
.build()
.unwrap()
.request_invoice(keys.public_key())
.amount_msats(1000)
.build()
.unwrap()
.sign(|digest| secp_ctx.sign_schnorr_no_aux_rand(digest, &keys))
.unwrap();
assert_eq!(invoice_request.amount_msats(), Some(1000));

let mut buffer = Vec::new();
invoice_request.write(&mut buffer).unwrap();

let invoice_request = InvoiceRequest::try_from(buffer).unwrap();
assert_eq!(invoice_request.amount_msats(), Some(1000));
}
}
2 changes: 1 addition & 1 deletion lightning/src/offers/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use prelude::*;
/// Valid type range for signature TLV records.
const SIGNATURE_TYPES: core::ops::RangeInclusive<u64> = 240..=1000;

tlv_stream!(SignatureTlvStream, SignatureTlvStreamRef, {
tlv_stream!(SignatureTlvStream, SignatureTlvStreamRef, SIGNATURE_TYPES, {
(240, signature: Signature),
});

Expand Down
8 changes: 5 additions & 3 deletions lightning/src/offers/offer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use ln::msgs::MAX_VALUE_MSAT;
use offers::invoice_request::InvoiceRequestBuilder;
use offers::parse::{Bech32Encode, ParseError, SemanticError};
use onion_message::BlindedPath;
use util::ser::{HighZeroBytesDroppedBigSize, Readable, WithoutLength, Writeable, Writer};
use util::ser::{HighZeroBytesDroppedBigSize, SeekReadable, WithoutLength, Writeable, Writer};

use prelude::*;

Expand Down Expand Up @@ -497,7 +497,9 @@ impl TryFrom<Vec<u8>> for Offer {
type Error = ParseError;

fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
let tlv_stream: OfferTlvStream = Readable::read(&mut &bytes[..])?;
let mut cursor = io::Cursor::new(bytes);
let tlv_stream: OfferTlvStream = SeekReadable::read(&mut cursor)?;
let bytes = cursor.into_inner();
Offer::try_from((bytes, tlv_stream))
}
}
Expand Down Expand Up @@ -533,7 +535,7 @@ impl Amount {
/// An ISO 4712 three-letter currency code (e.g., USD).
pub type CurrencyCode = [u8; 3];

tlv_stream!(OfferTlvStream, OfferTlvStreamRef, {
tlv_stream!(OfferTlvStream, OfferTlvStreamRef, 1..80, {
(2, chains: (Vec<ChainHash>, WithoutLength)),
(4, metadata: (Vec<u8>, WithoutLength)),
(6, currency: CurrencyCode),
Expand Down
10 changes: 7 additions & 3 deletions lightning/src/offers/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ use bitcoin::bech32;
use bitcoin::bech32::{FromBase32, ToBase32};
use bitcoin::secp256k1;
use core::fmt;
use io;
use ln::msgs::DecodeError;
use util::ser::Readable;
use util::ser::SeekReadable;

use prelude::*;

/// Indicates a message can be encoded using bech32.
pub(crate) trait Bech32Encode: AsRef<[u8]> {
/// TLV stream that a bech32-encoded message is parsed into.
type TlvStream: Readable;
type TlvStream: SeekReadable;

/// Human readable part of the message's bech32 encoding.
const BECH32_HRP: &'static str;
Expand All @@ -44,7 +45,10 @@ pub(crate) trait Bech32Encode: AsRef<[u8]> {
}

let data = Vec::<u8>::from_base32(&data)?;
Ok((Readable::read(&mut &data[..])?, data))
let mut cursor = io::Cursor::new(data);
let tlv_stream = SeekReadable::read(&mut cursor)?;
let bytes = cursor.into_inner();
Ok((tlv_stream, bytes))
}

/// Formats the message using bech32-encoding.
Expand Down
2 changes: 1 addition & 1 deletion lightning/src/offers/payer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ use prelude::*;
#[derive(Clone, Debug)]
pub(crate) struct PayerContents(pub Option<Vec<u8>>);

tlv_stream!(PayerTlvStream, PayerTlvStreamRef, {
tlv_stream!(PayerTlvStream, PayerTlvStreamRef, 0..1, {
(0, metadata: (Vec<u8>, WithoutLength)),
});
23 changes: 22 additions & 1 deletion lightning/src/util/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//! as ChannelsManagers and ChannelMonitors.
use prelude::*;
use io::{self, Read, Write};
use io::{self, Read, Seek, Write};
use io_extras::{copy, sink};
use core::hash::Hash;
use sync::Mutex;
Expand Down Expand Up @@ -220,6 +220,17 @@ pub trait Readable
fn read<R: Read>(reader: &mut R) -> Result<Self, DecodeError>;
}

/// A trait that various rust-lightning types implement allowing them to be read in from a
/// `Read + Seek`.
///
/// (C-not exported) as we only export serialization to/from byte arrays instead
pub trait SeekReadable
where Self: Sized
{
/// Reads a Self in from the given Read
fn read<R: Read + Seek>(reader: &mut R) -> Result<Self, DecodeError>;
}

/// A trait that various higher-level rust-lightning types implement allowing them to be read in
/// from a Read given some additional set of arguments which is required to deserialize.
///
Expand Down Expand Up @@ -1026,6 +1037,16 @@ impl<A: Writeable, B: Writeable, C: Writeable, D: Writeable> Writeable for (A, B
}
}

impl<A: SeekReadable, B: SeekReadable, C: SeekReadable, D: SeekReadable> SeekReadable for (A, B, C, D) {
fn read<R: Read + Seek>(r: &mut R) -> Result<Self, DecodeError> {
let a: A = SeekReadable::read(r)?;
let b: B = SeekReadable::read(r)?;
let c: C = SeekReadable::read(r)?;
let d: D = SeekReadable::read(r)?;
Ok((a, b, c, d))
}
}

impl Writeable for () {
fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
Ok(())
Expand Down
28 changes: 22 additions & 6 deletions lightning/src/util/ser_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ macro_rules! decode_tlv {

macro_rules! decode_tlv_stream {
($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { {
let rewind = |_, _| { unreachable!() };
decode_tlv_stream_range!($stream, 0.., rewind, {$(($type, $field, $fieldty)),*});
} }
}

macro_rules! decode_tlv_stream_range {
($stream: expr, $range: expr, $rewind: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { {
use ln::msgs::DecodeError;
let mut last_seen_type: Option<u64> = None;
let mut stream_ref = $stream;
Expand All @@ -210,7 +217,7 @@ macro_rules! decode_tlv_stream {
// UnexpectedEof. This should in every case be largely cosmetic, but its nice to
// pass the TLV test vectors exactly, which requre this distinction.
let mut tracking_reader = ser::ReadTrackingReader::new(&mut stream_ref);
match ser::Readable::read(&mut tracking_reader) {
match <ser::BigSize as ser::Readable>::read(&mut tracking_reader) {
Err(DecodeError::ShortRead) => {
if !tracking_reader.have_read {
break 'tlv_read;
Expand All @@ -219,7 +226,13 @@ macro_rules! decode_tlv_stream {
}
},
Err(e) => return Err(e),
Ok(t) => t,
Ok(t) => if $range.contains(&t.0) { t } else {
use util::ser::Writeable;
drop(tracking_reader);
let bytes_read = t.serialized_length();
$rewind(stream_ref, bytes_read);
break 'tlv_read;
},
}
};

Expand Down Expand Up @@ -458,7 +471,7 @@ macro_rules! impl_writeable_tlv_based {
/// [`Readable`]: crate::util::ser::Readable
/// [`Writeable`]: crate::util::ser::Writeable
macro_rules! tlv_stream {
($name:ident, $nameref:ident, {
($name:ident, $nameref:ident, $range:expr, {
$(($type:expr, $field:ident : $fieldty:tt)),* $(,)*
}) => {
#[derive(Debug)]
Expand All @@ -483,12 +496,15 @@ macro_rules! tlv_stream {
}
}

impl ::util::ser::Readable for $name {
fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
impl ::util::ser::SeekReadable for $name {
fn read<R: $crate::io::Read + $crate::io::Seek>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
$(
init_tlv_field_var!($field, option);
)*
decode_tlv_stream!(reader, {
let rewind = |cursor: &mut R, offset: usize| {
cursor.seek($crate::io::SeekFrom::Current(-(offset as i64))).expect("");
};
decode_tlv_stream_range!(reader, $range, rewind, {
$(($type, $field, (option, encoding: $fieldty))),*
});

Expand Down

0 comments on commit 91c2906

Please sign in to comment.