From d73983d3db3f70155c20967663b8386774be5bc5 Mon Sep 17 00:00:00 2001 From: acheron <98934430+acheroncrypto@users.noreply.github.com> Date: Sun, 4 Aug 2024 22:51:46 +0200 Subject: [PATCH] lang: Add `discriminator` argument to `#[account]` attribute (#3149) --- .github/workflows/reusable-tests.yaml | 2 +- CHANGELOG.md | 1 + lang/attribute/account/src/lib.rs | 66 +++++++++++++++---- .../programs/custom-discriminator/src/lib.rs | 25 +++++++ .../tests/custom-discriminator.ts | 24 ++++++- 5 files changed, 102 insertions(+), 16 deletions(-) diff --git a/.github/workflows/reusable-tests.yaml b/.github/workflows/reusable-tests.yaml index 05be12adea..3a01ab8a5f 100644 --- a/.github/workflows/reusable-tests.yaml +++ b/.github/workflows/reusable-tests.yaml @@ -439,7 +439,7 @@ jobs: path: tests/safety-checks - cmd: cd tests/custom-coder && anchor test --skip-lint && npx tsc --noEmit path: tests/custom-coder - - cmd: cd tests/custom-discriminator && anchor test && npx tsc --noEmit + - cmd: cd tests/custom-discriminator && anchor test path: tests/custom-discriminator - cmd: cd tests/validator-clone && anchor test --skip-lint && npx tsc --noEmit path: tests/validator-clone diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ccad0f035..0dc280f1bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ The minor version will be incremented upon a breaking change and the patch versi - client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)). - lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)). - lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)). +- lang: Add `discriminator` argument to `#[account]` attribute ([#3149](https://github.com/coral-xyz/anchor/pull/3149)). ### Fixes diff --git a/lang/attribute/account/src/lib.rs b/lang/attribute/account/src/lib.rs index f4f936ece8..fd90257b2e 100644 --- a/lang/attribute/account/src/lib.rs +++ b/lang/attribute/account/src/lib.rs @@ -6,7 +6,7 @@ use syn::{ parse::{Parse, ParseStream}, parse_macro_input, token::{Comma, Paren}, - Ident, LitStr, + Expr, Ident, Lit, LitStr, Token, }; mod id; @@ -31,6 +31,22 @@ mod id; /// check this discriminator. If it doesn't match, an invalid account was given, /// and the account deserialization will exit with an error. /// +/// # Args +/// +/// - `discriminator`: Override the default 8-byte discriminator +/// +/// **Usage:** `discriminator = ` +/// +/// All constant expressions are supported. +/// +/// **Examples:** +/// +/// - `discriminator = 0` (shortcut for `[0]`) +/// - `discriminator = [1, 2, 3, 4]` +/// - `discriminator = b"hi"` +/// - `discriminator = MY_DISC` +/// - `discriminator = get_disc(...)` +/// /// # Zero Copy Deserialization /// /// **WARNING**: Zero copy deserialization is an experimental feature. It's @@ -83,23 +99,21 @@ pub fn account( let account_name_str = account_name.to_string(); let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl(); - let discriminator: proc_macro2::TokenStream = { + let discriminator = args.discriminator.unwrap_or_else(|| { // Namespace the discriminator to prevent collisions. - let discriminator_preimage = { - // For now, zero copy accounts can't be namespaced. - if namespace.is_empty() { - format!("account:{account_name}") - } else { - format!("{namespace}:{account_name}") - } + let discriminator_preimage = if namespace.is_empty() { + format!("account:{account_name}") + } else { + format!("{namespace}:{account_name}") }; let mut discriminator = [0u8; 8]; discriminator.copy_from_slice( &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8], ); - format!("{discriminator:?}").parse().unwrap() - }; + let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap(); + quote! { &#discriminator } + }); let disc = if account_strct.generics.lt_token.is_some() { quote! { #account_name::#type_gen::DISCRIMINATOR } } else { @@ -159,7 +173,7 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause { - const DISCRIMINATOR: &'static [u8] = &#discriminator; + const DISCRIMINATOR: &'static [u8] = #discriminator; } // This trait is useful for clients deserializing accounts. @@ -229,7 +243,7 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause { - const DISCRIMINATOR: &'static [u8] = &#discriminator; + const DISCRIMINATOR: &'static [u8] = #discriminator; } #owner_impl @@ -242,7 +256,10 @@ pub fn account( struct AccountArgs { /// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)` zero_copy: Option, + /// Account namespace override, `account` if not specified namespace: Option, + /// Discriminator override + discriminator: Option, } impl Parse for AccountArgs { @@ -257,6 +274,9 @@ impl Parse for AccountArgs { AccountArg::Namespace(ns) => { parsed.namespace.replace(ns); } + AccountArg::Discriminator(disc) => { + parsed.discriminator.replace(disc); + } } } @@ -267,6 +287,7 @@ impl Parse for AccountArgs { enum AccountArg { ZeroCopy { is_unsafe: bool }, Namespace(String), + Discriminator(proc_macro2::TokenStream), } impl Parse for AccountArg { @@ -300,7 +321,24 @@ impl Parse for AccountArg { return Ok(Self::ZeroCopy { is_unsafe }); }; - Err(syn::Error::new(ident.span(), "Unexpected argument")) + // Named arguments + // TODO: Share the common arguments with `#[instruction]` + input.parse::()?; + let value = input.parse::()?; + match ident.to_string().as_str() { + "discriminator" => { + let value = match value { + // Allow `discriminator = 42` + Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] }, + // Allow `discriminator = [0, 1, 2, 3]` + Expr::Array(arr) => quote! { &#arr }, + expr => expr.to_token_stream(), + }; + + Ok(Self::Discriminator(value)) + } + _ => Err(syn::Error::new(ident.span(), "Invalid argument")), + } } } diff --git a/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs b/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs index 64d8335dc7..cdcc479375 100644 --- a/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs +++ b/tests/custom-discriminator/programs/custom-discriminator/src/lib.rs @@ -39,9 +39,34 @@ pub mod custom_discriminator { pub fn const_fn(_ctx: Context) -> Result<()> { Ok(()) } + + pub fn account(ctx: Context, field: u8) -> Result<()> { + ctx.accounts.my_account.field = field; + Ok(()) + } } #[derive(Accounts)] pub struct DefaultIx<'info> { pub signer: Signer<'info>, } + +#[derive(Accounts)] +pub struct CustomAccountIx<'info> { + #[account(mut)] + pub signer: Signer<'info>, + #[account( + init, + payer = signer, + space = MyAccount::DISCRIMINATOR.len() + core::mem::size_of::(), + seeds = [b"my_account"], + bump + )] + pub my_account: Account<'info, MyAccount>, + pub system_program: Program<'info, System>, +} + +#[account(discriminator = 1)] +pub struct MyAccount { + pub field: u8, +} diff --git a/tests/custom-discriminator/tests/custom-discriminator.ts b/tests/custom-discriminator/tests/custom-discriminator.ts index 0d4cdbf10a..f0d5fc1af8 100644 --- a/tests/custom-discriminator/tests/custom-discriminator.ts +++ b/tests/custom-discriminator/tests/custom-discriminator.ts @@ -8,7 +8,7 @@ describe("custom-discriminator", () => { const program: anchor.Program = anchor.workspace.customDiscriminator; - describe("Can use custom instruction discriminators", () => { + describe("Instructions", () => { const testCommon = async (ixName: keyof typeof program["methods"]) => { const tx = await program.methods[ixName]().transaction(); @@ -28,4 +28,26 @@ describe("custom-discriminator", () => { it("Constant", () => testCommon("constant")); it("Const Fn", () => testCommon("constFn")); }); + + describe("Accounts", () => { + it("Works", async () => { + // Verify discriminator + const acc = program.idl.accounts.find((acc) => acc.name === "myAccount")!; + assert(acc.discriminator.length < 8); + + // Verify regular `init` ix works + const field = 5; + const { pubkeys, signature } = await program.methods + .account(field) + .rpcAndKeys(); + await program.provider.connection.confirmTransaction( + signature, + "confirmed" + ); + const myAccount = await program.account.myAccount.fetch( + pubkeys.myAccount + ); + assert.strictEqual(field, myAccount.field); + }); + }); });