From 4f2dec820ade13a6d3e03056a11245e8b01a2813 Mon Sep 17 00:00:00 2001 From: Amine Khaldi Date: Wed, 21 Jun 2023 18:53:05 +0100 Subject: [PATCH] Annotate test_coin_store.py (#15571) Annotate test_coin_store.py. --- mypy-exclusions.txt | 1 - .../core/full_node/stores/test_coin_store.py | 74 +++++++++++-------- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/mypy-exclusions.txt b/mypy-exclusions.txt index 100f451b1e75..2f9a67a60a1e 100644 --- a/mypy-exclusions.txt +++ b/mypy-exclusions.txt @@ -99,7 +99,6 @@ tests.core.custom_types.test_coin tests.core.custom_types.test_spend_bundle tests.core.daemon.test_daemon tests.core.full_node.full_sync.test_full_sync -tests.core.full_node.stores.test_coin_store tests.core.full_node.stores.test_full_node_store tests.core.full_node.stores.test_hint_store tests.core.full_node.stores.test_sync_store diff --git a/tests/core/full_node/stores/test_coin_store.py b/tests/core/full_node/stores/test_coin_store.py index e4f654ffad9a..31e62d6dcb4a 100644 --- a/tests/core/full_node/stores/test_coin_store.py +++ b/tests/core/full_node/stores/test_coin_store.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from pathlib import Path from typing import List, Optional, Set, Tuple import pytest @@ -11,7 +12,7 @@ from chia.full_node.block_store import BlockStore from chia.full_node.coin_store import CoinStore from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions -from chia.simulator.block_tools import test_constants +from chia.simulator.block_tools import BlockTools, test_constants from chia.simulator.wallet_tools import WalletTool from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.sized_bytes import bytes32 @@ -51,7 +52,7 @@ def get_future_reward_coins(block: FullBlock) -> Tuple[Coin, Coin]: class TestCoinStoreWithBlocks: @pytest.mark.asyncio - async def test_basic_coin_store(self, db_version, softfork_height, bt): + async def test_basic_coin_store(self, db_version: int, softfork_height: uint32, bt: BlockTools) -> None: wallet_a = WALLET_A reward_ph = wallet_a.get_new_puzzlehash() @@ -135,11 +136,13 @@ async def test_basic_coin_store(self, db_version, softfork_height, bt): for coin_name in tx_removals: # Check that the removed coins are set to spent record = await coin_store.get_coin_record(coin_name) + assert record is not None assert record.spent all_records.add(record) for coin in tx_additions: # Check that the added coins are added record = await coin_store.get_coin_record(coin.name()) + assert record is not None assert not record.spent assert coin == record.coin all_records.add(record) @@ -156,7 +159,7 @@ async def test_basic_coin_store(self, db_version, softfork_height, bt): should_be_included = set() @pytest.mark.asyncio - async def test_set_spent(self, db_version, bt): + async def test_set_spent(self, db_version: int, bt: BlockTools) -> None: blocks = bt.get_consecutive_blocks(9, []) async with DBConnection(db_version) as db_wrapper: @@ -181,23 +184,25 @@ async def test_set_spent(self, db_version, bt): coins = block.get_included_reward_coins() records = [await coin_store.get_coin_record(coin.name()) for coin in coins] - await coin_store._set_spent([r.name for r in records], block.height) + await coin_store._set_spent([r.name for r in records if r is not None], block.height) if len(records) > 0: for r in records: + assert r is not None assert (await coin_store.get_coin_record(r.name)) is not None # Check that we can't spend a coin twice in DB with pytest.raises(ValueError, match="Invalid operation to set spent"): - await coin_store._set_spent([r.name for r in records], block.height) + await coin_store._set_spent([r.name for r in records if r is not None], block.height) records = [await coin_store.get_coin_record(coin.name()) for coin in coins] for record in records: + assert record is not None assert record.spent assert record.spent_block_index == block.height @pytest.mark.asyncio - async def test_num_unspent(self, bt, db_version): + async def test_num_unspent(self, bt: BlockTools, db_version: int) -> None: blocks = bt.get_consecutive_blocks(37, []) expect_unspent = 0 @@ -229,7 +234,7 @@ async def test_num_unspent(self, bt, db_version): assert test_excercised @pytest.mark.asyncio - async def test_rollback(self, db_version, bt): + async def test_rollback(self, db_version: int, bt: BlockTools) -> None: blocks = bt.get_consecutive_blocks(20) async with DBConnection(db_version) as db_wrapper: @@ -251,16 +256,16 @@ async def test_rollback(self, db_version, bt): additions, removals, ) - coins = list(block.get_included_reward_coins()) - records: List[CoinRecord] = [await coin_store.get_coin_record(coin.name()) for coin in coins] + coins = block.get_included_reward_coins() + records = [await coin_store.get_coin_record(coin.name()) for coin in coins] spend_selected_coin = selected_coin is not None if block.height != 0 and selected_coin is None: # Select the first CoinRecord which will be spent at the next transaction block. selected_coin = records[0] - await coin_store._set_spent([r.name for r in records[1:]], block.height) + await coin_store._set_spent([r.name for r in records[1:] if r is not None], block.height) else: - await coin_store._set_spent([r.name for r in records], block.height) + await coin_store._set_spent([r.name for r in records if r is not None], block.height) if spend_selected_coin: assert selected_coin is not None @@ -286,7 +291,7 @@ async def test_rollback(self, db_version, bt): reorg_index = selected_coin.confirmed_block_index # Get all CoinRecords. - all_records: List[CoinRecord] = [await coin_store.get_coin_record(coin.name()) for coin in all_coins] + all_records = [await coin_store.get_coin_record(coin.name()) for coin in all_coins] # The reorg will revert the creation and spend of many coins. It will also revert the spend (but not the # creation) of the selected coin. @@ -294,6 +299,7 @@ async def test_rollback(self, db_version, bt): changed_coin_records = [cr.coin for cr in changed_records] assert selected_coin in changed_records for coin_record in all_records: + assert coin_record is not None if coin_record.confirmed_block_index > reorg_index: assert coin_record.coin in changed_coin_records if coin_record.spent_block_index > reorg_index: @@ -313,7 +319,7 @@ async def test_rollback(self, db_version, bt): assert record is None @pytest.mark.asyncio - async def test_basic_reorg(self, tmp_dir, db_version, bt): + async def test_basic_reorg(self, tmp_dir: Path, db_version: int, bt: BlockTools) -> None: async with DBConnection(db_version) as db_wrapper: initial_block_count = 30 reorg_length = 15 @@ -366,7 +372,7 @@ async def test_basic_reorg(self, tmp_dir, db_version, bt): b.shut_down() @pytest.mark.asyncio - async def test_get_puzzle_hash(self, tmp_dir, db_version, bt): + async def test_get_puzzle_hash(self, tmp_dir: Path, db_version: int, bt: BlockTools) -> None: async with DBConnection(db_version) as db_wrapper: num_blocks = 20 farmer_ph = bytes32(32 * b"0") @@ -395,7 +401,7 @@ async def test_get_puzzle_hash(self, tmp_dir, db_version, bt): b.shut_down() @pytest.mark.asyncio - async def test_get_coin_states(self, tmp_dir, db_version): + async def test_get_coin_states(self, db_version: int) -> None: async with DBConnection(db_version) as db_wrapper: crs = [ CoinRecord( @@ -420,32 +426,40 @@ async def test_get_coin_states(self, tmp_dir, db_version): coin_store = await CoinStore.create(db_wrapper) await coin_store._add_coin_records(crs) - assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 0)) == 300 - assert len(await coin_store.get_coin_states_by_puzzle_hashes(False, {std_hash(b"2")}, 0)) == 0 - assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 300)) == 151 - assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 603)) == 0 - assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"1")}, 0)) == 0 + assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, uint32(0))) == 300 + assert len(await coin_store.get_coin_states_by_puzzle_hashes(False, {std_hash(b"2")}, uint32(0))) == 0 + assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, uint32(300))) == 151 + assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, uint32(603))) == 0 + assert len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"1")}, uint32(0))) == 0 # test max_items limit for limit in [0, 1, 42, 300]: assert ( - len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 0, max_items=limit)) + len( + await coin_store.get_coin_states_by_puzzle_hashes( + True, {std_hash(b"2")}, uint32(0), max_items=limit + ) + ) == limit ) # if the limit is very high, we should get all of them assert ( - len(await coin_store.get_coin_states_by_puzzle_hashes(True, {std_hash(b"2")}, 0, max_items=10000)) + len( + await coin_store.get_coin_states_by_puzzle_hashes( + True, {std_hash(b"2")}, uint32(0), max_items=10000 + ) + ) == 300 ) coins = {cr.coin.name() for cr in crs} - bad_coins = [std_hash(cr.coin.name()) for cr in crs] - assert len(await coin_store.get_coin_states_by_ids(True, coins, 0)) == 600 - assert len(await coin_store.get_coin_states_by_ids(False, coins, 0)) == 0 - assert len(await coin_store.get_coin_states_by_ids(True, coins, 300)) == 302 - assert len(await coin_store.get_coin_states_by_ids(True, coins, 603)) == 0 - assert len(await coin_store.get_coin_states_by_ids(True, bad_coins, 0)) == 0 + bad_coins = {std_hash(cr.coin.name()) for cr in crs} + assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0))) == 600 + assert len(await coin_store.get_coin_states_by_ids(False, coins, uint32(0))) == 0 + assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(300))) == 302 + assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(603))) == 0 + assert len(await coin_store.get_coin_states_by_ids(True, bad_coins, uint32(0))) == 0 # Test max_height assert len(await coin_store.get_coin_states_by_ids(True, coins, max_height=uint32(603))) == 600 assert len(await coin_store.get_coin_states_by_ids(True, coins, max_height=uint32(602))) == 600 @@ -467,7 +481,7 @@ async def test_get_coin_states(self, tmp_dir, db_version): # test max_items limit for limit in [0, 1, 42, 300]: - assert len(await coin_store.get_coin_states_by_ids(True, coins, 0, max_items=limit)) == limit + assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0), max_items=limit)) == limit # if the limit is very high, we should get all of them - assert len(await coin_store.get_coin_states_by_ids(True, coins, 0, max_items=10000)) == 600 + assert len(await coin_store.get_coin_states_by_ids(True, coins, uint32(0), max_items=10000)) == 600