Skip to content

Commit

Permalink
Handle multiple aux data returns from code expansion
Browse files Browse the repository at this point in the history
commit-id:0504e339
  • Loading branch information
maciektr committed Mar 22, 2024
1 parent aa83ef3 commit da02999
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 14 deletions.
44 changes: 35 additions & 9 deletions scarb/src/compiler/plugin/proc_macro/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use itertools::Itertools;
use scarb_stable_hash::short_hash;
use std::any::Any;
use std::sync::Arc;
use std::vec::IntoIter;
use tracing::{debug, trace_span};

/// A Cairo compiler plugin controlling the procedural macro execution.
Expand All @@ -45,7 +46,7 @@ impl ProcMacroId {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProcMacroAuxData {
value: Vec<u8>,
macro_id: ProcMacroId,
Expand All @@ -63,14 +64,35 @@ impl From<ProcMacroAuxData> for AuxData {
}
}

impl GeneratedFileAuxData for ProcMacroAuxData {
#[derive(Debug, Clone, Default)]
pub struct EmittedAuxData(Vec<ProcMacroAuxData>);

impl GeneratedFileAuxData for EmittedAuxData {
fn as_any(&self) -> &dyn Any {
self
}

fn eq(&self, other: &dyn GeneratedFileAuxData) -> bool {
self.value == other.as_any().downcast_ref::<Self>().unwrap().value
&& self.macro_id == other.as_any().downcast_ref::<Self>().unwrap().macro_id
self.0 == other.as_any().downcast_ref::<Self>().unwrap().0
}
}

impl EmittedAuxData {
pub fn push(&mut self, aux_data: ProcMacroAuxData) {
self.0.push(aux_data);
}

pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

impl IntoIterator for EmittedAuxData {
type Item = ProcMacroAuxData;
type IntoIter = IntoIter<Self::Item>;

fn into_iter(self) -> IntoIter<ProcMacroAuxData> {
self.0.into_iter()
}
}

Expand Down Expand Up @@ -172,9 +194,9 @@ impl ProcMacroHostPlugin {
let aux_data = file_info
.aux_data
.as_ref()
.and_then(|ad| ad.as_any().downcast_ref::<ProcMacroAuxData>());
.and_then(|ad| ad.as_any().downcast_ref::<EmittedAuxData>());
if let Some(aux_data) = aux_data {
data.push(aux_data.clone());
data.extend(aux_data.clone().into_iter());
}
}
}
Expand Down Expand Up @@ -218,7 +240,7 @@ impl MacroPlugin for ProcMacroHostPlugin {

let mut token_stream = TokenStream::from_item_ast(db, item_ast)
.with_metadata(TokenStreamMetadata::new(file_path, file_id));
let mut aux_data: Option<ProcMacroAuxData> = None;
let mut aux_data = EmittedAuxData::default();
let mut modified = false;
let mut all_diagnostics: Vec<Diagnostic> = Vec::new();
for input in expansions {
Expand All @@ -235,7 +257,7 @@ impl MacroPlugin for ProcMacroHostPlugin {
} => {
token_stream = new_token_stream;
if let Some(new_aux_data) = new_aux_data {
aux_data = Some(ProcMacroAuxData::new(
aux_data.push(ProcMacroAuxData::new(
new_aux_data.into(),
ProcMacroId::new(input.package_id, input.expansion.clone()),
));
Expand All @@ -262,7 +284,11 @@ impl MacroPlugin for ProcMacroHostPlugin {
name: "proc_macro".into(),
content: token_stream.to_string(),
code_mappings: Default::default(),
aux_data: aux_data.map(DynGeneratedFileAuxData::new),
aux_data: if aux_data.is_empty() {
None
} else {
Some(DynGeneratedFileAuxData::new(aux_data))
},
}),
diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr),
remove_original_item: true,
Expand Down
23 changes: 18 additions & 5 deletions scarb/tests/build_cairo_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ fn can_define_multiple_macros() {
simple_project_with_code(
&t,
indoc! {r##"
use cairo_lang_macro::{ProcMacroResult, TokenStream, attribute_macro};
use cairo_lang_macro::{ProcMacroResult, TokenStream, attribute_macro, AuxData, post_process};
#[attribute_macro]
pub fn hello(token_stream: TokenStream) -> ProcMacroResult {
Expand All @@ -607,7 +607,8 @@ fn can_define_multiple_macros() {
.replace("#[hello]", "")
.replace("12", "34")
);
ProcMacroResult::replace(token_stream, None)
let aux_data = AuxData::new(Vec::new());
ProcMacroResult::replace(token_stream, Some(aux_data))
}
#[attribute_macro]
Expand All @@ -619,7 +620,13 @@ fn can_define_multiple_macros() {
.replace("#[world]", "")
.replace("56", "78")
);
ProcMacroResult::replace(token_stream, None)
let aux_data = AuxData::new(Vec::new());
ProcMacroResult::replace(token_stream, Some(aux_data))
}
#[post_process]
pub fn callback(aux_data: Vec<AuxData>) {
assert_eq!(aux_data.len(), 2);
}
"##},
);
Expand All @@ -628,7 +635,7 @@ fn can_define_multiple_macros() {
simple_project_with_code_and_name(
&w,
indoc! {r##"
use cairo_lang_macro::{ProcMacroResult, TokenStream, attribute_macro};
use cairo_lang_macro::{ProcMacroResult, TokenStream, attribute_macro, AuxData, post_process};
#[attribute_macro]
pub fn beautiful(token_stream: TokenStream) -> ProcMacroResult {
Expand All @@ -639,7 +646,13 @@ fn can_define_multiple_macros() {
.replace("#[beautiful]", "")
.replace("90", "09")
);
ProcMacroResult::replace(token_stream, None)
let aux_data = AuxData::new(Vec::new());
ProcMacroResult::replace(token_stream, Some(aux_data))
}
#[post_process]
pub fn callback(aux_data: Vec<AuxData>) {
assert_eq!(aux_data.len(), 1);
}
"##},
"other",
Expand Down

0 comments on commit da02999

Please sign in to comment.