diff --git a/token/program-2022/src/extension/cpi_guard/processor.rs b/token/program-2022/src/extension/cpi_guard/processor.rs index 0e3ae0d9da3..2fdc4d15799 100644 --- a/token/program-2022/src/extension/cpi_guard/processor.rs +++ b/token/program-2022/src/extension/cpi_guard/processor.rs @@ -18,37 +18,12 @@ use { }, }; -fn process_enable_cpi_guard(program_id: &Pubkey, accounts: &[AccountInfo]) -> ProgramResult { - let account_info_iter = &mut accounts.iter(); - let token_account_info = next_account_info(account_info_iter)?; - let owner_info = next_account_info(account_info_iter)?; - let owner_info_data_len = owner_info.data_len(); - - let mut account_data = token_account_info.data.borrow_mut(); - let mut account = StateWithExtensionsMut::::unpack(&mut account_data)?; - - Processor::validate_owner( - program_id, - &account.base.owner, - owner_info, - owner_info_data_len, - account_info_iter.as_slice(), - )?; - - if in_cpi() { - return Err(TokenError::CpiGuardSettingsLocked.into()); - } - - let extension = if let Ok(extension) = account.get_extension_mut::() { - extension - } else { - account.init_extension::(true)? - }; - extension.lock_cpi = true.into(); - Ok(()) -} - -fn process_diasble_cpi_guard(program_id: &Pubkey, accounts: &[AccountInfo]) -> ProgramResult { +/// Toggle the CpiGuard extension, initializing the extension if not already present. +fn process_toggle_cpi_guard( + program_id: &Pubkey, + accounts: &[AccountInfo], + enable: bool, +) -> ProgramResult { let account_info_iter = &mut accounts.iter(); let token_account_info = next_account_info(account_info_iter)?; let owner_info = next_account_info(account_info_iter)?; @@ -74,7 +49,7 @@ fn process_diasble_cpi_guard(program_id: &Pubkey, accounts: &[AccountInfo]) -> P } else { account.init_extension::(true)? }; - extension.lock_cpi = false.into(); + extension.lock_cpi = enable.into(); Ok(()) } @@ -88,11 +63,11 @@ pub(crate) fn process_instruction( match decode_instruction_type(input)? { CpiGuardInstruction::Enable => { msg!("CpiGuardInstruction::Enable"); - process_enable_cpi_guard(program_id, accounts) + process_toggle_cpi_guard(program_id, accounts, true /* enable */) } CpiGuardInstruction::Disable => { msg!("CpiGuardInstruction::Disable"); - process_diasble_cpi_guard(program_id, accounts) + process_toggle_cpi_guard(program_id, accounts, false /* disable */) } } } diff --git a/token/program-2022/src/extension/memo_transfer/processor.rs b/token/program-2022/src/extension/memo_transfer/processor.rs index 22daaa49c69..a2d406230c1 100644 --- a/token/program-2022/src/extension/memo_transfer/processor.rs +++ b/token/program-2022/src/extension/memo_transfer/processor.rs @@ -17,9 +17,11 @@ use { }, }; -fn process_enable_required_memo_transfers( +/// Toggle the RequiredMemoTransfers extension, initializing the extension if not already present. +fn process_toggle_required_memo_transfers( program_id: &Pubkey, accounts: &[AccountInfo], + enable: bool, ) -> ProgramResult { let account_info_iter = &mut accounts.iter(); let token_account_info = next_account_info(account_info_iter)?; @@ -42,36 +44,7 @@ fn process_enable_required_memo_transfers( } else { account.init_extension::(true)? }; - extension.require_incoming_transfer_memos = true.into(); - Ok(()) -} - -fn process_diasble_required_memo_transfers( - program_id: &Pubkey, - accounts: &[AccountInfo], -) -> ProgramResult { - let account_info_iter = &mut accounts.iter(); - let token_account_info = next_account_info(account_info_iter)?; - let owner_info = next_account_info(account_info_iter)?; - let owner_info_data_len = owner_info.data_len(); - - let mut account_data = token_account_info.data.borrow_mut(); - let mut account = StateWithExtensionsMut::::unpack(&mut account_data)?; - - Processor::validate_owner( - program_id, - &account.base.owner, - owner_info, - owner_info_data_len, - account_info_iter.as_slice(), - )?; - - let extension = if let Ok(extension) = account.get_extension_mut::() { - extension - } else { - account.init_extension::(true)? - }; - extension.require_incoming_transfer_memos = false.into(); + extension.require_incoming_transfer_memos = enable.into(); Ok(()) } @@ -85,11 +58,11 @@ pub(crate) fn process_instruction( match decode_instruction_type(input)? { RequiredMemoTransfersInstruction::Enable => { msg!("RequiredMemoTransfersInstruction::Enable"); - process_enable_required_memo_transfers(program_id, accounts) + process_toggle_required_memo_transfers(program_id, accounts, true /* enable */) } RequiredMemoTransfersInstruction::Disable => { msg!("RequiredMemoTransfersInstruction::Disable"); - process_diasble_required_memo_transfers(program_id, accounts) + process_toggle_required_memo_transfers(program_id, accounts, false /* disable */) } } }