Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect cache_exitnodes_to_disk from raising OSErrors #7039

Merged
merged 1 commit into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/tribler/core/components/tunnel/community/tunnel_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from binascii import unhexlify
from collections import Counter
from distutils.version import LooseVersion
from typing import List
from typing import List, Optional

import async_timeout

from ipv8.messaging.anonymization.caches import CreateRequestCache
from ipv8.messaging.anonymization.community import unpack_cell
from ipv8.messaging.anonymization.hidden_services import HiddenTunnelCommunity
Expand Down Expand Up @@ -48,6 +47,7 @@
RelayBalanceResponsePayload,
)
from tribler.core.utilities.bencodecheck import is_bencoded
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import (
DLSTATUS_DOWNLOADING,
DLSTATUS_METADATA,
Expand All @@ -72,7 +72,7 @@ class TriblerTunnelCommunity(HiddenTunnelCommunity):

def __init__(self, *args, **kwargs):
self.bandwidth_community = kwargs.pop('bandwidth_community', None)
self.exitnode_cache = kwargs.pop('exitnode_cache', None)
self.exitnode_cache: Optional[Path] = kwargs.pop('exitnode_cache', None)
self.config = kwargs.pop('config', None)
self.notifier = kwargs.pop('notifier', None)
self.dlmgr = kwargs.pop('dlmgr', None)
Expand Down Expand Up @@ -140,9 +140,12 @@ def cache_exitnodes_to_disk(self):
exit_nodes = Network()
for peer in self.get_candidates(PEER_FLAG_EXIT_BT):
exit_nodes.add_verified_peer(peer)
self.logger.debug('Writing exit nodes to cache: %s', self.exitnode_cache)
with open(self.exitnode_cache, 'wb') as cache:
cache.write(exit_nodes.snapshot())
snapshot = exit_nodes.snapshot()
self.logger.info(f'Writing exit nodes to cache file: {self.exitnode_cache}')
try:
self.exitnode_cache.write_bytes(snapshot)
except OSError as e:
self.logger.warning(f'{e.__class__.__name__}: {e}')

def restore_exitnodes_from_disk(self):
"""
Expand Down Expand Up @@ -601,7 +604,7 @@ async def on_http_request(self, source_address, payload, circuit_id):
response += line
if not line.strip():
# Read HTTP response body (1MB max)
response += await reader.read(1024**2)
response += await reader.read(1024 ** 2)
break
except OSError:
self.logger.warning('Tunnel HTTP request failed')
Expand All @@ -624,7 +627,7 @@ async def on_http_request(self, source_address, payload, circuit_id):
for i in range(num_cells):
self.send_cell(source_address,
HTTPResponsePayload(circuit_id, payload.identifier, i, num_cells,
response[i*MAX_HTTP_PACKET_SIZE:(i+1)*MAX_HTTP_PACKET_SIZE]))
response[i * MAX_HTTP_PACKET_SIZE:(i + 1) * MAX_HTTP_PACKET_SIZE]))

@unpack_cell(HTTPResponsePayload)
def on_http_response(self, source_address, payload, circuit_id):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from asyncio import Future, TimeoutError as AsyncTimeoutError, sleep, wait_for
from collections import defaultdict
from random import random
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock, patch

import pytest
from ipv8.messaging.anonymization.payload import EstablishIntroPayload
from ipv8.messaging.anonymization.tunnel import (
CIRCUIT_STATE_READY,
Expand All @@ -12,6 +13,7 @@
PEER_FLAG_EXIT_BT,
)
from ipv8.peer import Peer
from ipv8.peerdiscovery.network import Network
from ipv8.test.base import TestBase
from ipv8.test.messaging.anonymization import test_community
from ipv8.test.messaging.anonymization.test_community import MockDHTProvider
Expand All @@ -34,6 +36,24 @@
from tribler.core.utilities.utilities import MEMORY_DB


# pylint: disable=redefined-outer-name

@pytest.fixture()
def tunnel_community():
community = TriblerTunnelCommunity(MagicMock(),
MagicMock(),
MagicMock(),
socks_servers=MagicMock(),
config=MagicMock(),
notifier=MagicMock(),
dlmgr=MagicMock(),
bandwidth_community=MagicMock(),
dht_provider=MagicMock(),
exitnode_cache=MagicMock(),
settings=MagicMock())
return community


class TestTriblerTunnelCommunity(TestBase): # pylint: disable=too-many-public-methods

def setUp(self):
Expand Down Expand Up @@ -180,6 +200,7 @@ def test_monitor_downloads_recreate_ip(self):

def mock_create_ip(*_, **__):
mock_create_ip.called = True

mock_create_ip.called = False
self.nodes[0].overlay.create_introduction_point = mock_create_ip

Expand All @@ -200,8 +221,10 @@ def test_monitor_downloads_intro(self):
"""
Test whether rendezvous points are removed when a download is stopped
"""

def mocked_remove_circuit(circuit_id, *_, **__):
mocked_remove_circuit.circuit_id = circuit_id

mocked_remove_circuit.circuit_id = -1

mock_circuit = MockObject()
Expand All @@ -224,8 +247,10 @@ def test_monitor_downloads_stop_all(self):
"""
Test whether circuits are removed when all downloads are stopped
"""

def mocked_remove_circuit(circuit_id, *_, **__):
mocked_remove_circuit.circuit_id = circuit_id

mocked_remove_circuit.circuit_id = -1

mock_circuit = MockObject()
Expand Down Expand Up @@ -644,3 +669,21 @@ async def test_perform_http_request_failed(self):
await wait_for(self.nodes[0].overlay.perform_http_request(('127.0.0.1', 1234),
b'GET /scrape?info_hash=0 HTTP/1.1\r\n\r\n'),
timeout=.3)


@patch.object(Network, 'snapshot', Mock(return_value=b'snapshot'))
def test_cache_exitnodes_to_disk(tunnel_community: TriblerTunnelCommunity, tmp_path):
""" Test whether we can cache exit nodes to disk """
tunnel_community.exitnode_cache = tmp_path / 'exitnode_cache.dat'
tunnel_community.cache_exitnodes_to_disk()

assert tunnel_community.exitnode_cache.read_bytes() == b'snapshot'


@patch.object(Network, 'snapshot', Mock(return_value=b'snapshot'))
def test_cache_exitnodes_to_disk_os_error(tunnel_community: TriblerTunnelCommunity):
""" Test whether we can handle an OSError when caching exit nodes to disk and raise no errors """
tunnel_community.exitnode_cache = Mock(write_bytes=Mock(side_effect=FileNotFoundError))
tunnel_community.cache_exitnodes_to_disk()

assert tunnel_community.exitnode_cache.write_bytes.called