Skip to content

Commit

Permalink
[CHIA-711] Add WalletActionScope (#18125)
Browse files Browse the repository at this point in the history
* Add the concept of 'action scopes'

* Add `WalletActionScope`

* Add the concept of 'action scopes'

* pylint and test coverage

* add try/finally

* add try/except

* Undo giving a variable a name

* Test coverage

* Ban partial sigining in another scenario

* Make WalletActionScope an alias instead

* Add extra_spends to the action scope flow

* Add test for .add_pending_transactions
  • Loading branch information
Quexington authored Jun 24, 2024
1 parent bce4b4a commit 058b807
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 17 deletions.
80 changes: 80 additions & 0 deletions chia/_tests/wallet/test_wallet_action_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple

import pytest
from chia_rs import G2Element

from chia._tests.cmds.wallet.test_consts import STD_TX
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.spend_bundle import SpendBundle
from chia.wallet.signer_protocol import SigningResponse
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.wallet_action_scope import WalletSideEffects
from chia.wallet.wallet_state_manager import WalletStateManager

MOCK_SR = SigningResponse(b"hey", bytes32([0] * 32))
MOCK_SB = SpendBundle([], G2Element())


def test_back_and_forth_serialization() -> None:
assert bytes(WalletSideEffects()) == b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
assert WalletSideEffects.from_bytes(bytes(WalletSideEffects())) == WalletSideEffects()
assert WalletSideEffects.from_bytes(bytes(WalletSideEffects([STD_TX], [MOCK_SR], [MOCK_SB]))) == WalletSideEffects(
[STD_TX], [MOCK_SR], [MOCK_SB]
)
assert WalletSideEffects.from_bytes(
bytes(WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB]))
) == WalletSideEffects([STD_TX, STD_TX], [MOCK_SR, MOCK_SR], [MOCK_SB, MOCK_SB])


@dataclass
class MockWalletStateManager:
most_recent_call: Optional[
Tuple[List[TransactionRecord], bool, bool, bool, List[SigningResponse], List[SpendBundle]]
] = None

async def add_pending_transactions(
self,
txs: List[TransactionRecord],
push: bool,
merge_spends: bool,
sign: bool,
additional_signing_responses: List[SigningResponse],
extra_spends: List[SpendBundle],
) -> List[TransactionRecord]:
self.most_recent_call = (txs, push, merge_spends, sign, additional_signing_responses, extra_spends)
return txs


MockWalletStateManager.new_action_scope = WalletStateManager.new_action_scope # type: ignore[attr-defined]


@pytest.mark.anyio
async def test_wallet_action_scope() -> None:
wsm = MockWalletStateManager()
async with wsm.new_action_scope( # type: ignore[attr-defined]
push=True,
merge_spends=False,
sign=True,
additional_signing_responses=[],
extra_spends=[],
) as action_scope:
async with action_scope.use() as interface:
interface.side_effects.transactions = [STD_TX]

with pytest.raises(RuntimeError):
action_scope.side_effects

assert action_scope.side_effects.transactions == [STD_TX]
assert wsm.most_recent_call == ([STD_TX], True, False, True, [], [])

async with wsm.new_action_scope( # type: ignore[attr-defined]
push=False, merge_spends=True, sign=True, additional_signing_responses=[]
) as action_scope:
async with action_scope.use() as interface:
interface.side_effects.transactions = []

assert action_scope.side_effects.transactions == []
assert wsm.most_recent_call == ([], False, True, True, [], [])
111 changes: 110 additions & 1 deletion chia/_tests/wallet/test_wallet_state_manager.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from typing import AsyncIterator
from typing import AsyncIterator, List

import pytest
from chia_rs import G2Element

from chia._tests.environments.wallet import WalletTestFramework
from chia._tests.util.setup_nodes import OldSimulatorsAndWallets
from chia.protocols.wallet_protocol import CoinState
from chia.server.outbound_message import NodeType
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_spend import make_spend
from chia.types.peer_info import PeerInfo
from chia.types.spend_bundle import SpendBundle
from chia.util.ints import uint32, uint64
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.transaction_type import TransactionType
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_state_manager import WalletStateManager

Expand Down Expand Up @@ -95,3 +102,105 @@ async def test_determine_coin_type(simulator_and_wallet: OldSimulatorsAndWallets
assert (None, None) == await wallet_state_manager.determine_coin_type(
peer, CoinState(Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), uint32(0), uint32(0)), None
)


@pytest.mark.parametrize(
"wallet_environments",
[{"num_environments": 1, "blocks_needed": [1], "trusted": True, "reuse_puzhash": True}],
indirect=True,
)
@pytest.mark.limit_consensus_modes(reason="irrelevant")
@pytest.mark.anyio
async def test_commit_transactions_to_db(wallet_environments: WalletTestFramework) -> None:
env = wallet_environments.environments[0]
wsm = env.wallet_state_manager

coins = list(
await wsm.main_wallet.select_coins(
uint64(2_000_000_000_000), coin_selection_config=wallet_environments.tx_config.coin_selection_config
)
)
[tx1] = await wsm.main_wallet.generate_signed_transaction(
uint64(0),
bytes32([0] * 32),
wallet_environments.tx_config,
coins={coins[0]},
)
[tx2] = await wsm.main_wallet.generate_signed_transaction(
uint64(0),
bytes32([0] * 32),
wallet_environments.tx_config,
coins={coins[1]},
)

def flatten_spend_bundles(txs: List[TransactionRecord]) -> List[SpendBundle]:
return [tx.spend_bundle for tx in txs if tx.spend_bundle is not None]

assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)
new_txs = await wsm.add_pending_transactions(
[tx1, tx2],
push=False,
merge_spends=False,
sign=False,
extra_spends=[],
)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 2
for bundle in bundles:
assert bundle.aggregated_signature == G2Element()
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)

extra_coin_spend = make_spend(
Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), uint64(0)), Program.to(1), Program.to([None])
)
extra_spend = SpendBundle([extra_coin_spend], G2Element())

new_txs = await wsm.add_pending_transactions(
[tx1, tx2],
push=False,
merge_spends=False,
sign=False,
extra_spends=[extra_spend],
)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 2
for bundle in bundles:
assert bundle.aggregated_signature == G2Element()
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)
assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends]

new_txs = await wsm.add_pending_transactions(
[tx1, tx2],
push=False,
merge_spends=True,
sign=False,
extra_spends=[extra_spend],
)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 1
for bundle in bundles:
assert bundle.aggregated_signature == G2Element()
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 0
)
assert extra_coin_spend in [spend for bundle in bundles for spend in bundle.coin_spends]

[tx1, tx2] = await wsm.add_pending_transactions([tx1, tx2], push=True, merge_spends=True, sign=True)
bundles = flatten_spend_bundles(new_txs)
assert len(bundles) == 1
assert (
len(await wsm.tx_store.get_all_transactions_for_wallet(wsm.main_wallet.id(), type=TransactionType.OUTGOING_TX))
== 2
)

await wallet_environments.full_node.wait_transaction_records_entered_mempool([tx1, tx2])
95 changes: 95 additions & 0 deletions chia/wallet/wallet_action_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, AsyncIterator, List, Optional, cast

from chia.types.spend_bundle import SpendBundle
from chia.util.action_scope import ActionScope
from chia.wallet.signer_protocol import SigningResponse
from chia.wallet.transaction_record import TransactionRecord

if TYPE_CHECKING:
# Avoid a circular import here
from chia.wallet.wallet_state_manager import WalletStateManager


@dataclass
class WalletSideEffects:
transactions: List[TransactionRecord] = field(default_factory=list)
signing_responses: List[SigningResponse] = field(default_factory=list)
extra_spends: List[SpendBundle] = field(default_factory=list)

def __bytes__(self) -> bytes:
blob = b""
blob += len(self.transactions).to_bytes(4, "big")
for tx in self.transactions:
tx_bytes = bytes(tx)
blob += len(tx_bytes).to_bytes(4, "big") + tx_bytes
blob += len(self.signing_responses).to_bytes(4, "big")
for sr in self.signing_responses:
sr_bytes = bytes(sr)
blob += len(sr_bytes).to_bytes(4, "big") + sr_bytes
blob += len(self.extra_spends).to_bytes(4, "big")
for sb in self.extra_spends:
sb_bytes = bytes(sb)
blob += len(sb_bytes).to_bytes(4, "big") + sb_bytes
return blob

@classmethod
def from_bytes(cls, blob: bytes) -> WalletSideEffects:
instance = cls()
while blob != b"":
tx_len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
for _ in range(0, tx_len_prefix):
len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
instance.transactions.append(TransactionRecord.from_bytes(blob[:len_prefix]))
blob = blob[len_prefix:]
sr_len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
for _ in range(0, sr_len_prefix):
len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
instance.signing_responses.append(SigningResponse.from_bytes(blob[:len_prefix]))
blob = blob[len_prefix:]
sb_len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
for _ in range(0, sb_len_prefix):
len_prefix = int.from_bytes(blob[:4], "big")
blob = blob[4:]
instance.extra_spends.append(SpendBundle.from_bytes(blob[:len_prefix]))
blob = blob[len_prefix:]

return instance


WalletActionScope = ActionScope[WalletSideEffects]


@contextlib.asynccontextmanager
async def new_wallet_action_scope(
wallet_state_manager: WalletStateManager,
push: bool = False,
merge_spends: bool = True,
sign: Optional[bool] = None,
additional_signing_responses: List[SigningResponse] = [],
extra_spends: List[SpendBundle] = [],
) -> AsyncIterator[WalletActionScope]:
async with ActionScope.new_scope(WalletSideEffects) as self:
self = cast(WalletActionScope, self)
async with self.use() as interface:
interface.side_effects.signing_responses = additional_signing_responses.copy()
interface.side_effects.extra_spends = extra_spends.copy()

yield self

self.side_effects.transactions = await wallet_state_manager.add_pending_transactions(
self.side_effects.transactions,
push=push,
merge_spends=merge_spends,
sign=sign,
additional_signing_responses=self.side_effects.signing_responses,
extra_spends=self.side_effects.extra_spends,
)
Loading

0 comments on commit 058b807

Please sign in to comment.