Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qol improvements 01 #22

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
75 changes: 65 additions & 10 deletions contracts/margin-dex/MarginDex.vy
Original file line number Diff line number Diff line change
Expand Up @@ -229,20 +229,25 @@ def _full_close(_trade: Trade, _min_amount_out: uint256) -> uint256:
_trade.vault_position_uid, _min_amount_out
)

# cleanup trade
self.open_trades[_trade.uid] = empty(Trade)
uids: DynArray[bytes32, 1024] = self.trades_by_account[_trade.account]
self._cleanup_trade(_trade.uid)

log TradeClosed(_trade.account, _trade.uid, _trade, amount_out_received)
return amount_out_received


@internal
def _cleanup_trade(_trade_uid: bytes32):
account: address = self.open_trades[_trade_uid].account
self.open_trades[_trade_uid] = empty(Trade)
uids: DynArray[bytes32, 1024] = self.trades_by_account[account]
for i in range(1024):
if uids[i] == _trade.uid:
if uids[i] == _trade_uid:
uids[i] = uids[len(uids) - 1]
uids.pop()
break
if i == len(uids) - 1:
raise
self.trades_by_account[_trade.account] = uids

log TradeClosed(_trade.account, _trade.uid, _trade, amount_out_received)
return amount_out_received
self.trades_by_account[account] = uids


event TradeReduced:
Expand Down Expand Up @@ -289,6 +294,17 @@ def get_all_open_trades(_account: address) -> DynArray[Trade, 1024]:

return trades

@view
@external
def get_all_open_limit_orders(_account: address) -> DynArray[LimitOrder, 1024]:
uids: DynArray[bytes32, 1024] = self.limit_order_uids[_account]
limit_orders: DynArray[LimitOrder, 1024] = empty(DynArray[LimitOrder, 1024])

for uid in uids:
limit_orders.append(self.limit_orders[uid])

return limit_orders


@external
def swap_margin(
Expand Down Expand Up @@ -368,6 +384,42 @@ def add_sl_order(_trade_uid: bytes32, _sl_order: StopLossOrder):
log SlOrderAdded(_trade_uid, sl_order)


event TpUpdated:
trade_uid: bytes32
tp: TakeProfitOrder

@external
def update_tp_order(_trade_uid: bytes32, _tp_index: uint256, _updated_order: TakeProfitOrder):
trade: Trade = self.open_trades[_trade_uid]
assert (trade.account == msg.sender) or self.is_delegate[trade.account][msg.sender], "unauthorized"
assert self.is_accepting_new_orders, "paused"

assert len(trade.tp_orders) > _tp_index, "invalid index"

trade.tp_orders[_tp_index] = _updated_order

self.open_trades[_trade_uid] = trade
log TpUpdated(_trade_uid, _updated_order)


event SlUpdated:
trade_uid: bytes32
sl: StopLossOrder

@external
def update_sl_order(_trade_uid: bytes32, _sl_index: uint256, _updated_order: StopLossOrder):
trade: Trade = self.open_trades[_trade_uid]
assert (trade.account == msg.sender) or self.is_delegate[trade.account][msg.sender], "unauthorized"
assert self.is_accepting_new_orders, "paused"

assert len(trade.sl_orders) > _sl_index, "invalid index"

trade.sl_orders[_sl_index] = _updated_order

self.open_trades[_trade_uid] = trade
log SlUpdated(_trade_uid, _updated_order)


event TpExecuted:
trade_uid: bytes32
reduce_by_amount: uint256
Expand Down Expand Up @@ -397,7 +449,7 @@ def execute_tp_order(_trade_uid: bytes32, _tp_order_index: uint8):
trade.vault_position_uid
)
amount_out_received: uint256 = 0
if tp_order.reduce_by_amount == position_amount:
if tp_order.reduce_by_amount >= position_amount:
amount_out_received = self._full_close(trade, tp_order.min_amount_out)
else:
amount_out_received = self._partial_close(
Expand Down Expand Up @@ -530,7 +582,7 @@ def post_limit_order(
assert self.is_accepting_new_orders, "not accepting new orders"
assert (_account == msg.sender) or self.is_delegate[_account][msg.sender], "unauthorized"

assert Vault(self.vault).is_enabled_market(_debt_token, _position_token)
assert Vault(self.vault).is_enabled_market(_debt_token, _position_token), "market not enabled"
assert _margin_amount > 0, "invalid margin amount"
assert _debt_amount > _margin_amount, "invalid debt amount"

Expand Down Expand Up @@ -671,7 +723,10 @@ def liquidate(_trade_uid: bytes32):
allowed leverage.
"""
trade: Trade = self.open_trades[_trade_uid]
self._cleanup_trade(_trade_uid)

Vault(self.vault).liquidate(trade.vault_position_uid)

log Liquidation(trade.account, _trade_uid, trade)


Expand Down
4 changes: 2 additions & 2 deletions contracts/margin-dex/Vault.vy
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ is_whitelisted_dex: public(HashMap[address, bool])
# address -> address -> bool
is_whitelisted_token: public(HashMap[address, bool])
# token_in -> # token_out
is_enabled_market: HashMap[address, HashMap[address, bool]]
is_enabled_market: public(HashMap[address, HashMap[address, bool]])
# token_in -> # token_out
max_leverage: public(HashMap[address, HashMap[address, uint256]])
# token -> Chainlink oracle
Expand Down Expand Up @@ -1608,4 +1608,4 @@ def set_variable_interest_parameters(
_mid_interest_rate,
_max_interest_rate,
_rate_switch_utilization,
]
]
24 changes: 24 additions & 0 deletions tests/dex/test_open_trade.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,27 @@ def test_open_trade_creates_trade_struct(dex, owner, weth, usdc, mock_vault):
assert trade[3] == [] # tp orders
assert trade[4] == [] # sl orders


def test_open_trade_is_recorded_correctly(dex, owner, weth, usdc, mock_vault):
margin_amount = 15 * 10**6
usdc_in = 150 * 10**6
min_weth_out = 123
trade = dex.open_trade(owner, weth, min_weth_out, usdc, usdc_in, margin_amount, [], [])

trades = dex.get_all_open_trades(owner)
assert len(trades) == 1
assert trade[0] == trades[0][0]

def test_cleanup_trade(dex, owner, weth, usdc, mock_vault):
margin_amount = 15 * 10**6
usdc_in = 150 * 10**6
min_weth_out = 123
trade = dex.open_trade(owner, weth, min_weth_out, usdc, usdc_in, margin_amount, [], [])

trades = dex.get_all_open_trades(owner)
assert len(trades) == 1

dex.internal._cleanup_trade(trade[0])

trades_after = dex.get_all_open_trades(owner)
assert len(trades_after) == 0