From f5735bd09bae0cd980ac64f22a5c5b1cdc61a555 Mon Sep 17 00:00:00 2001 From: Angad Tendulkar Date: Tue, 5 Mar 2024 17:03:43 -0500 Subject: [PATCH 1/4] implement #[location] attr and location field for enums --- impl/src/attr.rs | 8 +++ impl/src/expand.rs | 158 +++++++++++++++++++++++++---------------- impl/src/lib.rs | 2 +- impl/src/prop.rs | 32 +++++++++ impl/src/valid.rs | 19 ++++- tests/test_location.rs | 24 +++++++ 6 files changed, 178 insertions(+), 65 deletions(-) create mode 100644 tests/test_location.rs diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 269c69e..e924692 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -12,6 +12,7 @@ pub struct Attrs<'a> { pub display: Option>, pub source: Option<&'a Attribute>, pub backtrace: Option<&'a Attribute>, + pub location: Option<&'a Attribute>, pub from: Option<&'a Attribute>, pub transparent: Option>, } @@ -50,6 +51,7 @@ pub fn get(input: &[Attribute]) -> Result { display: None, source: None, backtrace: None, + location: None, from: None, transparent: None, }; @@ -69,6 +71,12 @@ pub fn get(input: &[Attribute]) -> Result { return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute")); } attrs.backtrace = Some(attr); + } else if attr.path().is_ident("location") { + attr.meta.require_path_only()?; + if attrs.location.is_some() { + return Err(Error::new_spanned(attr, "duplicate #[location] attribute")); + } + attrs.location = Some(attr); } else if attr.path().is_ident("from") { match attr.meta { Meta::Path(_) => {} diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 1b44513..9da26bc 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -190,7 +190,7 @@ fn impl_struct(input: Struct) -> TokenStream { let from_impl = input.from_field().map(|from_field| { let backtrace_field = input.distinct_backtrace_field(); let from = unoptional_type(from_field.ty); - let body = from_initializer(from_field, backtrace_field); + let body = from_initializer(from_field, backtrace_field, None); quote! { #[allow(unused_qualifications)] impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { @@ -281,25 +281,43 @@ fn impl_enum(input: Enum) -> TokenStream { let request = quote!(request); let arms = input.variants.iter().map(|variant| { let ident = &variant.ident; - match (variant.backtrace_field(), variant.source_field()) { - (Some(backtrace_field), Some(source_field)) - if backtrace_field.attrs.backtrace.is_none() => - { - let backtrace = &backtrace_field.member; - let source = &source_field.member; - let varsource = quote!(source); - let source_provide = if type_is_option(source_field.ty) { - quote_spanned! {source.member_span()=> - if let ::core::option::Option::Some(source) = #varsource { - source.thiserror_provide(#request); - } - } + + let mut arm = vec![]; + let mut head = vec![]; + + if let Some((source_field, backtrace_field)) = variant + .backtrace_field() + .and_then(|b| variant.source_field().take().map(|s| (s, b))) + .and_then(|(s, b)| { + // TODO: replace with take_if upon stabilization + + if s.member == b.member { + Some((s, b)) } else { - quote_spanned! {source.member_span()=> - #varsource.thiserror_provide(#request); + None + } + }) + { + let backtrace = &backtrace_field.member; + let varsource = quote!(source); + let source_provide = if type_is_option(source_field.ty) { + quote_spanned! {backtrace.member_span()=> + if let ::core::option::Option::Some(source) = #varsource { + source.thiserror_provide(#request); } - }; - let self_provide = if type_is_option(backtrace_field.ty) { + } + } else { + quote_spanned! {backtrace.member_span()=> + #varsource.thiserror_provide(#request); + } + }; + + head.push(quote! { #backtrace: #varsource }); + arm.push(quote! { #source_provide }); + } else { + if let Some(backtrace_field) = variant.backtrace_field() { + let backtrace = &backtrace_field.member; + let body = if type_is_option(backtrace_field.ty) { quote! { if let ::core::option::Option::Some(backtrace) = backtrace { #request.provide_ref::(backtrace); @@ -310,68 +328,64 @@ fn impl_enum(input: Enum) -> TokenStream { #request.provide_ref::(backtrace); } }; - quote! { - #ty::#ident { - #backtrace: backtrace, - #source: #varsource, - .. - } => { - use thiserror::__private::ThiserrorProvide as _; - #source_provide - #self_provide - } - } + + head.push(quote! { #backtrace: backtrace }); + arm.push(quote! { #body }); } - (Some(backtrace_field), Some(source_field)) - if backtrace_field.member == source_field.member => - { - let backtrace = &backtrace_field.member; + + if let Some(source_field) = variant.source_field() { + let source = &source_field.member; let varsource = quote!(source); + let source_provide = if type_is_option(source_field.ty) { - quote_spanned! {backtrace.member_span()=> + quote_spanned! {source.member_span()=> if let ::core::option::Option::Some(source) = #varsource { source.thiserror_provide(#request); } } } else { - quote_spanned! {backtrace.member_span()=> + quote_spanned! {source.member_span()=> #varsource.thiserror_provide(#request); } }; + + head.push(quote! { #source: #varsource }); + arm.push(quote! { #source_provide }); + } + } + + if let Some(location_field) = variant.location_field() { + let location = &location_field.member; + + let location_provide = if type_is_option(location_field.ty) { quote! { - #ty::#ident {#backtrace: #varsource, ..} => { - use thiserror::__private::ThiserrorProvide as _; - #source_provide + if let ::core::option::Option::Some(location) = location { + #request.provide_ref::<::core::panic::Location>(location); } } - } - (Some(backtrace_field), _) => { - let backtrace = &backtrace_field.member; - let body = if type_is_option(backtrace_field.ty) { - quote! { - if let ::core::option::Option::Some(backtrace) = backtrace { - #request.provide_ref::(backtrace); - } - } - } else { - quote! { - #request.provide_ref::(backtrace); - } - }; + } else { quote! { - #ty::#ident {#backtrace: backtrace, ..} => { - #body - } + #request.provide_ref::<::core::panic::Location>(location); } + }; + + head.push(quote! { #location: location }); + arm.push(quote! { #location_provide }); + } + + quote! { + #ty::#ident { + #(#head,)* + .. + } => { + use thiserror::__private::ThiserrorProvide as _; + #(#arm)* } - (None, _) => quote! { - #ty::#ident {..} => {} - }, } }); Some(quote! { fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) { - #[allow(deprecated)] + #[allow(deprecated, unused_imports)] match self { #(#arms)* } @@ -444,13 +458,17 @@ fn impl_enum(input: Enum) -> TokenStream { let from_impls = input.variants.iter().filter_map(|variant| { let from_field = variant.from_field()?; let backtrace_field = variant.distinct_backtrace_field(); + let location_field = variant.location_field(); let variant = &variant.ident; let from = unoptional_type(from_field.ty); - let body = from_initializer(from_field, backtrace_field); + let body = from_initializer(from_field, backtrace_field, location_field); + let track_caller = location_field.map(|_| quote!(#[track_caller])); + Some(quote! { #[allow(unused_qualifications)] impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { #[allow(deprecated)] + #track_caller fn from(source: #from) -> Self { #ty::#variant #body } @@ -501,7 +519,11 @@ fn use_as_display(needs_as_display: bool) -> Option { } } -fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream { +fn from_initializer( + from_field: &Field, + backtrace_field: Option<&Field>, + location_field: Option<&Field>, +) -> TokenStream { let from_member = &from_field.member; let some_source = if type_is_option(from_field.ty) { quote!(::core::option::Option::Some(source)) @@ -520,9 +542,23 @@ fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> Toke } } }); + let location = location_field.map(|location_field| { + let location_member = &location_field.member; + + if type_is_option(location_field.ty) { + quote! { + #location_member: ::core::option::Option::Some(::core::panic::Location::caller()), + } + } else { + quote! { + #location_member: ::core::convert::From::from(::core::panic::Location::caller()), + } + } + }); quote!({ #from_member: #some_source, #backtrace + #location }) } diff --git a/impl/src/lib.rs b/impl/src/lib.rs index 58f4bb5..678acf3 100644 --- a/impl/src/lib.rs +++ b/impl/src/lib.rs @@ -29,7 +29,7 @@ mod valid; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; -#[proc_macro_derive(Error, attributes(backtrace, error, from, source))] +#[proc_macro_derive(Error, attributes(backtrace, error, from, source, location))] pub fn derive_error(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand::derive(&input).into() diff --git a/impl/src/prop.rs b/impl/src/prop.rs index 2867cd3..2ca30cf 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -62,6 +62,10 @@ impl Variant<'_> { backtrace_field(&self.fields) } + pub(crate) fn location_field(&self) -> Option<&Field> { + location_field(&self.fields) + } + pub(crate) fn distinct_backtrace_field(&self) -> Option<&Field> { let backtrace_field = self.backtrace_field()?; distinct_backtrace_field(backtrace_field, self.from_field()) @@ -73,6 +77,10 @@ impl Field<'_> { type_is_backtrace(self.ty) } + pub(crate) fn is_location(&self) -> bool { + type_is_location(self.ty) + } + pub(crate) fn source_span(&self) -> Span { if let Some(source_attr) = &self.attrs.source { source_attr.path().get_ident().unwrap().span() @@ -122,6 +130,20 @@ fn backtrace_field<'a, 'b>(fields: &'a [Field<'b>]) -> Option<&'a Field<'b>> { None } +fn location_field<'a, 'b>(fields: &'a [Field<'b>]) -> Option<&'a Field<'b>> { + for field in fields { + if field.attrs.location.is_some() { + return Some(field); + } + } + for field in fields { + if field.is_location() { + return Some(field); + } + } + None +} + // The #[backtrace] field, if it is not the same as the #[from] field. fn distinct_backtrace_field<'a, 'b>( backtrace_field: &'a Field<'b>, @@ -145,3 +167,13 @@ fn type_is_backtrace(ty: &Type) -> bool { let last = path.segments.last().unwrap(); last.ident == "Backtrace" && last.arguments.is_empty() } + +fn type_is_location(ty: &Type) -> bool { + let path = match ty { + Type::Path(ty) => &ty.path, + _ => return false, + }; + + let last = path.segments.last().unwrap(); + last.ident == "Location" && last.arguments.is_empty() +} diff --git a/impl/src/valid.rs b/impl/src/valid.rs index cf5b859..e353ad7 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -138,6 +138,7 @@ fn check_non_field_attrs(attrs: &Attrs) -> Result<()> { fn check_field_attrs(fields: &[Field]) -> Result<()> { let mut from_field = None; let mut source_field = None; + let mut location_field = None; let mut backtrace_field = None; let mut has_backtrace = false; for field in fields { @@ -163,6 +164,16 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { backtrace_field = Some(field); has_backtrace = true; } + if let Some(location) = field.attrs.location { + if location_field.is_some() { + return Err(Error::new_spanned( + location, + "duplicate #[location] attribute", + )); + } + + location_field = Some(field); + } if let Some(transparent) = field.attrs.transparent { return Err(Error::new_spanned( transparent.original, @@ -180,9 +191,11 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { } } if let Some(from_field) = from_field { - let max_expected_fields = match backtrace_field { - Some(backtrace_field) => 1 + !same_member(from_field, backtrace_field) as usize, - None => 1 + has_backtrace as usize, + let max_expected_fields = match (backtrace_field, location_field) { + (Some(backtrace), Some(_)) => 2 + !same_member(from_field, backtrace) as usize, + (Some(backtrace_field), None) => 1 + !same_member(from_field, backtrace_field) as usize, + (None, Some(_)) => 2 + has_backtrace as usize, + (None, None) => 1 + has_backtrace as usize, }; if fields.len() > max_expected_fields { return Err(Error::new_spanned( diff --git a/tests/test_location.rs b/tests/test_location.rs new file mode 100644 index 0000000..ace7ede --- /dev/null +++ b/tests/test_location.rs @@ -0,0 +1,24 @@ +use std::{fmt::Debug, io, panic::Location}; +use thiserror::Error; + +#[derive(Error, Debug)] +enum MError { + #[error("At {location}: location test error, sourced from {other}")] + Test { + #[location] + location: &'static Location<'static>, + #[from] + other: io::Error, + }, +} + +#[test] +#[should_panic] +fn test_enum() { + fn inner() -> Result<(), MError> { + Err(io::Error::new(io::ErrorKind::AddrInUse, String::new()))?; + Ok(()) + } + + inner().unwrap(); +} From 43eddc2d30575c5626105444ebec6298889ae611 Mon Sep 17 00:00:00 2001 From: Angad Tendulkar Date: Wed, 6 Mar 2024 08:47:25 -0500 Subject: [PATCH 2/4] impl for structs --- impl/src/expand.rs | 23 ++++++++++++++++++++++- impl/src/prop.rs | 4 ++++ tests/test_location.rs | 20 ++++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 9da26bc..eced50b 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -124,10 +124,28 @@ fn impl_struct(input: Struct) -> TokenStream { #request.provide_ref::(&self.#backtrace); }) }; + let location_provide = if let Some(location_field) = input.location_field() { + let location = &location_field.member; + + if type_is_option(location_field.ty) { + Some(quote! { + if let ::core::option::Option::Some(location) = &self.#location { + #request.provide_ref::<::core::panic::Location>(location); + } + }) + } else { + Some(quote! { + #request.provide_ref::<::core::panic::Location>(&self.#location); + }) + } + } else { + None + }; quote! { use thiserror::__private::ThiserrorProvide as _; #source_provide #self_provide + #location_provide } } else if type_is_option(backtrace_field.ty) { quote! { @@ -190,11 +208,14 @@ fn impl_struct(input: Struct) -> TokenStream { let from_impl = input.from_field().map(|from_field| { let backtrace_field = input.distinct_backtrace_field(); let from = unoptional_type(from_field.ty); - let body = from_initializer(from_field, backtrace_field, None); + let body = from_initializer(from_field, backtrace_field, input.location_field()); + let track_caller = input.location_field().map(|_| quote!(#[track_caller])); + quote! { #[allow(unused_qualifications)] impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { #[allow(deprecated)] + #track_caller fn from(source: #from) -> Self { #ty #body } diff --git a/impl/src/prop.rs b/impl/src/prop.rs index 2ca30cf..397c679 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -16,6 +16,10 @@ impl Struct<'_> { backtrace_field(&self.fields) } + pub(crate) fn location_field(&self) -> Option<&Field> { + location_field(&self.fields) + } + pub(crate) fn distinct_backtrace_field(&self) -> Option<&Field> { let backtrace_field = self.backtrace_field()?; distinct_backtrace_field(backtrace_field, self.from_field()) diff --git a/tests/test_location.rs b/tests/test_location.rs index ace7ede..c7a3679 100644 --- a/tests/test_location.rs +++ b/tests/test_location.rs @@ -12,6 +12,15 @@ enum MError { }, } +#[derive(Error, Debug)] +#[error("Atlocation test error, sourced from {other}")] +pub struct TestError { + #[from] + other: io::Error, + #[location] + location: &'static Location<'static>, +} + #[test] #[should_panic] fn test_enum() { @@ -22,3 +31,14 @@ fn test_enum() { inner().unwrap(); } + +#[test] +#[should_panic] +fn test_struct() { + fn inner() -> Result<(), TestError> { + Err(io::Error::new(io::ErrorKind::AddrInUse, String::new()))?; + Ok(()) + } + + inner().unwrap(); +} From bcd35d71cf79d0ec8a302dcb1ca6d0c9deaabc17 Mon Sep 17 00:00:00 2001 From: Angad Tendulkar Date: Wed, 6 Mar 2024 09:53:33 -0500 Subject: [PATCH 3/4] do not require location attr --- impl/src/prop.rs | 26 +++++++++++++++++++++++--- impl/src/valid.rs | 8 ++++++-- tests/test_location.rs | 6 ++---- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/impl/src/prop.rs b/impl/src/prop.rs index 397c679..120a6be 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -1,7 +1,10 @@ use crate::ast::{Enum, Field, Struct, Variant}; use crate::span::MemberSpan; use proc_macro2::Span; -use syn::{Member, Type}; +use syn::{ + AngleBracketedGenericArguments, GenericArgument, Lifetime, Member, PathArguments, Type, + TypeReference, +}; impl Struct<'_> { pub(crate) fn from_field(&self) -> Option<&Field> { @@ -174,10 +177,27 @@ fn type_is_backtrace(ty: &Type) -> bool { fn type_is_location(ty: &Type) -> bool { let path = match ty { - Type::Path(ty) => &ty.path, + Type::Reference(TypeReference { + lifetime: Some(Lifetime { ident: ltident, .. }), + elem, // TODO: replace with `elem: box Type::Path(path)` once box_patterns stabalizes + .. + }) if ltident == "static" => match &**elem { + Type::Path(ty) => &ty.path, + _ => return false, + }, _ => return false, }; let last = path.segments.last().unwrap(); - last.ident == "Location" && last.arguments.is_empty() + + last.ident == "Location" + && match &last.arguments { + PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => { + match args.first() { + Some(GenericArgument::Lifetime(Lifetime { ident, .. })) => ident == "static", + _ => false, + } + } + _ => false, + } } diff --git a/impl/src/valid.rs b/impl/src/valid.rs index e353ad7..212fa49 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -141,6 +141,7 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { let mut location_field = None; let mut backtrace_field = None; let mut has_backtrace = false; + let mut has_location = false; for field in fields { if let Some(from) = field.attrs.from { if from_field.is_some() { @@ -173,6 +174,7 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { } location_field = Some(field); + has_location = true; } if let Some(transparent) = field.attrs.transparent { return Err(Error::new_spanned( @@ -181,6 +183,7 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { )); } has_backtrace |= field.is_backtrace(); + has_location |= field.is_location(); } if let (Some(from_field), Some(source_field)) = (from_field, source_field) { if !same_member(from_field, source_field) { @@ -191,11 +194,12 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { } } if let Some(from_field) = from_field { + let extra_fields = has_backtrace as usize + has_location as usize; let max_expected_fields = match (backtrace_field, location_field) { (Some(backtrace), Some(_)) => 2 + !same_member(from_field, backtrace) as usize, (Some(backtrace_field), None) => 1 + !same_member(from_field, backtrace_field) as usize, - (None, Some(_)) => 2 + has_backtrace as usize, - (None, None) => 1 + has_backtrace as usize, + (None, Some(_)) => 1 + extra_fields, + (None, None) => 1 + extra_fields, }; if fields.len() > max_expected_fields { return Err(Error::new_spanned( diff --git a/tests/test_location.rs b/tests/test_location.rs index c7a3679..a8c7069 100644 --- a/tests/test_location.rs +++ b/tests/test_location.rs @@ -5,19 +5,17 @@ use thiserror::Error; enum MError { #[error("At {location}: location test error, sourced from {other}")] Test { - #[location] - location: &'static Location<'static>, #[from] other: io::Error, + location: &'static Location<'static>, }, } #[derive(Error, Debug)] -#[error("Atlocation test error, sourced from {other}")] +#[error("At {location} test error, sourced from {other}")] pub struct TestError { #[from] other: io::Error, - #[location] location: &'static Location<'static>, } From 7d8dae20ac65b5894916172c4794e78f5341613a Mon Sep 17 00:00:00 2001 From: Angad Tendulkar Date: Wed, 6 Mar 2024 09:58:58 -0500 Subject: [PATCH 4/4] update "deriving From..." message to include location --- impl/src/valid.rs | 2 +- tests/ui/from-backtrace-backtrace.stderr | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/impl/src/valid.rs b/impl/src/valid.rs index 212fa49..5710871 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -204,7 +204,7 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { if fields.len() > max_expected_fields { return Err(Error::new_spanned( from_field.attrs.from, - "deriving From requires no fields other than source and backtrace", + "deriving From requires no fields other than source, backtrace, and location", )); } } diff --git a/tests/ui/from-backtrace-backtrace.stderr b/tests/ui/from-backtrace-backtrace.stderr index 5c0b9a3..5ead7a5 100644 --- a/tests/ui/from-backtrace-backtrace.stderr +++ b/tests/ui/from-backtrace-backtrace.stderr @@ -1,4 +1,4 @@ -error: deriving From requires no fields other than source and backtrace +error: deriving From requires no fields other than source, backtrace, and location --> tests/ui/from-backtrace-backtrace.rs:9:5 | 9 | #[from]