diff --git a/Cargo.lock b/Cargo.lock index e6ce3c0fc149f8..ff3f0ad576cf3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -709,13 +709,12 @@ dependencies = [ [[package]] name = "curve25519-dalek" version = "2.1.0" -source = "git+https://github.com/garious/curve25519-dalek?rev=60efef3553d6bf3d7f3b09b5f97acd54d72529ff#60efef3553d6bf3d7f3b09b5f97acd54d72529ff" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d85653f070353a16313d0046f173f70d1aadd5b42600a14de626f0dfb3473a5" dependencies = [ - "borsh", "byteorder", "digest 0.8.1", "rand_core", - "serde", "subtle 2.2.3", "zeroize", ] @@ -723,12 +722,13 @@ dependencies = [ [[package]] name = "curve25519-dalek" version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d85653f070353a16313d0046f173f70d1aadd5b42600a14de626f0dfb3473a5" +source = "git+https://github.com/garious/curve25519-dalek?rev=60efef3553d6bf3d7f3b09b5f97acd54d72529ff#60efef3553d6bf3d7f3b09b5f97acd54d72529ff" dependencies = [ + "borsh", "byteorder", "digest 0.8.1", "rand_core", + "serde", "subtle 2.2.3", "zeroize", ] diff --git a/token-swap/js/cli/token-swap-test.js b/token-swap/js/cli/token-swap-test.js index 558328a42587cc..b00bdbea199c26 100644 --- a/token-swap/js/cli/token-swap-test.js +++ b/token-swap/js/cli/token-swap-test.js @@ -64,7 +64,7 @@ let currentFeeAmount = 0; // Because there is no withdraw fee in the production version, these numbers // need to get slightly tweaked in the two cases. const SWAP_AMOUNT_IN = 100000; -const SWAP_AMOUNT_OUT = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? 90662 : 90675; +const SWAP_AMOUNT_OUT = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? 90661 : 90674; const SWAP_FEE = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? 22272 : 22276; const HOST_SWAP_FEE = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? Math.floor((SWAP_FEE * HOST_FEE_NUMERATOR) / HOST_FEE_DENOMINATOR) diff --git a/token-swap/program/src/curve/calculator.rs b/token-swap/program/src/curve/calculator.rs index 349692a2811293..6083133ac76d03 100644 --- a/token-swap/program/src/curve/calculator.rs +++ b/token-swap/program/src/curve/calculator.rs @@ -8,7 +8,7 @@ use std::fmt::Debug; /// input amounts, and Balancer uses 100 * 10 ^ 18. pub const INITIAL_SWAP_POOL_AMOUNT: u128 = 1_000_000_000; -/// Helper function for calcuating swap fee +/// Helper function for calculating swap fee pub fn calculate_fee( token_amount: u128, fee_numerator: u128, @@ -43,14 +43,24 @@ pub struct SwapResult { pub new_source_amount: u128, /// New amount of destination token pub new_destination_amount: u128, + /// Amount of source token swapped (includes fees) + pub source_amount_swapped: u128, /// Amount of destination token swapped - pub amount_swapped: u128, + pub destination_amount_swapped: u128, /// Amount of source tokens going to pool holders pub trade_fee: u128, /// Amount of source tokens going to owner pub owner_fee: u128, } +/// Encodes all results of swapping from a source token to a destination token +pub struct SwapWithoutFeesResult { + /// Amount of source token swapped + pub source_amount_swapped: u128, + /// Amount of destination token swapped + pub destination_amount_swapped: u128, +} + /// Trait for packing of trait objects, required because structs that implement /// `Pack` cannot be used as trait objects (as `dyn Pack`). pub trait DynPack { @@ -62,12 +72,48 @@ pub trait DynPack { pub trait CurveCalculator: Debug + DynPack { /// Calculate how much destination token will be provided given an amount /// of source token. + fn swap_without_fees( + &self, + source_amount: u128, + swap_source_amount: u128, + swap_destination_amount: u128, + ) -> Option; + + /// Subtract fees and calculate how much destination token will be provided + /// given an amount of source token. fn swap( &self, source_amount: u128, swap_source_amount: u128, swap_destination_amount: u128, - ) -> Option; + ) -> Option { + // debit the fee to calculate the amount swapped + let trade_fee = self.trading_fee(source_amount)?; + let owner_fee = self.owner_trading_fee(source_amount)?; + + let total_fees = trade_fee.checked_add(owner_fee)?; + let source_amount_less_fees = source_amount.checked_sub(total_fees)?; + + let SwapWithoutFeesResult { + source_amount_swapped, + destination_amount_swapped, + } = self.swap_without_fees( + source_amount_less_fees, + swap_source_amount, + swap_destination_amount, + )?; + + let source_amount_swapped = source_amount_swapped.checked_add(total_fees)?; + Some(SwapResult { + new_source_amount: swap_source_amount.checked_add(source_amount_swapped)?, + new_destination_amount: swap_destination_amount + .checked_sub(destination_amount_swapped)?, + source_amount_swapped, + destination_amount_swapped, + trade_fee, + owner_fee, + }) + } /// Calculate the withdraw fee in pool tokens /// Default implementation assumes no fee diff --git a/token-swap/program/src/curve/constant_product.rs b/token-swap/program/src/curve/constant_product.rs index e23e586b1fdc71..be520f35ea69eb 100644 --- a/token-swap/program/src/curve/constant_product.rs +++ b/token-swap/program/src/curve/constant_product.rs @@ -6,7 +6,7 @@ use solana_program::{ }; use crate::curve::calculator::{ - calculate_fee, map_zero_to_none, CurveCalculator, DynPack, SwapResult, + calculate_fee, map_zero_to_none, CurveCalculator, DynPack, SwapWithoutFeesResult, }; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; use std::convert::TryFrom; @@ -34,33 +34,38 @@ pub struct ConstantProductCurve { impl CurveCalculator for ConstantProductCurve { /// Constant product swap ensures x * y = constant - fn swap( + fn swap_without_fees( &self, source_amount: u128, swap_source_amount: u128, swap_destination_amount: u128, - ) -> Option { - // debit the fee to calculate the amount swapped - let trade_fee = self.trading_fee(source_amount)?; - let owner_fee = self.owner_trading_fee(source_amount)?; - + ) -> Option { let invariant = swap_source_amount.checked_mul(swap_destination_amount)?; - let new_source_amount_less_fee = swap_source_amount - .checked_add(source_amount)? - .checked_sub(trade_fee)? - .checked_sub(owner_fee)?; - let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?; - let amount_swapped = - map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount)?)?; - // actually add the whole amount coming in - let new_source_amount = swap_source_amount.checked_add(source_amount)?; - Some(SwapResult { - new_source_amount, - new_destination_amount, - amount_swapped, - trade_fee, - owner_fee, + let mut new_swap_source_amount = swap_source_amount.checked_add(source_amount)?; + let mut new_swap_destination_amount = invariant.checked_div(new_swap_source_amount)?; + + // Ceiling the destination amount if there's any remainder, which will + // almost always be the case. + let remainder = invariant.checked_rem(new_swap_source_amount)?; + if remainder > 0 { + new_swap_destination_amount = new_swap_destination_amount.checked_add(1)?; + // now calculate the minimum amount of source token needed to get + // the destination amount to avoid taking too much from users + new_swap_source_amount = invariant.checked_div(new_swap_destination_amount)?; + let remainder = invariant.checked_rem(new_swap_destination_amount)?; + if remainder > 0 { + new_swap_source_amount = new_swap_source_amount.checked_add(1)?; + } + } + + let source_amount_swapped = new_swap_source_amount.checked_sub(swap_source_amount)?; + let destination_amount_swapped = + map_zero_to_none(swap_destination_amount.checked_sub(new_swap_destination_amount)?)?; + + Some(SwapWithoutFeesResult { + source_amount_swapped, + destination_amount_swapped, }) } @@ -255,8 +260,8 @@ mod tests { .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); assert_eq!(result.new_source_amount, 1100); - assert_eq!(result.amount_swapped, 4505); - assert_eq!(result.new_destination_amount, 45495); + assert_eq!(result.destination_amount_swapped, 4504); + assert_eq!(result.new_destination_amount, 45496); assert_eq!(result.trade_fee, 1); assert_eq!(result.owner_fee, 0); } @@ -289,8 +294,8 @@ mod tests { .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); assert_eq!(result.new_source_amount, 1100); - assert_eq!(result.amount_swapped, 4505); - assert_eq!(result.new_destination_amount, 45495); + assert_eq!(result.destination_amount_swapped, 4504); + assert_eq!(result.new_destination_amount, 45496); assert_eq!(result.trade_fee, 0); assert_eq!(result.owner_fee, 1); } @@ -305,8 +310,8 @@ mod tests { .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); assert_eq!(result.new_source_amount, 1100); - assert_eq!(result.amount_swapped, 4546); - assert_eq!(result.new_destination_amount, 45454); + assert_eq!(result.destination_amount_swapped, 4545); + assert_eq!(result.new_destination_amount, 45455); } #[test] @@ -347,4 +352,66 @@ mod tests { let unpacked = ConstantProductCurve::unpack(&packed).unwrap(); assert_eq!(curve, unpacked); } + + fn test_truncation( + curve: &ConstantProductCurve, + source_amount: u128, + swap_source_amount: u128, + swap_destination_amount: u128, + expected_source_amount_swapped: u128, + expected_destination_amount_swapped: u128, + ) { + let invariant = swap_source_amount * swap_destination_amount; + let result = curve + .swap(source_amount, swap_source_amount, swap_destination_amount) + .unwrap(); + assert_eq!(result.source_amount_swapped, expected_source_amount_swapped); + assert_eq!( + result.destination_amount_swapped, + expected_destination_amount_swapped + ); + let new_invariant = (swap_source_amount + result.source_amount_swapped) + * (swap_destination_amount - result.destination_amount_swapped); + assert!(new_invariant >= invariant); + } + + #[test] + fn constant_product_swap_rounding() { + let curve = ConstantProductCurve::default(); + + // much too small + assert!(curve + .swap_without_fees(10, 70_000_000_000, 4_000_000) + .is_none()); // spot: 10 * 4m / 70b = 0 + + let tests: &[(u128, u128, u128, u128, u128)] = &[ + (10, 4_000_000, 70_000_000_000, 10, 174_999), // spot: 10 * 70b / ~4m = 174,999.99 + (20, 30_000 - 20, 10_000, 18, 6), // spot: 20 * 1 / 3.000 = 6.6667 (source can be 18 to get 6 dest.) + (19, 30_000 - 20, 10_000, 18, 6), // spot: 19 * 1 / 2.999 = 6.3334 (source can be 18 to get 6 dest.) + (18, 30_000 - 20, 10_000, 18, 6), // spot: 18 * 1 / 2.999 = 6.0001 + (10, 20_000, 30_000, 10, 14), // spot: 10 * 3 / 2.0010 = 14.99 + (10, 20_000 - 9, 30_000, 10, 14), // spot: 10 * 3 / 2.0001 = 14.999 + (10, 20_000 - 10, 30_000, 10, 15), // spot: 10 * 3 / 2.0000 = 15 + (100, 60_000, 30_000, 99, 49), // spot: 100 * 3 / 6.001 = 49.99 (source can be 99 to get 49 dest.) + (99, 60_000, 30_000, 99, 49), // spot: 99 * 3 / 6.001 = 49.49 + (98, 60_000, 30_000, 97, 48), // spot: 98 * 3 / 6.001 = 48.99 (source can be 97 to get 48 dest.) + ]; + for ( + source_amount, + swap_source_amount, + swap_destination_amount, + expected_source_amount, + expected_destination_amount, + ) in tests.iter() + { + test_truncation( + &curve, + *source_amount, + *swap_source_amount, + *swap_destination_amount, + *expected_source_amount, + *expected_destination_amount, + ); + } + } } diff --git a/token-swap/program/src/curve/flat.rs b/token-swap/program/src/curve/flat.rs index 76cd31387dda62..3746650697f609 100644 --- a/token-swap/program/src/curve/flat.rs +++ b/token-swap/program/src/curve/flat.rs @@ -5,7 +5,7 @@ use solana_program::{ program_pack::{IsInitialized, Pack, Sealed}, }; -use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapResult}; +use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapWithoutFeesResult}; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; use std::convert::TryFrom; @@ -32,29 +32,15 @@ pub struct FlatCurve { impl CurveCalculator for FlatCurve { /// Flat curve swap always returns 1:1 (minus fee) - fn swap( + fn swap_without_fees( &self, source_amount: u128, - swap_source_amount: u128, - swap_destination_amount: u128, - ) -> Option { - // debit the fee to calculate the amount swapped - let trade_fee = self.trading_fee(source_amount)?; - let owner_fee = self.owner_trading_fee(source_amount)?; - - let amount_swapped = source_amount - .checked_sub(trade_fee)? - .checked_sub(owner_fee)?; - let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?; - - // actually add the whole amount coming in - let new_source_amount = swap_source_amount.checked_add(source_amount)?; - Some(SwapResult { - new_source_amount, - new_destination_amount, - amount_swapped, - trade_fee, - owner_fee, + _swap_source_amount: u128, + _swap_destination_amount: u128, + ) -> Option { + Some(SwapWithoutFeesResult { + source_amount_swapped: source_amount, + destination_amount_swapped: source_amount, }) } @@ -191,7 +177,7 @@ mod tests { .unwrap(); let amount_swapped = 97; assert_eq!(result.new_source_amount, 1100); - assert_eq!(result.amount_swapped, amount_swapped); + assert_eq!(result.destination_amount_swapped, amount_swapped); assert_eq!(result.trade_fee, 1); assert_eq!(result.owner_fee, 2); assert_eq!( diff --git a/token-swap/program/src/curve/stable.rs b/token-swap/program/src/curve/stable.rs index e613e2232c6c6b..c7cd7b13988152 100644 --- a/token-swap/program/src/curve/stable.rs +++ b/token-swap/program/src/curve/stable.rs @@ -6,7 +6,7 @@ use solana_program::{ program_pack::{IsInitialized, Pack, Sealed}, }; -use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapResult}; +use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapWithoutFeesResult}; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; use std::convert::TryFrom; @@ -128,38 +128,26 @@ fn compute_new_destination_amount( impl CurveCalculator for StableCurve { /// Stable curve - fn swap( + fn swap_without_fees( &self, source_amount: u128, swap_source_amount: u128, swap_destination_amount: u128, - ) -> Option { - let trade_fee = self.trading_fee(source_amount)?; - let owner_fee = self.owner_trading_fee(source_amount)?; - - let new_source_amount_less_fee = swap_source_amount - .checked_add(source_amount)? - .checked_sub(trade_fee)? - .checked_sub(owner_fee)?; - + ) -> Option { let leverage = self.amp.checked_mul(N_COINS as u64)?; + let new_source_amount = swap_source_amount.checked_add(source_amount)?; let new_destination_amount = compute_new_destination_amount( leverage, - new_source_amount_less_fee, + new_source_amount, compute_d(leverage, swap_source_amount, swap_destination_amount)?, )?; - //let amount_swapped = - // map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount.as_u128())?)?; let amount_swapped = swap_destination_amount.checked_sub(new_destination_amount)?; - let new_source_amount = swap_source_amount.checked_add(source_amount)?; - Some(SwapResult { - new_source_amount, - new_destination_amount, - amount_swapped, - trade_fee, - owner_fee, + + Some(SwapWithoutFeesResult { + source_amount_swapped: source_amount, + destination_amount_swapped: amount_swapped, }) } @@ -364,7 +352,7 @@ mod tests { .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); assert_eq!(result.new_source_amount, 1_100); - assert_eq!(result.amount_swapped, 2_063); + assert_eq!(result.destination_amount_swapped, 2_063); assert_eq!(result.new_destination_amount, 47_937); assert_eq!(result.trade_fee, 1); assert_eq!(result.owner_fee, 0); @@ -399,7 +387,7 @@ mod tests { .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); assert_eq!(result.new_source_amount, 1100); - assert_eq!(result.amount_swapped, 2024); + assert_eq!(result.destination_amount_swapped, 2024); assert_eq!(result.new_destination_amount, 47976); assert_eq!(result.trade_fee, 1); assert_eq!(result.owner_fee, 2); @@ -446,7 +434,7 @@ mod tests { } println!( - "trying: source_amount={}, swap_source_amount={}, swap_destination_amount={}", + "trying: source_amount={}, swap_source_amount={}, swap_destination_amount={}", source_amount, swap_source_amount, swap_destination_amount ); @@ -467,14 +455,15 @@ mod tests { println!( "result={}, sim_result={}", - result.amount_swapped, sim_result + result.destination_amount_swapped, sim_result ); - let diff = (sim_result as i128 - result.amount_swapped as i128).abs(); + let diff = + (sim_result as i128 - result.destination_amount_swapped as i128).abs(); assert!( diff <= 1, "result={}, sim_result={}, amp={}, source_amount={}, swap_source_amount={}, swap_destination_amount={}", - result.amount_swapped, + result.destination_amount_swapped, sim_result, amp, source_amount, diff --git a/token-swap/program/src/error.rs b/token-swap/program/src/error.rs index 9d7a5e2cabd18d..a34e92d239dd1c 100644 --- a/token-swap/program/src/error.rs +++ b/token-swap/program/src/error.rs @@ -79,6 +79,9 @@ pub enum SwapError { /// The provided fee does not match the program owner's constraints #[error("The provided fee does not match the program owner's constraints")] InvalidFee, + /// The provided token program does not match the token program expected by the swap + #[error("The provided token program does not match the token program expected by the swap")] + IncorrectTokenProgramId, } impl From for ProgramError { fn from(e: SwapError) -> Self { diff --git a/token-swap/program/src/processor.rs b/token-swap/program/src/processor.rs index 193bcc64ab8732..ee3c3537355a85 100644 --- a/token-swap/program/src/processor.rs +++ b/token-swap/program/src/processor.rs @@ -29,13 +29,29 @@ const TOKENS_IN_POOL: u64 = 2; pub struct Processor {} impl Processor { /// Unpacks a spl_token `Account`. - pub fn unpack_token_account(data: &[u8]) -> Result { - spl_token::state::Account::unpack(data).map_err(|_| SwapError::ExpectedAccount) + pub fn unpack_token_account( + account_info: &AccountInfo, + token_program_id: &Pubkey, + ) -> Result { + if account_info.owner != token_program_id { + Err(SwapError::IncorrectTokenProgramId) + } else { + spl_token::state::Account::unpack(&account_info.data.borrow()) + .map_err(|_| SwapError::ExpectedAccount) + } } /// Unpacks a spl_token `Mint`. - pub fn unpack_mint(data: &[u8]) -> Result { - spl_token::state::Mint::unpack(data).map_err(|_| SwapError::ExpectedMint) + pub fn unpack_mint( + account_info: &AccountInfo, + token_program_id: &Pubkey, + ) -> Result { + if account_info.owner != token_program_id { + Err(SwapError::IncorrectTokenProgramId) + } else { + spl_token::state::Mint::unpack(&account_info.data.borrow()) + .map_err(|_| SwapError::ExpectedMint) + } } /// Calculates the authority id by generating a program address. @@ -149,6 +165,7 @@ impl Processor { let destination_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; + let token_program_id = *token_program_info.key; let token_swap = SwapInfo::unpack_unchecked(&swap_info.data.borrow())?; if token_swap.is_initialized { return Err(SwapError::AlreadyInUse.into()); @@ -157,11 +174,11 @@ impl Processor { if *authority_info.key != Self::authority_id(program_id, swap_info.key, nonce)? { return Err(SwapError::InvalidProgramAddress.into()); } - let token_a = Self::unpack_token_account(&token_a_info.data.borrow())?; - let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; - let fee_account = Self::unpack_token_account(&fee_account_info.data.borrow())?; - let destination = Self::unpack_token_account(&destination_info.data.borrow())?; - let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; + let token_a = Self::unpack_token_account(token_a_info, &token_program_id)?; + let token_b = Self::unpack_token_account(token_b_info, &token_program_id)?; + let fee_account = Self::unpack_token_account(fee_account_info, &token_program_id)?; + let destination = Self::unpack_token_account(destination_info, &token_program_id)?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_program_id)?; if *authority_info.key != token_a.owner { return Err(SwapError::InvalidOwner.into()); } @@ -236,7 +253,7 @@ impl Processor { let obj = SwapInfo { is_initialized: true, nonce, - token_program_id: *token_program_info.key, + token_program_id, token_a: *token_a_info.key, token_b: *token_b_info.key, pool_mint: *pool_mint_info.key, @@ -267,6 +284,9 @@ impl Processor { let pool_fee_account_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; + if swap_info.owner != program_id { + return Err(ProgramError::IncorrectProgramId); + } let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; if *authority_info.key != Self::authority_id(program_id, swap_info.key, token_swap.nonce)? { @@ -297,10 +317,15 @@ impl Processor { if *pool_fee_account_info.key != token_swap.pool_fee_account { return Err(SwapError::IncorrectFeeAccount.into()); } + if *token_program_info.key != token_swap.token_program_id { + return Err(SwapError::IncorrectTokenProgramId.into()); + } - let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?; - let dest_account = Self::unpack_token_account(&swap_destination_info.data.borrow())?; - let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; + let source_account = + Self::unpack_token_account(swap_source_info, &token_swap.token_program_id)?; + let dest_account = + Self::unpack_token_account(swap_destination_info, &token_swap.token_program_id)?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; let result = token_swap .swap_curve @@ -311,7 +336,7 @@ impl Processor { to_u128(dest_account.amount)?, ) .ok_or(SwapError::ZeroTradingTokens)?; - if result.amount_swapped < to_u128(minimum_amount_out)? { + if result.destination_amount_swapped < to_u128(minimum_amount_out)? { return Err(SwapError::ExceededSlippage.into()); } Self::token_transfer( @@ -321,7 +346,7 @@ impl Processor { swap_source_info.clone(), authority_info.clone(), token_swap.nonce, - amount_in, + to_u64(result.source_amount_swapped)?, )?; Self::token_transfer( swap_info.key, @@ -330,11 +355,12 @@ impl Processor { destination_info.clone(), authority_info.clone(), token_swap.nonce, - to_u64(result.amount_swapped)?, + to_u64(result.destination_amount_swapped)?, )?; // mint pool tokens equivalent to the owner fee - let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?; + let source_account = + Self::unpack_token_account(swap_source_info, &token_swap.token_program_id)?; let mut pool_token_amount = token_swap .swap_curve .calculator @@ -348,8 +374,10 @@ impl Processor { if pool_token_amount > 0 { // Allow error to fall through if let Ok(host_fee_account_info) = next_account_info(account_info_iter) { - let host_fee_account = - Self::unpack_token_account(&host_fee_account_info.data.borrow())?; + let host_fee_account = Self::unpack_token_account( + host_fee_account_info, + &token_swap.token_program_id, + )?; if *pool_mint_info.key != host_fee_account.mint { return Err(SwapError::IncorrectPoolMint.into()); } @@ -405,6 +433,9 @@ impl Processor { let dest_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; + if swap_info.owner != program_id { + return Err(ProgramError::IncorrectProgramId); + } let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; if *authority_info.key != Self::authority_id(program_id, swap_info.key, token_swap.nonce)? { return Err(SwapError::InvalidProgramAddress.into()); @@ -424,10 +455,13 @@ impl Processor { if token_b_info.key == source_b_info.key { return Err(SwapError::InvalidInput.into()); } + if *token_program_info.key != token_swap.token_program_id { + return Err(SwapError::IncorrectTokenProgramId.into()); + } - let token_a = Self::unpack_token_account(&token_a_info.data.borrow())?; - let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; - let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; + let token_a = Self::unpack_token_account(token_a_info, &token_swap.token_program_id)?; + let token_b = Self::unpack_token_account(token_b_info, &token_swap.token_program_id)?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; let pool_token_amount = to_u128(pool_token_amount)?; let pool_mint_supply = to_u128(pool_mint.supply)?; @@ -505,6 +539,9 @@ impl Processor { let pool_fee_account_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; + if swap_info.owner != program_id { + return Err(ProgramError::IncorrectProgramId); + } let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; if *authority_info.key != Self::authority_id(program_id, swap_info.key, token_swap.nonce)? { return Err(SwapError::InvalidProgramAddress.into()); @@ -527,10 +564,13 @@ impl Processor { if token_b_info.key == dest_token_b_info.key { return Err(SwapError::InvalidInput.into()); } + if *token_program_info.key != token_swap.token_program_id { + return Err(SwapError::IncorrectTokenProgramId.into()); + } - let token_a = Self::unpack_token_account(&token_a_info.data.borrow())?; - let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; - let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; + let token_a = Self::unpack_token_account(token_a_info, &token_swap.token_program_id)?; + let token_b = Self::unpack_token_account(token_b_info, &token_swap.token_program_id)?; + let pool_mint = Self::unpack_mint(pool_mint_info, &token_swap.token_program_id)?; let calculator = token_swap.swap_curve.calculator; @@ -719,6 +759,9 @@ impl PrintProgramError for SwapError { SwapError::InvalidFee => { msg!("Error: The provided fee does not match the program owner's constraints") } + SwapError::IncorrectTokenProgramId => { + msg!("Error: The provided token program does not match the token program expected by the swap") + } } } } @@ -1437,7 +1480,7 @@ mod tests { // uninitialized token a account { let old_account = accounts.token_a_account; - accounts.token_a_account = Account::default(); + accounts.token_a_account = Account::new(0, 0, &TOKEN_PROGRAM_ID); assert_eq!( Err(SwapError::ExpectedAccount.into()), accounts.initialize_swap() @@ -1448,7 +1491,7 @@ mod tests { // uninitialized token b account { let old_account = accounts.token_b_account; - accounts.token_b_account = Account::default(); + accounts.token_b_account = Account::new(0, 0, &TOKEN_PROGRAM_ID); assert_eq!( Err(SwapError::ExpectedAccount.into()), accounts.initialize_swap() @@ -1459,7 +1502,7 @@ mod tests { // uninitialized pool mint { let old_account = accounts.pool_mint_account; - accounts.pool_mint_account = Account::default(); + accounts.pool_mint_account = Account::new(0, 0, &TOKEN_PROGRAM_ID); assert_eq!( Err(SwapError::ExpectedMint.into()), accounts.initialize_swap() @@ -1569,6 +1612,46 @@ mod tests { accounts.pool_mint_account = old_mint; } + // token A account owned by wrong program + { + let (_token_a_key, mut token_a_account) = mint_token( + &TOKEN_PROGRAM_ID, + &accounts.token_a_mint_key, + &mut accounts.token_a_mint_account, + &user_key, + &accounts.authority_key, + token_a_amount, + ); + token_a_account.owner = SWAP_PROGRAM_ID; + let old_account = accounts.token_a_account; + accounts.token_a_account = token_a_account; + assert_eq!( + Err(SwapError::IncorrectTokenProgramId.into()), + accounts.initialize_swap() + ); + accounts.token_a_account = old_account; + } + + // token B account owned by wrong program + { + let (_token_b_key, mut token_b_account) = mint_token( + &TOKEN_PROGRAM_ID, + &accounts.token_b_mint_key, + &mut accounts.token_b_mint_account, + &user_key, + &accounts.authority_key, + token_b_amount, + ); + token_b_account.owner = SWAP_PROGRAM_ID; + let old_account = accounts.token_b_account; + accounts.token_b_account = token_b_account; + assert_eq!( + Err(SwapError::IncorrectTokenProgramId.into()), + accounts.initialize_swap() + ); + accounts.token_b_account = old_account; + } + // empty token A account { let (_token_a_key, token_a_account) = mint_token( @@ -1819,7 +1902,7 @@ mod tests { { let wrong_program_id = Pubkey::new_unique(); assert_eq!( - Err(ProgramError::InvalidAccountData), + Err(SwapError::IncorrectTokenProgramId.into()), do_process_instruction( initialize( &SWAP_PROGRAM_ID, @@ -2103,13 +2186,13 @@ mod tests { assert_eq!(swap_info.token_a_mint, accounts.token_a_mint_key); assert_eq!(swap_info.token_b_mint, accounts.token_b_mint_key); assert_eq!(swap_info.pool_fee_account, accounts.pool_fee_key); - let token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + let token_a = spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); assert_eq!(token_a.amount, token_a_amount); - let token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + let token_b = spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); assert_eq!(token_b.amount, token_b_amount); let pool_account = - Processor::unpack_token_account(&accounts.pool_token_account.data).unwrap(); - let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.pool_token_account.data).unwrap(); + let pool_mint = spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap(); assert_eq!(pool_mint.supply, pool_account.amount); } @@ -2179,6 +2262,38 @@ mod tests { accounts.initialize_swap().unwrap(); + // wrong owner for swap account + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + pool_key, + mut pool_account, + ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); + let old_swap_account = accounts.swap_account; + let mut wrong_swap_account = old_swap_account.clone(); + wrong_swap_account.owner = TOKEN_PROGRAM_ID; + accounts.swap_account = wrong_swap_account; + assert_eq!( + Err(ProgramError::IncorrectProgramId), + accounts.deposit( + &depositor_key, + &token_a_key, + &mut token_a_account, + &token_b_key, + &mut token_b_account, + &pool_key, + &mut pool_account, + pool_amount.try_into().unwrap(), + deposit_a, + deposit_b, + ) + ); + accounts.swap_account = old_swap_account; + } + // wrong nonce for authority_key { let ( @@ -2399,7 +2514,7 @@ mod tests { ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); let wrong_key = Pubkey::new_unique(); assert_eq!( - Err(ProgramError::InvalidAccountData), + Err(SwapError::IncorrectTokenProgramId.into()), do_process_instruction( deposit( &SWAP_PROGRAM_ID, @@ -2663,19 +2778,20 @@ mod tests { .unwrap(); let swap_token_a = - Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); assert_eq!(swap_token_a.amount, deposit_a + token_a_amount); let swap_token_b = - Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); assert_eq!(swap_token_b.amount, deposit_b + token_b_amount); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap(); assert_eq!(token_a.amount, 0); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap(); assert_eq!(token_b.amount, 0); - let pool_account = Processor::unpack_token_account(&pool_account.data).unwrap(); + let pool_account = spl_token::state::Account::unpack(&pool_account.data).unwrap(); let swap_pool_account = - Processor::unpack_token_account(&accounts.pool_token_account.data).unwrap(); - let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.pool_token_account.data).unwrap(); + let pool_mint = + spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap(); assert_eq!( pool_mint.supply, pool_account.amount + swap_pool_account.amount @@ -2752,6 +2868,38 @@ mod tests { accounts.initialize_swap().unwrap(); + // wrong owner for swap account + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + pool_key, + mut pool_account, + ) = accounts.setup_token_accounts(&user_key, &withdrawer_key, initial_a, initial_b, 0); + let old_swap_account = accounts.swap_account; + let mut wrong_swap_account = old_swap_account.clone(); + wrong_swap_account.owner = TOKEN_PROGRAM_ID; + accounts.swap_account = wrong_swap_account; + assert_eq!( + Err(ProgramError::IncorrectProgramId), + accounts.withdraw( + &withdrawer_key, + &pool_key, + &mut pool_account, + &token_a_key, + &mut token_a_account, + &token_b_key, + &mut token_b_account, + withdraw_amount.try_into().unwrap(), + minimum_token_a_amount, + minimum_token_b_amount, + ) + ); + accounts.swap_account = old_swap_account; + } + // wrong nonce for authority_key { let ( @@ -3024,7 +3172,7 @@ mod tests { ); let wrong_key = Pubkey::new_unique(); assert_eq!( - Err(ProgramError::InvalidAccountData), + Err(SwapError::IncorrectTokenProgramId.into()), do_process_instruction( withdraw( &SWAP_PROGRAM_ID, @@ -3342,10 +3490,11 @@ mod tests { .unwrap(); let swap_token_a = - Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); let swap_token_b = - Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); - let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); + let pool_mint = + spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap(); let withdraw_fee = accounts .swap_curve .calculator @@ -3377,17 +3526,17 @@ mod tests { swap_token_b.amount, token_b_amount - to_u64(withdrawn_b).unwrap() ); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap(); assert_eq!(token_a.amount, initial_a + to_u64(withdrawn_a).unwrap()); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap(); assert_eq!(token_b.amount, initial_b + to_u64(withdrawn_b).unwrap()); - let pool_account = Processor::unpack_token_account(&pool_account.data).unwrap(); + let pool_account = spl_token::state::Account::unpack(&pool_account.data).unwrap(); assert_eq!( pool_account.amount, to_u64(initial_pool - withdraw_amount).unwrap() ); let fee_account = - Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.pool_fee_account.data).unwrap(); assert_eq!( fee_account.amount, TryInto::::try_into(withdraw_fee).unwrap() @@ -3407,7 +3556,7 @@ mod tests { let pool_fee_key = accounts.pool_fee_key; let mut pool_fee_account = accounts.pool_fee_account.clone(); - let fee_account = Processor::unpack_token_account(&pool_fee_account.data).unwrap(); + let fee_account = spl_token::state::Account::unpack(&pool_fee_account.data).unwrap(); let pool_fee_amount = fee_account.amount; accounts @@ -3426,10 +3575,11 @@ mod tests { .unwrap(); let swap_token_a = - Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); let swap_token_b = - Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); - let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); + let pool_mint = + spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap(); let withdrawn_a = accounts .swap_curve .calculator @@ -3439,7 +3589,7 @@ mod tests { swap_token_a.amount.try_into().unwrap(), ) .unwrap(); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap(); assert_eq!( token_a.amount, TryInto::::try_into(withdrawn_a).unwrap() @@ -3453,7 +3603,7 @@ mod tests { swap_token_b.amount.try_into().unwrap(), ) .unwrap(); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap(); assert_eq!( token_b.amount, TryInto::::try_into(withdrawn_b).unwrap() @@ -3496,7 +3646,7 @@ mod tests { // swap one way let a_to_b_amount = initial_a / 10; let minimum_token_b_amount = 0; - let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + let pool_mint = spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap(); let initial_supply = pool_mint.supply; accounts .swap( @@ -3521,25 +3671,27 @@ mod tests { ) .unwrap(); - let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + let swap_token_a = + spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); let token_a_amount = swap_token_a.amount; assert_eq!( token_a_amount, TryInto::::try_into(results.new_source_amount).unwrap() ); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap(); assert_eq!(token_a.amount, initial_a - a_to_b_amount); - let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + let swap_token_b = + spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); let token_b_amount = swap_token_b.amount; assert_eq!( token_b_amount, TryInto::::try_into(results.new_destination_amount).unwrap() ); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap(); assert_eq!( token_b.amount, - initial_b + to_u64(results.amount_swapped).unwrap() + initial_b + to_u64(results.destination_amount_swapped).unwrap() ); let first_fee = swap_curve @@ -3551,16 +3703,17 @@ mod tests { TOKENS_IN_POOL.try_into().unwrap(), ) .unwrap(); - let fee_account = Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + let fee_account = + spl_token::state::Account::unpack(&accounts.pool_fee_account.data).unwrap(); assert_eq!( fee_account.amount, TryInto::::try_into(first_fee).unwrap() ); - let first_swap_amount = results.amount_swapped; + let first_swap_amount = results.destination_amount_swapped; // swap the other way - let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + let pool_mint = spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap(); let initial_supply = pool_mint.supply; let b_to_a_amount = initial_b / 10; @@ -3588,28 +3741,31 @@ mod tests { ) .unwrap(); - let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + let swap_token_a = + spl_token::state::Account::unpack(&accounts.token_a_account.data).unwrap(); let token_a_amount = swap_token_a.amount; assert_eq!( token_a_amount, TryInto::::try_into(results.new_destination_amount).unwrap() ); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap(); assert_eq!( token_a.amount, - initial_a - a_to_b_amount + to_u64(results.amount_swapped).unwrap() + initial_a - a_to_b_amount + to_u64(results.destination_amount_swapped).unwrap() ); - let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + let swap_token_b = + spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap(); let token_b_amount = swap_token_b.amount; assert_eq!( token_b_amount, TryInto::::try_into(results.new_source_amount).unwrap() ); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap(); assert_eq!( token_b.amount, - initial_b + to_u64(first_swap_amount).unwrap() - b_to_a_amount + initial_b + to_u64(first_swap_amount).unwrap() + - to_u64(results.source_amount_swapped).unwrap() ); let second_fee = swap_curve @@ -3621,7 +3777,8 @@ mod tests { TOKENS_IN_POOL.try_into().unwrap(), ) .unwrap(); - let fee_account = Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + let fee_account = + spl_token::state::Account::unpack(&accounts.pool_fee_account.data).unwrap(); assert_eq!(fee_account.amount, to_u64(first_fee + second_fee).unwrap()); } @@ -3811,9 +3968,9 @@ mod tests { .unwrap(); // check that fees were taken in the host fee account - let host_fee_account = Processor::unpack_token_account(&pool_account.data).unwrap(); + let host_fee_account = spl_token::state::Account::unpack(&pool_account.data).unwrap(); let owner_fee_account = - Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + spl_token::state::Account::unpack(&accounts.pool_fee_account.data).unwrap(); let total_fee = host_fee_account.amount * host_fee_denominator / host_fee_numerator; assert_eq!( total_fee, @@ -3888,6 +4045,37 @@ mod tests { accounts.initialize_swap().unwrap(); + // wrong swap account program id + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + _pool_key, + _pool_account, + ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); + let old_swap_account = accounts.swap_account; + let mut wrong_swap_account = old_swap_account.clone(); + wrong_swap_account.owner = TOKEN_PROGRAM_ID; + accounts.swap_account = wrong_swap_account; + assert_eq!( + Err(ProgramError::IncorrectProgramId), + accounts.swap( + &swapper_key, + &token_a_key, + &mut token_a_account, + &swap_token_a_key, + &swap_token_b_key, + &token_b_key, + &mut token_b_account, + initial_a, + minimum_token_b_amount, + ) + ); + accounts.swap_account = old_swap_account; + } + // wrong nonce { let ( @@ -3933,7 +4121,7 @@ mod tests { ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); let wrong_program_id = Pubkey::new_unique(); assert_eq!( - Err(ProgramError::InvalidAccountData), + Err(SwapError::IncorrectTokenProgramId.into()), do_process_instruction( swap( &SWAP_PROGRAM_ID,