Skip to content

Commit

Permalink
Annotate test_block_store.py (#15569)
Browse files Browse the repository at this point in the history
Annotate test_block_store.py.
  • Loading branch information
AmineKhaldi authored Jun 21, 2023
1 parent eb0282e commit 3a914ce
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
1 change: 0 additions & 1 deletion mypy-exclusions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ tests.core.custom_types.test_coin
tests.core.custom_types.test_spend_bundle
tests.core.daemon.test_daemon
tests.core.full_node.full_sync.test_full_sync
tests.core.full_node.stores.test_block_store
tests.core.full_node.stores.test_coin_store
tests.core.full_node.stores.test_full_node_store
tests.core.full_node.stores.test_hint_store
Expand Down
47 changes: 26 additions & 21 deletions tests/core/full_node/stores/test_block_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import logging
import random
import sqlite3
from pathlib import Path
from typing import List

import pytest
from clvm.casts import int_to_bytes
Expand All @@ -14,19 +16,20 @@
from chia.consensus.full_block_to_block_record import header_block_to_sub_block_record
from chia.full_node.block_store import BlockStore
from chia.full_node.coin_store import CoinStore
from chia.simulator.block_tools import BlockTools
from chia.types.blockchain_format.serialized_program import SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.vdf import VDFProof
from chia.types.full_block import FullBlock
from chia.util.ints import uint8
from chia.util.ints import uint8, uint32, uint64
from tests.blockchain.blockchain_test_utils import _validate_and_add_block
from tests.util.db_connection import DBConnection

log = logging.getLogger(__name__)


@pytest.mark.asyncio
async def test_block_store(tmp_dir, db_version, bt):
async def test_block_store(tmp_dir: Path, db_version: int, bt: BlockTools) -> None:
assert sqlite3.threadsafety >= 1
blocks = bt.get_consecutive_blocks(10)

Expand All @@ -53,9 +56,9 @@ async def test_block_store(tmp_dir, db_version, bt):
await store.set_peak(block_record.header_hash)
await store.set_peak(block_record.header_hash)

assert len(await store.get_full_blocks_at([1])) == 1
assert len(await store.get_full_blocks_at([0])) == 1
assert len(await store.get_full_blocks_at([100])) == 0
assert len(await store.get_full_blocks_at([uint32(1)])) == 1
assert len(await store.get_full_blocks_at([uint32(0)])) == 1
assert len(await store.get_full_blocks_at([uint32(100)])) == 0

# get_block_records_in_range
block_record_records = await store.get_block_records_in_range(0, 0xFFFFFFFF)
Expand All @@ -78,7 +81,7 @@ async def test_block_store(tmp_dir, db_version, bt):


@pytest.mark.asyncio
async def test_deadlock(tmp_dir, db_version, bt):
async def test_deadlock(tmp_dir: Path, db_version: int, bt: BlockTools) -> None:
"""
This test was added because the store was deadlocking in certain situations, when fetching and
adding blocks repeatedly. The issue was patched.
Expand All @@ -94,7 +97,7 @@ async def test_deadlock(tmp_dir, db_version, bt):
for block in blocks:
await _validate_and_add_block(bc, block)
block_records.append(bc.block_record(block.header_hash))
tasks = []
tasks: List[asyncio.Task[object]] = []

for i in range(10000):
rand_i = random.randint(0, 9)
Expand All @@ -110,7 +113,7 @@ async def test_deadlock(tmp_dir, db_version, bt):


@pytest.mark.asyncio
async def test_rollback(bt, tmp_dir):
async def test_rollback(bt: BlockTools, tmp_dir: Path) -> None:
blocks = bt.get_consecutive_blocks(10)

async with DBConnection(2) as db_wrapper:
Expand All @@ -134,7 +137,7 @@ async def test_rollback(bt, tmp_dir):
async with conn.execute(
"SELECT in_main_chain FROM full_blocks WHERE header_hash=?", (block.header_hash,)
) as cursor:
rows = await cursor.fetchall()
rows = list(await cursor.fetchall())
assert len(rows) == 1
assert rows[0][0]

Expand All @@ -147,15 +150,15 @@ async def test_rollback(bt, tmp_dir):
"SELECT in_main_chain FROM full_blocks WHERE header_hash=? ORDER BY height",
(block.header_hash,),
) as cursor:
rows = await cursor.fetchall()
rows = list(await cursor.fetchall())
print(count, rows)
assert len(rows) == 1
assert rows[0][0] == (count <= 5)
count += 1


@pytest.mark.asyncio
async def test_count_compactified_blocks(bt, tmp_dir, db_version):
async def test_count_compactified_blocks(bt: BlockTools, tmp_dir: Path, db_version: int) -> None:
blocks = bt.get_consecutive_blocks(10)

async with DBConnection(db_version) as db_wrapper:
Expand All @@ -174,7 +177,7 @@ async def test_count_compactified_blocks(bt, tmp_dir, db_version):


@pytest.mark.asyncio
async def test_count_uncompactified_blocks(bt, tmp_dir, db_version):
async def test_count_uncompactified_blocks(bt: BlockTools, tmp_dir: Path, db_version: int) -> None:
blocks = bt.get_consecutive_blocks(10)

async with DBConnection(db_version) as db_wrapper:
Expand All @@ -193,10 +196,10 @@ async def test_count_uncompactified_blocks(bt, tmp_dir, db_version):


@pytest.mark.asyncio
async def test_replace_proof(bt, tmp_dir, db_version):
async def test_replace_proof(bt: BlockTools, tmp_dir: Path, db_version: int) -> None:
blocks = bt.get_consecutive_blocks(10)

def rand_bytes(num) -> bytes:
def rand_bytes(num: int) -> bytes:
ret = bytearray(num)
for i in range(num):
ret[i] = random.getrandbits(8)
Expand Down Expand Up @@ -227,17 +230,19 @@ def rand_vdf_proof() -> VDFProof:

for block, proof in zip(blocks, replaced):
b = await block_store.get_full_block(block.header_hash)
assert b is not None
assert b.challenge_chain_ip_proof == proof

# make sure we get the same result when we hit the database
# itself (and not just the block cache)
block_store.rollback_cache_block(block.header_hash)
b = await block_store.get_full_block(block.header_hash)
assert b is not None
assert b.challenge_chain_ip_proof == proof


@pytest.mark.asyncio
async def test_get_generator(bt, db_version):
async def test_get_generator(bt: BlockTools, db_version: int) -> None:
blocks = bt.get_consecutive_blocks(10)

def generator(i: int) -> SerializedProgram:
Expand All @@ -250,7 +255,7 @@ def generator(i: int) -> SerializedProgram:
for i, block in enumerate(blocks):
block = dataclasses.replace(block, transactions_generator=generator(i))
block_record = header_block_to_sub_block_record(
DEFAULT_CONSTANTS, 0, block, 0, False, 0, max(0, block.height - 1), None
DEFAULT_CONSTANTS, uint64(0), block, uint64(0), False, uint8(0), uint32(max(0, block.height - 1)), None
)
await store.add_full_block(block.header_hash, block, block_record)
await store.set_in_chain([(block_record.header_hash,)])
Expand All @@ -259,16 +264,16 @@ def generator(i: int) -> SerializedProgram:

if db_version == 2:
expected_generators = list(map(lambda x: x.transactions_generator, new_blocks[1:10]))
generators = await store.get_generators_at(range(1, 10))
generators = await store.get_generators_at([uint32(x) for x in range(1, 10)])
assert generators == expected_generators

# test out-of-order heights
expected_generators = list(map(lambda x: x.transactions_generator, [new_blocks[i] for i in [4, 8, 3, 9]]))
generators = await store.get_generators_at([4, 8, 3, 9])
generators = await store.get_generators_at([uint32(4), uint32(8), uint32(3), uint32(9)])
assert generators == expected_generators

with pytest.raises(KeyError):
await store.get_generators_at([100])
await store.get_generators_at([uint32(100)])

assert await store.get_generator(blocks[2].header_hash) == new_blocks[2].transactions_generator
assert await store.get_generator(blocks[4].header_hash) == new_blocks[4].transactions_generator
Expand All @@ -277,7 +282,7 @@ def generator(i: int) -> SerializedProgram:


@pytest.mark.asyncio
async def test_get_blocks_by_hash(tmp_dir, bt, db_version):
async def test_get_blocks_by_hash(tmp_dir: Path, bt: BlockTools, db_version: int) -> None:
assert sqlite3.threadsafety >= 1
blocks = bt.get_consecutive_blocks(10)

Expand Down Expand Up @@ -315,7 +320,7 @@ async def test_get_blocks_by_hash(tmp_dir, bt, db_version):


@pytest.mark.asyncio
async def test_get_block_bytes_in_range(tmp_dir, bt, db_version):
async def test_get_block_bytes_in_range(tmp_dir: Path, bt: BlockTools, db_version: int) -> None:
assert sqlite3.threadsafety >= 1
blocks = bt.get_consecutive_blocks(10)

Expand Down

0 comments on commit 3a914ce

Please sign in to comment.