Skip to content

Commit

Permalink
Annotate test_coin_store.py (#15571)
Browse files Browse the repository at this point in the history
Annotate test_coin_store.py.
  • Loading branch information
AmineKhaldi authored Jun 21, 2023
1 parent 8fd44c1 commit 4f2dec8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 31 deletions.
1 change: 0 additions & 1 deletion mypy-exclusions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ tests.core.custom_types.test_coin
tests.core.custom_types.test_spend_bundle
tests.core.daemon.test_daemon
tests.core.full_node.full_sync.test_full_sync
tests.core.full_node.stores.test_coin_store
tests.core.full_node.stores.test_full_node_store
tests.core.full_node.stores.test_hint_store
tests.core.full_node.stores.test_sync_store
Expand Down
74 changes: 44 additions & 30 deletions tests/core/full_node/stores/test_coin_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional, Set, Tuple

import pytest
Expand All @@ -11,7 +12,7 @@
from chia.full_node.block_store import BlockStore
from chia.full_node.coin_store import CoinStore
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions
from chia.simulator.block_tools import test_constants
from chia.simulator.block_tools import BlockTools, test_constants
from chia.simulator.wallet_tools import WalletTool
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
Expand Down Expand Up @@ -51,7 +52,7 @@ def get_future_reward_coins(block: FullBlock) -> Tuple[Coin, Coin]:

class TestCoinStoreWithBlocks:
@pytest.mark.asyncio
async def test_basic_coin_store(self, db_version, softfork_height, bt):
async def test_basic_coin_store(self, db_version: int, softfork_height: uint32, bt: BlockTools) -> None:
wallet_a = WALLET_A
reward_ph = wallet_a.get_new_puzzlehash()

Expand Down Expand Up @@ -135,11 +136,13 @@ async def test_basic_coin_store(self, db_version, softfork_height, bt):
for coin_name in tx_removals:
# Check that the removed coins are set to spent
record = await coin_store.get_coin_record(coin_name)
assert record is not None
assert record.spent
all_records.add(record)
for coin in tx_additions:
# Check that the added coins are added
record = await coin_store.get_coin_record(coin.name())
assert record is not None
assert not record.spent
assert coin == record.coin
all_records.add(record)
Expand All @@ -156,7 +159,7 @@ async def test_basic_coin_store(self, db_version, softfork_height, bt):
should_be_included = set()

@pytest.mark.asyncio
async def test_set_spent(self, db_version, bt):
async def test_set_spent(self, db_version: int, bt: BlockTools) -> None:
blocks = bt.get_consecutive_blocks(9, [])

async with DBConnection(db_version) as db_wrapper:
Expand All @@ -181,23 +184,25 @@ async def test_set_spent(self, db_version, bt):
coins = block.get_included_reward_coins()
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]

await coin_store._set_spent([r.name for r in records], block.height)
await coin_store._set_spent([r.name for r in records if r is not None], block.height)

if len(records) > 0:
for r in records:
assert r is not None
assert (await coin_store.get_coin_record(r.name)) is not None

# Check that we can't spend a coin twice in DB
with pytest.raises(ValueError, match="Invalid operation to set spent"):
await coin_store._set_spent([r.name for r in records], block.height)
await coin_store._set_spent([r.name for r in records if r is not None], block.height)

records = [await coin_store.get_coin_record(coin.name()) for coin in coins]
for record in records:
assert record is not None
assert record.spent
assert record.spent_block_index == block.height

@pytest.mark.asyncio
async def test_num_unspent(self, bt, db_version):
async def test_num_unspent(self, bt: BlockTools, db_version: int) -> None:
blocks = bt.get_consecutive_blocks(37, [])

expect_unspent = 0
Expand Down Expand Up @@ -229,7 +234,7 @@ async def test_num_unspent(self, bt, db_version):
assert test_excercised

@pytest.mark.asyncio
async def test_rollback(self, db_version, bt):
async def test_rollback(self, db_version: int, bt: BlockTools) -> None:
blocks = bt.get_consecutive_blocks(20)

async with DBConnection(db_version) as db_wrapper:
Expand All @@ -251,16 +256,16 @@ async def test_rollback(self, db_version, bt):
additions,
removals,
)
coins = list(block.get_included_reward_coins())
records: List[CoinRecord] = [await coin_store.get_coin_record(coin.name()) for coin in coins]
coins = block.get_included_reward_coins()
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]

spend_selected_coin = selected_coin is not None
if block.height != 0 and selected_coin is None:
# Select the first CoinRecord which will be spent at the next transaction block.
selected_coin = records[0]
await coin_store._set_spent([r.name for r in records[1:]], block.height)
await coin_store._set_spent([r.name for r in records[1:] if r is not None], block.height)
else:
await coin_store._set_spent([r.name for r in records], block.height)
await coin_store._set_spent([r.name for r in records if r is not None], block.height)

if spend_selected_coin:
assert selected_coin is not None
Expand All @@ -286,14 +291,15 @@ async def test_rollback(self, db_version, bt):
reorg_index = selected_coin.confirmed_block_index

# Get all CoinRecords.
all_records: List[CoinRecord] = [await coin_store.get_coin_record(coin.name()) for coin in all_coins]
all_records = [await coin_store.get_coin_record(coin.name()) for coin in all_coins]

# The reorg will revert the creation and spend of many coins. It will also revert the spend (but not the
# creation) of the selected coin.
changed_records = await coin_store.rollback_to_block(reorg_index)
changed_coin_records = [cr.coin for cr in changed_records]
assert selected_coin in changed_records
for coin_record in all_records:
assert coin_record is not None
if coin_record.confirmed_block_index > reorg_index:
assert coin_record.coin in changed_coin_records
if coin_record.spent_block_index > reorg_index:
Expand All @@ -313,7 +319,7 @@ async def test_rollback(self, db_version, bt):
assert record is None

@pytest.mark.asyncio
async def test_basic_reorg(self, tmp_dir, db_version, bt):
async def test_basic_reorg(self, tmp_dir: Path, db_version: int, bt: BlockTools) -> None:
async with DBConnection(db_version) as db_wrapper:
initial_block_count = 30
reorg_length = 15
Expand Down Expand Up @@ -366,7 +372,7 @@ async def test_basic_reorg(self, tmp_dir, db_version, bt):
b.shut_down()

@pytest.mark.asyncio
async def test_get_puzzle_hash(self, tmp_dir, db_version, bt):
async def test_get_puzzle_hash(self, tmp_dir: Path, db_version: int, bt: BlockTools) -> None:
async with DBConnection(db_version) as db_wrapper:
num_blocks = 20
farmer_ph = bytes32(32 * b"0")
Expand Down Expand Up @@ -395,7 +401,7 @@ async def test_get_puzzle_hash(self, tmp_dir, db_version, bt):
b.shut_down()

@pytest.mark.asyncio
async def test_get_coin_states(self, tmp_dir, db_version):
async def test_get_coin_states(self, db_version: int) -> None:
async with DBConnection(db_version) as db_wrapper:
crs = [
CoinRecord(
Expand All @@ -420,32 +426,40 @@ async def test_get_coin_states(self, tmp_dir, db_version):
coin_store = await CoinStore.create(db_wrapper)
await coin_store._add_coin_records(crs)

assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 0)) == 300
assert len(await coin_store.get_coin_states_by_puzzle_hashes(False, {std_hash(b"2")}, 0)) == 0
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 300)) == 151
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 603)) == 0
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"1")}, 0)) == 0
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, uint32(0))) == 300
assert len(await coin_store.get_coin_states_by_puzzle_hashes(False, {std_hash(b"2")}, uint32(0))) == 0
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, uint32(300))) == 151
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, uint32(603))) == 0
assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"1")}, uint32(0))) == 0

# test max_items limit
for limit in [0, 1, 42, 300]:
assert (
len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 0, max_items=limit))
len(
await coin_store.get_coin_states_by_puzzle_hashes(
True, {std_hash(b"2")}, uint32(0), max_items=limit
)
)
== limit
)

# if the limit is very high, we should get all of them
assert (
len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 0, max_items=10000))
len(
await coin_store.get_coin_states_by_puzzle_hashes(
True, {std_hash(b"2")}, uint32(0), max_items=10000
)
)
== 300
)

coins = {cr.coin.name() for cr in crs}
bad_coins = [std_hash(cr.coin.name()) for cr in crs]
assert len(await coin_store.get_coin_states_by_ids(True, coins, 0)) == 600
assert len(await coin_store.get_coin_states_by_ids(False, coins, 0)) == 0
assert len(await coin_store.get_coin_states_by_ids(True, coins, 300)) == 302
assert len(await coin_store.get_coin_states_by_ids(True, coins, 603)) == 0
assert len(await coin_store.get_coin_states_by_ids(True, bad_coins, 0)) == 0
bad_coins = {std_hash(cr.coin.name()) for cr in crs}
assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0))) == 600
assert len(await coin_store.get_coin_states_by_ids(False, coins, uint32(0))) == 0
assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(300))) == 302
assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(603))) == 0
assert len(await coin_store.get_coin_states_by_ids(True, bad_coins, uint32(0))) == 0
# Test max_height
assert len(await coin_store.get_coin_states_by_ids(True, coins, max_height=uint32(603))) == 600
assert len(await coin_store.get_coin_states_by_ids(True, coins, max_height=uint32(602))) == 600
Expand All @@ -467,7 +481,7 @@ async def test_get_coin_states(self, tmp_dir, db_version):

# test max_items limit
for limit in [0, 1, 42, 300]:
assert len(await coin_store.get_coin_states_by_ids(True, coins, 0, max_items=limit)) == limit
assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0), max_items=limit)) == limit

# if the limit is very high, we should get all of them
assert len(await coin_store.get_coin_states_by_ids(True, coins, 0, max_items=10000)) == 600
assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0), max_items=10000)) == 600

0 comments on commit 4f2dec8

Please sign in to comment.