From eec717bd632ab2d6ba97a389bb491fa5ac7cb6b5 Mon Sep 17 00:00:00 2001 From: Kent Ross Date: Fri, 24 May 2024 03:01:09 -0700 Subject: [PATCH] better checking of tag duplicates, avoid discarding invalid variant errs (#951) * better checking of tag duplicates, avoid discarding invalid variant errors * add some simple derive tests :) * use itertools duplicates() * don't print out unreadable syn junk when encountering unknown attributes * clarify and test backstop for the "multiple tags in a oneof variant" condition * use expect_err * nicer framing around the unknown attribute tokens * express higher minimal versions of itertools and proc-macro2 for cargo hack check * simplify bounds for proc-macro2 Co-authored-by: Casper Meijn * update the other instance of the .tuple_windows() trick to use .duplicates() * clarify & shorten assertion --------- Co-authored-by: Casper Meijn --- prost-derive/Cargo.toml | 4 +- prost-derive/src/field/group.rs | 9 +- prost-derive/src/field/message.rs | 12 +-- prost-derive/src/field/oneof.rs | 12 +-- prost-derive/src/field/scalar.rs | 9 +- prost-derive/src/lib.rs | 164 +++++++++++++++++++++++------- 6 files changed, 148 insertions(+), 62 deletions(-) diff --git a/prost-derive/Cargo.toml b/prost-derive/Cargo.toml index 2ee5c28d1..5049aa24f 100644 --- a/prost-derive/Cargo.toml +++ b/prost-derive/Cargo.toml @@ -20,7 +20,7 @@ proc_macro = true [dependencies] anyhow = "1.0.1" -itertools = { version = ">=0.10, <=0.12", default-features = false, features = ["use_alloc"] } -proc-macro2 = "1" +itertools = ">=0.10.1, <=0.12" +proc-macro2 = "1.0.60" quote = "1" syn = { version = "2", features = ["extra-traits"] } diff --git a/prost-derive/src/field/group.rs b/prost-derive/src/field/group.rs index 076b577d7..485ecfc1b 100644 --- a/prost-derive/src/field/group.rs +++ b/prost-derive/src/field/group.rs @@ -38,10 +38,11 @@ impl Field { return Ok(None); } - match unknown_attrs.len() { - 0 => (), - 1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]), - _ => bail!("unknown attributes for group field: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s) for group field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/field/message.rs b/prost-derive/src/field/message.rs index 3bcdddfb1..f6ac391e7 100644 --- a/prost-derive/src/field/message.rs +++ b/prost-derive/src/field/message.rs @@ -38,13 +38,11 @@ impl Field { return Ok(None); } - match unknown_attrs.len() { - 0 => (), - 1 => bail!( - "unknown attribute for message field: {:?}", - unknown_attrs[0] - ), - _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s) for message field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/field/oneof.rs b/prost-derive/src/field/oneof.rs index 78c77eeb1..ad1e32f19 100644 --- a/prost-derive/src/field/oneof.rs +++ b/prost-derive/src/field/oneof.rs @@ -44,13 +44,11 @@ impl Field { None => return Ok(None), }; - match unknown_attrs.len() { - 0 => (), - 1 => bail!( - "unknown attribute for message field: {:?}", - unknown_attrs[0] - ), - _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s) for message field: #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tags = match tags { diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 6be16cd70..c2e870524 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -46,10 +46,11 @@ impl Field { None => return Ok(None), }; - match unknown_attrs.len() { - 0 => (), - 1 => bail!("unknown attribute: {:?}", unknown_attrs[0]), - _ => bail!("unknown attributes: {:?}", unknown_attrs), + if !unknown_attrs.is_empty() { + bail!( + "unknown attribute(s): #[prost({})]", + quote!(#(#unknown_attrs),*) + ); } let tag = match tag.or(inferred_tag) { diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 472e1e2fd..42f0ccd1d 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -7,8 +7,7 @@ extern crate proc_macro; use anyhow::{bail, Error}; use itertools::Itertools; -use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, @@ -19,7 +18,7 @@ mod field; use crate::field::Field; fn try_message(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; @@ -91,16 +90,18 @@ fn try_message(input: TokenStream) -> Result { fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap()); let fields = fields; - let mut tags = fields + if let Some(duplicate_tag) = fields .iter() .flat_map(|(_, field)| field.tags()) - .collect::>(); - let num_tags = tags.len(); - tags.sort_unstable(); - tags.dedup(); - if tags.len() != num_tags { - bail!("message {} has fields with duplicate tags", ident); - } + .duplicates() + .next() + { + bail!( + "message {} has multiple fields with tag {}", + ident, + duplicate_tag + ) + }; let encoded_len = fields .iter() @@ -251,16 +252,16 @@ fn try_message(input: TokenStream) -> Result { #methods }; - Ok(expanded.into()) + Ok(expanded) } #[proc_macro_derive(Message, attributes(prost))] -pub fn message(input: TokenStream) -> TokenStream { - try_message(input).unwrap() +pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_message(input.into()).unwrap().into() } fn try_enumeration(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; let generics = &input.generics; @@ -359,16 +360,16 @@ fn try_enumeration(input: TokenStream) -> Result { } }; - Ok(expanded.into()) + Ok(expanded) } #[proc_macro_derive(Enumeration, attributes(prost))] -pub fn enumeration(input: TokenStream) -> TokenStream { - try_enumeration(input).unwrap() +pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_enumeration(input.into()).unwrap().into() } fn try_oneof(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; + let input: DeriveInput = syn::parse2(input)?; let ident = input.ident; @@ -412,23 +413,21 @@ fn try_oneof(input: TokenStream) -> Result { } } - let mut tags = fields + // Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple + // tags. + assert!(fields.iter().all(|(_, field)| field.tags().len() == 1)); + + if let Some(duplicate_tag) = fields .iter() - .flat_map(|(variant_ident, field)| -> Result { - if field.tags().len() > 1 { - bail!( - "invalid oneof variant {}::{}: oneof variants may only have a single tag", - ident, - variant_ident - ); - } - Ok(field.tags()[0]) - }) - .collect::>(); - tags.sort_unstable(); - tags.dedup(); - if tags.len() != fields.len() { - panic!("invalid oneof {}: variants have duplicate tags", ident); + .flat_map(|(_, field)| field.tags()) + .duplicates() + .next() + { + bail!( + "invalid oneof {}: multiple variants have tag {}", + ident, + duplicate_tag + ); } let encode = fields.iter().map(|(variant_ident, field)| { @@ -519,10 +518,99 @@ fn try_oneof(input: TokenStream) -> Result { } }; - Ok(expanded.into()) + Ok(expanded) } #[proc_macro_derive(Oneof, attributes(prost))] -pub fn oneof(input: TokenStream) -> TokenStream { - try_oneof(input).unwrap() +pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + try_oneof(input.into()).unwrap().into() +} + +#[cfg(test)] +mod test { + use crate::{try_message, try_oneof}; + use quote::quote; + + #[test] + fn test_rejects_colliding_message_fields() { + let output = try_message(quote!( + struct Invalid { + #[prost(bool, tag = "1")] + a: bool, + #[prost(oneof = "super::Whatever", tags = "4, 5, 1")] + b: Option, + } + )); + assert_eq!( + output + .expect_err("did not reject colliding message fields") + .to_string(), + "message Invalid has multiple fields with tag 1" + ); + } + + #[test] + fn test_rejects_colliding_oneof_variants() { + let output = try_oneof(quote!( + pub enum Invalid { + #[prost(bool, tag = "1")] + A(bool), + #[prost(bool, tag = "3")] + B(bool), + #[prost(bool, tag = "1")] + C(bool), + } + )); + assert_eq!( + output + .expect_err("did not reject colliding oneof variants") + .to_string(), + "invalid oneof Invalid: multiple variants have tag 1" + ); + } + + #[test] + fn test_rejects_multiple_tags_oneof_variant() { + let output = try_oneof(quote!( + enum What { + #[prost(bool, tag = "1", tag = "2")] + A(bool), + } + )); + assert_eq!( + output + .expect_err("did not reject multiple tags on oneof variant") + .to_string(), + "duplicate tag attributes: 1 and 2" + ); + + let output = try_oneof(quote!( + enum What { + #[prost(bool, tag = "3")] + #[prost(tag = "4")] + A(bool), + } + )); + assert!(output.is_err()); + assert_eq!( + output + .expect_err("did not reject multiple tags on oneof variant") + .to_string(), + "duplicate tag attributes: 3 and 4" + ); + + let output = try_oneof(quote!( + enum What { + #[prost(bool, tags = "5,6")] + A(bool), + } + )); + assert!(output.is_err()); + assert_eq!( + output + .expect_err("did not reject multiple tags on oneof variant") + .to_string(), + "unknown attribute(s): #[prost(tags = \"5,6\")]" + ); + } }