Skip to content

Commit

Permalink
better checking of tag duplicates, avoid discarding invalid variant e…
Browse files Browse the repository at this point in the history
…rrs (#951)

* better checking of tag duplicates, avoid discarding invalid variant errors

* add some simple derive tests :)

* use itertools duplicates()

* don't print out unreadable syn junk when encountering unknown attributes

* clarify and test backstop for the "multiple tags in a oneof variant" condition

* use expect_err

* nicer framing around the unknown attribute tokens

* express higher minimal versions of itertools and proc-macro2 for cargo hack check

* simplify bounds for proc-macro2

Co-authored-by: Casper Meijn <[email protected]>

* update the other instance of the .tuple_windows() trick to use .duplicates()

* clarify & shorten assertion

---------

Co-authored-by: Casper Meijn <[email protected]>
  • Loading branch information
mumbleskates and caspermeijn authored May 24, 2024
1 parent b9c4d3d commit eec717b
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 62 deletions.
4 changes: 2 additions & 2 deletions prost-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ proc_macro = true

[dependencies]
anyhow = "1.0.1"
itertools = { version = ">=0.10, <=0.12", default-features = false, features = ["use_alloc"] }
proc-macro2 = "1"
itertools = ">=0.10.1, <=0.12"
proc-macro2 = "1.0.60"
quote = "1"
syn = { version = "2", features = ["extra-traits"] }
9 changes: 5 additions & 4 deletions prost-derive/src/field/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ impl Field {
return Ok(None);
}

match unknown_attrs.len() {
0 => (),
1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]),
_ => bail!("unknown attributes for group field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for group field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
12 changes: 5 additions & 7 deletions prost-derive/src/field/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ impl Field {
return Ok(None);
}

match unknown_attrs.len() {
0 => (),
1 => bail!(
"unknown attribute for message field: {:?}",
unknown_attrs[0]
),
_ => bail!("unknown attributes for message field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for message field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
12 changes: 5 additions & 7 deletions prost-derive/src/field/oneof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,11 @@ impl Field {
None => return Ok(None),
};

match unknown_attrs.len() {
0 => (),
1 => bail!(
"unknown attribute for message field: {:?}",
unknown_attrs[0]
),
_ => bail!("unknown attributes for message field: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s) for message field: #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tags = match tags {
Expand Down
9 changes: 5 additions & 4 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ impl Field {
None => return Ok(None),
};

match unknown_attrs.len() {
0 => (),
1 => bail!("unknown attribute: {:?}", unknown_attrs[0]),
_ => bail!("unknown attributes: {:?}", unknown_attrs),
if !unknown_attrs.is_empty() {
bail!(
"unknown attribute(s): #[prost({})]",
quote!(#(#unknown_attrs),*)
);
}

let tag = match tag.or(inferred_tag) {
Expand Down
164 changes: 126 additions & 38 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ extern crate proc_macro;

use anyhow::{bail, Error};
use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
Expand All @@ -19,7 +18,7 @@ mod field;
use crate::field::Field;

fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -91,16 +90,18 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
let fields = fields;

let mut tags = fields
if let Some(duplicate_tag) = fields
.iter()
.flat_map(|(_, field)| field.tags())
.collect::<Vec<_>>();
let num_tags = tags.len();
tags.sort_unstable();
tags.dedup();
if tags.len() != num_tags {
bail!("message {} has fields with duplicate tags", ident);
}
.duplicates()
.next()
{
bail!(
"message {} has multiple fields with tag {}",
ident,
duplicate_tag
)
};

let encoded_len = fields
.iter()
Expand Down Expand Up @@ -251,16 +252,16 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#methods
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Message, attributes(prost))]
pub fn message(input: TokenStream) -> TokenStream {
try_message(input).unwrap()
pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_message(input.into()).unwrap().into()
}

fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;
let ident = input.ident;

let generics = &input.generics;
Expand Down Expand Up @@ -359,16 +360,16 @@ fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
}
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Enumeration, attributes(prost))]
pub fn enumeration(input: TokenStream) -> TokenStream {
try_enumeration(input).unwrap()
pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_enumeration(input.into()).unwrap().into()
}

fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
let input: DeriveInput = syn::parse(input)?;
let input: DeriveInput = syn::parse2(input)?;

let ident = input.ident;

Expand Down Expand Up @@ -412,23 +413,21 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
}

let mut tags = fields
// Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple
// tags.
assert!(fields.iter().all(|(_, field)| field.tags().len() == 1));

if let Some(duplicate_tag) = fields
.iter()
.flat_map(|(variant_ident, field)| -> Result<u32, Error> {
if field.tags().len() > 1 {
bail!(
"invalid oneof variant {}::{}: oneof variants may only have a single tag",
ident,
variant_ident
);
}
Ok(field.tags()[0])
})
.collect::<Vec<_>>();
tags.sort_unstable();
tags.dedup();
if tags.len() != fields.len() {
panic!("invalid oneof {}: variants have duplicate tags", ident);
.flat_map(|(_, field)| field.tags())
.duplicates()
.next()
{
bail!(
"invalid oneof {}: multiple variants have tag {}",
ident,
duplicate_tag
);
}

let encode = fields.iter().map(|(variant_ident, field)| {
Expand Down Expand Up @@ -519,10 +518,99 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
}
};

Ok(expanded.into())
Ok(expanded)
}

#[proc_macro_derive(Oneof, attributes(prost))]
pub fn oneof(input: TokenStream) -> TokenStream {
try_oneof(input).unwrap()
pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
try_oneof(input.into()).unwrap().into()
}

#[cfg(test)]
mod test {
use crate::{try_message, try_oneof};
use quote::quote;

#[test]
fn test_rejects_colliding_message_fields() {
let output = try_message(quote!(
struct Invalid {
#[prost(bool, tag = "1")]
a: bool,
#[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
b: Option<super::Whatever>,
}
));
assert_eq!(
output
.expect_err("did not reject colliding message fields")
.to_string(),
"message Invalid has multiple fields with tag 1"
);
}

#[test]
fn test_rejects_colliding_oneof_variants() {
let output = try_oneof(quote!(
pub enum Invalid {
#[prost(bool, tag = "1")]
A(bool),
#[prost(bool, tag = "3")]
B(bool),
#[prost(bool, tag = "1")]
C(bool),
}
));
assert_eq!(
output
.expect_err("did not reject colliding oneof variants")
.to_string(),
"invalid oneof Invalid: multiple variants have tag 1"
);
}

#[test]
fn test_rejects_multiple_tags_oneof_variant() {
let output = try_oneof(quote!(
enum What {
#[prost(bool, tag = "1", tag = "2")]
A(bool),
}
));
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"duplicate tag attributes: 1 and 2"
);

let output = try_oneof(quote!(
enum What {
#[prost(bool, tag = "3")]
#[prost(tag = "4")]
A(bool),
}
));
assert!(output.is_err());
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"duplicate tag attributes: 3 and 4"
);

let output = try_oneof(quote!(
enum What {
#[prost(bool, tags = "5,6")]
A(bool),
}
));
assert!(output.is_err());
assert_eq!(
output
.expect_err("did not reject multiple tags on oneof variant")
.to_string(),
"unknown attribute(s): #[prost(tags = \"5,6\")]"
);
}
}

0 comments on commit eec717b

Please sign in to comment.