From e32a8ffb0f1185353acf5eb5322632f17496f002 Mon Sep 17 00:00:00 2001 From: Matt Hauff Date: Wed, 18 Oct 2023 15:13:31 -0700 Subject: [PATCH] Allow set_status to overwrite trade in store (#16636) --- chia/wallet/trading/trade_store.py | 17 +++--- tests/wallet/cat_wallet/test_trades.py | 76 ++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 8 deletions(-) diff --git a/chia/wallet/trading/trade_store.py b/chia/wallet/trading/trade_store.py index 5f1617bdb7e5..5c7fc1225013 100644 --- a/chia/wallet/trading/trade_store.py +++ b/chia/wallet/trading/trade_store.py @@ -164,17 +164,18 @@ async def create( return self - async def add_trade_record(self, record: TradeRecord, offer_name: bytes32) -> None: + async def add_trade_record(self, record: TradeRecord, offer_name: bytes32, replace: bool = False) -> None: """ Store TradeRecord into DB """ async with self.db_wrapper.writer_maybe_transaction() as conn: - existing_trades_with_same_offer = await conn.execute_fetchall( - "SELECT trade_id FROM trade_records WHERE offer_name=? AND trade_id<>? LIMIT 1", - (offer_name, record.trade_id.hex()), - ) - if existing_trades_with_same_offer: - raise ValueError("Trade for this offer already exists.") + if not replace: + existing_trades_with_same_offer = await conn.execute_fetchall( + "SELECT trade_id FROM trade_records WHERE offer_name=? AND trade_id<>? LIMIT 1", + (offer_name, record.trade_id.hex()), + ) + if existing_trades_with_same_offer: + raise ValueError("Trade for this offer already exists.") cursor = await conn.execute( "INSERT OR REPLACE INTO trade_records " "(trade_record, trade_id, status, confirmed_at_index, created_at_time, sent, offer_name, is_my_offer) " @@ -242,7 +243,7 @@ async def set_status( sent_to=current.sent_to, valid_times=current.valid_times, ) - await self.add_trade_record(tx, offer_name) + await self.add_trade_record(tx, offer_name, replace=True) async def increment_sent( self, id: bytes32, name: str, send_status: MempoolInclusionStatus, err: Optional[Err] diff --git a/tests/wallet/cat_wallet/test_trades.py b/tests/wallet/cat_wallet/test_trades.py index fb7621738197..fb84dcd5e4ef 100644 --- a/tests/wallet/cat_wallet/test_trades.py +++ b/tests/wallet/cat_wallet/test_trades.py @@ -1094,3 +1094,79 @@ async def get_trade_and_status(trade_manager, trade) -> TradeStatus: ) await full_node.process_transaction_records(records=txs1) await time_out_assert(15, get_trade_and_status, TradeStatus.CONFIRMED, trade_manager_taker, tr1) + + @pytest.mark.asyncio + async def test_aggregated_trade_state(self, wallets_prefarm): + ( + [wallet_node_maker, maker_funds], + [wallet_node_taker, taker_funds], + full_node, + ) = wallets_prefarm + wallet_maker = wallet_node_maker.wallet_state_manager.main_wallet + xch_to_cat_amount = uint64(100) + + async with wallet_node_maker.wallet_state_manager.lock: + cat_wallet_maker: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node_maker.wallet_state_manager, + wallet_maker, + {"identifier": "genesis_by_id"}, + xch_to_cat_amount, + DEFAULT_TX_CONFIG, + ) + + tx_records: List[TransactionRecord] = await wallet_node_maker.wallet_state_manager.tx_store.get_not_sent() + + await full_node.process_transaction_records(records=tx_records) + + await time_out_assert(15, cat_wallet_maker.get_confirmed_balance, xch_to_cat_amount) + await time_out_assert(15, cat_wallet_maker.get_unconfirmed_balance, xch_to_cat_amount) + maker_funds -= xch_to_cat_amount + await time_out_assert(15, wallet_maker.get_confirmed_balance, maker_funds) + + chia_for_cat = { + wallet_maker.id(): 2, + cat_wallet_maker.id(): -2, + } + cat_for_chia = { + wallet_maker.id(): -1, + cat_wallet_maker.id(): 1, + } + + trade_manager_maker = wallet_node_maker.wallet_state_manager.trade_manager + trade_manager_taker = wallet_node_taker.wallet_state_manager.trade_manager + + async def get_trade_and_status(trade_manager, trade) -> TradeStatus: + trade_rec = await trade_manager.get_trade_by_id(trade.trade_id) + if trade_rec: + return TradeStatus(trade_rec.status) + raise ValueError("Couldn't find the trade record") # pragma: no cover + + success, trade_make_1, error = await trade_manager_maker.create_offer_for_ids(chia_for_cat, DEFAULT_TX_CONFIG) + await time_out_assert(10, get_trade_and_status, TradeStatus.PENDING_ACCEPT, trade_manager_maker, trade_make_1) + assert error is None + assert success is True + assert trade_make_1 is not None + success, trade_make_2, error = await trade_manager_maker.create_offer_for_ids(cat_for_chia, DEFAULT_TX_CONFIG) + await time_out_assert(10, get_trade_and_status, TradeStatus.PENDING_ACCEPT, trade_manager_maker, trade_make_2) + assert error is None + assert success is True + assert trade_make_2 is not None + + agg_offer = Offer.aggregate([Offer.from_bytes(trade_make_1.offer), Offer.from_bytes(trade_make_2.offer)]) + + peer = wallet_node_taker.get_full_node_peer() + trade_take, tx_records = await trade_manager_taker.respond_to_offer( + agg_offer, + peer, + DEFAULT_TX_CONFIG, + ) + assert trade_take is not None + assert tx_records is not None + + await full_node.process_transaction_records(records=tx_records) + await full_node.wait_for_wallets_synced(wallet_nodes=[wallet_node_maker, wallet_node_taker], timeout=60) + + await time_out_assert(15, wallet_maker.get_confirmed_balance, maker_funds + 1) + await time_out_assert(15, wallet_maker.get_unconfirmed_balance, maker_funds + 1) + await time_out_assert(15, cat_wallet_maker.get_confirmed_balance, xch_to_cat_amount - 1) + await time_out_assert(15, cat_wallet_maker.get_unconfirmed_balance, xch_to_cat_amount - 1)