From e7f89a7eea766af95604405d666e089a42cd4b48 Mon Sep 17 00:00:00 2001 From: Ding Xiang Fei Date: Sat, 13 Jul 2024 20:14:22 +0800 Subject: [PATCH] derive(SmartPointer): rewrite bounds in where and generic bounds --- .../src/deriving/smart_ptr.rs | 244 ++++++++++++++++-- .../deriving-smart-pointer-expanded.rs | 22 ++ .../deriving-smart-pointer-expanded.stdout | 44 ++++ .../smart-pointer-bounds-issue-127647.rs | 78 ++++++ 4 files changed, 371 insertions(+), 17 deletions(-) create mode 100644 tests/ui/deriving/deriving-smart-pointer-expanded.rs create mode 100644 tests/ui/deriving/deriving-smart-pointer-expanded.stdout create mode 100644 tests/ui/deriving/smart-pointer-bounds-issue-127647.rs diff --git a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs index bbc7cd3962720..02555bd799c42 100644 --- a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs +++ b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs @@ -1,14 +1,18 @@ use std::mem::swap; +use ast::ptr::P; use ast::HasAttrs; +use rustc_ast::mut_visit::MutVisitor; +use rustc_ast::visit::BoundKind; use rustc_ast::{ self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem, - TraitBoundModifiers, VariantData, + TraitBoundModifiers, VariantData, WherePredicate, }; use rustc_attr as attr; +use rustc_data_structures::flat_map_in_place::FlatMapInPlace; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{sym, Ident}; -use rustc_span::Span; +use rustc_span::{Span, Symbol}; use smallvec::{smallvec, SmallVec}; use thin_vec::{thin_vec, ThinVec}; @@ -141,33 +145,239 @@ pub fn expand_deriving_smart_ptr( alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone()); let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params)); + // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location + // // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it. let mut impl_generics = generics.clone(); + let pointee_ty_ident = generics.params[pointee_param_idx].ident; + let mut self_bounds; { - let p = &mut impl_generics.params[pointee_param_idx]; + let pointee = &mut impl_generics.params[pointee_param_idx]; + self_bounds = pointee.bounds.clone(); let arg = GenericArg::Type(s_ty.clone()); let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]); - p.bounds.push(cx.trait_bound(unsize, false)); + pointee.bounds.push(cx.trait_bound(unsize, false)); let mut attrs = thin_vec![]; - swap(&mut p.attrs, &mut attrs); - p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect(); + swap(&mut pointee.attrs, &mut attrs); + // Drop `#[pointee]` attribute since it should not be recognized outside `derive(SmartPointer)` + pointee.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect(); } - // Add the `__S: ?Sized` extra parameter to the impl block. + // # Rewrite generic parameter bounds + // For each bound `U: ..` in `struct`, make a new bound with `__S` in place of `#[pointee]` + // Example: + // ``` + // struct< + // U: Trait, + // #[pointee] T: Trait, + // V: Trait> ... + // ``` + // ... generates this `impl` generic parameters + // ``` + // impl< + // U: Trait + Trait<__S>, + // T: Trait + Unsize<__S>, // (**) + // __S: Trait<__S> + ?Sized, // (*) + // V: Trait + Trait<__S>> ... + // ``` + // The new bound marked with (*) has to be done separately. + // See next section + for (idx, (params, orig_params)) in + impl_generics.params.iter_mut().zip(&generics.params).enumerate() + { + // Default type parameters are rejected for `impl` block. + // We should drop them now. + match &mut params.kind { + ast::GenericParamKind::Const { default, .. } => *default = None, + ast::GenericParamKind::Type { default } => *default = None, + ast::GenericParamKind::Lifetime => {} + } + // We CANNOT rewrite `#[pointee]` type parameter bounds. + // This has been set in stone. (**) + // So we skip over it. + // Otherwise, we push extra bounds involving `__S`. + if idx != pointee_param_idx { + for bound in &orig_params.bounds { + let mut bound = bound.clone(); + let mut substitution = TypeSubstitution { + from_name: pointee_ty_ident.name, + to_ty: &s_ty, + rewritten: false, + }; + substitution.visit_param_bound(&mut bound, BoundKind::Bound); + if substitution.rewritten { + // We found use of `#[pointee]` somewhere, + // so we make a new bound using `__S` in place of `#[pointee]` + params.bounds.push(bound); + } + } + } + } + + // # Insert `__S` type parameter + // + // We now insert `__S` with the missing bounds marked with (*) above. + // We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`. let sized = cx.path_global(span, path!(span, core::marker::Sized)); - let bound = GenericBound::Trait( - cx.poly_trait_ref(span, sized), - TraitBoundModifiers { - polarity: ast::BoundPolarity::Maybe(span), - constness: ast::BoundConstness::Never, - asyncness: ast::BoundAsyncness::Normal, - }, - ); - let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None); - impl_generics.params.push(extra_param); + // For some reason, we are not allowed to write `?Sized` bound twice like `__S: ?Sized + ?Sized`. + if !contains_maybe_sized_bound(&self_bounds) + && !contains_maybe_sized_bound_on_pointee( + &generics.where_clause.predicates, + pointee_ty_ident.name, + ) + { + self_bounds.push(GenericBound::Trait( + cx.poly_trait_ref(span, sized), + TraitBoundModifiers { + polarity: ast::BoundPolarity::Maybe(span), + constness: ast::BoundConstness::Never, + asyncness: ast::BoundAsyncness::Normal, + }, + )); + } + { + let mut substitution = + TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false }; + for bound in &mut self_bounds { + substitution.visit_param_bound(bound, BoundKind::Bound); + } + } + + // # Rewrite `where` clauses + // + // Move on to `where` clauses. + // Example: + // ``` + // struct MyPointer<#[pointee] T, ..> + // where + // U: Trait + Trait, + // Companion: Trait, + // T: Trait, + // { .. } + // ``` + // ... will have a impl prelude like so + // ``` + // impl<..> .. + // where + // U: Trait + Trait, + // U: Trait<__S>, + // Companion: Trait, + // Companion<__S>: Trait<__S>, + // T: Trait, + // __S: Trait<__S>, + // ``` + // + // We should also write a few new `where` bounds from `#[pointee] T` to `__S` + // as well as any bound that indirectly involves the `#[pointee] T` type. + for bound in &generics.where_clause.predicates { + if let ast::WherePredicate::BoundPredicate(bound) = bound { + let mut substitution = TypeSubstitution { + from_name: pointee_ty_ident.name, + to_ty: &s_ty, + rewritten: false, + }; + let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate { + span: bound.span, + bound_generic_params: bound.bound_generic_params.clone(), + bounded_ty: bound.bounded_ty.clone(), + bounds: bound.bounds.clone(), + }); + substitution.visit_where_predicate(&mut predicate); + if substitution.rewritten { + impl_generics.where_clause.predicates.push(predicate); + } + } + } + + let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None); + impl_generics.params.insert(pointee_param_idx + 1, extra_param); // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`. let gen_args = vec![GenericArg::Type(alt_self_type.clone())]; add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone()); add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone()); } + +fn contains_maybe_sized_bound_on_pointee(predicates: &[WherePredicate], pointee: Symbol) -> bool { + for bound in predicates { + if let ast::WherePredicate::BoundPredicate(bound) = bound + && bound.bounded_ty.kind.is_simple_path().is_some_and(|name| name == pointee) + { + for bound in &bound.bounds { + if is_maybe_sized_bound(bound) { + return true; + } + } + } + } + false +} + +fn is_maybe_sized_bound(bound: &GenericBound) -> bool { + if let GenericBound::Trait( + trait_ref, + TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. }, + ) = bound + { + is_sized_marker(&trait_ref.trait_ref.path) + } else { + false + } +} + +fn contains_maybe_sized_bound(bounds: &[GenericBound]) -> bool { + bounds.iter().any(is_maybe_sized_bound) +} + +fn path_segment_is_exact_match(path_segments: &[ast::PathSegment], syms: &[Symbol]) -> bool { + path_segments.iter().zip(syms).all(|(segment, &symbol)| segment.ident.name == symbol) +} + +fn is_sized_marker(path: &ast::Path) -> bool { + const CORE_UNSIZE: [Symbol; 3] = [sym::core, sym::marker, sym::Sized]; + const STD_UNSIZE: [Symbol; 3] = [sym::std, sym::marker, sym::Sized]; + if path.segments.len() == 4 && path.is_global() { + path_segment_is_exact_match(&path.segments[1..], &CORE_UNSIZE) + || path_segment_is_exact_match(&path.segments[1..], &STD_UNSIZE) + } else if path.segments.len() == 3 { + path_segment_is_exact_match(&path.segments, &CORE_UNSIZE) + || path_segment_is_exact_match(&path.segments, &STD_UNSIZE) + } else { + *path == sym::Sized + } +} + +struct TypeSubstitution<'a> { + from_name: Symbol, + to_ty: &'a ast::Ty, + rewritten: bool, +} + +impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> { + fn visit_ty(&mut self, ty: &mut P) { + if let Some(name) = ty.kind.is_simple_path() + && name == self.from_name + { + **ty = self.to_ty.clone(); + self.rewritten = true; + } else { + ast::mut_visit::walk_ty(self, ty); + } + } + + fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) { + match where_predicate { + rustc_ast::WherePredicate::BoundPredicate(bound) => { + bound + .bound_generic_params + .flat_map_in_place(|param| self.flat_map_generic_param(param)); + self.visit_ty(&mut bound.bounded_ty); + for bound in &mut bound.bounds { + self.visit_param_bound(bound, BoundKind::Bound) + } + } + rustc_ast::WherePredicate::RegionPredicate(_) + | rustc_ast::WherePredicate::EqPredicate(_) => {} + } + } +} diff --git a/tests/ui/deriving/deriving-smart-pointer-expanded.rs b/tests/ui/deriving/deriving-smart-pointer-expanded.rs new file mode 100644 index 0000000000000..b78258c25290e --- /dev/null +++ b/tests/ui/deriving/deriving-smart-pointer-expanded.rs @@ -0,0 +1,22 @@ +//@ check-pass +//@ compile-flags: -Zunpretty=expanded +#![feature(derive_smart_pointer)] +use std::marker::SmartPointer; + +pub trait MyTrait {} + +#[derive(SmartPointer)] +#[repr(transparent)] +struct MyPointer<'a, #[pointee] T: ?Sized> { + ptr: &'a T, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct MyPointer2<'a, Y, Z: MyTrait, #[pointee] T: ?Sized + MyTrait, X: MyTrait = ()> +where + Y: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} diff --git a/tests/ui/deriving/deriving-smart-pointer-expanded.stdout b/tests/ui/deriving/deriving-smart-pointer-expanded.stdout new file mode 100644 index 0000000000000..3c7e719818042 --- /dev/null +++ b/tests/ui/deriving/deriving-smart-pointer-expanded.stdout @@ -0,0 +1,44 @@ +#![feature(prelude_import)] +#![no_std] +//@ check-pass +//@ compile-flags: -Zunpretty=expanded +#![feature(derive_smart_pointer)] +#[prelude_import] +use ::std::prelude::rust_2015::*; +#[macro_use] +extern crate std; +use std::marker::SmartPointer; + +pub trait MyTrait {} + +#[repr(transparent)] +struct MyPointer<'a, #[pointee] T: ?Sized> { + ptr: &'a T, +} +#[automatically_derived] +impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized> + ::core::ops::DispatchFromDyn> for MyPointer<'a, T> { +} +#[automatically_derived] +impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized> + ::core::ops::CoerceUnsized> for MyPointer<'a, T> { +} + +#[repr(transparent)] +pub struct MyPointer2<'a, Y, Z: MyTrait, #[pointee] T: ?Sized + MyTrait, + X: MyTrait = ()> where Y: MyTrait { + data: &'a mut T, + x: core::marker::PhantomData, +} +#[automatically_derived] +impl<'a, Y, Z: MyTrait + MyTrait<__S>, T: ?Sized + MyTrait + + ::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait + + MyTrait<__S>> ::core::ops::DispatchFromDyn> + for MyPointer2<'a, Y, Z, T, X> where Y: MyTrait, Y: MyTrait<__S> { +} +#[automatically_derived] +impl<'a, Y, Z: MyTrait + MyTrait<__S>, T: ?Sized + MyTrait + + ::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait + + MyTrait<__S>> ::core::ops::CoerceUnsized> for + MyPointer2<'a, Y, Z, T, X> where Y: MyTrait, Y: MyTrait<__S> { +} diff --git a/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs new file mode 100644 index 0000000000000..4cae1b32896fb --- /dev/null +++ b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs @@ -0,0 +1,78 @@ +//@ check-pass + +#![feature(derive_smart_pointer)] + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub trait OnDrop { + fn on_drop(&mut self); +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr2<'a, #[pointee] T: ?Sized, X> +where + T: OnDrop, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub trait MyTrait {} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr3<'a, #[pointee] T: ?Sized, X> +where + T: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr4<'a, #[pointee] T: MyTrait + ?Sized, X> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr5<'a, #[pointee] T: ?Sized, X> +where + Ptr5Companion: MyTrait, + Ptr5Companion2: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub struct Ptr5Companion(core::marker::PhantomData); +pub struct Ptr5Companion2; + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait = (), const PARAM: usize = 0> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +// a reduced example from https://lore.kernel.org/all/20240402-linked-list-v1-1-b1c59ba7ae3b@google.com/ +#[repr(transparent)] +#[derive(core::marker::SmartPointer)] +pub struct ListArc<#[pointee] T, const ID: u64 = 0> +where + T: ListArcSafe + ?Sized, +{ + arc: *const T, +} + +pub trait ListArcSafe {} + +fn main() {}