Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lang: add the InitSpace macro #2346

Merged
merged 19 commits into from
Jan 26, 2023
14 changes: 12 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lang/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ anchor-attribute-error = { path = "./attribute/error", version = "0.26.0" }
anchor-attribute-program = { path = "./attribute/program", version = "0.26.0" }
anchor-attribute-event = { path = "./attribute/event", version = "0.26.0" }
anchor-derive-accounts = { path = "./derive/accounts", version = "0.26.0" }
anchor-derive-space = { path = "./derive/space", version = "0.26.0" }
arrayref = "0.3.6"
base64 = "0.13.0"
borsh = "0.9"
Expand Down
21 changes: 15 additions & 6 deletions lang/attribute/account/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod id;
/// - [`Clone`](https://doc.rust-lang.org/std/clone/trait.Clone.html)
/// - [`Discriminator`](./trait.Discriminator.html)
/// - [`Owner`](./trait.Owner.html)
/// - [`InitSpace`](./trait.InitSpace.html)
///
/// When implementing account serialization traits the first 8 bytes are
/// reserved for a unique account discriminator, self described by the first 8
Expand Down Expand Up @@ -66,29 +67,36 @@ pub fn account(
) -> proc_macro::TokenStream {
let mut namespace = "".to_string();
let mut is_zero_copy = false;
let mut skip_init_space = false;
Aursen marked this conversation as resolved.
Show resolved Hide resolved
let args_str = args.to_string();
let args: Vec<&str> = args_str.split(',').collect();
if args.len() > 2 {
panic!("Only two args are allowed to the account attribute.")
}
for arg in args {
let ns = arg
let ns: String = arg
.to_string()
.replace('\"', "")
.chars()
.filter(|c| !c.is_whitespace())
.collect();
if ns == "zero_copy" {
is_zero_copy = true;
} else {
namespace = ns;
}

match ns.as_str() {
"zero_copy" => is_zero_copy = true,
"skip_space" => skip_init_space = true,
_ => namespace = ns,
};
}

let account_strct = parse_macro_input!(input as syn::ItemStruct);
let account_name = &account_strct.ident;
let account_name_str = account_name.to_string();
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
let init_space = if skip_init_space {
quote!()
} else {
quote!(#[derive(InitSpace)])
};

let discriminator: proc_macro2::TokenStream = {
// Namespace the discriminator to prevent collisions.
Expand Down Expand Up @@ -170,6 +178,7 @@ pub fn account(
}
} else {
quote! {
#init_space
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#account_strct

Expand Down
17 changes: 17 additions & 0 deletions lang/derive/space/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "anchor-derive-space"
version = "0.26.0"
authors = ["Serum Foundation <[email protected]>"]
repository = "https://github.com/coral-xyz/anchor"
license = "Apache-2.0"
description = "Anchor Derive macro to automatically calculate the size of a structure or an enum"
rust-version = "1.59"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = "1.0"
172 changes: 172 additions & 0 deletions lang/derive/space/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
parse_macro_input,
punctuated::{IntoIter, Punctuated},
Attribute, DeriveInput, Fields, GenericArgument, LitInt, PathArguments, Token, Type, TypeArray,
};

/// Implements a [`Space`](./trait.Space.html) trait on the given
/// struct or enum.
///
/// # Example
/// ```ignore
/// #[account]
/// pub struct ExampleAccount {
/// pub data: u64,
/// }
///
/// #[derive(Accounts)]
/// pub struct Initialize<'info> {
/// #[account(mut)]
/// pub payer: Signer<'info>,
/// pub system_program: Program<'info, System>,
/// #[account(init, payer = payer, space = 8 + ExampleAccount::INIT_SPACE)]
/// pub data: Account<'info, ExampleAccount>,
/// }
/// ```
#[proc_macro_derive(InitSpace, attributes(max_len))]
pub fn derive_anchor_deserialize(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let name = input.ident;

let expanded: TokenStream2 = match input.data {
syn::Data::Struct(strct) => match strct.fields {
Fields::Named(named) => {
let recurse = named.named.into_iter().map(|f| {
let mut max_len_args = get_max_len_args(&f.attrs);
len_from_type(f.ty, &mut max_len_args)
});

quote! {
#[automatically_derived]
impl #impl_generics anchor_lang::Space for #name #ty_generics #where_clause {
const INIT_SPACE: usize = 0 #(+ #recurse)*;
}
}
}
_ => panic!("Please use named fields in account structure"),
},
syn::Data::Enum(enm) => {
let variants = enm.variants.into_iter().map(|v| {
let len = v.fields.into_iter().map(|f| {
let mut max_len_args = get_max_len_args(&f.attrs);
len_from_type(f.ty, &mut max_len_args)
});

quote! {
0 #(+ #len)*
}
});

let max = gen_max(variants);

quote! {
#[automatically_derived]
impl anchor_lang::Space for #name {
const INIT_SPACE: usize = 1 + #max;
}
}
}
_ => unimplemented!(),
};

TokenStream::from(expanded)
}

fn gen_max<T: Iterator<Item = TokenStream2>>(mut iter: T) -> TokenStream2 {
if let Some(item) = iter.next() {
let next_item = gen_max(iter);
quote!(anchor_lang::__private::max(#item, #next_item))
} else {
quote!(0)
}
}

fn len_from_type(ty: Type, attrs: &mut Option<IntoIter<LitInt>>) -> TokenStream2 {
match ty {
Type::Array(TypeArray { elem, len, .. }) => {
let array_len = len.to_token_stream();
let type_len = len_from_type(*elem, attrs);
quote!((#array_len * #type_len))
}
Type::Path(path) => {
let path_segment = path.path.segments.last().unwrap();
let ident = &path_segment.ident;
let type_name = ident.to_string();
let first_ty = get_first_ty_arg(&path_segment.arguments);

match type_name.as_str() {
"i8" | "u8" | "bool" => quote!(1),
"i16" | "u16" => quote!(2),
"i32" | "u32" | "f32" => quote!(4),
"i64" | "u64" | "f64" => quote!(8),
"i128" | "u128" => quote!(16),
"String" => {
let max_len = get_next_arg(ident, attrs);
quote!((4 + #max_len))
}
"Pubkey" => quote!(32),
"Option" => {
if let Some(ty) = first_ty {
let type_len = len_from_type(ty, attrs);

quote!((1 + #type_len))
} else {
quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
}
}
"Vec" => {
if let Some(ty) = first_ty {
let max_len = get_next_arg(ident, attrs);
let type_len = len_from_type(ty, attrs);

quote!((4 + #type_len * #max_len))
} else {
quote_spanned!(ident.span() => compile_error!("Invalid argument in Vec"))
}
}
_ => {
Aursen marked this conversation as resolved.
Show resolved Hide resolved
let ty = &path_segment.ident;
quote!(<#ty as anchor_lang::Space>::INIT_SPACE)
}
}
}
_ => panic!("Type {:?} is not supported", ty),
}
}

fn get_first_ty_arg(args: &PathArguments) -> Option<Type> {
match args {
PathArguments::AngleBracketed(bracket) => bracket.args.iter().find_map(|el| match el {
GenericArgument::Type(ty) => Some(ty.to_owned()),
_ => None,
}),
_ => None,
}
}

fn get_max_len_args(attributes: &[Attribute]) -> Option<IntoIter<LitInt>> {
attributes
.iter()
.find(|a| a.path.is_ident("max_len"))
.and_then(|a| {
a.parse_args_with(Punctuated::<LitInt, Token![,]>::parse_terminated)
.ok()
})
.map(|p| p.into_iter())
}

fn get_next_arg(ident: &Ident, args: &mut Option<IntoIter<LitInt>>) -> TokenStream2 {
Henry-E marked this conversation as resolved.
Show resolved Hide resolved
if let Some(arg_list) = args {
if let Some(arg) = arg_list.next() {
quote!(#arg)
} else {
quote_spanned!(ident.span() => compile_error!("The number of lengths are invalid."))
}
} else {
quote_spanned!(ident.span() => compile_error!("Expected max_len attribute."))
}
}
2 changes: 1 addition & 1 deletion lang/src/idl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub struct IdlSetBuffer<'info> {
//
// Note: we use the same account for the "write buffer", similar to the
// bpf upgradeable loader's mechanism.
#[account("internal")]
#[account("internal", skip_space)]
#[derive(Debug)]
pub struct IdlAccount {
// Address that can modify the IDL.
Expand Down
17 changes: 15 additions & 2 deletions lang/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub use anchor_attribute_error::*;
pub use anchor_attribute_event::{emit, event};
pub use anchor_attribute_program::program;
pub use anchor_derive_accounts::Accounts;
pub use anchor_derive_space::InitSpace;
/// Borsh is the default serialization format for instructions and accounts.
pub use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSerialize};
pub use solana_program;
Expand Down Expand Up @@ -209,6 +210,11 @@ pub trait Discriminator {
}
}

/// Defines the space of an account for initialization.
pub trait Space {
const INIT_SPACE: usize;
}

/// Bump seed for program derived addresses.
pub trait Bump {
fn seed(&self) -> u8;
Expand Down Expand Up @@ -247,8 +253,8 @@ pub mod prelude {
require, require_eq, require_gt, require_gte, require_keys_eq, require_keys_neq,
require_neq, solana_program::bpf_loader_upgradeable::UpgradeableLoaderState, source,
system_program::System, zero_copy, AccountDeserialize, AccountSerialize, Accounts,
AccountsClose, AccountsExit, AnchorDeserialize, AnchorSerialize, Id, Key, Owner,
ProgramData, Result, ToAccountInfo, ToAccountInfos, ToAccountMetas,
AccountsClose, AccountsExit, AnchorDeserialize, AnchorSerialize, Id, InitSpace, Key, Owner,
ProgramData, Result, Space, ToAccountInfo, ToAccountInfos, ToAccountMetas,
};
pub use anchor_attribute_error::*;
pub use borsh;
Expand Down Expand Up @@ -288,6 +294,13 @@ pub mod __private {

use solana_program::pubkey::Pubkey;

// Used to calculate the maximum between two expressions.
// It is necessary for the calculation of the enum space.
#[doc(hidden)]
pub const fn max(a: usize, b: usize) -> usize {
[a, b][(a < b) as usize]
}

// Very experimental trait.
#[doc(hidden)]
pub trait ZeroCopyAccessor<Ty> {
Expand Down
6 changes: 3 additions & 3 deletions lang/tests/generics_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");
#[derive(Accounts)]
pub struct GenericsTest<'info, T, U, const N: usize>
where
T: AccountSerialize + AccountDeserialize + Owner + Clone,
U: BorshSerialize + BorshDeserialize + Default + Clone,
T: AccountSerialize + AccountDeserialize + Space + Owner + Clone,
U: BorshSerialize + BorshDeserialize + Space + Default + Clone,
{
pub non_generic: AccountInfo<'info>,
pub generic: Account<'info, T>,
Expand All @@ -31,7 +31,7 @@ pub struct FooAccount<const N: usize> {
#[derive(Default)]
pub struct Associated<T>
where
T: BorshDeserialize + BorshSerialize + Default,
T: BorshDeserialize + BorshSerialize + Space + Default,
{
pub data: T,
}
Expand Down
Loading