Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lang: Refactor discriminator generation #3182

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions lang/attribute/account/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern crate proc_macro;

use anchor_syn::Overrides;
use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
use quote::{quote, ToTokens};
use syn::{
parenthesized,
Expand Down Expand Up @@ -105,19 +105,13 @@ pub fn account(
.and_then(|ov| ov.discriminator)
.unwrap_or_else(|| {
// Namespace the discriminator to prevent collisions.
let discriminator_preimage = if namespace.is_empty() {
format!("account:{account_name}")
let namespace = if namespace.is_empty() {
"account"
} else {
format!("{namespace}:{account_name}")
&namespace
};

let mut discriminator = [0u8; 8];
discriminator.copy_from_slice(
&anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
);
let discriminator: proc_macro2::TokenStream =
format!("{discriminator:?}").parse().unwrap();
quote! { &#discriminator }
gen_discriminator(namespace, account_name)
});
let disc = if account_strct.generics.lt_token.is_some() {
quote! { #account_name::#type_gen::DISCRIMINATOR }
Expand Down
12 changes: 4 additions & 8 deletions lang/attribute/event/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate proc_macro;

#[cfg(feature = "event-cpi")]
use anchor_syn::parser::accounts::event_cpi::{add_event_cpi_accounts, EventAuthority};
use anchor_syn::Overrides;
use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
use quote::quote;
use syn::parse_macro_input;

Expand Down Expand Up @@ -37,13 +37,9 @@ pub fn event(
let event_strct = parse_macro_input!(input as syn::ItemStruct);
let event_name = &event_strct.ident;

let discriminator = args.discriminator.unwrap_or_else(|| {
let discriminator_preimage = format!("event:{event_name}").into_bytes();
let discriminator = anchor_syn::hash::hash(&discriminator_preimage);
let discriminator: proc_macro2::TokenStream =
format!("{:?}", &discriminator.0[..8]).parse().unwrap();
quote! { &#discriminator }
});
let discriminator = args
.discriminator
.unwrap_or_else(|| gen_discriminator("event", event_name));

let ret = quote! {
#[derive(anchor_lang::__private::EventIndex, AnchorSerialize, AnchorDeserialize)]
Expand Down
5 changes: 5 additions & 0 deletions lang/syn/src/codegen/program/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ pub fn sighash(namespace: &str, name: &str) -> [u8; 8] {
sighash
}

pub fn gen_discriminator(namespace: &str, name: impl ToString) -> proc_macro2::TokenStream {
let discriminator = sighash(namespace, name.to_string().as_str());
format!("&{:?}", discriminator).parse().unwrap()
}

pub fn generate_ix_variant(name: String, args: &[IxArg]) -> proc_macro2::TokenStream {
let ix_arg_names: Vec<&syn::Ident> = args.iter().map(|arg| &arg.name).collect();
let ix_name_camel: proc_macro2::TokenStream = {
Expand Down
14 changes: 7 additions & 7 deletions lang/syn/src/codegen/program/cpi.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::codegen::program::common::{generate_ix_variant, sighash, SIGHASH_GLOBAL_NAMESPACE};
use crate::codegen::program::common::{
gen_discriminator, generate_ix_variant, SIGHASH_GLOBAL_NAMESPACE,
};
use crate::Program;
use heck::SnakeCase;
use quote::{quote, ToTokens};
Expand All @@ -11,13 +13,11 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
.map(|ix| {
let accounts_ident: proc_macro2::TokenStream = format!("crate::cpi::accounts::{}", &ix.anchor_ident.to_string()).parse().unwrap();
let cpi_method = {
let ix_variant = generate_ix_variant(ix.raw_method.sig.ident.to_string(), &ix.args);
let name = &ix.raw_method.sig.ident;
let ix_variant = generate_ix_variant(name.to_string(), &ix.args);
let method_name = &ix.ident;
let args: Vec<&syn::PatType> = ix.args.iter().map(|arg| &arg.raw_arg).collect();
let name = &ix.raw_method.sig.ident.to_string();
let sighash_arr = sighash(SIGHASH_GLOBAL_NAMESPACE, name);
let sighash_tts: proc_macro2::TokenStream =
format!("{sighash_arr:?}").parse().unwrap();
let discriminator = gen_discriminator(SIGHASH_GLOBAL_NAMESPACE, name);
let ret_type = &ix.returns.ty.to_token_stream();
let (method_ret, maybe_return) = match ret_type.to_string().as_str() {
"()" => (quote! {anchor_lang::Result<()> }, quote! { Ok(()) }),
Expand All @@ -35,7 +35,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
let ix = {
let ix = instruction::#ix_variant;
let mut data = Vec::with_capacity(256);
data.extend_from_slice(&#sighash_tts);
data.extend_from_slice(#discriminator);
AnchorSerialize::serialize(&ix, &mut data)
.map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotSerialize)?;
let accounts = ctx.to_account_metas(None);
Expand Down
14 changes: 5 additions & 9 deletions lang/syn/src/codegen/program/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
Some(overrides) if overrides.discriminator.is_some() => {
overrides.discriminator.as_ref().unwrap().to_owned()
}
_ => {
// TODO: Remove `interface_discriminator`
let discriminator = ix
.interface_discriminator
.unwrap_or_else(|| sighash(SIGHASH_GLOBAL_NAMESPACE, name));
let discriminator: proc_macro2::TokenStream =
format!("{discriminator:?}").parse().unwrap();
quote! { &#discriminator }
}
// TODO: Remove `interface_discriminator`
_ => match &ix.interface_discriminator {
Some(disc) => format!("&{disc:?}").parse().unwrap(),
_ => gen_discriminator(SIGHASH_GLOBAL_NAMESPACE, name),
},
};

quote! {
Expand Down
Loading