diff --git a/src/tribler/core/components/tunnel/community/tunnel_community.py b/src/tribler/core/components/tunnel/community/tunnel_community.py index 9563718664e..efe8fd2ccdb 100644 --- a/src/tribler/core/components/tunnel/community/tunnel_community.py +++ b/src/tribler/core/components/tunnel/community/tunnel_community.py @@ -6,10 +6,9 @@ from binascii import unhexlify from collections import Counter from distutils.version import LooseVersion -from typing import List +from typing import List, Optional import async_timeout - from ipv8.messaging.anonymization.caches import CreateRequestCache from ipv8.messaging.anonymization.community import unpack_cell from ipv8.messaging.anonymization.hidden_services import HiddenTunnelCommunity @@ -48,6 +47,7 @@ RelayBalanceResponsePayload, ) from tribler.core.utilities.bencodecheck import is_bencoded +from tribler.core.utilities.path_util import Path from tribler.core.utilities.simpledefs import ( DLSTATUS_DOWNLOADING, DLSTATUS_METADATA, @@ -72,7 +72,7 @@ class TriblerTunnelCommunity(HiddenTunnelCommunity): def __init__(self, *args, **kwargs): self.bandwidth_community = kwargs.pop('bandwidth_community', None) - self.exitnode_cache = kwargs.pop('exitnode_cache', None) + self.exitnode_cache: Optional[Path] = kwargs.pop('exitnode_cache', None) self.config = kwargs.pop('config', None) self.notifier = kwargs.pop('notifier', None) self.dlmgr = kwargs.pop('dlmgr', None) @@ -140,9 +140,12 @@ def cache_exitnodes_to_disk(self): exit_nodes = Network() for peer in self.get_candidates(PEER_FLAG_EXIT_BT): exit_nodes.add_verified_peer(peer) - self.logger.debug('Writing exit nodes to cache: %s', self.exitnode_cache) - with open(self.exitnode_cache, 'wb') as cache: - cache.write(exit_nodes.snapshot()) + snapshot = exit_nodes.snapshot() + self.logger.info(f'Writing exit nodes to cache file: {self.exitnode_cache}') + try: + self.exitnode_cache.write_bytes(snapshot) + except OSError as e: + self.logger.warning(f'{e.__class__.__name__}: {e}') def restore_exitnodes_from_disk(self): """ @@ -601,7 +604,7 @@ async def on_http_request(self, source_address, payload, circuit_id): response += line if not line.strip(): # Read HTTP response body (1MB max) - response += await reader.read(1024**2) + response += await reader.read(1024 ** 2) break except OSError: self.logger.warning('Tunnel HTTP request failed') @@ -624,7 +627,7 @@ async def on_http_request(self, source_address, payload, circuit_id): for i in range(num_cells): self.send_cell(source_address, HTTPResponsePayload(circuit_id, payload.identifier, i, num_cells, - response[i*MAX_HTTP_PACKET_SIZE:(i+1)*MAX_HTTP_PACKET_SIZE])) + response[i * MAX_HTTP_PACKET_SIZE:(i + 1) * MAX_HTTP_PACKET_SIZE])) @unpack_cell(HTTPResponsePayload) def on_http_response(self, source_address, payload, circuit_id): diff --git a/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py b/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py index c6382329fca..bec8ef04530 100644 --- a/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py +++ b/src/tribler/core/components/tunnel/tests/test_triblertunnel_community.py @@ -2,8 +2,9 @@ from asyncio import Future, TimeoutError as AsyncTimeoutError, sleep, wait_for from collections import defaultdict from random import random -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch +import pytest from ipv8.messaging.anonymization.payload import EstablishIntroPayload from ipv8.messaging.anonymization.tunnel import ( CIRCUIT_STATE_READY, @@ -12,6 +13,7 @@ PEER_FLAG_EXIT_BT, ) from ipv8.peer import Peer +from ipv8.peerdiscovery.network import Network from ipv8.test.base import TestBase from ipv8.test.messaging.anonymization import test_community from ipv8.test.messaging.anonymization.test_community import MockDHTProvider @@ -34,6 +36,24 @@ from tribler.core.utilities.utilities import MEMORY_DB +# pylint: disable=redefined-outer-name + +@pytest.fixture() +def tunnel_community(): + community = TriblerTunnelCommunity(MagicMock(), + MagicMock(), + MagicMock(), + socks_servers=MagicMock(), + config=MagicMock(), + notifier=MagicMock(), + dlmgr=MagicMock(), + bandwidth_community=MagicMock(), + dht_provider=MagicMock(), + exitnode_cache=MagicMock(), + settings=MagicMock()) + return community + + class TestTriblerTunnelCommunity(TestBase): # pylint: disable=too-many-public-methods def setUp(self): @@ -180,6 +200,7 @@ def test_monitor_downloads_recreate_ip(self): def mock_create_ip(*_, **__): mock_create_ip.called = True + mock_create_ip.called = False self.nodes[0].overlay.create_introduction_point = mock_create_ip @@ -200,8 +221,10 @@ def test_monitor_downloads_intro(self): """ Test whether rendezvous points are removed when a download is stopped """ + def mocked_remove_circuit(circuit_id, *_, **__): mocked_remove_circuit.circuit_id = circuit_id + mocked_remove_circuit.circuit_id = -1 mock_circuit = MockObject() @@ -224,8 +247,10 @@ def test_monitor_downloads_stop_all(self): """ Test whether circuits are removed when all downloads are stopped """ + def mocked_remove_circuit(circuit_id, *_, **__): mocked_remove_circuit.circuit_id = circuit_id + mocked_remove_circuit.circuit_id = -1 mock_circuit = MockObject() @@ -644,3 +669,21 @@ async def test_perform_http_request_failed(self): await wait_for(self.nodes[0].overlay.perform_http_request(('127.0.0.1', 1234), b'GET /scrape?info_hash=0 HTTP/1.1\r\n\r\n'), timeout=.3) + + +@patch.object(Network, 'snapshot', Mock(return_value=b'snapshot')) +def test_cache_exitnodes_to_disk(tunnel_community: TriblerTunnelCommunity, tmp_path): + """ Test whether we can cache exit nodes to disk """ + tunnel_community.exitnode_cache = tmp_path / 'exitnode_cache.dat' + tunnel_community.cache_exitnodes_to_disk() + + assert tunnel_community.exitnode_cache.read_bytes() == b'snapshot' + + +@patch.object(Network, 'snapshot', Mock(return_value=b'snapshot')) +def test_cache_exitnodes_to_disk_os_error(tunnel_community: TriblerTunnelCommunity): + """ Test whether we can handle an OSError when caching exit nodes to disk and raise no errors """ + tunnel_community.exitnode_cache = Mock(write_bytes=Mock(side_effect=FileNotFoundError)) + tunnel_community.cache_exitnodes_to_disk() + + assert tunnel_community.exitnode_cache.write_bytes.called