From fac605376128a7983eb8646bc0240283e6f2f075 Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Thu, 14 Nov 2024 00:01:05 +0900 Subject: [PATCH 1/5] Assign memo ingredients per salsa-struct-ingredient --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 19 +++++++++++ src/accumulator.rs | 4 +++ src/function.rs | 4 +-- src/ingredient.rs | 11 +++++- src/input.rs | 10 +++++- src/interned.rs | 5 +++ src/tracked_struct.rs | 9 +++-- src/zalsa.rs | 34 +++++++++++++++---- 8 files changed, 84 insertions(+), 12 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 79ec96d5..1a23ebea 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -199,7 +199,22 @@ macro_rules! setup_tracked_fn { aux: &dyn $zalsa::JarAux, first_index: $zalsa::IngredientIndex, ) -> Vec> { + let struct_index = $zalsa::macro_if! { + if $needs_interner { + first_index.successor(0) + } else { + aux + .lookup_struct_ingredient_index( + core::any::TypeId::of::<$InternedData<'static>>() + ) + .expect( + "Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!" + ) + } + }; + let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( + struct_index, first_index, aux, ); @@ -219,6 +234,10 @@ macro_rules! setup_tracked_fn { } } } + + fn salsa_struct_type_id(&self) -> Option { + None + } } #[allow(non_local_definitions)] diff --git a/src/accumulator.rs b/src/accumulator.rs index b9355419..cd566f40 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -53,6 +53,10 @@ impl Jar for JarImpl { ) -> Vec> { vec![Box::new(>::new(first_index))] } + + fn salsa_struct_type_id(&self) -> Option { + None + } } pub struct IngredientImpl { diff --git a/src/function.rs b/src/function.rs index 07f13d49..b06be486 100644 --- a/src/function.rs +++ b/src/function.rs @@ -126,10 +126,10 @@ impl IngredientImpl where C: Configuration, { - pub fn new(index: IngredientIndex, aux: &dyn JarAux) -> Self { + pub fn new(struct_index: IngredientIndex, index: IngredientIndex, aux: &dyn JarAux) -> Self { Self { index, - memo_ingredient_index: aux.next_memo_ingredient_index(index), + memo_ingredient_index: aux.next_memo_ingredient_index(struct_index, index), lru: Default::default(), deleted_entries: Default::default(), } diff --git a/src/ingredient.rs b/src/ingredient.rs index 383fdc6b..12601b16 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -23,10 +23,19 @@ pub trait Jar: Any { aux: &dyn JarAux, first_index: IngredientIndex, ) -> Vec>; + + /// If this jar's first ingredient is a salsa struct, return its `TypeId` + fn salsa_struct_type_id(&self) -> Option; } pub trait JarAux { - fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex; + fn lookup_struct_ingredient_index(&self, type_id: TypeId) -> Option; + + fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex; } pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { diff --git a/src/input.rs b/src/input.rs index fdad27ac..4b82ef2f 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,4 +1,8 @@ -use std::{any::Any, fmt, ops::DerefMut}; +use std::{ + any::{Any, TypeId}, + fmt, + ops::DerefMut, +}; pub mod input_field; pub mod setter; @@ -60,6 +64,10 @@ impl Jar for JarImpl { })) .collect() } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct>()) + } } pub struct IngredientImpl { diff --git a/src/interned.rs b/src/interned.rs index 0c6d32cd..f6767fb1 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -9,6 +9,7 @@ use crate::table::Slot; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; +use std::any::TypeId; use std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; use std::marker::PhantomData; @@ -92,6 +93,10 @@ impl Jar for JarImpl { ) -> Vec> { vec![Box::new(IngredientImpl::::new(first_index)) as _] } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct<'static>>()) + } } impl IngredientImpl diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 540bb765..8cca21e9 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -1,4 +1,4 @@ -use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; +use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; use crossbeam::{atomic::AtomicCell, queue::SegQueue}; use tracked_field::FieldIngredientImpl; @@ -112,6 +112,10 @@ impl Jar for JarImpl { })) .collect() } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct<'static>>()) + } } pub trait TrackedStructInDb: SalsaStructInDb { @@ -501,7 +505,8 @@ where // and the code that references the memo-table has a read-lock. let memo_table = unsafe { (*data).take_memo_table() }; for (memo_ingredient_index, memo) in memo_table.into_memos() { - let ingredient_index = zalsa.ingredient_index_for_memo(memo_ingredient_index); + let ingredient_index = + zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index); let executor = DatabaseKeyIndex { ingredient_index, diff --git a/src/zalsa.rs b/src/zalsa.rs index 2f8fa95f..fbe14eed 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -1,5 +1,5 @@ use append_only_vec::AppendOnlyVec; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use rustc_hash::FxHashMap; use std::any::{Any, TypeId}; use std::marker::PhantomData; @@ -119,8 +119,10 @@ pub struct Zalsa { nonce: Nonce, - /// Number of memo ingredient indices created by calls to [`next_memo_ingredient_index`](`Self::next_memo_ingredient_index`) - memo_ingredients: Mutex>, + /// Map from the [`IngredientIndex`][] of a salsa struct to a list of + /// [ingredient-indices](`IngredientIndex`)for tracked functions that have this salsa struct + /// as input. + memo_ingredients: RwLock>>, /// Map from the type-id of an `impl Jar` to the index of its first ingredient. /// This is using a `Mutex` (versus, say, a `FxDashMap`) @@ -130,6 +132,9 @@ pub struct Zalsa { /// adding new kinds of ingredients. jar_map: Mutex>, + /// Map from the type-id of a salsa struct to the index of its first ingredient. + salsa_struct_map: Mutex>, + /// Vector of ingredients. /// /// Immutable unless the mutex on `ingredients_map` is held. @@ -149,6 +154,7 @@ impl Zalsa { views_of: Views::new::(), nonce: NONCE.nonce(), jar_map: Default::default(), + salsa_struct_map: Default::default(), ingredients_vec: AppendOnlyVec::new(), ingredients_requiring_reset: AppendOnlyVec::new(), runtime: Runtime::default(), @@ -211,6 +217,9 @@ impl Zalsa { ); } + if let Some(type_id) = jar.salsa_struct_type_id() { + self.salsa_struct_map.lock().insert(type_id, index); + } index }) } @@ -290,15 +299,28 @@ impl Zalsa { pub(crate) fn ingredient_index_for_memo( &self, + struct_ingredient_index: IngredientIndex, memo_ingredient_index: MemoIngredientIndex, ) -> IngredientIndex { - self.memo_ingredients.lock()[memo_ingredient_index.as_usize()] + self.memo_ingredients.read()[&struct_ingredient_index][memo_ingredient_index.as_usize()] } } impl JarAux for Zalsa { - fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex { - let mut memo_ingredients = self.memo_ingredients.lock(); + fn lookup_struct_ingredient_index(&self, type_id: TypeId) -> Option { + self.salsa_struct_map + .lock() + .get(&type_id) + .map(ToOwned::to_owned) + } + + fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex { + let mut memo_ingredients = self.memo_ingredients.write(); + let memo_ingredients = memo_ingredients.entry(struct_ingredient_index).or_default(); let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); memo_ingredients.push(ingredient_index); mi From 2cae1ef5bb36b7eedae89179fe0effbe8f1fc0ee Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Fri, 15 Nov 2024 01:40:45 +0900 Subject: [PATCH 2/5] Modify struct ingredient lookups --- .../src/setup_input_struct.rs | 3 + .../src/setup_interned_struct.rs | 3 + .../salsa-macro-rules/src/setup_tracked_fn.rs | 8 +-- .../src/setup_tracked_struct.rs | 3 + src/ingredient.rs | 2 +- src/salsa_struct.rs | 6 +- src/zalsa.rs | 67 ++++++++++--------- 7 files changed, 55 insertions(+), 37 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 51cd482b..b506b28b 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -124,6 +124,9 @@ macro_rules! setup_input_struct { } impl $zalsa::SalsaStructInDb for $Struct { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } impl $Struct { diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index bf9d98f5..e8b8af18 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -141,6 +141,9 @@ macro_rules! setup_interned_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } unsafe impl $zalsa::Update for $Struct<'_> { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 1a23ebea..ce5d313d 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -99,6 +99,9 @@ macro_rules! setup_tracked_fn { $zalsa::IngredientCache::new(); impl $zalsa::SalsaStructInDb for $InternedData<'_> { + fn lookup_ingredient_index(_aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + None + } } impl $zalsa::interned::Configuration for $Configuration { @@ -203,10 +206,7 @@ macro_rules! setup_tracked_fn { if $needs_interner { first_index.successor(0) } else { - aux - .lookup_struct_ingredient_index( - core::any::TypeId::of::<$InternedData<'static>>() - ) + <$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(aux) .expect( "Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!" ) diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index d0d42c6d..a783e376 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -152,6 +152,9 @@ macro_rules! setup_tracked_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } impl $zalsa::TrackedStructInDb for $Struct<'_> { diff --git a/src/ingredient.rs b/src/ingredient.rs index 12601b16..f20c95ef 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -29,7 +29,7 @@ pub trait Jar: Any { } pub trait JarAux { - fn lookup_struct_ingredient_index(&self, type_id: TypeId) -> Option; + fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option; fn next_memo_ingredient_index( &self, diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index fcf7920a..8674dc12 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1 +1,5 @@ -pub trait SalsaStructInDb {} +use crate::{plumbing::JarAux, IngredientIndex}; + +pub trait SalsaStructInDb { + fn lookup_ingredient_index(aux: &dyn JarAux) -> Option; +} diff --git a/src/zalsa.rs b/src/zalsa.rs index fbe14eed..0d7ccddc 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -119,10 +119,10 @@ pub struct Zalsa { nonce: Nonce, - /// Map from the [`IngredientIndex`][] of a salsa struct to a list of - /// [ingredient-indices](`IngredientIndex`)for tracked functions that have this salsa struct + /// Map from the [`IngredientIndex::as_usize`][] of a salsa struct to a list of + /// [ingredient-indices](`IngredientIndex`) for tracked functions that have this salsa struct /// as input. - memo_ingredients: RwLock>>, + memo_ingredient_indices: RwLock>>, /// Map from the type-id of an `impl Jar` to the index of its first ingredient. /// This is using a `Mutex` (versus, say, a `FxDashMap`) @@ -132,9 +132,6 @@ pub struct Zalsa { /// adding new kinds of ingredients. jar_map: Mutex>, - /// Map from the type-id of a salsa struct to the index of its first ingredient. - salsa_struct_map: Mutex>, - /// Vector of ingredients. /// /// Immutable unless the mutex on `ingredients_map` is held. @@ -154,11 +151,10 @@ impl Zalsa { views_of: Views::new::(), nonce: NONCE.nonce(), jar_map: Default::default(), - salsa_struct_map: Default::default(), ingredients_vec: AppendOnlyVec::new(), ingredients_requiring_reset: AppendOnlyVec::new(), runtime: Runtime::default(), - memo_ingredients: Default::default(), + memo_ingredient_indices: Default::default(), } } @@ -192,11 +188,14 @@ impl Zalsa { { let jar_type_id = jar.type_id(); let mut jar_map = self.jar_map.lock(); - *jar_map - .entry(jar_type_id) - .or_insert_with(|| { - let index = IngredientIndex::from(self.ingredients_vec.len()); - let ingredients = jar.create_ingredients(self, index); + let mut should_create = false; + let index = *jar_map.entry(jar_type_id).or_insert_with(|| { + should_create = true; + IngredientIndex::from(self.ingredients_vec.len()) + }); + if should_create { + let aux = JarAuxImpl(self, &jar_map); + let ingredients = jar.create_ingredients(&aux, index); for ingredient in ingredients { let expected_index = ingredient.ingredient_index(); @@ -204,9 +203,7 @@ impl Zalsa { self.ingredients_requiring_reset.push(expected_index); } - let actual_index = self - .ingredients_vec - .push(ingredient); + let actual_index = self.ingredients_vec.push(ingredient); assert_eq!( expected_index.as_usize(), actual_index, @@ -215,13 +212,10 @@ impl Zalsa { expected_index, actual_index, ); - } - if let Some(type_id) = jar.salsa_struct_type_id() { - self.salsa_struct_map.lock().insert(type_id, index); - } - index - }) + } + + index } } @@ -302,16 +296,16 @@ impl Zalsa { struct_ingredient_index: IngredientIndex, memo_ingredient_index: MemoIngredientIndex, ) -> IngredientIndex { - self.memo_ingredients.read()[&struct_ingredient_index][memo_ingredient_index.as_usize()] + self.memo_ingredient_indices.read()[struct_ingredient_index.as_usize()] + [memo_ingredient_index.as_usize()] } } -impl JarAux for Zalsa { - fn lookup_struct_ingredient_index(&self, type_id: TypeId) -> Option { - self.salsa_struct_map - .lock() - .get(&type_id) - .map(ToOwned::to_owned) +struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap); + +impl<'a> JarAux for JarAuxImpl<'a> { + fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { + self.1.get(&jar.type_id()).map(ToOwned::to_owned) } fn next_memo_ingredient_index( @@ -319,8 +313,19 @@ impl JarAux for Zalsa { struct_ingredient_index: IngredientIndex, ingredient_index: IngredientIndex, ) -> MemoIngredientIndex { - let mut memo_ingredients = self.memo_ingredients.write(); - let memo_ingredients = memo_ingredients.entry(struct_ingredient_index).or_default(); + let mut memo_ingredients = self.0.memo_ingredient_indices.write(); + let memo_ingredients = if let Some(memo_ingredients) = + memo_ingredients.get_mut(struct_ingredient_index.as_usize()) + { + memo_ingredients + } else { + while memo_ingredients.len() <= struct_ingredient_index.as_usize() { + memo_ingredients.push(Vec::new()); + } + memo_ingredients + .get_mut(struct_ingredient_index.as_usize()) + .unwrap() + }; let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); memo_ingredients.push(ingredient_index); mi From d34ed1495e72900f891da2e2d7ee8e1bc895c7dc Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Fri, 15 Nov 2024 01:41:18 +0900 Subject: [PATCH 3/5] Add a test for tracked function with multiple salsa struct args --- tests/tracked_fn_multiple_args.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/tracked_fn_multiple_args.rs diff --git a/tests/tracked_fn_multiple_args.rs b/tests/tracked_fn_multiple_args.rs new file mode 100644 index 00000000..7c014356 --- /dev/null +++ b/tests/tracked_fn_multiple_args.rs @@ -0,0 +1,25 @@ +//! Test that a `tracked` fn on multiple salsa struct args +//! compiles and executes successfully. + +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::interned] +struct MyInterned<'db> { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput, interned: MyInterned<'db>) -> u32 { + input.field(db) + interned.field(db) +} + +#[test] +fn execute() { + let db = salsa::DatabaseImpl::new(); + let input = MyInput::new(&db, 22); + let interned = MyInterned::new(&db, 33); + assert_eq!(tracked_fn(&db, input, interned), 55); +} From 738d5f94037b5a8bd47f38ee5b639e66cc686e2f Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Fri, 15 Nov 2024 01:46:08 +0900 Subject: [PATCH 4/5] Fix clippy --- src/zalsa.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zalsa.rs b/src/zalsa.rs index 0d7ccddc..b15ac80d 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -303,7 +303,7 @@ impl Zalsa { struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap); -impl<'a> JarAux for JarAuxImpl<'a> { +impl JarAux for JarAuxImpl<'_> { fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { self.1.get(&jar.type_id()).map(ToOwned::to_owned) } From 38ad6555ebe4374bbc69022e005f4136a1839ad9 Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Wed, 20 Nov 2024 10:19:28 +0900 Subject: [PATCH 5/5] Clean up code a bit --- src/zalsa.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/zalsa.rs b/src/zalsa.rs index b15ac80d..a1416f6d 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -314,17 +314,12 @@ impl JarAux for JarAuxImpl<'_> { ingredient_index: IngredientIndex, ) -> MemoIngredientIndex { let mut memo_ingredients = self.0.memo_ingredient_indices.write(); - let memo_ingredients = if let Some(memo_ingredients) = - memo_ingredients.get_mut(struct_ingredient_index.as_usize()) - { + let idx = struct_ingredient_index.as_usize(); + let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { memo_ingredients } else { - while memo_ingredients.len() <= struct_ingredient_index.as_usize() { - memo_ingredients.push(Vec::new()); - } - memo_ingredients - .get_mut(struct_ingredient_index.as_usize()) - .unwrap() + memo_ingredients.resize_with(idx + 1, Vec::new); + memo_ingredients.get_mut(idx).unwrap() }; let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); memo_ingredients.push(ingredient_index);