Skip to content

Commit

Permalink
Coin simplification (#11567)
Browse files Browse the repository at this point in the history
* factor out as_list() from Coin into free function. Remove unused name_str() from Coin. Minor optimization of hash_coin_ids() for common case.

* extend Coin unit test
  • Loading branch information
arvidn authored May 23, 2022
1 parent 3d8081c commit 00a2a15
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 29 deletions.
4 changes: 2 additions & 2 deletions chia/rpc/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from chia.server.outbound_message import NodeType, make_msg
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.types.announcement import Announcement
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.coin import Coin, coin_as_list
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.spend_bundle import SpendBundle
from chia.util.bech32m import decode_puzzle_hash, encode_puzzle_hash
Expand Down Expand Up @@ -524,7 +524,7 @@ async def create_new_wallet(self, request: Dict):
assert did_wallet.did_info.temp_pubkey is not None
my_did = did_wallet.get_my_DID()
coin_name = did_wallet.did_info.temp_coin.name().hex()
coin_list = did_wallet.did_info.temp_coin.as_list()
coin_list = coin_as_list(did_wallet.did_info.temp_coin)
newpuzhash = did_wallet.did_info.temp_puzhash
pubkey = did_wallet.did_info.temp_pubkey
return {
Expand Down
14 changes: 7 additions & 7 deletions chia/types/blockchain_format/coin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ def get_hash(self) -> bytes32:
def name(self) -> bytes32:
return self.get_hash()

def as_list(self) -> List[Any]:
return [self.parent_coin_info, self.puzzle_hash, self.amount]

@property
def name_str(self) -> str:
return self.name().hex()

@classmethod
def from_bytes(cls, blob):
# this function is never called. We rely on the standard streamable
Expand All @@ -53,7 +46,14 @@ def __bytes__(self) -> bytes: # pylint: disable=E0308
assert False


def coin_as_list(c: Coin) -> List[Any]:
return [c.parent_coin_info, c.puzzle_hash, c.amount]


def hash_coin_ids(coin_ids: List[bytes32]) -> bytes32:
if len(coin_ids) == 1:
return std_hash(coin_ids[0])

coin_ids.sort(reverse=True)
buffer = bytearray()

Expand Down
4 changes: 2 additions & 2 deletions chia/wallet/cat_wallet/cat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from blspy import G2Element

from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.coin import Coin, coin_as_list
from chia.types.blockchain_format.program import Program, INFINITE_COST
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.condition_opcodes import ConditionOpcode
Expand Down Expand Up @@ -105,7 +105,7 @@ def unsigned_spend_bundle_for_spendable_cats(mod_code: Program, spendable_cat_li
ids = []
for _ in spendable_cat_list:
infos_for_next.append(next_info_for_spendable_cat(_))
infos_for_me.append(Program.to(_.coin.as_list()))
infos_for_me.append(Program.to(coin_as_list(_.coin)))
ids.append(_.coin.name())

coin_spends = []
Expand Down
4 changes: 2 additions & 2 deletions chia/wallet/trade_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set

from chia.protocols.wallet_protocol import CoinState
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.coin import Coin, coin_as_list
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.spend_bundle import SpendBundle
Expand Down Expand Up @@ -435,7 +435,7 @@ async def calculate_tx_records_for_offer(self, offer: Offer, validate: bool) ->
for wid, grouped_removals in removal_dict.items():
wallet = self.wallet_state_manager.wallets[wid]
to_puzzle_hash = bytes32([1] * 32) # We use all zeros to be clear not to send here
removal_tree_hash = Program.to([rem.as_list() for rem in grouped_removals]).get_tree_hash()
removal_tree_hash = Program.to([coin_as_list(rem) for rem in grouped_removals]).get_tree_hash()
# We also need to calculate the sent amount
removed: int = sum(c.amount for c in grouped_removals)
change_coins: List[Coin] = addition_dict[wid] if wid in addition_dict else []
Expand Down
4 changes: 2 additions & 2 deletions chia/wallet/trading/offer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from blspy import G2Element

from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.coin import Coin, coin_as_list
from chia.types.blockchain_format.program import Program
from chia.types.announcement import Announcement
from chia.types.coin_spend import CoinSpend
Expand Down Expand Up @@ -60,7 +60,7 @@ def notarize_payments(
) -> Dict[Optional[bytes32], List[NotarizedPayment]]:
# This sort should be reproducible in CLVM with `>s`
sorted_coins: List[Coin] = sorted(coins, key=Coin.name)
sorted_coin_list: List[List] = [c.as_list() for c in sorted_coins]
sorted_coin_list: List[List] = [coin_as_list(c) for c in sorted_coins]
nonce: bytes32 = Program.to(sorted_coin_list).get_tree_hash()

notarized_payments: Dict[Optional[bytes32], List[NotarizedPayment]] = {}
Expand Down
60 changes: 46 additions & 14 deletions tests/core/custom_types/test_coin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
from chia.util.hash import std_hash
import io
import pytest


def coin_serialize(amount: uint64, clvm_serialize: bytes, full_serialize: bytes):
Expand All @@ -24,20 +26,50 @@ def coin_serialize(amount: uint64, clvm_serialize: bytes, full_serialize: bytes)
assert c2 == c


class TestCoin:
def test_coin_serialization(self):
def test_serialization():

coin_serialize(uint64(0xFFFF), bytes([0, 0xFF, 0xFF]), bytes([0, 0, 0, 0, 0, 0, 0xFF, 0xFF]))
coin_serialize(uint64(1337000000), bytes([0x4F, 0xB1, 0x00, 0x40]), bytes([0, 0, 0, 0, 0x4F, 0xB1, 0x00, 0x40]))
coin_serialize(uint64(0xFFFF), bytes([0, 0xFF, 0xFF]), bytes([0, 0, 0, 0, 0, 0, 0xFF, 0xFF]))
coin_serialize(uint64(1337000000), bytes([0x4F, 0xB1, 0x00, 0x40]), bytes([0, 0, 0, 0, 0x4F, 0xB1, 0x00, 0x40]))

# if the amount is 0, the amount is omitted in the "short" format,
# that's hashed
coin_serialize(uint64(0), b"", bytes([0, 0, 0, 0, 0, 0, 0, 0]))
# if the amount is 0, the amount is omitted in the "short" format,
# that's hashed
coin_serialize(uint64(0), b"", bytes([0, 0, 0, 0, 0, 0, 0, 0]))

# when amount is > INT64_MAX, the "short" serialization format is 1 byte
# longer, since it needs a leading zero to make it positive
coin_serialize(
uint64(0xFFFFFFFFFFFFFFFF),
bytes([0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
bytes([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
)
# when amount is > INT64_MAX, the "short" serialization format is 1 byte
# longer, since it needs a leading zero to make it positive
coin_serialize(
uint64(0xFFFFFFFFFFFFFFFF),
bytes([0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
bytes([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
)


@pytest.mark.parametrize(
"amount, clvm",
[
(0, []),
(1, [1]),
(0xFF, [0, 0xFF]),
(0xFFFF, [0, 0xFF, 0xFF]),
(0xFFFFFF, [0, 0xFF, 0xFF, 0xFF]),
(0xFFFFFFFF, [0, 0xFF, 0xFF, 0xFF, 0xFF]),
(0xFFFFFFFFFF, [0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
(0xFFFFFFFFFFFF, [0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
(0xFFFFFFFFFFFFFF, [0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
(0xFFFFFFFFFFFFFFFF, [0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
(0x7F, [0x7F]),
(0x7FFF, [0x7F, 0xFF]),
(0x7FFFFF, [0x7F, 0xFF, 0xFF]),
(0x7FFFFFFF, [0x7F, 0xFF, 0xFF, 0xFF]),
(0x7FFFFFFFFF, [0x7F, 0xFF, 0xFF, 0xFF, 0xFF]),
(0x7FFFFFFFFFFF, [0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
(0x7FFFFFFFFFFFFF, [0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
(0x7FFFFFFFFFFFFFFF, [0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]),
],
)
def test_name(amount: int, clvm: List[int]) -> None:

H1 = bytes32(b"a" * 32)
H2 = bytes32(b"b" * 32)

assert Coin(H1, H2, uint64(amount)).name() == std_hash(H1 + H2 + bytes(clvm))

0 comments on commit 00a2a15

Please sign in to comment.