From 058b80775a7fa4e0235dd7dbed13114f87b51eb8 Mon Sep 17 00:00:00 2001 From: Matt Hauff Date: Mon, 24 Jun 2024 14:34:02 -0700 Subject: [PATCH] [CHIA-711] Add `WalletActionScope` (#18125) * Add the concept of 'action scopes' * Add `WalletActionScope` * Add the concept of 'action scopes' * pylint and test coverage * add try/finally * add try/except * Undo giving a variable a name * Test coverage * Ban partial sigining in another scenario * Make WalletActionScope an alias instead * Add extra_spends to the action scope flow * Add test for .add_pending_transactions --- .../_tests/wallet/test_wallet_action_scope.py | 80 +++++++++++++ .../wallet/test_wallet_state_manager.py | 111 +++++++++++++++++- chia/wallet/wallet_action_scope.py | 95 +++++++++++++++ chia/wallet/wallet_state_manager.py | 69 ++++++++--- 4 files changed, 338 insertions(+), 17 deletions(-) create mode 100644 chia/_tests/wallet/test_wallet_action_scope.py create mode 100644 chia/wallet/wallet_action_scope.py diff --git a/chia/_tests/wallet/test_wallet_action_scope.py b/chia/_tests/wallet/test_wallet_action_scope.py new file mode 100644 index 000000000000..54583e96bc9c --- /dev/null +++ b/chia/_tests/wallet/test_wallet_action_scope.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import pytest +from chia_rs import G2Element + +from chia._tests.cmds.wallet.test_consts import STD_TX +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.spend_bundle import SpendBundle +from chia.wallet.signer_protocol import SigningResponse +from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.wallet_action_scope import WalletSideEffects +from chia.wallet.wallet_state_manager import WalletStateManager + +MOCK_SR = SigningResponse(b"hey", bytes32([0] * 32)) +MOCK_SB = SpendBundle([], G2Element()) + + +def test_back_and_forth_serialization() -> None: + assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects())) == WalletSideEffects() + assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX], [MOCK_SR], [MOCK_SB]))) == WalletSideEffects( + [STD_TX], [MOCK_SR], [MOCK_SB] + ) + assert WalletSideEffects.from_bytes( + bytes(WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB])) + ) == WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB]) + + +@dataclass +class MockWalletStateManager: + most_recent_call: Optional[ + Tuple[List[TransactionRecord], bool, bool, bool, List[SigningResponse], List[SpendBundle]] + ] = None + + async def add_pending_transactions( + self, + txs: List[TransactionRecord], + push: bool, + merge_spends: bool, + sign: bool, + additional_signing_responses: List[SigningResponse], + extra_spends: List[SpendBundle], + ) -> List[TransactionRecord]: + self.most_recent_call = (txs, push, merge_spends, sign, additional_signing_responses, extra_spends) + return txs + + +MockWalletStateManager.new_action_scope = WalletStateManager.new_action_scope # type: ignore[attr-defined] + + +@pytest.mark.anyio +async def test_wallet_action_scope() -> None: + wsm = MockWalletStateManager() + async with wsm.new_action_scope( # type: ignore[attr-defined] + push=True, + merge_spends=False, + sign=True, + additional_signing_responses=[], + extra_spends=[], + ) as action_scope: + async with action_scope.use() as interface: + interface.side_effects.transactions = [STD_TX] + + with pytest.raises(RuntimeError): + action_scope.side_effects + + assert action_scope.side_effects.transactions == [STD_TX] + assert wsm.most_recent_call == ([STD_TX], True, False, True, [], []) + + async with wsm.new_action_scope( # type: ignore[attr-defined] + push=False, merge_spends=True, sign=True, additional_signing_responses=[] + ) as action_scope: + async with action_scope.use() as interface: + interface.side_effects.transactions = [] + + assert action_scope.side_effects.transactions == [] + assert wsm.most_recent_call == ([], False, True, True, [], []) diff --git a/chia/_tests/wallet/test_wallet_state_manager.py b/chia/_tests/wallet/test_wallet_state_manager.py index 06ee7bf518e5..826db346e8d3 100644 --- a/chia/_tests/wallet/test_wallet_state_manager.py +++ b/chia/_tests/wallet/test_wallet_state_manager.py @@ -1,19 +1,26 @@ from __future__ import annotations from contextlib import asynccontextmanager -from typing import AsyncIterator +from typing import AsyncIterator, List import pytest +from chia_rs import G2Element +from chia._tests.environments.wallet import WalletTestFramework from chia._tests.util.setup_nodes import OldSimulatorsAndWallets from chia.protocols.wallet_protocol import CoinState from chia.server.outbound_message import NodeType from chia.types.blockchain_format.coin import Coin +from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.coin_spend import make_spend from chia.types.peer_info import PeerInfo +from chia.types.spend_bundle import SpendBundle from chia.util.ints import uint32, uint64 from chia.wallet.derivation_record import DerivationRecord from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened +from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.util.transaction_type import TransactionType from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_state_manager import WalletStateManager @@ -95,3 +102,105 @@ async def test_determine_coin_type(simulator_and_wallet: OldSimulatorsAndWallets assert (None, None) == await wallet_state_manager.determine_coin_type( peer, CoinState(Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), uint32(0), uint32(0)), None ) + + +@pytest.mark.parametrize( + "wallet_environments", + [{"num_environments": 1, "blocks_needed": [1], "trusted": True, "reuse_puzhash": True}], + indirect=True, +) +@pytest.mark.limit_consensus_modes(reason="irrelevant") +@pytest.mark.anyio +async def test_commit_transactions_to_db(wallet_environments: WalletTestFramework) -> None: + env = wallet_environments.environments[0] + wsm = env.wallet_state_manager + + coins = list( + await wsm.main_wallet.select_coins( + uint64(2_000_000_000_000), coin_selection_config=wallet_environments.tx_config.coin_selection_config + ) + ) + [tx1] = await wsm.main_wallet.generate_signed_transaction( + uint64(0), + bytes32([0] * 32), + wallet_environments.tx_config, + coins={coins[0]}, + ) + [tx2] = await wsm.main_wallet.generate_signed_transaction( + uint64(0), + bytes32([0] * 32), + wallet_environments.tx_config, + coins={coins[1]}, + ) + + def flatten_spend_bundles(txs: List[TransactionRecord]) -> List[SpendBundle]: + return [tx.spend_bundle for tx in txs if tx.spend_bundle is not None] + + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + new_txs = await wsm.add_pending_transactions( + [tx1, tx2], + push=False, + merge_spends=False, + sign=False, + extra_spends=[], + ) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 2 + for bundle in bundles: + assert bundle.aggregated_signature == G2Element() + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + + extra_coin_spend = make_spend( + Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), Program.to(1), Program.to([None]) + ) + extra_spend = SpendBundle([extra_coin_spend], G2Element()) + + new_txs = await wsm.add_pending_transactions( + [tx1, tx2], + push=False, + merge_spends=False, + sign=False, + extra_spends=[extra_spend], + ) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 2 + for bundle in bundles: + assert bundle.aggregated_signature == G2Element() + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends] + + new_txs = await wsm.add_pending_transactions( + [tx1, tx2], + push=False, + merge_spends=True, + sign=False, + extra_spends=[extra_spend], + ) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 1 + for bundle in bundles: + assert bundle.aggregated_signature == G2Element() + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 0 + ) + assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends] + + [tx1, tx2] = await wsm.add_pending_transactions([tx1, tx2], push=True, merge_spends=True, sign=True) + bundles = flatten_spend_bundles(new_txs) + assert len(bundles) == 1 + assert ( + len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX)) + == 2 + ) + + await wallet_environments.full_node.wait_transaction_records_entered_mempool([tx1, tx2]) diff --git a/chia/wallet/wallet_action_scope.py b/chia/wallet/wallet_action_scope.py new file mode 100644 index 000000000000..85f4cb759b8f --- /dev/null +++ b/chia/wallet/wallet_action_scope.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, AsyncIterator, List, Optional, cast + +from chia.types.spend_bundle import SpendBundle +from chia.util.action_scope import ActionScope +from chia.wallet.signer_protocol import SigningResponse +from chia.wallet.transaction_record import TransactionRecord + +if TYPE_CHECKING: + # Avoid a circular import here + from chia.wallet.wallet_state_manager import WalletStateManager + + +@dataclass +class WalletSideEffects: + transactions: List[TransactionRecord] = field(default_factory=list) + signing_responses: List[SigningResponse] = field(default_factory=list) + extra_spends: List[SpendBundle] = field(default_factory=list) + + def __bytes__(self) -> bytes: + blob = b"" + blob += len(self.transactions).to_bytes(4, "big") + for tx in self.transactions: + tx_bytes = bytes(tx) + blob += len(tx_bytes).to_bytes(4, "big") + tx_bytes + blob += len(self.signing_responses).to_bytes(4, "big") + for sr in self.signing_responses: + sr_bytes = bytes(sr) + blob += len(sr_bytes).to_bytes(4, "big") + sr_bytes + blob += len(self.extra_spends).to_bytes(4, "big") + for sb in self.extra_spends: + sb_bytes = bytes(sb) + blob += len(sb_bytes).to_bytes(4, "big") + sb_bytes + return blob + + @classmethod + def from_bytes(cls, blob: bytes) -> WalletSideEffects: + instance = cls() + while blob != b"": + tx_len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + for _ in range(0, tx_len_prefix): + len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + instance.transactions.append(TransactionRecord.from_bytes(blob[:len_prefix])) + blob = blob[len_prefix:] + sr_len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + for _ in range(0, sr_len_prefix): + len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + instance.signing_responses.append(SigningResponse.from_bytes(blob[:len_prefix])) + blob = blob[len_prefix:] + sb_len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + for _ in range(0, sb_len_prefix): + len_prefix = int.from_bytes(blob[:4], "big") + blob = blob[4:] + instance.extra_spends.append(SpendBundle.from_bytes(blob[:len_prefix])) + blob = blob[len_prefix:] + + return instance + + +WalletActionScope = ActionScope[WalletSideEffects] + + +@contextlib.asynccontextmanager +async def new_wallet_action_scope( + wallet_state_manager: WalletStateManager, + push: bool = False, + merge_spends: bool = True, + sign: Optional[bool] = None, + additional_signing_responses: List[SigningResponse] = [], + extra_spends: List[SpendBundle] = [], +) -> AsyncIterator[WalletActionScope]: + async with ActionScope.new_scope(WalletSideEffects) as self: + self = cast(WalletActionScope, self) + async with self.use() as interface: + interface.side_effects.signing_responses = additional_signing_responses.copy() + interface.side_effects.extra_spends = extra_spends.copy() + + yield self + + self.side_effects.transactions = await wallet_state_manager.add_pending_transactions( + self.side_effects.transactions, + push=push, + merge_spends=merge_spends, + sign=sign, + additional_signing_responses=self.side_effects.signing_responses, + extra_spends=self.side_effects.extra_spends, + ) diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index f0a8073f7724..0b2c514073d1 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import dataclasses import logging import multiprocessing.context @@ -143,6 +144,7 @@ from chia.wallet.vc_wallet.vc_store import VCStore from chia.wallet.vc_wallet.vc_wallet import VCWallet from chia.wallet.wallet import Wallet +from chia.wallet.wallet_action_scope import WalletActionScope, new_wallet_action_scope from chia.wallet.wallet_blockchain import WalletBlockchain from chia.wallet.wallet_coin_record import MetadataTypes, WalletCoinRecord from chia.wallet.wallet_coin_store import WalletCoinStore @@ -2254,9 +2256,11 @@ async def coin_added( async def add_pending_transactions( self, tx_records: List[TransactionRecord], + push: bool = True, merge_spends: bool = True, sign: Optional[bool] = None, - additional_signing_responses: List[SigningResponse] = [], + additional_signing_responses: Optional[List[SigningResponse]] = None, + extra_spends: Optional[List[SpendBundle]] = None, ) -> List[TransactionRecord]: """ Add a list of transactions to be submitted to the full node. @@ -2267,6 +2271,8 @@ async def add_pending_transactions( agg_spend: SpendBundle = SpendBundle.aggregate( [tx.spend_bundle for tx in tx_records if tx.spend_bundle is not None] ) + if extra_spends is not None: + agg_spend = SpendBundle.aggregate([agg_spend, *extra_spends]) actual_spend_involved: bool = agg_spend != SpendBundle([], G2Element()) if merge_spends and actual_spend_involved: tx_records = [ @@ -2277,27 +2283,39 @@ async def add_pending_transactions( ) for i, tx in enumerate(tx_records) ] + elif extra_spends is not None and extra_spends != []: + extra_spends.extend([] if tx_records[0].spend_bundle is None else [tx_records[0].spend_bundle]) + extra_spend_bundle = SpendBundle.aggregate(extra_spends) + tx_records = [ + dataclasses.replace( + tx, + spend_bundle=extra_spend_bundle if i == 0 else tx.spend_bundle, + name=extra_spend_bundle.name() if i == 0 else bytes32.secret(), + ) + for i, tx in enumerate(tx_records) + ] if sign: tx_records, _ = await self.sign_transactions( tx_records, - additional_signing_responses, - additional_signing_responses != [], + [] if additional_signing_responses is None else additional_signing_responses, + additional_signing_responses != [] and additional_signing_responses is not None, ) - all_coins_names = [] - async with self.db_wrapper.writer_maybe_transaction(): - for tx_record in tx_records: - # Wallet node will use this queue to retry sending this transaction until full nodes receives it - await self.tx_store.add_transaction_record(tx_record) - all_coins_names.extend([coin.name() for coin in tx_record.additions]) - all_coins_names.extend([coin.name() for coin in tx_record.removals]) + if push: + all_coins_names = [] + async with self.db_wrapper.writer_maybe_transaction(): + for tx_record in tx_records: + # Wallet node will use this queue to retry sending this transaction until full nodes receives it + await self.tx_store.add_transaction_record(tx_record) + all_coins_names.extend([coin.name() for coin in tx_record.additions]) + all_coins_names.extend([coin.name() for coin in tx_record.removals]) - await self.add_interested_coin_ids(all_coins_names) + await self.add_interested_coin_ids(all_coins_names) - if actual_spend_involved: - self.tx_pending_changed() - for wallet_id in {tx.wallet_id for tx in tx_records}: - self.state_changed("pending_transaction", wallet_id) - await self.wallet_node.update_ui() + if actual_spend_involved: + self.tx_pending_changed() + for wallet_id in {tx.wallet_id for tx in tx_records}: + self.state_changed("pending_transaction", wallet_id) + await self.wallet_node.update_ui() return tx_records @@ -2738,3 +2756,22 @@ async def submit_transactions(self, signed_txs: List[SignedTransaction]) -> List for bundle in bundles: await self.wallet_node.push_tx(bundle) return [bundle.name() for bundle in bundles] + + @contextlib.asynccontextmanager + async def new_action_scope( + self, + push: bool = False, + merge_spends: bool = True, + sign: Optional[bool] = None, + additional_signing_responses: List[SigningResponse] = [], + extra_spends: List[SpendBundle] = [], + ) -> AsyncIterator[WalletActionScope]: + async with new_wallet_action_scope( + self, + push=push, + merge_spends=merge_spends, + sign=sign, + additional_signing_responses=additional_signing_responses, + extra_spends=extra_spends, + ) as action_scope: + yield action_scope