From db42f7abbea2451898850b62583bf05c1bb4a3a1 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Wed, 21 Oct 2020 20:46:50 +0200 Subject: [PATCH] token-swap: Add curve and pool token trait and integrate in processor (#624) * Add SwapCurve trait and integrate in processor * Add PoolTokenConverter trait to correspond * Add curve type parameter to initialization and JS * Refactor for flat curve test, fmt * Update token-swap/program/src/curve.rs Co-authored-by: Tyera Eulberg * Refactor swap curve to allow for any implementation * Rename SwapCurveWrapper -> SwapCurve * Run cargo fmt * Update CurveType to enum in JS Co-authored-by: Tyera Eulberg --- token-swap/js/cli/token-swap-test.js | 8 +- token-swap/js/client/token-swap.js | 33 +- token-swap/js/module.d.ts | 6 +- token-swap/js/module.flow.js | 7 +- token-swap/program/src/curve.rs | 529 ++++++++++++++++++++------ token-swap/program/src/instruction.rs | 68 ++-- token-swap/program/src/processor.rs | 372 ++++++++++-------- token-swap/program/src/state.rs | 66 ++-- 8 files changed, 736 insertions(+), 353 deletions(-) diff --git a/token-swap/js/cli/token-swap-test.js b/token-swap/js/cli/token-swap-test.js index f4031a9591ebd7..6806eb9a37aa4a 100644 --- a/token-swap/js/cli/token-swap-test.js +++ b/token-swap/js/cli/token-swap-test.js @@ -10,7 +10,7 @@ import { } from '@solana/web3.js'; import {Token} from '../../../token/js/client/token'; -import {TokenSwap} from '../client/token-swap'; +import {TokenSwap, CurveType} from '../client/token-swap'; import {Store} from '../client/util/store'; import {newAccountWithLamports} from '../client/util/new-account-with-lamports'; import {url} from '../url'; @@ -34,6 +34,8 @@ let mintB: Token; let tokenAccountA: PublicKey; let tokenAccountB: PublicKey; +// curve type used to calculate swaps and deposits +const CURVE_TYPE = CurveType.ConstantProduct; // Initial amount in each swap token const BASE_AMOUNT = 1000; // Amount passed to swap instruction @@ -194,13 +196,14 @@ export async function createTokenSwap(): Promise { swapPayer, tokenSwapAccount, authority, - nonce, tokenAccountA, tokenAccountB, tokenPool.publicKey, tokenAccountPool, tokenSwapProgramId, tokenProgramId, + nonce, + CURVE_TYPE, 1, 4, ); @@ -217,6 +220,7 @@ export async function createTokenSwap(): Promise { assert(fetchedTokenSwap.tokenAccountA.equals(tokenAccountA)); assert(fetchedTokenSwap.tokenAccountB.equals(tokenAccountB)); assert(fetchedTokenSwap.poolToken.equals(tokenPool.publicKey)); + assert(CURVE_TYPE == fetchedTokenSwap.curveType); assert(1 == fetchedTokenSwap.feeNumerator.toNumber()); assert(4 == fetchedTokenSwap.feeDenominator.toNumber()); } diff --git a/token-swap/js/client/token-swap.js b/token-swap/js/client/token-swap.js index 12506045d541c4..a5f6665315e8b7 100644 --- a/token-swap/js/client/token-swap.js +++ b/token-swap/js/client/token-swap.js @@ -64,11 +64,18 @@ export const TokenSwapLayout: typeof BufferLayout.Structure = BufferLayout.struc Layout.publicKey('tokenAccountA'), Layout.publicKey('tokenAccountB'), Layout.publicKey('tokenPool'), + BufferLayout.u8('curveType'), Layout.uint64('feeNumerator'), Layout.uint64('feeDenominator'), + BufferLayout.blob(48, 'padding'), ], ); +export const CurveType = Object.freeze({ + ConstantProduct: 0, // Constant product curve, Uniswap-style + Flat: 1, // Flat curve, always 1:1 trades +}); + /** * A program to exchange tokens against a pool of liquidity */ @@ -123,6 +130,11 @@ export class TokenSwap { */ feeDenominator: Numberu64; + /** + * CurveType, current options are: + */ + curveType: number; + /** * Fee payer */ @@ -150,6 +162,7 @@ export class TokenSwap { authority: PublicKey, tokenAccountA: PublicKey, tokenAccountB: PublicKey, + curveType: number, feeNumerator: Numberu64, feeDenominator: Numberu64, payer: Account, @@ -163,6 +176,7 @@ export class TokenSwap { authority, tokenAccountA, tokenAccountB, + curveType, feeNumerator, feeDenominator, payer, @@ -185,13 +199,14 @@ export class TokenSwap { static createInitSwapInstruction( tokenSwapAccount: Account, authority: PublicKey, - nonce: number, tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, swapProgramId: PublicKey, + nonce: number, + curveType: number, feeNumerator: number, feeDenominator: number, ): TransactionInstruction { @@ -206,18 +221,21 @@ export class TokenSwap { ]; const commandDataLayout = BufferLayout.struct([ BufferLayout.u8('instruction'), + BufferLayout.u8('nonce'), + BufferLayout.u8('curveType'), BufferLayout.nu64('feeNumerator'), BufferLayout.nu64('feeDenominator'), - BufferLayout.u8('nonce'), + BufferLayout.blob(48, 'padding'), ]); let data = Buffer.alloc(1024); { const encodeLength = commandDataLayout.encode( { instruction: 0, // InitializeSwap instruction + nonce, + curveType, feeNumerator, feeDenominator, - nonce, }, data, ); @@ -254,6 +272,7 @@ export class TokenSwap { const feeNumerator = Numberu64.fromBuffer(tokenSwapData.feeNumerator); const feeDenominator = Numberu64.fromBuffer(tokenSwapData.feeDenominator); + const curveType = tokenSwapData.curveType; return new TokenSwap( connection, @@ -264,6 +283,7 @@ export class TokenSwap { authority, tokenAccountA, tokenAccountB, + curveType, feeNumerator, feeDenominator, payer, @@ -293,13 +313,14 @@ export class TokenSwap { payer: Account, tokenSwapAccount: Account, authority: PublicKey, - nonce: number, tokenAccountA: PublicKey, tokenAccountB: PublicKey, poolToken: PublicKey, tokenAccountPool: PublicKey, swapProgramId: PublicKey, tokenProgramId: PublicKey, + nonce: number, + curveType: number, feeNumerator: number, feeDenominator: number, ): Promise { @@ -313,6 +334,7 @@ export class TokenSwap { authority, tokenAccountA, tokenAccountB, + curveType, new Numberu64(feeNumerator), new Numberu64(feeDenominator), payer, @@ -336,13 +358,14 @@ export class TokenSwap { const instruction = TokenSwap.createInitSwapInstruction( tokenSwapAccount, authority, - nonce, tokenAccountA, tokenAccountB, poolToken, tokenAccountPool, tokenProgramId, swapProgramId, + nonce, + curveType, feeNumerator, feeDenominator, ); diff --git a/token-swap/js/module.d.ts b/token-swap/js/module.d.ts index 0cb1a989ece5f7..862c5da58edb06 100644 --- a/token-swap/js/module.d.ts +++ b/token-swap/js/module.d.ts @@ -17,6 +17,7 @@ declare module '@solana/spl-token-swap' { } export const TokenSwapLayout: Layout; + export const CurveType: Object; export class TokenSwap { constructor( @@ -28,6 +29,7 @@ declare module '@solana/spl-token-swap' { authority: PublicKey, tokenAccountA: PublicKey, tokenAccountB: PublicKey, + curveType: number, feeNumerator: Numberu64, feeDenominator: Numberu64, payer: Account, @@ -40,13 +42,14 @@ declare module '@solana/spl-token-swap' { static createInitSwapInstruction( tokenSwapAccount: Account, authority: PublicKey, - nonce: number, tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, swapProgramId: PublicKey, + nonce: number, + curveType: number, feeNumerator: number, feeDenominator: number, ): TransactionInstruction; @@ -69,6 +72,7 @@ declare module '@solana/spl-token-swap' { tokenAccountPool: PublicKey, tokenProgramId: PublicKey, nonce: number, + curveType: number, feeNumerator: number, feeDenominator: number, swapProgramId: PublicKey, diff --git a/token-swap/js/module.flow.js b/token-swap/js/module.flow.js index 5212163df295ce..a977521f6d1d3b 100644 --- a/token-swap/js/module.flow.js +++ b/token-swap/js/module.flow.js @@ -14,6 +14,8 @@ declare module '@solana/spl-token-swap' { declare export var TokenSwapLayout: Layout; + declare export var CurveType: Object; + declare export class TokenSwap { constructor( connection: Connection, @@ -24,6 +26,7 @@ declare module '@solana/spl-token-swap' { authority: PublicKey, tokenAccountA: PublicKey, tokenAccountB: PublicKey, + curveType: number, feeNumerator: Numberu64, feeDenominator: Numberu64, payer: Account, @@ -37,12 +40,13 @@ declare module '@solana/spl-token-swap' { programId: PublicKey, tokenSwapAccount: Account, authority: PublicKey, - nonce: number, tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, + nonce: number, + curveType: number, feeNumerator: number, feeDenominator: number, ): TransactionInstruction; @@ -65,6 +69,7 @@ declare module '@solana/spl-token-swap' { tokenAccountPool: PublicKey, tokenProgramId: PublicKey, nonce: number, + curveType: number, feeNumerator: number, feeDenominator: number, programId: PublicKey, diff --git a/token-swap/program/src/curve.rs b/token-swap/program/src/curve.rs index 76e9006dc68c0b..b2bf4739c04843 100644 --- a/token-swap/program/src/curve.rs +++ b/token-swap/program/src/curve.rs @@ -1,11 +1,153 @@ //! Swap calculations and curve implementations +use solana_sdk::{ + program_error::ProgramError, + program_pack::{IsInitialized, Pack, Sealed}, +}; + +use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; +use std::convert::{TryFrom, TryInto}; +use std::fmt::Debug; + /// Initial amount of pool tokens for swap contract, hard-coded to something /// "sensible" given a maximum of u64. /// Note that on Ethereum, Uniswap uses the geometric mean of all provided /// input amounts, and Balancer uses 100 * 10 ^ 18. pub const INITIAL_SWAP_POOL_AMOUNT: u64 = 1_000_000_000; +/// Curve types supported by the token-swap program. +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum CurveType { + /// Uniswap-style constant product curve, invariant = token_a_amount * token_b_amount + ConstantProduct, + /// Flat line, always providing 1:1 from one token to another + Flat, +} + +/// Concrete struct to wrap around the trait object which performs calculation. +#[repr(C)] +#[derive(Debug)] +pub struct SwapCurve { + /// The type of curve contained in the calculator, helpful for outside + /// queries + pub curve_type: CurveType, + /// The actual calculator, represented as a trait object to allow for many + /// different types of curves + pub calculator: Box, +} + +/// Default implementation for SwapCurve cannot be derived because of +/// the contained Box. +impl Default for SwapCurve { + fn default() -> Self { + let curve_type: CurveType = Default::default(); + let calculator: ConstantProductCurve = Default::default(); + Self { + curve_type, + calculator: Box::new(calculator), + } + } +} + +/// Simple implementation for PartialEq which assumes that the output of +/// `Pack` is enough to guarantee equality +impl PartialEq for SwapCurve { + fn eq(&self, other: &Self) -> bool { + let mut packed_self = [0u8; Self::LEN]; + Self::pack_into_slice(self, &mut packed_self); + let mut packed_other = [0u8; Self::LEN]; + Self::pack_into_slice(other, &mut packed_other); + packed_self[..] == packed_other[..] + } +} + +impl Sealed for SwapCurve {} +impl Pack for SwapCurve { + /// Size of encoding of all curve parameters, which include fees and any other + /// constants used to calculate swaps, deposits, and withdrawals. + /// This includes 1 byte for the type, and 64 for the calculator to use as + /// it needs. Some calculators may be smaller than 64 bytes. + const LEN: usize = 65; + + /// Unpacks a byte buffer into a SwapCurve + fn unpack_from_slice(input: &[u8]) -> Result { + let input = array_ref![input, 0, 65]; + #[allow(clippy::ptr_offset_with_cast)] + let (curve_type, calculator) = array_refs![input, 1, 64]; + let curve_type = curve_type[0].try_into()?; + Ok(Self { + curve_type, + calculator: match curve_type { + CurveType::ConstantProduct => { + Box::new(ConstantProductCurve::unpack_from_slice(calculator)?) + } + CurveType::Flat => Box::new(FlatCurve::unpack_from_slice(calculator)?), + }, + }) + } + + /// Pack SwapCurve into a byte buffer + fn pack_into_slice(&self, output: &mut [u8]) { + let output = array_mut_ref![output, 0, 65]; + let (curve_type, calculator) = mut_array_refs![output, 1, 64]; + curve_type[0] = self.curve_type as u8; + self.calculator.pack_into_slice(&mut calculator[..]); + } +} + +/// Sensible default of CurveType to ConstantProduct, the most popular and +/// well-known curve type. +impl Default for CurveType { + fn default() -> Self { + CurveType::ConstantProduct + } +} + +impl TryFrom for CurveType { + type Error = ProgramError; + + fn try_from(curve_type: u8) -> Result { + match curve_type { + 0 => Ok(CurveType::ConstantProduct), + 1 => Ok(CurveType::Flat), + _ => Err(ProgramError::InvalidAccountData), + } + } +} + +/// Trait for packing of trait objects, required because structs that implement +/// `Pack` cannot be used as trait objects (as `dyn Pack`). +pub trait DynPack { + /// Only required function is to pack given a trait object + fn pack_into_slice(&self, dst: &mut [u8]); +} + +/// Trait representing operations required on a swap curve +pub trait CurveCalculator: Debug + DynPack { + /// Calculate how much destination token will be provided given an amount + /// of source token. + fn swap( + &self, + source_amount: u64, + swap_source_amount: u64, + swap_destination_amount: u64, + ) -> Option; + + /// Get the supply of a new pool (can be a default amount or calculated + /// based on parameters) + fn new_pool_supply(&self) -> u64; + + /// Get the amount of liquidity tokens for pool tokens given the total amount + /// of liquidity tokens in the pool + fn liquidity_tokens( + &self, + pool_tokens: u64, + pool_token_supply: u64, + total_liquidity_tokens: u64, + ) -> Option; +} + /// Encodes all results of swapping from a source token to a destination token pub struct SwapResult { /// New amount of source token @@ -16,30 +158,42 @@ pub struct SwapResult { pub amount_swapped: u64, } -impl SwapResult { - /// SwapResult for swap from one currency into another, given pool information - /// and fee - pub fn swap_to( +/// Helper function for mapping to SwapError::CalculationFailure +fn map_zero_to_none(x: u64) -> Option { + if x == 0 { + None + } else { + Some(x) + } +} + +/// Simple constant 1:1 swap curve, example of different swap curve implementations +#[derive(Clone, Debug, Default, PartialEq)] +pub struct FlatCurve { + /// Fee numerator + pub fee_numerator: u64, + /// Fee denominator + pub fee_denominator: u64, +} + +impl CurveCalculator for FlatCurve { + /// Flat curve swap always returns 1:1 (minus fee) + fn swap( + &self, source_amount: u64, swap_source_amount: u64, swap_destination_amount: u64, - fee_numerator: u64, - fee_denominator: u64, ) -> Option { - let invariant = swap_source_amount.checked_mul(swap_destination_amount)?; - // debit the fee to calculate the amount swapped let mut fee = source_amount - .checked_mul(fee_numerator)? - .checked_div(fee_denominator)?; + .checked_mul(self.fee_numerator)? + .checked_div(self.fee_denominator)?; if fee == 0 { fee = 1; // minimum fee of one token } - let new_source_amount_less_fee = swap_source_amount - .checked_add(source_amount)? - .checked_sub(fee)?; - let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?; - let amount_swapped = swap_destination_amount.checked_sub(new_destination_amount)?; + + let amount_swapped = source_amount.checked_sub(fee)?; + let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?; // actually add the whole amount coming in let new_source_amount = swap_source_amount.checked_add(source_amount)?; @@ -49,105 +203,154 @@ impl SwapResult { amount_swapped, }) } + + /// Balancer-style fixed initial supply + fn new_pool_supply(&self) -> u64 { + INITIAL_SWAP_POOL_AMOUNT + } + + /// Simple ratio calculation for how many liquidity tokens correspond to + /// a certain number of pool tokens + fn liquidity_tokens( + &self, + pool_tokens: u64, + pool_token_supply: u64, + total_liquidity_tokens: u64, + ) -> Option { + pool_tokens + .checked_mul(total_liquidity_tokens)? + .checked_div(pool_token_supply) + .and_then(map_zero_to_none) + } } -fn map_zero_to_none(x: u64) -> Option { - if x == 0 { - None - } else { - Some(x) +/// IsInitialized is required to use `Pack::pack` and `Pack::unpack` +impl IsInitialized for FlatCurve { + fn is_initialized(&self) -> bool { + true + } +} +impl Sealed for FlatCurve {} +impl Pack for FlatCurve { + const LEN: usize = 16; + /// Unpacks a byte buffer into a SwapCurve + fn unpack_from_slice(input: &[u8]) -> Result { + let input = array_ref![input, 0, 16]; + #[allow(clippy::ptr_offset_with_cast)] + let (fee_numerator, fee_denominator) = array_refs![input, 8, 8]; + Ok(Self { + fee_numerator: u64::from_le_bytes(*fee_numerator), + fee_denominator: u64::from_le_bytes(*fee_denominator), + }) + } + + fn pack_into_slice(&self, output: &mut [u8]) { + (self as &dyn DynPack).pack_into_slice(output); + } +} + +impl DynPack for FlatCurve { + fn pack_into_slice(&self, output: &mut [u8]) { + let output = array_mut_ref![output, 0, 16]; + let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8]; + *fee_numerator = self.fee_numerator.to_le_bytes(); + *fee_denominator = self.fee_denominator.to_le_bytes(); } } /// The Uniswap invariant calculator. -pub struct ConstantProduct { - /// Token A - pub token_a: u64, - /// Token B - pub token_b: u64, +#[derive(Clone, Debug, Default, PartialEq)] +pub struct ConstantProductCurve { /// Fee numerator pub fee_numerator: u64, /// Fee denominator pub fee_denominator: u64, } -impl ConstantProduct { - /// Swap token a to b - pub fn swap_a_to_b(&mut self, token_a: u64) -> Option { - let result = SwapResult::swap_to( - token_a, - self.token_a, - self.token_b, - self.fee_numerator, - self.fee_denominator, - )?; - self.token_a = result.new_source_amount; - self.token_b = result.new_destination_amount; - map_zero_to_none(result.amount_swapped) - } - - /// Swap token b to a - pub fn swap_b_to_a(&mut self, token_b: u64) -> Option { - let result = SwapResult::swap_to( - token_b, - self.token_b, - self.token_a, - self.fee_numerator, - self.fee_denominator, - )?; - self.token_b = result.new_source_amount; - self.token_a = result.new_destination_amount; - map_zero_to_none(result.amount_swapped) - } -} - -/// Conversions for pool tokens, how much to deposit / withdraw, along with -/// proper initialization -pub struct PoolTokenConverter { - /// Total supply - pub supply: u64, - /// Token A amount - pub token_a: u64, - /// Token B amount - pub token_b: u64, -} - -impl PoolTokenConverter { - /// Create a converter based on existing market information - pub fn new_existing(supply: u64, token_a: u64, token_b: u64) -> Self { - Self { - supply, - token_a, - token_b, +impl CurveCalculator for ConstantProductCurve { + /// Constant product swap ensures x * y = constant + fn swap( + &self, + source_amount: u64, + swap_source_amount: u64, + swap_destination_amount: u64, + ) -> Option { + let invariant = swap_source_amount.checked_mul(swap_destination_amount)?; + + // debit the fee to calculate the amount swapped + let mut fee = source_amount + .checked_mul(self.fee_numerator)? + .checked_div(self.fee_denominator)?; + if fee == 0 { + fee = 1; // minimum fee of one token } + let new_source_amount_less_fee = swap_source_amount + .checked_add(source_amount)? + .checked_sub(fee)?; + let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?; + let amount_swapped = + map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount)?)?; + + // actually add the whole amount coming in + let new_source_amount = swap_source_amount.checked_add(source_amount)?; + Some(SwapResult { + new_source_amount, + new_destination_amount, + amount_swapped, + }) } - /// Create a converter for a new pool token, no supply present yet. - /// According to Uniswap, the geometric mean protects the pool creator - /// in case the initial ratio is off the market. - pub fn new_pool(token_a: u64, token_b: u64) -> Self { - let supply = INITIAL_SWAP_POOL_AMOUNT; - Self { - supply, - token_a, - token_b, - } + /// Balancer-style supply starts at a constant. This could be modified to + /// follow the geometric mean, as done in Uniswap v2. + fn new_pool_supply(&self) -> u64 { + INITIAL_SWAP_POOL_AMOUNT } - /// A tokens for pool tokens, returns None if output is less than 0 - pub fn token_a_rate(&self, pool_tokens: u64) -> Option { + /// Simple ratio calculation to get the amount of liquidity tokens given + /// pool information + fn liquidity_tokens( + &self, + pool_tokens: u64, + pool_token_supply: u64, + total_liquidity_tokens: u64, + ) -> Option { pool_tokens - .checked_mul(self.token_a)? - .checked_div(self.supply) + .checked_mul(total_liquidity_tokens)? + .checked_div(pool_token_supply) .and_then(map_zero_to_none) } +} - /// B tokens for pool tokens, returns None is output is less than 0 - pub fn token_b_rate(&self, pool_tokens: u64) -> Option { - pool_tokens - .checked_mul(self.token_b)? - .checked_div(self.supply) - .and_then(map_zero_to_none) +/// IsInitialized is required to use `Pack::pack` and `Pack::unpack` +impl IsInitialized for ConstantProductCurve { + fn is_initialized(&self) -> bool { + true + } +} +impl Sealed for ConstantProductCurve {} +impl Pack for ConstantProductCurve { + const LEN: usize = 16; + fn unpack_from_slice(input: &[u8]) -> Result { + let input = array_ref![input, 0, 16]; + #[allow(clippy::ptr_offset_with_cast)] + let (fee_numerator, fee_denominator) = array_refs![input, 8, 8]; + Ok(Self { + fee_numerator: u64::from_le_bytes(*fee_numerator), + fee_denominator: u64::from_le_bytes(*fee_denominator), + }) + } + + fn pack_into_slice(&self, output: &mut [u8]) { + (self as &dyn DynPack).pack_into_slice(output); + } +} + +impl DynPack for ConstantProductCurve { + fn pack_into_slice(&self, output: &mut [u8]) { + let output = array_mut_ref![output, 0, 16]; + let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8]; + *fee_numerator = self.fee_numerator.to_le_bytes(); + *fee_denominator = self.fee_denominator.to_le_bytes(); } } @@ -157,48 +360,152 @@ mod tests { #[test] fn initial_pool_amount() { - let token_converter = PoolTokenConverter::new_pool(1, 5); - assert_eq!(token_converter.supply, INITIAL_SWAP_POOL_AMOUNT); + let fee_numerator = 0; + let fee_denominator = 1; + let calculator = ConstantProductCurve { + fee_numerator, + fee_denominator, + }; + assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT); } - fn check_pool_token_a_rate( + fn check_liquidity_pool_token_rate( token_a: u64, - token_b: u64, deposit: u64, supply: u64, expected: Option, ) { - let calculator = PoolTokenConverter::new_existing(supply, token_a, token_b); - assert_eq!(calculator.token_a_rate(deposit), expected); + let fee_numerator = 0; + let fee_denominator = 1; + let calculator = ConstantProductCurve { + fee_numerator, + fee_denominator, + }; + assert_eq!( + calculator.liquidity_tokens(deposit, supply, token_a), + expected + ); } #[test] fn issued_tokens() { - check_pool_token_a_rate(2, 50, 5, 10, Some(1)); - check_pool_token_a_rate(10, 10, 5, 10, Some(5)); - check_pool_token_a_rate(5, 100, 5, 10, Some(2)); - check_pool_token_a_rate(5, u64::MAX, 5, 10, Some(2)); - check_pool_token_a_rate(u64::MAX, u64::MAX, 5, 10, None); + check_liquidity_pool_token_rate(2, 5, 10, Some(1)); + check_liquidity_pool_token_rate(10, 5, 10, Some(5)); + check_liquidity_pool_token_rate(5, 5, 10, Some(2)); + check_liquidity_pool_token_rate(5, 5, 10, Some(2)); + check_liquidity_pool_token_rate(u64::MAX, 5, 10, None); } #[test] - fn swap_calculation() { + fn constant_product_swap_calculation() { // calculation on https://github.com/solana-labs/solana-program-library/issues/341 let swap_source_amount: u64 = 1000; let swap_destination_amount: u64 = 50000; let fee_numerator: u64 = 1; let fee_denominator: u64 = 100; let source_amount: u64 = 100; - let result = SwapResult::swap_to( - source_amount, - swap_source_amount, - swap_destination_amount, + let curve = ConstantProductCurve { fee_numerator, fee_denominator, - ) - .unwrap(); + }; + let result = curve + .swap(source_amount, swap_source_amount, swap_destination_amount) + .unwrap(); assert_eq!(result.new_source_amount, 1100); assert_eq!(result.amount_swapped, 4505); assert_eq!(result.new_destination_amount, 45495); } + + #[test] + fn flat_swap_calculation() { + let swap_source_amount: u64 = 1000; + let swap_destination_amount: u64 = 50000; + let fee_numerator: u64 = 1; + let fee_denominator: u64 = 100; + let source_amount: u64 = 100; + let curve = FlatCurve { + fee_numerator, + fee_denominator, + }; + let result = curve + .swap(source_amount, swap_source_amount, swap_destination_amount) + .unwrap(); + let amount_swapped = 99; + assert_eq!(result.new_source_amount, 1100); + assert_eq!(result.amount_swapped, amount_swapped); + assert_eq!( + result.new_destination_amount, + swap_destination_amount - amount_swapped + ); + } + + #[test] + fn pack_flat_curve() { + let fee_numerator = 1; + let fee_denominator = 4; + let curve = FlatCurve { + fee_numerator, + fee_denominator, + }; + + let mut packed = [0u8; FlatCurve::LEN]; + Pack::pack_into_slice(&curve, &mut packed[..]); + let unpacked = FlatCurve::unpack(&packed).unwrap(); + assert_eq!(curve, unpacked); + + let mut packed = vec![]; + packed.extend_from_slice(&fee_numerator.to_le_bytes()); + packed.extend_from_slice(&fee_denominator.to_le_bytes()); + let unpacked = FlatCurve::unpack(&packed).unwrap(); + assert_eq!(curve, unpacked); + } + + #[test] + fn pack_constant_product_curve() { + let fee_numerator = 1; + let fee_denominator = 4; + let curve = ConstantProductCurve { + fee_numerator, + fee_denominator, + }; + + let mut packed = [0u8; ConstantProductCurve::LEN]; + Pack::pack_into_slice(&curve, &mut packed[..]); + let unpacked = ConstantProductCurve::unpack(&packed).unwrap(); + assert_eq!(curve, unpacked); + + let mut packed = vec![]; + packed.extend_from_slice(&fee_numerator.to_le_bytes()); + packed.extend_from_slice(&fee_denominator.to_le_bytes()); + let unpacked = ConstantProductCurve::unpack(&packed).unwrap(); + assert_eq!(curve, unpacked); + } + + #[test] + fn pack_swap_curve() { + let fee_numerator = 1; + let fee_denominator = 4; + let curve = ConstantProductCurve { + fee_numerator, + fee_denominator, + }; + let curve_type = CurveType::ConstantProduct; + let swap_curve = SwapCurve { + curve_type, + calculator: Box::new(curve), + }; + + let mut packed = [0u8; SwapCurve::LEN]; + Pack::pack_into_slice(&swap_curve, &mut packed[..]); + let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap(); + assert_eq!(swap_curve, unpacked); + + let mut packed = vec![]; + packed.push(curve_type as u8); + packed.extend_from_slice(&fee_numerator.to_le_bytes()); + packed.extend_from_slice(&fee_denominator.to_le_bytes()); + packed.extend_from_slice(&[0u8; 48]); // padding + let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap(); + assert_eq!(swap_curve, unpacked); + } } diff --git a/token-swap/program/src/instruction.rs b/token-swap/program/src/instruction.rs index e60c7dbefc2d3e..cff4f273e3956c 100644 --- a/token-swap/program/src/instruction.rs +++ b/token-swap/program/src/instruction.rs @@ -2,10 +2,12 @@ #![allow(clippy::too_many_arguments)] +use crate::curve::{ConstantProductCurve, CurveType, FlatCurve, SwapCurve}; use crate::error::SwapError; use solana_sdk::{ instruction::{AccountMeta, Instruction}, program_error::ProgramError, + program_pack::Pack, pubkey::Pubkey, }; use std::convert::TryInto; @@ -13,7 +15,7 @@ use std::mem::size_of; /// Instructions supported by the SwapInfo program. #[repr(C)] -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] pub enum SwapInstruction { /// Initializes a new SwapInfo. /// @@ -25,12 +27,11 @@ pub enum SwapInstruction { /// 5. `[writable]` Pool Token Account to deposit the minted tokens. Must be empty, owned by user. /// 6. '[]` Token program id Initialize { - /// swap pool fee numerator - fee_numerator: u64, - /// swap pool fee denominator - fee_denominator: u64, /// nonce used to create valid program address nonce: u8, + /// swap curve info for pool, including CurveType, fees, and anything + /// else that may be required + swap_curve: SwapCurve, }, /// Swap the tokens in the pool. @@ -99,14 +100,9 @@ impl SwapInstruction { let (&tag, rest) = input.split_first().ok_or(SwapError::InvalidInstruction)?; Ok(match tag { 0 => { - let (fee_numerator, rest) = Self::unpack_u64(rest)?; - let (fee_denominator, rest) = Self::unpack_u64(rest)?; - let (&nonce, _rest) = rest.split_first().ok_or(SwapError::InvalidInstruction)?; - Self::Initialize { - fee_numerator, - fee_denominator, - nonce, - } + let (&nonce, rest) = rest.split_first().ok_or(SwapError::InvalidInstruction)?; + let swap_curve = SwapCurve::unpack_unchecked(rest)?; + Self::Initialize { nonce, swap_curve } } 1 => { let (amount_in, rest) = Self::unpack_u64(rest)?; @@ -157,16 +153,13 @@ impl SwapInstruction { /// Packs a [SwapInstruction](enum.SwapInstruction.html) into a byte buffer. pub fn pack(&self) -> Vec { let mut buf = Vec::with_capacity(size_of::()); - match *self { - Self::Initialize { - fee_numerator, - fee_denominator, - nonce, - } => { + match &*self { + Self::Initialize { nonce, swap_curve } => { buf.push(0); - buf.extend_from_slice(&fee_numerator.to_le_bytes()); - buf.extend_from_slice(&fee_denominator.to_le_bytes()); - buf.push(nonce); + buf.push(*nonce); + let mut swap_curve_slice = [0u8; SwapCurve::LEN]; + Pack::pack_into_slice(swap_curve, &mut swap_curve_slice[..]); + buf.extend_from_slice(&swap_curve_slice); } Self::Swap { amount_in, @@ -212,14 +205,24 @@ pub fn initialize( pool_pubkey: &Pubkey, destination_pubkey: &Pubkey, nonce: u8, + curve_type: CurveType, fee_numerator: u64, fee_denominator: u64, ) -> Result { - let init_data = SwapInstruction::Initialize { - fee_numerator, - fee_denominator, - nonce, + let swap_curve = SwapCurve { + curve_type, + calculator: match curve_type { + CurveType::ConstantProduct => Box::new(ConstantProductCurve { + fee_numerator, + fee_denominator, + }), + CurveType::Flat => Box::new(FlatCurve { + fee_numerator, + fee_denominator, + }), + }, }; + let init_data = SwapInstruction::Initialize { nonce, swap_curve }; let data = init_data.pack(); let accounts = vec![ @@ -379,17 +382,24 @@ mod tests { let fee_numerator: u64 = 1; let fee_denominator: u64 = 4; let nonce: u8 = 255; - let check = SwapInstruction::Initialize { + let curve_type = CurveType::Flat; + let calculator = Box::new(FlatCurve { fee_numerator, fee_denominator, - nonce, + }); + let swap_curve = SwapCurve { + curve_type, + calculator, }; + let check = SwapInstruction::Initialize { nonce, swap_curve }; let packed = check.pack(); let mut expect = vec![]; expect.push(0 as u8); + expect.push(nonce); + expect.push(curve_type as u8); expect.extend_from_slice(&fee_numerator.to_le_bytes()); expect.extend_from_slice(&fee_denominator.to_le_bytes()); - expect.push(nonce); + expect.extend_from_slice(&[0u8; 48]); // padding assert_eq!(packed, expect); let unpacked = SwapInstruction::unpack(&expect).unwrap(); assert_eq!(unpacked, check); diff --git a/token-swap/program/src/processor.rs b/token-swap/program/src/processor.rs index 5f55c2a0052e22..6df63167e3fb5f 100644 --- a/token-swap/program/src/processor.rs +++ b/token-swap/program/src/processor.rs @@ -2,12 +2,7 @@ #![cfg(feature = "program")] -use crate::{ - curve::{ConstantProduct, PoolTokenConverter}, - error::SwapError, - instruction::SwapInstruction, - state::SwapInfo, -}; +use crate::{curve::SwapCurve, error::SwapError, instruction::SwapInstruction, state::SwapInfo}; use num_traits::FromPrimitive; #[cfg(not(target_arch = "bpf"))] use solana_sdk::instruction::Instruction; @@ -142,8 +137,7 @@ impl Processor { pub fn process_initialize( program_id: &Pubkey, nonce: u8, - fee_numerator: u64, - fee_denominator: u64, + swap_curve: SwapCurve, accounts: &[AccountInfo], ) -> ProgramResult { let account_info_iter = &mut accounts.iter(); @@ -198,8 +192,7 @@ impl Processor { return Err(SwapError::InvalidSupply.into()); } - let converter = PoolTokenConverter::new_pool(token_a.amount, token_b.amount); - let initial_amount = converter.supply; + let initial_amount = swap_curve.calculator.new_pool_supply(); Self::token_mint_to( swap_info.key, @@ -218,8 +211,7 @@ impl Processor { token_a: *token_a_info.key, token_b: *token_b_info.key, pool_mint: *pool_mint_info.key, - fee_numerator, - fee_denominator, + swap_curve, }; SwapInfo::pack(obj, &mut swap_info.data.borrow_mut())?; Ok(()) @@ -262,28 +254,12 @@ impl Processor { let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?; let dest_account = Self::unpack_token_account(&swap_destination_info.data.borrow())?; - let amount_out = if *swap_source_info.key == token_swap.token_a { - let mut invariant = ConstantProduct { - token_a: source_account.amount, - token_b: dest_account.amount, - fee_numerator: token_swap.fee_numerator, - fee_denominator: token_swap.fee_denominator, - }; - invariant - .swap_a_to_b(amount_in) - .ok_or(SwapError::CalculationFailure)? - } else { - let mut invariant = ConstantProduct { - token_a: dest_account.amount, - token_b: source_account.amount, - fee_numerator: token_swap.fee_numerator, - fee_denominator: token_swap.fee_denominator, - }; - invariant - .swap_b_to_a(amount_in) - .ok_or(SwapError::CalculationFailure)? - }; - if amount_out < minimum_amount_out { + let result = token_swap + .swap_curve + .calculator + .swap(amount_in, source_account.amount, dest_account.amount) + .ok_or(SwapError::CalculationFailure)?; + if result.amount_swapped < minimum_amount_out { return Err(SwapError::ExceededSlippage.into()); } Self::token_transfer( @@ -302,7 +278,7 @@ impl Processor { destination_info.clone(), authority_info.clone(), token_swap.nonce, - amount_out, + result.amount_swapped, )?; Ok(()) } @@ -344,17 +320,16 @@ impl Processor { let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; - let converter = - PoolTokenConverter::new_existing(pool_mint.supply, token_a.amount, token_b.amount); + let calculator = token_swap.swap_curve.calculator; - let a_amount = converter - .token_a_rate(pool_token_amount) + let a_amount = calculator + .liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount) .ok_or(SwapError::CalculationFailure)?; if a_amount > maximum_token_a_amount { return Err(SwapError::ExceededSlippage.into()); } - let b_amount = converter - .token_b_rate(pool_token_amount) + let b_amount = calculator + .liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount) .ok_or(SwapError::CalculationFailure)?; if b_amount > maximum_token_b_amount { return Err(SwapError::ExceededSlippage.into()); @@ -428,17 +403,16 @@ impl Processor { let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; - let converter = - PoolTokenConverter::new_existing(pool_mint.supply, token_a.amount, token_b.amount); + let calculator = token_swap.swap_curve.calculator; - let a_amount = converter - .token_a_rate(pool_token_amount) + let a_amount = calculator + .liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount) .ok_or(SwapError::CalculationFailure)?; if a_amount < minimum_token_a_amount { return Err(SwapError::ExceededSlippage.into()); } - let b_amount = converter - .token_b_rate(pool_token_amount) + let b_amount = calculator + .liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount) .ok_or(SwapError::CalculationFailure)?; if b_amount < minimum_token_b_amount { return Err(SwapError::ExceededSlippage.into()); @@ -478,19 +452,9 @@ impl Processor { pub fn process(program_id: &Pubkey, accounts: &[AccountInfo], input: &[u8]) -> ProgramResult { let instruction = SwapInstruction::unpack(input)?; match instruction { - SwapInstruction::Initialize { - fee_numerator, - fee_denominator, - nonce, - } => { + SwapInstruction::Initialize { nonce, swap_curve } => { info!("Instruction: Init"); - Self::process_initialize( - program_id, - nonce, - fee_numerator, - fee_denominator, - accounts, - ) + Self::process_initialize(program_id, nonce, swap_curve, accounts) } SwapInstruction::Swap { amount_in, @@ -618,7 +582,9 @@ solana_sdk::program_stubs!(); mod tests { use super::*; use crate::{ - curve::{SwapResult, INITIAL_SWAP_POOL_AMOUNT}, + curve::{ + ConstantProductCurve, CurveCalculator, CurveType, FlatCurve, INITIAL_SWAP_POOL_AMOUNT, + }, instruction::{deposit, initialize, swap, withdraw}, }; use solana_sdk::{ @@ -634,6 +600,7 @@ mod tests { struct SwapAccountInfo { nonce: u8, + curve_type: CurveType, authority_key: Pubkey, fee_numerator: u64, fee_denominator: u64, @@ -656,6 +623,7 @@ mod tests { impl SwapAccountInfo { pub fn new( user_key: &Pubkey, + curve_type: CurveType, fee_numerator: u64, fee_denominator: u64, token_a_amount: u64, @@ -699,6 +667,7 @@ mod tests { SwapAccountInfo { nonce, + curve_type, authority_key, fee_numerator, fee_denominator, @@ -731,6 +700,7 @@ mod tests { &self.pool_mint_key, &self.pool_token_key, self.nonce, + self.curve_type, self.fee_numerator, self.fee_denominator, ) @@ -1179,9 +1149,11 @@ mod tests { let token_a_amount = 1000; let token_b_amount = 2000; let pool_token_amount = 10; + let curve_type = CurveType::ConstantProduct; let mut accounts = SwapAccountInfo::new( &user_key, + curve_type, fee_numerator, fee_denominator, token_a_amount, @@ -1474,6 +1446,7 @@ mod tests { &accounts.pool_mint_key, &accounts.pool_token_key, accounts.nonce, + accounts.curve_type, accounts.fee_numerator, accounts.fee_denominator, ) @@ -1513,6 +1486,19 @@ mod tests { // create valid swap accounts.initialize_swap().unwrap(); + // create valid flat swap + { + let mut accounts = SwapAccountInfo::new( + &user_key, + CurveType::Flat, + fee_numerator, + fee_denominator, + token_a_amount, + token_b_amount, + ); + accounts.initialize_swap().unwrap(); + } + // create again { assert_eq!( @@ -1523,11 +1509,10 @@ mod tests { let swap_info = SwapInfo::unpack(&accounts.swap_account.data).unwrap(); assert_eq!(swap_info.is_initialized, true); assert_eq!(swap_info.nonce, accounts.nonce); + assert_eq!(swap_info.swap_curve.curve_type, accounts.curve_type); assert_eq!(swap_info.token_a, accounts.token_a_key); assert_eq!(swap_info.token_b, accounts.token_b_key); assert_eq!(swap_info.pool_mint, accounts.pool_mint_key); - assert_eq!(swap_info.fee_denominator, fee_denominator); - assert_eq!(swap_info.fee_numerator, fee_numerator); let token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); assert_eq!(token_a.amount, token_a_amount); let token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); @@ -1546,8 +1531,11 @@ mod tests { let fee_denominator = 2; let token_a_amount = 1000; let token_b_amount = 9000; + let curve_type = CurveType::ConstantProduct; + let mut accounts = SwapAccountInfo::new( &user_key, + curve_type, fee_numerator, fee_denominator, token_a_amount, @@ -2063,18 +2051,24 @@ mod tests { let fee_denominator = 2; let token_a_amount = 1000; let token_b_amount = 2000; + let curve_type = CurveType::ConstantProduct; + let mut accounts = SwapAccountInfo::new( &user_key, + curve_type, fee_numerator, fee_denominator, token_a_amount, token_b_amount, ); let withdrawer_key = pubkey_rand(); - let pool_converter = PoolTokenConverter::new_pool(token_a_amount, token_b_amount); + let calculator = ConstantProductCurve { + fee_numerator, + fee_denominator, + }; let initial_a = token_a_amount / 10; let initial_b = token_b_amount / 10; - let initial_pool = pool_converter.supply / 10; + let initial_pool = calculator.new_pool_supply() / 10; let withdraw_amount = initial_pool / 4; let minimum_a_amount = initial_a / 40; let minimum_b_amount = initial_b / 40; @@ -2583,15 +2577,18 @@ mod tests { let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); - let pool_converter = PoolTokenConverter::new_existing( - pool_mint.supply, - swap_token_a.amount, - swap_token_b.amount, - ); + let calculator = ConstantProductCurve { + fee_numerator, + fee_denominator, + }; - let withdrawn_a = pool_converter.token_a_rate(withdraw_amount).unwrap(); + let withdrawn_a = calculator + .liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_a.amount) + .unwrap(); assert_eq!(swap_token_a.amount, token_a_amount - withdrawn_a); - let withdrawn_b = pool_converter.token_b_rate(withdraw_amount).unwrap(); + let withdrawn_b = calculator + .liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_b.amount) + .unwrap(); assert_eq!(swap_token_b.amount, token_b_amount - withdrawn_b); let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); assert_eq!(token_a.amount, initial_a + withdrawn_a); @@ -2602,16 +2599,150 @@ mod tests { } } + fn check_valid_swap_curve(curve_type: CurveType, calculator: Box) { + let user_key = pubkey_rand(); + let swapper_key = pubkey_rand(); + let fee_numerator = 1; + let fee_denominator = 10; + let token_a_amount = 1000; + let token_b_amount = 5000; + + let swap_curve = SwapCurve { + curve_type, + calculator, + }; + + let mut accounts = SwapAccountInfo::new( + &user_key, + curve_type, + fee_numerator, + fee_denominator, + token_a_amount, + token_b_amount, + ); + let initial_a = token_a_amount / 5; + let initial_b = token_b_amount / 5; + accounts.initialize_swap().unwrap(); + + let swap_token_a_key = accounts.token_a_key; + let swap_token_b_key = accounts.token_b_key; + + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + _pool_key, + _pool_account, + ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); + // swap one way + let a_to_b_amount = initial_a / 10; + let minimum_b_amount = 0; + accounts + .swap( + &swapper_key, + &token_a_key, + &mut token_a_account, + &swap_token_a_key, + &swap_token_b_key, + &token_b_key, + &mut token_b_account, + a_to_b_amount, + minimum_b_amount, + ) + .unwrap(); + + let results = swap_curve + .calculator + .swap(a_to_b_amount, token_a_amount, token_b_amount) + .unwrap(); + + let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + let token_a_amount = swap_token_a.amount; + assert_eq!(token_a_amount, results.new_source_amount); + let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + assert_eq!(token_a.amount, initial_a - a_to_b_amount); + + let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + let token_b_amount = swap_token_b.amount; + assert_eq!(token_b_amount, results.new_destination_amount); + let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + assert_eq!(token_b.amount, initial_b + results.amount_swapped); + + let first_swap_amount = results.amount_swapped; + + // swap the other way + let b_to_a_amount = initial_b / 10; + let minimum_a_amount = 0; + accounts + .swap( + &swapper_key, + &token_b_key, + &mut token_b_account, + &swap_token_b_key, + &swap_token_a_key, + &token_a_key, + &mut token_a_account, + b_to_a_amount, + minimum_a_amount, + ) + .unwrap(); + + let results = swap_curve + .calculator + .swap(b_to_a_amount, token_b_amount, token_a_amount) + .unwrap(); + + let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + assert_eq!(swap_token_a.amount, results.new_destination_amount); + let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + assert_eq!( + token_a.amount, + initial_a - a_to_b_amount + results.amount_swapped + ); + + let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + assert_eq!(swap_token_b.amount, results.new_source_amount); + let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + assert_eq!( + token_b.amount, + initial_b + first_swap_amount - b_to_a_amount + ); + } + + #[test] + fn test_valid_swap_curves() { + let fee_numerator = 1; + let fee_denominator = 10; + check_valid_swap_curve( + CurveType::ConstantProduct, + Box::new(ConstantProductCurve { + fee_numerator, + fee_denominator, + }), + ); + check_valid_swap_curve( + CurveType::Flat, + Box::new(FlatCurve { + fee_numerator, + fee_denominator, + }), + ); + } + #[test] - fn test_swap() { + fn test_invalid_swap() { let user_key = pubkey_rand(); let swapper_key = pubkey_rand(); let fee_numerator = 1; let fee_denominator = 10; let token_a_amount = 1000; let token_b_amount = 5000; + let curve_type = CurveType::ConstantProduct; + let mut accounts = SwapAccountInfo::new( &user_key, + curve_type, fee_numerator, fee_denominator, token_a_amount, @@ -2932,102 +3063,5 @@ mod tests { ) ); } - - // correct swap - { - let ( - token_a_key, - mut token_a_account, - token_b_key, - mut token_b_account, - _pool_key, - _pool_account, - ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); - // swap one way - let a_to_b_amount = initial_a / 10; - let minimum_b_amount = initial_b / 20; - accounts - .swap( - &swapper_key, - &token_a_key, - &mut token_a_account, - &swap_token_a_key, - &swap_token_b_key, - &token_b_key, - &mut token_b_account, - a_to_b_amount, - minimum_b_amount, - ) - .unwrap(); - - let results = SwapResult::swap_to( - a_to_b_amount, - token_a_amount, - token_b_amount, - fee_numerator, - fee_denominator, - ) - .unwrap(); - - let swap_token_a = - Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); - let token_a_amount = swap_token_a.amount; - assert_eq!(token_a_amount, results.new_source_amount); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); - assert_eq!(token_a.amount, initial_a - a_to_b_amount); - - let swap_token_b = - Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); - let token_b_amount = swap_token_b.amount; - assert_eq!(token_b_amount, results.new_destination_amount); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); - assert_eq!(token_b.amount, initial_b + results.amount_swapped); - - let first_swap_amount = results.amount_swapped; - - // swap the other way - let b_to_a_amount = initial_b / 10; - let minimum_a_amount = initial_a / 20; - accounts - .swap( - &swapper_key, - &token_b_key, - &mut token_b_account, - &swap_token_b_key, - &swap_token_a_key, - &token_a_key, - &mut token_a_account, - b_to_a_amount, - minimum_a_amount, - ) - .unwrap(); - - let results = SwapResult::swap_to( - b_to_a_amount, - token_b_amount, - token_a_amount, - fee_numerator, - fee_denominator, - ) - .unwrap(); - - let swap_token_a = - Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); - assert_eq!(swap_token_a.amount, results.new_destination_amount); - let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); - assert_eq!( - token_a.amount, - initial_a - a_to_b_amount + results.amount_swapped - ); - - let swap_token_b = - Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); - assert_eq!(swap_token_b.amount, results.new_source_amount); - let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); - assert_eq!( - token_b.amount, - initial_b + first_swap_amount - b_to_a_amount - ); - } } } diff --git a/token-swap/program/src/state.rs b/token-swap/program/src/state.rs index c21a083e4fe860..e722a298fd4b6d 100644 --- a/token-swap/program/src/state.rs +++ b/token-swap/program/src/state.rs @@ -1,5 +1,6 @@ //! State transition types +use crate::curve::SwapCurve; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; use solana_sdk::{ program_error::ProgramError, @@ -9,7 +10,7 @@ use solana_sdk::{ /// Program states. #[repr(C)] -#[derive(Clone, Copy, Debug, Default, PartialEq)] +#[derive(Debug, Default, PartialEq)] pub struct SwapInfo { /// Initialized state. pub is_initialized: bool, @@ -32,10 +33,9 @@ pub struct SwapInfo { /// Pool tokens can be withdrawn back to the original A or B token. pub pool_mint: Pubkey, - /// Numerator of fee applied to the input token amount prior to output calculation. - pub fee_numerator: u64, - /// Denominator of fee applied to the input token amount prior to output calculation. - pub fee_denominator: u64, + /// Swap curve parameters, to be unpacked and used by the SwapCurve, which + /// calculates swaps, deposits, and withdrawals + pub swap_curve: SwapCurve, } impl Sealed for SwapInfo {} @@ -46,22 +46,14 @@ impl IsInitialized for SwapInfo { } impl Pack for SwapInfo { - const LEN: usize = 146; + const LEN: usize = 195; /// Unpacks a byte buffer into a [SwapInfo](struct.SwapInfo.html). fn unpack_from_slice(input: &[u8]) -> Result { - let input = array_ref![input, 0, 146]; + let input = array_ref![input, 0, 195]; #[allow(clippy::ptr_offset_with_cast)] - let ( - is_initialized, - nonce, - token_program_id, - token_a, - token_b, - pool_mint, - fee_numerator, - fee_denominator, - ) = array_refs![input, 1, 1, 32, 32, 32, 32, 8, 8]; + let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) = + array_refs![input, 1, 1, 32, 32, 32, 32, 65]; Ok(Self { is_initialized: match is_initialized { [0] => false, @@ -73,41 +65,36 @@ impl Pack for SwapInfo { token_a: Pubkey::new_from_array(*token_a), token_b: Pubkey::new_from_array(*token_b), pool_mint: Pubkey::new_from_array(*pool_mint), - fee_numerator: u64::from_le_bytes(*fee_numerator), - fee_denominator: u64::from_le_bytes(*fee_denominator), + swap_curve: SwapCurve::unpack_from_slice(swap_curve)?, }) } fn pack_into_slice(&self, output: &mut [u8]) { - let output = array_mut_ref![output, 0, 146]; - let ( - is_initialized, - nonce, - token_program_id, - token_a, - token_b, - pool_mint, - fee_numerator, - fee_denominator, - ) = mut_array_refs![output, 1, 1, 32, 32, 32, 32, 8, 8]; + let output = array_mut_ref![output, 0, 195]; + let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) = + mut_array_refs![output, 1, 1, 32, 32, 32, 32, 65]; is_initialized[0] = self.is_initialized as u8; nonce[0] = self.nonce; token_program_id.copy_from_slice(self.token_program_id.as_ref()); token_a.copy_from_slice(self.token_a.as_ref()); token_b.copy_from_slice(self.token_b.as_ref()); pool_mint.copy_from_slice(self.pool_mint.as_ref()); - *fee_numerator = self.fee_numerator.to_le_bytes(); - *fee_denominator = self.fee_denominator.to_le_bytes(); + self.swap_curve.pack_into_slice(&mut swap_curve[..]); } } #[cfg(test)] mod tests { use super::*; + use crate::curve::FlatCurve; + + use std::convert::TryInto; #[test] fn test_swap_info_packing() { let nonce = 255; + let curve_type_raw: u8 = 1; + let curve_type = curve_type_raw.try_into().unwrap(); let token_program_id_raw = [1u8; 32]; let token_a_raw = [1u8; 32]; let token_b_raw = [2u8; 32]; @@ -118,6 +105,14 @@ mod tests { let pool_mint = Pubkey::new_from_array(pool_mint_raw); let fee_numerator = 1; let fee_denominator = 4; + let calculator = Box::new(FlatCurve { + fee_numerator, + fee_denominator, + }); + let swap_curve = SwapCurve { + curve_type, + calculator, + }; let is_initialized = true; let swap_info = SwapInfo { is_initialized, @@ -126,12 +121,11 @@ mod tests { token_a, token_b, pool_mint, - fee_numerator, - fee_denominator, + swap_curve, }; let mut packed = [0u8; SwapInfo::LEN]; - SwapInfo::pack(swap_info, &mut packed).unwrap(); + SwapInfo::pack_into_slice(&swap_info, &mut packed); let unpacked = SwapInfo::unpack(&packed).unwrap(); assert_eq!(swap_info, unpacked); @@ -142,10 +136,12 @@ mod tests { packed.extend_from_slice(&token_a_raw); packed.extend_from_slice(&token_b_raw); packed.extend_from_slice(&pool_mint_raw); + packed.push(curve_type_raw); packed.push(fee_numerator as u8); packed.extend_from_slice(&[0u8; 7]); // padding packed.push(fee_denominator as u8); packed.extend_from_slice(&[0u8; 7]); // padding + packed.extend_from_slice(&[0u8; 48]); // padding let unpacked = SwapInfo::unpack(&packed).unwrap(); assert_eq!(swap_info, unpacked);