Skip to content

Commit

Permalink
token-swap: Cleanup check (solana-labs#903)
Browse files Browse the repository at this point in the history
* Add extra check

* Add extra check on init

* Run cargo fmt

* Clippy

* Add token program id check on unpack

* Run cargo fmt

* Add checks for program id ownership on swap account

* Add truncation check during swapping

* Run cargo fmt

* Update truncation to ceiling the value

* Run cargo fmt

* Fix JS test

* Refund back not needed source tokens

* Run cargo fmt

* Add swap_without_fees method to trait

* Fix merge problem

Co-authored-by: Justin Starry <[email protected]>
  • Loading branch information
joncinque and jstarry authored Dec 3, 2020
1 parent e49201b commit 04a3c83
Show file tree
Hide file tree
Showing 8 changed files with 440 additions and 161 deletions.
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion token-swap/js/cli/token-swap-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ let currentFeeAmount = 0;
// Because there is no withdraw fee in the production version, these numbers
// need to get slightly tweaked in the two cases.
const SWAP_AMOUNT_IN = 100000;
const SWAP_AMOUNT_OUT = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? 90662 : 90675;
const SWAP_AMOUNT_OUT = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? 90661 : 90674;
const SWAP_FEE = SWAP_PROGRAM_OWNER_FEE_ADDRESS ? 22272 : 22276;
const HOST_SWAP_FEE = SWAP_PROGRAM_OWNER_FEE_ADDRESS
? Math.floor((SWAP_FEE * HOST_FEE_NUMERATOR) / HOST_FEE_DENOMINATOR)
Expand Down
52 changes: 49 additions & 3 deletions token-swap/program/src/curve/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::fmt::Debug;
/// input amounts, and Balancer uses 100 * 10 ^ 18.
pub const INITIAL_SWAP_POOL_AMOUNT: u128 = 1_000_000_000;

/// Helper function for calcuating swap fee
/// Helper function for calculating swap fee
pub fn calculate_fee(
token_amount: u128,
fee_numerator: u128,
Expand Down Expand Up @@ -43,14 +43,24 @@ pub struct SwapResult {
pub new_source_amount: u128,
/// New amount of destination token
pub new_destination_amount: u128,
/// Amount of source token swapped (includes fees)
pub source_amount_swapped: u128,
/// Amount of destination token swapped
pub amount_swapped: u128,
pub destination_amount_swapped: u128,
/// Amount of source tokens going to pool holders
pub trade_fee: u128,
/// Amount of source tokens going to owner
pub owner_fee: u128,
}

/// Encodes all results of swapping from a source token to a destination token
pub struct SwapWithoutFeesResult {
/// Amount of source token swapped
pub source_amount_swapped: u128,
/// Amount of destination token swapped
pub destination_amount_swapped: u128,
}

/// Trait for packing of trait objects, required because structs that implement
/// `Pack` cannot be used as trait objects (as `dyn Pack`).
pub trait DynPack {
Expand All @@ -62,12 +72,48 @@ pub trait DynPack {
pub trait CurveCalculator: Debug + DynPack {
/// Calculate how much destination token will be provided given an amount
/// of source token.
fn swap_without_fees(
&self,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapWithoutFeesResult>;

/// Subtract fees and calculate how much destination token will be provided
/// given an amount of source token.
fn swap(
&self,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapResult>;
) -> Option<SwapResult> {
// debit the fee to calculate the amount swapped
let trade_fee = self.trading_fee(source_amount)?;
let owner_fee = self.owner_trading_fee(source_amount)?;

let total_fees = trade_fee.checked_add(owner_fee)?;
let source_amount_less_fees = source_amount.checked_sub(total_fees)?;

let SwapWithoutFeesResult {
source_amount_swapped,
destination_amount_swapped,
} = self.swap_without_fees(
source_amount_less_fees,
swap_source_amount,
swap_destination_amount,
)?;

let source_amount_swapped = source_amount_swapped.checked_add(total_fees)?;
Some(SwapResult {
new_source_amount: swap_source_amount.checked_add(source_amount_swapped)?,
new_destination_amount: swap_destination_amount
.checked_sub(destination_amount_swapped)?,
source_amount_swapped,
destination_amount_swapped,
trade_fee,
owner_fee,
})
}

/// Calculate the withdraw fee in pool tokens
/// Default implementation assumes no fee
Expand Down
123 changes: 95 additions & 28 deletions token-swap/program/src/curve/constant_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use solana_program::{
};

use crate::curve::calculator::{
calculate_fee, map_zero_to_none, CurveCalculator, DynPack, SwapResult,
calculate_fee, map_zero_to_none, CurveCalculator, DynPack, SwapWithoutFeesResult,
};
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use std::convert::TryFrom;
Expand Down Expand Up @@ -34,33 +34,38 @@ pub struct ConstantProductCurve {

impl CurveCalculator for ConstantProductCurve {
/// Constant product swap ensures x * y = constant
fn swap(
fn swap_without_fees(
&self,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapResult> {
// debit the fee to calculate the amount swapped
let trade_fee = self.trading_fee(source_amount)?;
let owner_fee = self.owner_trading_fee(source_amount)?;

) -> Option<SwapWithoutFeesResult> {
let invariant = swap_source_amount.checked_mul(swap_destination_amount)?;
let new_source_amount_less_fee = swap_source_amount
.checked_add(source_amount)?
.checked_sub(trade_fee)?
.checked_sub(owner_fee)?;
let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?;
let amount_swapped =
map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount)?)?;

// actually add the whole amount coming in
let new_source_amount = swap_source_amount.checked_add(source_amount)?;
Some(SwapResult {
new_source_amount,
new_destination_amount,
amount_swapped,
trade_fee,
owner_fee,
let mut new_swap_source_amount = swap_source_amount.checked_add(source_amount)?;
let mut new_swap_destination_amount = invariant.checked_div(new_swap_source_amount)?;

// Ceiling the destination amount if there's any remainder, which will
// almost always be the case.
let remainder = invariant.checked_rem(new_swap_source_amount)?;
if remainder > 0 {
new_swap_destination_amount = new_swap_destination_amount.checked_add(1)?;
// now calculate the minimum amount of source token needed to get
// the destination amount to avoid taking too much from users
new_swap_source_amount = invariant.checked_div(new_swap_destination_amount)?;
let remainder = invariant.checked_rem(new_swap_destination_amount)?;
if remainder > 0 {
new_swap_source_amount = new_swap_source_amount.checked_add(1)?;
}
}

let source_amount_swapped = new_swap_source_amount.checked_sub(swap_source_amount)?;
let destination_amount_swapped =
map_zero_to_none(swap_destination_amount.checked_sub(new_swap_destination_amount)?)?;

Some(SwapWithoutFeesResult {
source_amount_swapped,
destination_amount_swapped,
})
}

Expand Down Expand Up @@ -255,8 +260,8 @@ mod tests {
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, 4505);
assert_eq!(result.new_destination_amount, 45495);
assert_eq!(result.destination_amount_swapped, 4504);
assert_eq!(result.new_destination_amount, 45496);
assert_eq!(result.trade_fee, 1);
assert_eq!(result.owner_fee, 0);
}
Expand Down Expand Up @@ -289,8 +294,8 @@ mod tests {
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, 4505);
assert_eq!(result.new_destination_amount, 45495);
assert_eq!(result.destination_amount_swapped, 4504);
assert_eq!(result.new_destination_amount, 45496);
assert_eq!(result.trade_fee, 0);
assert_eq!(result.owner_fee, 1);
}
Expand All @@ -305,8 +310,8 @@ mod tests {
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, 4546);
assert_eq!(result.new_destination_amount, 45454);
assert_eq!(result.destination_amount_swapped, 4545);
assert_eq!(result.new_destination_amount, 45455);
}

#[test]
Expand Down Expand Up @@ -347,4 +352,66 @@ mod tests {
let unpacked = ConstantProductCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
}

fn test_truncation(
curve: &ConstantProductCurve,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
expected_source_amount_swapped: u128,
expected_destination_amount_swapped: u128,
) {
let invariant = swap_source_amount * swap_destination_amount;
let result = curve
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
assert_eq!(result.source_amount_swapped, expected_source_amount_swapped);
assert_eq!(
result.destination_amount_swapped,
expected_destination_amount_swapped
);
let new_invariant = (swap_source_amount + result.source_amount_swapped)
* (swap_destination_amount - result.destination_amount_swapped);
assert!(new_invariant >= invariant);
}

#[test]
fn constant_product_swap_rounding() {
let curve = ConstantProductCurve::default();

// much too small
assert!(curve
.swap_without_fees(10, 70_000_000_000, 4_000_000)
.is_none()); // spot: 10 * 4m / 70b = 0

let tests: &[(u128, u128, u128, u128, u128)] = &[
(10, 4_000_000, 70_000_000_000, 10, 174_999), // spot: 10 * 70b / ~4m = 174,999.99
(20, 30_000 - 20, 10_000, 18, 6), // spot: 20 * 1 / 3.000 = 6.6667 (source can be 18 to get 6 dest.)
(19, 30_000 - 20, 10_000, 18, 6), // spot: 19 * 1 / 2.999 = 6.3334 (source can be 18 to get 6 dest.)
(18, 30_000 - 20, 10_000, 18, 6), // spot: 18 * 1 / 2.999 = 6.0001
(10, 20_000, 30_000, 10, 14), // spot: 10 * 3 / 2.0010 = 14.99
(10, 20_000 - 9, 30_000, 10, 14), // spot: 10 * 3 / 2.0001 = 14.999
(10, 20_000 - 10, 30_000, 10, 15), // spot: 10 * 3 / 2.0000 = 15
(100, 60_000, 30_000, 99, 49), // spot: 100 * 3 / 6.001 = 49.99 (source can be 99 to get 49 dest.)
(99, 60_000, 30_000, 99, 49), // spot: 99 * 3 / 6.001 = 49.49
(98, 60_000, 30_000, 97, 48), // spot: 98 * 3 / 6.001 = 48.99 (source can be 97 to get 48 dest.)
];
for (
source_amount,
swap_source_amount,
swap_destination_amount,
expected_source_amount,
expected_destination_amount,
) in tests.iter()
{
test_truncation(
&curve,
*source_amount,
*swap_source_amount,
*swap_destination_amount,
*expected_source_amount,
*expected_destination_amount,
);
}
}
}
32 changes: 9 additions & 23 deletions token-swap/program/src/curve/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use solana_program::{
program_pack::{IsInitialized, Pack, Sealed},
};

use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapResult};
use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapWithoutFeesResult};
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use std::convert::TryFrom;

Expand All @@ -32,29 +32,15 @@ pub struct FlatCurve {

impl CurveCalculator for FlatCurve {
/// Flat curve swap always returns 1:1 (minus fee)
fn swap(
fn swap_without_fees(
&self,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapResult> {
// debit the fee to calculate the amount swapped
let trade_fee = self.trading_fee(source_amount)?;
let owner_fee = self.owner_trading_fee(source_amount)?;

let amount_swapped = source_amount
.checked_sub(trade_fee)?
.checked_sub(owner_fee)?;
let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?;

// actually add the whole amount coming in
let new_source_amount = swap_source_amount.checked_add(source_amount)?;
Some(SwapResult {
new_source_amount,
new_destination_amount,
amount_swapped,
trade_fee,
owner_fee,
_swap_source_amount: u128,
_swap_destination_amount: u128,
) -> Option<SwapWithoutFeesResult> {
Some(SwapWithoutFeesResult {
source_amount_swapped: source_amount,
destination_amount_swapped: source_amount,
})
}

Expand Down Expand Up @@ -191,7 +177,7 @@ mod tests {
.unwrap();
let amount_swapped = 97;
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, amount_swapped);
assert_eq!(result.destination_amount_swapped, amount_swapped);
assert_eq!(result.trade_fee, 1);
assert_eq!(result.owner_fee, 2);
assert_eq!(
Expand Down
Loading

0 comments on commit 04a3c83

Please sign in to comment.