Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lang: Reuse common override arguments #3154

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions lang/attribute/account/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
extern crate proc_macro;

use anchor_syn::Overrides;
use quote::{quote, ToTokens};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input,
token::{Comma, Paren},
Expr, Ident, Lit, LitStr, Token,
Ident, LitStr,
};

mod id;
Expand Down Expand Up @@ -99,21 +100,25 @@ pub fn account(
let account_name_str = account_name.to_string();
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();

let discriminator = args.discriminator.unwrap_or_else(|| {
// Namespace the discriminator to prevent collisions.
let discriminator_preimage = if namespace.is_empty() {
format!("account:{account_name}")
} else {
format!("{namespace}:{account_name}")
};
let discriminator = args
.overrides
.and_then(|ov| ov.discriminator)
.unwrap_or_else(|| {
// Namespace the discriminator to prevent collisions.
let discriminator_preimage = if namespace.is_empty() {
format!("account:{account_name}")
} else {
format!("{namespace}:{account_name}")
};

let mut discriminator = [0u8; 8];
discriminator.copy_from_slice(
&anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
);
let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
quote! { &#discriminator }
});
let mut discriminator = [0u8; 8];
discriminator.copy_from_slice(
&anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
);
let discriminator: proc_macro2::TokenStream =
format!("{discriminator:?}").parse().unwrap();
quote! { &#discriminator }
});
let disc = if account_strct.generics.lt_token.is_some() {
quote! { #account_name::#type_gen::DISCRIMINATOR }
} else {
Expand Down Expand Up @@ -258,8 +263,8 @@ struct AccountArgs {
zero_copy: Option<bool>,
/// Account namespace override, `account` if not specified
namespace: Option<String>,
/// Discriminator override
discriminator: Option<proc_macro2::TokenStream>,
/// Named overrides
overrides: Option<Overrides>,
}

impl Parse for AccountArgs {
Expand All @@ -274,8 +279,8 @@ impl Parse for AccountArgs {
AccountArg::Namespace(ns) => {
parsed.namespace.replace(ns);
}
AccountArg::Discriminator(disc) => {
parsed.discriminator.replace(disc);
AccountArg::Overrides(ov) => {
parsed.overrides.replace(ov);
}
}
}
Expand All @@ -287,7 +292,7 @@ impl Parse for AccountArgs {
enum AccountArg {
ZeroCopy { is_unsafe: bool },
Namespace(String),
Discriminator(proc_macro2::TokenStream),
Overrides(Overrides),
}

impl Parse for AccountArg {
Expand All @@ -300,8 +305,8 @@ impl Parse for AccountArg {
}

// Zero copy
let ident = input.parse::<Ident>()?;
if ident == "zero_copy" {
if input.fork().parse::<Ident>()? == "zero_copy" {
input.parse::<Ident>()?;
let is_unsafe = if input.peek(Paren) {
let content;
parenthesized!(content in input);
Expand All @@ -321,24 +326,8 @@ impl Parse for AccountArg {
return Ok(Self::ZeroCopy { is_unsafe });
};

// Named arguments
// TODO: Share the common arguments with `#[instruction]`
input.parse::<Token![=]>()?;
let value = input.parse::<Expr>()?;
match ident.to_string().as_str() {
"discriminator" => {
let value = match value {
// Allow `discriminator = 42`
Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
// Allow `discriminator = [0, 1, 2, 3]`
Expr::Array(arr) => quote! { &#arr },
expr => expr.to_token_stream(),
};

Ok(Self::Discriminator(value))
}
_ => Err(syn::Error::new(ident.span(), "Invalid argument")),
}
// Overrides
input.parse::<Overrides>().map(Self::Overrides)
}
}

Expand Down
60 changes: 4 additions & 56 deletions lang/attribute/event/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@ extern crate proc_macro;

#[cfg(feature = "event-cpi")]
use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority};
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
token::Comma,
Expr, Ident, Lit, Token,
};
use anchor_syn::Overrides;
use quote::quote;
use syn::parse_macro_input;

/// The event attribute allows a struct to be used with
/// [emit!](./macro.emit.html) so that programs can log significant events in
Expand Down Expand Up @@ -37,7 +33,7 @@ pub fn event(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let args = parse_macro_input!(args as EventArgs);
let args = parse_macro_input!(args as Overrides);
let event_strct = parse_macro_input!(input as syn::ItemStruct);
let event_name = &event_strct.ident;

Expand Down Expand Up @@ -80,54 +76,6 @@ pub fn event(
proc_macro::TokenStream::from(ret)
}

#[derive(Debug, Default)]
struct EventArgs {
/// Discriminator override
discriminator: Option<proc_macro2::TokenStream>,
}

impl Parse for EventArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
// TODO: Share impl with `#[instruction]`
let mut parsed = Self::default();
let args = input.parse_terminated::<_, Comma>(EventArg::parse)?;
for arg in args {
match arg.name.to_string().as_str() {
"discriminator" => {
let value = match &arg.value {
// Allow `discriminator = 42`
Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
// Allow `discriminator = [0, 1, 2, 3]`
Expr::Array(arr) => quote! { &#arr },
expr => expr.to_token_stream(),
};
parsed.discriminator.replace(value);
}
_ => return Err(syn::Error::new(arg.name.span(), "Invalid argument")),
}
}

Ok(parsed)
}
}

struct EventArg {
name: Ident,
#[allow(dead_code)]
eq_token: Token![=],
value: Expr,
}

impl Parse for EventArg {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(Self {
name: input.parse()?,
eq_token: input.parse()?,
value: input.parse()?,
})
}
}

// EventIndex is a marker macro. It functionally does nothing other than
// allow one to mark fields with the `#[index]` inert attribute, which is
// used to add metadata to IDLs.
Expand Down
6 changes: 3 additions & 3 deletions lang/syn/src/codegen/program/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
})
.collect();
let impls = {
let discriminator = match ix.ix_attr.as_ref() {
Some(ix_attr) if ix_attr.discriminator.is_some() => {
ix_attr.discriminator.as_ref().unwrap().to_owned()
let discriminator = match ix.overrides.as_ref() {
Some(overrides) if overrides.discriminator.is_some() => {
overrides.discriminator.as_ref().unwrap().to_owned()
}
_ => {
// TODO: Remove `interface_discriminator`
Expand Down
22 changes: 11 additions & 11 deletions lang/syn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ pub struct Ix {
// The ident for the struct deriving Accounts.
pub anchor_ident: Ident,
// The discriminator based on the `#[interface]` attribute.
// TODO: Remove and use `ix_attr`
// TODO: Remove and use `overrides`
pub interface_discriminator: Option<[u8; 8]>,
/// `#[instruction]` attribute
pub ix_attr: Option<IxAttr>,
/// Overrides coming from the `#[instruction]` attribute
pub overrides: Option<Overrides>,
}

/// `#[instruction]` attribute proc-macro
/// Common overrides for the `#[instruction]`, `#[account]` and `#[event]` attributes
#[derive(Debug, Default)]
pub struct IxAttr {
/// Discriminator override
pub struct Overrides {
/// Override the default 8-byte discriminator
pub discriminator: Option<TokenStream>,
}

impl Parse for IxAttr {
impl Parse for Overrides {
fn parse(input: ParseStream) -> ParseResult<Self> {
let mut attr = Self::default();
let args = input.parse_terminated::<_, Comma>(AttrArg::parse)?;
let args = input.parse_terminated::<_, Comma>(NamedArg::parse)?;
for arg in args {
match arg.name.to_string().as_str() {
"discriminator" => {
Expand All @@ -106,14 +106,14 @@ impl Parse for IxAttr {
}
}

struct AttrArg {
struct NamedArg {
name: Ident,
#[allow(dead_code)]
eq_token: Token!(=),
eq_token: Token![=],
value: Expr,
}

impl Parse for AttrArg {
impl Parse for NamedArg {
fn parse(input: ParseStream) -> ParseResult<Self> {
Ok(Self {
name: input.parse()?,
Expand Down
10 changes: 5 additions & 5 deletions lang/syn/src/parser/program/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::parser::docs;
use crate::parser::program::ctx_accounts_ident;
use crate::parser::spl_interface;
use crate::{FallbackFn, Ix, IxArg, IxAttr, IxReturn};
use crate::{FallbackFn, Ix, IxArg, IxReturn, Overrides};
use syn::parse::{Error as ParseError, Result as ParseResult};
use syn::spanned::Spanned;

Expand All @@ -25,7 +25,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
})
.map(|method: &syn::ItemFn| {
let (ctx, args) = parse_args(method)?;
let ix_attr = parse_ix_attr(&method.attrs)?;
let overrides = parse_overrides(&method.attrs)?;
let interface_discriminator = spl_interface::parse(&method.attrs);
let docs = docs::parse(&method.attrs);
let returns = parse_return(method)?;
Expand All @@ -38,7 +38,7 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
anchor_ident,
returns,
interface_discriminator,
ix_attr,
overrides,
})
})
.collect::<ParseResult<Vec<Ix>>>()?;
Expand Down Expand Up @@ -73,8 +73,8 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
Ok((ixs, fallback_fn))
}

/// Parse `#[instruction]` attribute proc-macro.
fn parse_ix_attr(attrs: &[syn::Attribute]) -> ParseResult<Option<IxAttr>> {
/// Parse overrides from the `#[instruction]` attribute proc-macro.
fn parse_overrides(attrs: &[syn::Attribute]) -> ParseResult<Option<Overrides>> {
attrs
.iter()
.find(|attr| match attr.path.segments.last() {
Expand Down
Loading