Skip to content

Commit

Permalink
Refactoring query_http_uri
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Jan 25, 2024
1 parent ed1d03b commit 51338cd
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tribler.core.components.libtorrent.utils import torrent_utils
from tribler.core.components.libtorrent.utils.libtorrent_helper import libtorrent as lt
from tribler.core.utilities import path_util
from tribler.core.utilities.aiohttp.aiohttp_utils import unshorten
from tribler.core.utilities.network_utils import default_network_utils
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.path_util import Path
Expand All @@ -38,7 +39,7 @@
)
from tribler.core.utilities.simpledefs import DownloadStatus, MAX_LIBTORRENT_RATE_LIMIT, STATEDIR_CHECKPOINT_DIR
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import bdecode_compat, has_bep33_support, parse_magnetlink, unshorten
from tribler.core.utilities.utilities import bdecode_compat, has_bep33_support, parse_magnetlink
from tribler.core.version import version_id

SOCKS5_PROXY_DEF = 2
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import json
import shutil
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from binascii import unhexlify
from ssl import SSLError
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from urllib.parse import quote_plus, unquote_plus

import pytest
from aiohttp import ClientConnectorError, ClientResponseError, ServerConnectionError
from ipv8.util import succeed

from tribler.core import notifications
Expand Down Expand Up @@ -98,7 +95,7 @@ async def mock_http_query(*_):
with open(tmp_path / "ubuntu.torrent", 'rb') as f:
return f.read()

with patch(f"{TARGET}.query_http_uri", new=mock_http_query):
with patch("tribler.core.components.libtorrent.restapi.torrentinfo_endpoint.query_uri", new=mock_http_query):
verify_valid_dict(await do_request(rest_api, url, params={'uri': path}, expected_code=200))

path = quote_plus(f'magnet:?xt=urn:btih:{hexlify(UBUNTU_1504_INFOHASH)}'
Expand Down Expand Up @@ -167,10 +164,10 @@ async def get_metainfo(infohash, timeout=20, hops=None, url=None): # pylint: di

async def test_get_torrentinfo_invalid_magnet(rest_api):
# Test that invalid magnet link casues an error
mocked_query_http_uri = AsyncMock(return_value=b'magnet:?xt=urn:ed2k:' + b"any hash")
mocked_query_uri = AsyncMock(return_value=b'magnet:?xt=urn:ed2k:' + b"any hash")
params = {'uri': 'http://any.uri'}

with patch(f'{TARGET}.query_http_uri', mocked_query_http_uri):
with patch('tribler.core.components.libtorrent.restapi.torrentinfo_endpoint.query_uri', mocked_query_uri):
result = await do_request(rest_api, 'torrentinfo', params=params, expected_code=HTTP_INTERNAL_SERVER_ERROR)

assert 'error' in result
Expand All @@ -182,12 +179,12 @@ async def test_get_torrentinfo_invalid_magnet(rest_api):
async def test_get_torrentinfo_get_metainfo_from_downloaded_magnet(rest_api, download_manager: DownloadManager):
# Test that the `get_metainfo` function passes the correct arguments.
magnet = b'magnet:?xt=urn:btih:' + b'0' * 40
mocked_query_http_uri = AsyncMock(return_value=magnet)
mocked_query_uri = AsyncMock(return_value=magnet)
params = {'uri': 'any non empty uri'}

download_manager.get_metainfo = AsyncMock(return_value={b'info': {}})

with patch(f'{TARGET}.query_http_uri', mocked_query_http_uri):
with patch(f'{TARGET}.query_uri', mocked_query_uri):
await do_request(rest_api, 'torrentinfo', params=params)

expected_url = magnet.decode('utf-8')
Expand All @@ -202,28 +199,3 @@ async def test_on_got_invalid_metainfo(rest_api):
path = f"magnet:?xt=urn:btih:{hexlify(UBUNTU_1504_INFOHASH)}&dn={quote_plus('test torrent')}"
res = await do_request(rest_api, f'torrentinfo?uri={path}', expected_code=HTTP_INTERNAL_SERVER_ERROR)
assert "error" in res


# These are the exceptions that are handled by torrent info endpoint when querying an HTTP URI.
caught_exceptions = [
ServerConnectionError(),
ClientResponseError(Mock(), Mock()),
SSLError(),
ClientConnectorError(Mock(), Mock()),
AsyncTimeoutError()
]


@patch(f"{TARGET}.query_http_uri")
@pytest.mark.parametrize("exception", caught_exceptions)
async def test_torrentinfo_endpoint_timeout_error(mocked_query_http_uri: AsyncMock, exception: Exception):
# Test that in the case of exceptions related to querying HTTP URI specified in this tests,
# no exception is raised.
mocked_query_http_uri.side_effect = exception

endpoint = TorrentInfoEndpoint(MagicMock())
request = MagicMock(query={'uri': 'http://some_torrent_url'})

info = await endpoint.get_torrent_info(request)

assert info.status == HTTP_INTERNAL_SERVER_ERROR
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import hashlib
import json
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
from copy import deepcopy
from ssl import SSLError

from aiohttp import ClientConnectorError, ClientResponseError, ClientSession, ServerConnectionError, web
from aiohttp import web
from aiohttp_apispec import docs
from ipv8.REST.schema import schema
from marshmallow.fields import String
Expand All @@ -20,6 +18,8 @@
RESTEndpoint,
RESTResponse,
)
from tribler.core.utilities.aiohttp.aiohttp_utils import query_uri, unshorten
from tribler.core.utilities.aiohttp.exceptions import AiohttpException
from tribler.core.utilities.rest_utils import (
FILE_SCHEME,
HTTPS_SCHEME,
Expand All @@ -29,16 +29,7 @@
url_to_path,
)
from tribler.core.utilities.unicode import hexlify, recursive_unicode
from tribler.core.utilities.utilities import bdecode_compat, froze_it, parse_magnetlink, unshorten


async def query_http_uri(uri: str) -> bytes:
# This is moved to a separate method to be able to patch it separately,
# for compatibility with pytest-aiohttp
async with ClientSession(raise_for_status=True) as session:
response = await session.get(uri)
response = await response.read()
return response
from tribler.core.utilities.utilities import bdecode_compat, froze_it, parse_magnetlink


@froze_it
Expand Down Expand Up @@ -100,9 +91,8 @@ async def get_torrent_info(self, request):
status=HTTP_INTERNAL_SERVER_ERROR)
elif scheme in (HTTP_SCHEME, HTTPS_SCHEME):
try:
response = await query_http_uri(uri)
except (ServerConnectionError, ClientResponseError, SSLError, ClientConnectorError, AsyncTimeoutError) as e:
self._logger.warning(f'Error while querying http uri: {e}')
response = await query_uri(uri)
except AiohttpException as e:
return RESTResponse({"error": str(e)}, status=HTTP_INTERNAL_SERVER_ERROR)

if response.startswith(b'magnet'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from unittest.mock import Mock

import pytest
from aiohttp import ClientSession

from tribler.core.components.socks_servers.socks5.aiohttp_connector import Socks5Connector
from tribler.core.components.socks_servers.socks5.client import Socks5Client, Socks5Error
from tribler.core.components.socks_servers.socks5.conversion import UdpPacket, socks5_serializer
from tribler.core.components.socks_servers.socks5.server import Socks5Server
from tribler.core.utilities.aiohttp.aiohttp_utils import query_uri


@pytest.fixture(name='socks5_server')
Expand Down Expand Up @@ -107,7 +107,8 @@ def return_data(conn, target, _):
conn.transport.close()

socks5_server.output_stream.on_socks5_tcp_data = return_data

async with ClientSession(connector=Socks5Connector(('127.0.0.1', socks5_server.port))) as session:
async with session.get('http://localhost') as response:
assert (await response.read()) == b'Hello'
result = await query_uri(
uri='http://localhost',
connector=Socks5Connector(('127.0.0.1', socks5_server.port))
)
assert result == b'Hello'
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import platform
from asyncio import sleep
from dataclasses import dataclass
Expand All @@ -11,6 +10,7 @@
from tribler.core.components.restapi.rest.rest_endpoint import RESTResponse
from tribler.core.components.version_check import versioncheck_manager
from tribler.core.components.version_check.versioncheck_manager import VersionCheckManager
from tribler.core.utilities.aiohttp.exceptions import AiohttpException
from tribler.core.version import version_id

# pylint: disable=redefined-outer-name, protected-access
Expand Down Expand Up @@ -74,12 +74,12 @@ async def test_start(version_check_manager: VersionCheckManager):
@patch('platform.python_version', Mock(return_value='3.0.0'))
@patch('platform.architecture', Mock(return_value=('64bit', 'FooBar')))
async def test_user_agent(version_server: VersionCheckManager):
result = await version_server._check_urls()

actual = result.request_info.headers['User-Agent']
expected = f'Tribler/{version_id} (machine=machine; os=os 1; python=3.0.0; executable=64bit)'

assert actual == expected
with patch('tribler.core.components.version_check.versioncheck_manager.query_uri') as mocked_query_uri:
await version_server._check_urls()
actual = mocked_query_uri.call_args.kwargs['headers']['User-Agent']
assert actual == expected


@patch.object(ResponseSettings, 'response', first_version)
Expand Down Expand Up @@ -111,7 +111,7 @@ async def test_version_check_api_timeout(version_server: VersionCheckManager):

# Since the time to respond is higher than the time version checker waits for response,
# it should raise the `asyncio.TimeoutError`
with pytest.raises(asyncio.TimeoutError):
with pytest.raises(AiohttpException):
await version_server._raw_request_new_version(version_server.urls[0])


Expand Down
27 changes: 15 additions & 12 deletions src/tribler/core/components/version_check/versioncheck_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
import platform
from distutils.version import LooseVersion
from typing import List, Optional
from typing import Dict, List, Optional

from aiohttp import ClientResponse, ClientSession, ClientTimeout
from aiohttp import ClientTimeout
from ipv8.taskmanager import TaskManager

from tribler.core import notifications
from tribler.core.utilities.aiohttp.aiohttp_utils import query_uri
from tribler.core.utilities.notifier import Notifier
from tribler.core.version import version_id

Expand Down Expand Up @@ -42,28 +43,30 @@ def timeout(self):
def timeout(self, value: float):
self._timeout = ClientTimeout(total=value)

async def _check_urls(self) -> Optional[ClientResponse]:
async def _check_urls(self) -> Optional[Dict]:
for version_check_url in self.urls:
if result := await self._request_new_version(version_check_url):
return result

async def _request_new_version(self, version_check_url: str) -> Optional[ClientResponse]:
async def _request_new_version(self, version_check_url: str) -> Optional[Dict]:
try:
return await self._raw_request_new_version(version_check_url)
except Exception as e: # pylint: disable=broad-except
# broad exception handling for preventing an application crash that may follow
# the occurrence of an exception in the version check manager
self._logger.warning(e)

async def _raw_request_new_version(self, version_check_url: str) -> Optional[ClientResponse]:
async def _raw_request_new_version(self, version_check_url: str) -> Optional[Dict]:
headers = {'User-Agent': self._get_user_agent_string(version_id, platform)}
async with ClientSession(raise_for_status=True) as session:
response = await session.get(version_check_url, headers=headers, timeout=self.timeout)
response_dict = await response.json(content_type=None)
version = response_dict['name'][1:]
if LooseVersion(version) > LooseVersion(version_id):
self.notifier[notifications.tribler_new_version](version)
return response
json_dict = await query_uri(version_check_url, headers=headers, timeout=self.timeout, return_json=True)
version = json_dict['name'][1:]
print('!!!')
print(version)
if LooseVersion(version) > LooseVersion(version_id):
self.notifier[notifications.tribler_new_version](version)
return json_dict

return None

@staticmethod
def _get_user_agent_string(tribler_version, platform_module):
Expand Down
62 changes: 62 additions & 0 deletions src/tribler/core/utilities/aiohttp/aiohttp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio
import logging
from ssl import SSLError
from typing import Dict, Optional, Union

from aiohttp import BaseConnector, ClientConnectorError, ClientResponseError, ClientSession, ClientTimeout, \
ServerConnectionError
from aiohttp.hdrs import LOCATION
from aiohttp.typedefs import LooseHeaders

from tribler.core.utilities.aiohttp.exceptions import AiohttpException
from tribler.core.utilities.rest_utils import HTTPS_SCHEME, HTTP_SCHEME, scheme_from_url

logger = logging.getLogger(__name__)


async def query_uri(uri: str, connector: Optional[BaseConnector] = None, headers: Optional[LooseHeaders] = None,
timeout: ClientTimeout = None, return_json: bool = False, ) -> Union[Dict, bytes]:
kwargs = {'headers': headers}
if timeout:
# ClientSession uses a sentinel object for the default timeout. Therefore, it should only be specified if an
# actual value has been passed to this function.
kwargs['timeout'] = timeout

async with ClientSession(connector=connector, raise_for_status=True) as session:
try:
async with await session.get(uri, **kwargs) as response:
if return_json:
return await response.json(content_type=None)
return await response.read()
except (ServerConnectionError, ClientResponseError, SSLError, ClientConnectorError, asyncio.TimeoutError) as e:
message = f'Error while querying http uri. {e.__class__.__name__}: {e}'
logger.warning(message, exc_info=e)
raise AiohttpException(message) from e


async def unshorten(uri: str) -> str:
""" Unshorten a URI if it is a short URI. Return the original URI if it is not a short URI.
Args:
uri (str): A string representing the shortened URL that needs to be unshortened.
Returns:
str: The unshortened URL. If the original URL does not redirect to another URL, the original URL is returned.
"""

scheme = scheme_from_url(uri)
if scheme not in (HTTP_SCHEME, HTTPS_SCHEME):
return uri

logger.info(f'Unshortening URI: {uri}')

async with ClientSession() as session:
try:
async with await session.get(uri, allow_redirects=False) as response:
if response.status in (301, 302, 303, 307, 308):
uri = response.headers.get(LOCATION, uri)
except Exception as e:
logger.warning(f'Error while unshortening a URI: {e.__class__.__name__}: {e}', exc_info=e)

logger.info(f'Unshorted URI: {uri}')
return uri
5 changes: 5 additions & 0 deletions src/tribler/core/utilities/aiohttp/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from tribler.core.exceptions import TriblerException

Check notice on line 1 in src/tribler/core/utilities/aiohttp/exceptions.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

src/tribler/core/utilities/aiohttp/exceptions.py#L1

Similar lines in 2 files


class AiohttpException(TriblerException):
""" Base class for all aiohttp exceptions. """
Loading

0 comments on commit 51338cd

Please sign in to comment.