Skip to content

Commit

Permalink
Merge pull request #593 from ShoyuVanilla/issue-577
Browse files Browse the repository at this point in the history
fix: Replace `SelfTy` with actual type in tracked methods
  • Loading branch information
davidbarsky authored Oct 15, 2024
2 parents b14be5c + 7e3426e commit c6c51a0
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 4 deletions.
27 changes: 24 additions & 3 deletions components/salsa-macros/src/tracked_impl.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::HashSet;

use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::parse::Nothing;
use syn::{parse::Nothing, visit_mut::VisitMut};

use crate::{hygiene::Hygiene, tracked_fn::FnArgs};
use crate::{hygiene::Hygiene, tracked_fn::FnArgs, xform::ChangeSelfPath};

pub(crate) fn tracked_impl(
args: proc_macro::TokenStream,
Expand Down Expand Up @@ -32,8 +34,19 @@ struct MethodArguments<'syn> {
impl Macro {
fn try_generate(&self, mut impl_item: syn::ItemImpl) -> syn::Result<TokenStream> {
let mut member_items = std::mem::take(&mut impl_item.items);
let member_idents: HashSet<_> = member_items
.iter()
.filter_map(|item| match item {
syn::ImplItem::Const(it) => Some(it.ident.clone()),
syn::ImplItem::Fn(it) => Some(it.sig.ident.clone()),
syn::ImplItem::Type(it) => Some(it.ident.clone()),
syn::ImplItem::Macro(_) => None,
syn::ImplItem::Verbatim(_) => None,
_ => None,
})
.collect();
for member_item in &mut member_items {
self.modify_member(&impl_item, member_item)?;
self.modify_member(&impl_item, member_item, &member_idents)?;
}
impl_item.items = member_items;
Ok(crate::debug::dump_tokens(
Expand All @@ -47,6 +60,7 @@ impl Macro {
&self,
impl_item: &syn::ItemImpl,
member_item: &mut syn::ImplItem,
member_idents: &HashSet<syn::Ident>,
) -> syn::Result<()> {
let syn::ImplItem::Fn(fn_item) = member_item else {
return Ok(());
Expand All @@ -59,6 +73,13 @@ impl Macro {
return Ok(());
};

let trait_ = match &impl_item.trait_ {
Some((None, path, _)) => Some((path, member_idents)),
_ => None,
};
let mut change = ChangeSelfPath::new(self_ty, trait_);
change.visit_impl_item_fn_mut(fn_item);

let salsa_tracked_attr = fn_item.attrs.remove(tracked_attr_index);
let args: FnArgs = match &salsa_tracked_attr.meta {
syn::Meta::Path(..) => Default::default(),
Expand Down
117 changes: 116 additions & 1 deletion components/salsa-macros/src/xform.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use syn::visit_mut::VisitMut;
use std::collections::HashSet;

use quote::ToTokens;
use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut};

pub(crate) struct ChangeLt<'a> {
from: Option<&'a str>,
Expand All @@ -12,6 +15,7 @@ impl ChangeLt<'_> {
to: db_lt.ident.to_string(),
}
}

pub fn in_type(mut self, ty: &syn::Type) -> syn::Type {
let mut ty = ty.clone();
self.visit_type_mut(&mut ty);
Expand All @@ -26,3 +30,114 @@ impl syn::visit_mut::VisitMut for ChangeLt<'_> {
}
}
}

pub(crate) struct ChangeSelfPath<'a> {
self_ty: &'a syn::Type,
trait_: Option<(&'a syn::Path, &'a HashSet<syn::Ident>)>,
}

impl ChangeSelfPath<'_> {
pub fn new<'a>(
self_ty: &'a syn::Type,
trait_: Option<(&'a syn::Path, &'a HashSet<syn::Ident>)>,
) -> ChangeSelfPath<'a> {
ChangeSelfPath { self_ty, trait_ }
}
}

impl syn::visit_mut::VisitMut for ChangeSelfPath<'_> {
fn visit_type_mut(&mut self, i: &mut syn::Type) {
if let syn::Type::Path(syn::TypePath { qself: None, path }) = i {
if path.segments.len() == 1 && path.segments.first().is_some_and(|s| s.ident == "Self")
{
let span = path.segments.first().unwrap().span();
*i = respan(self.self_ty, span);
}
}
syn::visit_mut::visit_type_mut(self, i);
}

fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
// `<Self as ..>` cases are handled in `visit_type_mut`
if i.qself.is_some() {
syn::visit_mut::visit_type_path_mut(self, i);
return;
}

// A single path `Self` case is handled in `visit_type_mut`
if i.path.segments.first().is_some_and(|s| s.ident == "Self") && i.path.segments.len() > 1 {
let span = i.path.segments.first().unwrap().span();
let ty = Box::new(respan::<syn::Type>(self.self_ty, span));
let lt_token = syn::Token![<](span);
let gt_token = syn::Token![>](span);
match self.trait_ {
// If the next segment's ident is a trait member, replace `Self::` with
// `<ActualTy as Trait>::`
Some((trait_, member_idents))
if member_idents.contains(&i.path.segments.iter().nth(1).unwrap().ident) =>
{
let qself = syn::QSelf {
lt_token,
ty,
position: trait_.segments.len(),
as_token: Some(syn::Token![as](span)),
gt_token,
};
i.qself = Some(qself);
i.path.segments = Punctuated::from_iter(
trait_
.segments
.iter()
.chain(i.path.segments.iter().skip(1))
.cloned(),
);
}
// Replace `Self::` with `<ActualTy>::` otherwise
_ => {
let qself = syn::QSelf {
lt_token,
ty,
position: 0,
as_token: None,
gt_token,
};
i.qself = Some(qself);
i.path.segments =
Punctuated::from_iter(i.path.segments.iter().skip(1).cloned());
}
}
}

syn::visit_mut::visit_type_path_mut(self, i);
}
}

fn respan<T>(t: &T, span: proc_macro2::Span) -> T
where
T: ToTokens + Spanned + syn::parse::Parse,
{
let tokens = t.to_token_stream();
let respanned = respan_tokenstream(tokens, span);
syn::parse2(respanned).unwrap()
}

fn respan_tokenstream(
stream: proc_macro2::TokenStream,
span: proc_macro2::Span,
) -> proc_macro2::TokenStream {
stream
.into_iter()
.map(|token| respan_token(token, span))
.collect()
}

fn respan_token(
mut token: proc_macro2::TokenTree,
span: proc_macro2::Span,
) -> proc_macro2::TokenTree {
if let proc_macro2::TokenTree::Group(g) = &mut token {
*g = proc_macro2::Group::new(g.delimiter(), respan_tokenstream(g.stream(), span));
}
token.set_span(span);
token
}
44 changes: 44 additions & 0 deletions tests/tracked_method_with_self_ty.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//! Test that a `tracked` fn with `Self` in its signature or body on a `salsa::input`
//! compiles and executes successfully.
#![allow(warnings)]

trait TrackedTrait {
type Type;

fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type;

fn untracked_trait_fn();
}

#[salsa::input]
struct MyInput {
field: u32,
}

#[salsa::tracked]
impl MyInput {
#[salsa::tracked]
fn tracked_fn(self, db: &dyn salsa::Database, other: Self) -> u32 {
self.field(db) + other.field(db)
}
}

#[salsa::tracked]
impl TrackedTrait for MyInput {
type Type = u32;

#[salsa::tracked]
fn tracked_trait_fn(self, db: &dyn salsa::Database, ty: Self::Type) -> Self::Type {
Self::untracked_trait_fn();
Self::tracked_fn(self, db, self) + ty
}

fn untracked_trait_fn() {}
}

#[test]
fn execute() {
let mut db = salsa::DatabaseImpl::new();
let object = MyInput::new(&mut db, 10);
assert_eq!(object.tracked_trait_fn(&db, 1), 21);
}

0 comments on commit c6c51a0

Please sign in to comment.