From 1aa3af34fa80b139cb396771dc1c047a3e7bd764 Mon Sep 17 00:00:00 2001
From: arvidn
Date: Sun, 5 Nov 2023 16:23:47 +0100
Subject: [PATCH] when we add a block, there's a subtle coroutine suspension
 point between updating the height-to-hash map and the peak height. This
 leaves a window for a call to get_peak() to fail, because the peak height
 doesn't exist in the height-to-hash map. This patch moves the suspension
 point to make these updates atomic. This allows removing a hack in a test
 that previously had to catch this exception

---
 chia/consensus/blockchain.py           | 53 ++++++++++++++++----------
 tests/core/full_node/test_full_node.py | 29 ++++----------
 2 files changed, 39 insertions(+), 43 deletions(-)

diff --git a/chia/consensus/blockchain.py b/chia/consensus/blockchain.py
index 2e10136e3816..79d6e29bbb34 100644
--- a/chia/consensus/blockchain.py
+++ b/chia/consensus/blockchain.py
@@ -271,9 +271,14 @@ async def add_block(
             block,
             None,
         )
-        # Always add the block to the database
-        async with self.block_store.db_wrapper.writer():
-            try:
+
+        # in case we fail and need to restore the blockchain state, remember the
+        # peak height
+        previous_peak_height = self._peak_height
+
+        try:
+            # Always add the block to the database
+            async with self.block_store.db_wrapper.writer():
                 header_hash: bytes32 = block.header_hash
                 # Perform the DB operations to update the state, and rollback if something goes wrong
                 await self.block_store.add_full_block(header_hash, block, block_record)
@@ -284,26 +289,32 @@ async def add_block(
                 # Then update the memory cache. It is important that this is not cancelled and does not throw
                 # This is done after all async/DB operations, so there is a decreased chance of failure.
                 self.add_block_record(block_record)
-                if state_change_summary is not None:
-                    self.__height_map.rollback(state_change_summary.fork_height)
-                for fetched_block_record in records:
-                    self.__height_map.update_height(
-                        fetched_block_record.height,
-                        fetched_block_record.header_hash,
-                        fetched_block_record.sub_epoch_summary_included,
-                    )
-            except BaseException as e:
-                self.block_store.rollback_cache_block(header_hash)
-                log.error(
-                    f"Error while adding block {block.header_hash} height {block.height},"
-                    f" rolling back: {traceback.format_exc()} {e}"
+
+            # there's a suspension point here, as we leave the async context
+            # manager
+
+            # make sure to update _peak_height after the transaction is committed,
+            # otherwise other tasks may go look for this block before it's available
+            if state_change_summary is not None:
+                self.__height_map.rollback(state_change_summary.fork_height)
+            for fetched_block_record in records:
+                self.__height_map.update_height(
+                    fetched_block_record.height,
+                    fetched_block_record.header_hash,
+                    fetched_block_record.sub_epoch_summary_included,
                 )
-                raise
-        # make sure to update _peak_height after the transaction is committed,
-        # otherwise other tasks may go look for this block before it's available
-        if state_change_summary is not None:
-            self._peak_height = block_record.height
+            if state_change_summary is not None:
+                self._peak_height = block_record.height
+
+        except BaseException as e:
+            self.block_store.rollback_cache_block(header_hash)
+            self._peak_height = previous_peak_height
+            log.error(
+                f"Error while adding block {block.header_hash} height {block.height},"
+                f" rolling back: {traceback.format_exc()} {e}"
+            )
+            raise

         # This is done outside the try-except in case it fails, since we do not want to revert anything if it does
         await self.__height_map.maybe_flush()

diff --git a/tests/core/full_node/test_full_node.py b/tests/core/full_node/test_full_node.py
index 1577c42bce5a..434e4e08ca38 100644
--- a/tests/core/full_node/test_full_node.py
+++ b/tests/core/full_node/test_full_node.py
@@ -6,7 +6,6 @@
 import logging
 import random
 import time
-import traceback
 from typing import Coroutine, Dict, List, Optional, Tuple

 import pytest
@@ -2190,16 +2189,9 @@ async def test_long_reorg_nodes(
     await full_node_2.full_node.add_block(reorg_blocks[-1])

     def check_nodes_in_sync():
-        try:
-            p1 = full_node_2.full_node.blockchain.get_peak()
-            p2 = full_node_1.full_node.blockchain.get_peak()
-            return p1 == p2
-        except Exception as e:
-            # TODO: understand why we get an exception here sometimes. Fix it or
-            # add comment explaining why we need to catch here
-            traceback.print_exc()
-            print(f"e: {e}")
-            return False
+        p1 = full_node_2.full_node.blockchain.get_peak()
+        p2 = full_node_1.full_node.blockchain.get_peak()
+        return p1 == p2

     await time_out_assert(120, check_nodes_in_sync)
     peak = full_node_2.full_node.blockchain.get_peak()
@@ -2225,17 +2217,10 @@ def check_nodes_in_sync():
     # await connect_and_get_peer(full_node_3.full_node.server, full_node_2.full_node.server, self_hostname)

     def check_nodes_in_sync2():
-        try:
-            p1 = full_node_1.full_node.blockchain.get_peak()
-            # p2 = full_node_2.full_node.blockchain.get_peak()
-            p3 = full_node_3.full_node.blockchain.get_peak()
-            return p1.header_hash == p3.header_hash
-        except Exception as e:
-            # TODO: understand why we get an exception here sometimes. Fix it or
-            # add comment explaining why we need to catch here
-            traceback.print_exc()
-            print(f"e: {e}")
-            return False
+        p1 = full_node_1.full_node.blockchain.get_peak()
+        # p2 = full_node_2.full_node.blockchain.get_peak()
+        p3 = full_node_3.full_node.blockchain.get_peak()
+        return p1.header_hash == p3.header_hash

     await time_out_assert(950, check_nodes_in_sync2)
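
Note for reviewers: below is a minimal sketch of the race the commit message describes. HeightMap, Chain, peak_height and the helper names in it are simplified stand-ins, not the real BlockHeightMap/Blockchain API; it only illustrates why _peak_height must not be observable while it points at a height the height-to-hash map no longer contains.

from typing import Dict, Optional


class HeightMap:
    """Simplified stand-in for the height-to-hash map."""

    def __init__(self) -> None:
        self._height_to_hash: Dict[int, bytes] = {}

    def rollback(self, fork_height: int) -> None:
        # drop every entry above the fork point, as a reorg does
        for height in [h for h in self._height_to_hash if h > fork_height]:
            del self._height_to_hash[height]

    def update_height(self, height: int, header_hash: bytes) -> None:
        self._height_to_hash[height] = header_hash

    def height_to_hash(self, height: int) -> bytes:
        # raises KeyError when the height is missing
        return self._height_to_hash[height]


class Chain:
    """Simplified stand-in for the blockchain cache."""

    def __init__(self) -> None:
        self.height_map = HeightMap()
        self.peak_height: Optional[int] = None

    def get_peak(self) -> Optional[bytes]:
        if self.peak_height is None:
            return None
        # fails if peak_height points at a height the map does not contain
        return self.height_map.height_to_hash(self.peak_height)


chain = Chain()
for h in range(5):
    chain.height_map.update_height(h, bytes([h]))
chain.peak_height = 4

# a reorg to a shorter fork: the map is rolled back and repopulated first ...
chain.height_map.rollback(2)
chain.height_map.update_height(3, b"new-3")

# ... and if a coroutine suspension happens here, before peak_height is
# updated, another task calling get_peak() looks up height 4 and raises
try:
    chain.get_peak()
except KeyError:
    print("get_peak() failed: peak height missing from the height-to-hash map")

# once both updates happen in the same synchronous stretch, the two can no
# longer be observed out of sync
chain.peak_height = 3
assert chain.get_peak() == b"new-3"

In the patched add_block() the height-map updates and the _peak_height assignment both run after the database transaction commits, with no await between them, and the except branch restores previous_peak_height, which is why the try/except hack in check_nodes_in_sync()/check_nodes_in_sync2() can be removed.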