Skip to content

Commit

Permalink
lang: Add discriminator argument to #[account] attribute (#3149)
Browse files Browse the repository at this point in the history
  • Loading branch information
acheroncrypto authored Aug 4, 2024
1 parent 4853cd1 commit d73983d
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/reusable-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ jobs:
path: tests/safety-checks
- cmd: cd tests/custom-coder && anchor test --skip-lint && npx tsc --noEmit
path: tests/custom-coder
- cmd: cd tests/custom-discriminator && anchor test && npx tsc --noEmit
- cmd: cd tests/custom-discriminator && anchor test
path: tests/custom-discriminator
- cmd: cd tests/validator-clone && anchor test --skip-lint && npx tsc --noEmit
path: tests/validator-clone
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
- lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
- lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)).
- lang: Add `discriminator` argument to `#[account]` attribute ([#3149](https://github.com/coral-xyz/anchor/pull/3149)).

### Fixes

Expand Down
66 changes: 52 additions & 14 deletions lang/attribute/account/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
token::{Comma, Paren},
Ident, LitStr,
Expr, Ident, Lit, LitStr, Token,
};

mod id;
Expand All @@ -31,6 +31,22 @@ mod id;
/// check this discriminator. If it doesn't match, an invalid account was given,
/// and the account deserialization will exit with an error.
///
/// # Args
///
/// - `discriminator`: Override the default 8-byte discriminator
///
/// **Usage:** `discriminator = <CONST_EXPR>`
///
/// All constant expressions are supported.
///
/// **Examples:**
///
/// - `discriminator = 0` (shortcut for `[0]`)
/// - `discriminator = [1, 2, 3, 4]`
/// - `discriminator = b"hi"`
/// - `discriminator = MY_DISC`
/// - `discriminator = get_disc(...)`
///
/// # Zero Copy Deserialization
///
/// **WARNING**: Zero copy deserialization is an experimental feature. It's
Expand Down Expand Up @@ -83,23 +99,21 @@ pub fn account(
let account_name_str = account_name.to_string();
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();

let discriminator: proc_macro2::TokenStream = {
let discriminator = args.discriminator.unwrap_or_else(|| {
// Namespace the discriminator to prevent collisions.
let discriminator_preimage = {
// For now, zero copy accounts can't be namespaced.
if namespace.is_empty() {
format!("account:{account_name}")
} else {
format!("{namespace}:{account_name}")
}
let discriminator_preimage = if namespace.is_empty() {
format!("account:{account_name}")
} else {
format!("{namespace}:{account_name}")
};

let mut discriminator = [0u8; 8];
discriminator.copy_from_slice(
&anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
);
format!("{discriminator:?}").parse().unwrap()
};
let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
quote! { &#discriminator }
});
let disc = if account_strct.generics.lt_token.is_some() {
quote! { #account_name::#type_gen::DISCRIMINATOR }
} else {
Expand Down Expand Up @@ -159,7 +173,7 @@ pub fn account(

#[automatically_derived]
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
const DISCRIMINATOR: &'static [u8] = &#discriminator;
const DISCRIMINATOR: &'static [u8] = #discriminator;
}

// This trait is useful for clients deserializing accounts.
Expand Down Expand Up @@ -229,7 +243,7 @@ pub fn account(

#[automatically_derived]
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
const DISCRIMINATOR: &'static [u8] = &#discriminator;
const DISCRIMINATOR: &'static [u8] = #discriminator;
}

#owner_impl
Expand All @@ -242,7 +256,10 @@ pub fn account(
struct AccountArgs {
/// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)`
zero_copy: Option<bool>,
/// Account namespace override, `account` if not specified
namespace: Option<String>,
/// Discriminator override
discriminator: Option<proc_macro2::TokenStream>,
}

impl Parse for AccountArgs {
Expand All @@ -257,6 +274,9 @@ impl Parse for AccountArgs {
AccountArg::Namespace(ns) => {
parsed.namespace.replace(ns);
}
AccountArg::Discriminator(disc) => {
parsed.discriminator.replace(disc);
}
}
}

Expand All @@ -267,6 +287,7 @@ impl Parse for AccountArgs {
enum AccountArg {
ZeroCopy { is_unsafe: bool },
Namespace(String),
Discriminator(proc_macro2::TokenStream),
}

impl Parse for AccountArg {
Expand Down Expand Up @@ -300,7 +321,24 @@ impl Parse for AccountArg {
return Ok(Self::ZeroCopy { is_unsafe });
};

Err(syn::Error::new(ident.span(), "Unexpected argument"))
// Named arguments
// TODO: Share the common arguments with `#[instruction]`
input.parse::<Token![=]>()?;
let value = input.parse::<Expr>()?;
match ident.to_string().as_str() {
"discriminator" => {
let value = match value {
// Allow `discriminator = 42`
Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
// Allow `discriminator = [0, 1, 2, 3]`
Expr::Array(arr) => quote! { &#arr },
expr => expr.to_token_stream(),
};

Ok(Self::Discriminator(value))
}
_ => Err(syn::Error::new(ident.span(), "Invalid argument")),
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,34 @@ pub mod custom_discriminator {
pub fn const_fn(_ctx: Context<DefaultIx>) -> Result<()> {
Ok(())
}

pub fn account(ctx: Context<CustomAccountIx>, field: u8) -> Result<()> {
ctx.accounts.my_account.field = field;
Ok(())
}
}

#[derive(Accounts)]
pub struct DefaultIx<'info> {
pub signer: Signer<'info>,
}

#[derive(Accounts)]
pub struct CustomAccountIx<'info> {
#[account(mut)]
pub signer: Signer<'info>,
#[account(
init,
payer = signer,
space = MyAccount::DISCRIMINATOR.len() + core::mem::size_of::<MyAccount>(),
seeds = [b"my_account"],
bump
)]
pub my_account: Account<'info, MyAccount>,
pub system_program: Program<'info, System>,
}

#[account(discriminator = 1)]
pub struct MyAccount {
pub field: u8,
}
24 changes: 23 additions & 1 deletion tests/custom-discriminator/tests/custom-discriminator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ describe("custom-discriminator", () => {
const program: anchor.Program<CustomDiscriminator> =
anchor.workspace.customDiscriminator;

describe("Can use custom instruction discriminators", () => {
describe("Instructions", () => {
const testCommon = async (ixName: keyof typeof program["methods"]) => {
const tx = await program.methods[ixName]().transaction();

Expand All @@ -28,4 +28,26 @@ describe("custom-discriminator", () => {
it("Constant", () => testCommon("constant"));
it("Const Fn", () => testCommon("constFn"));
});

describe("Accounts", () => {
it("Works", async () => {
// Verify discriminator
const acc = program.idl.accounts.find((acc) => acc.name === "myAccount")!;
assert(acc.discriminator.length < 8);

// Verify regular `init` ix works
const field = 5;
const { pubkeys, signature } = await program.methods
.account(field)
.rpcAndKeys();
await program.provider.connection.confirmTransaction(
signature,
"confirmed"
);
const myAccount = await program.account.myAccount.fetch(
pubkeys.myAccount
);
assert.strictEqual(field, myAccount.field);
});
});
});

0 comments on commit d73983d

Please sign in to comment.