Skip to content

Commit

Permalink
Add unshorten URI logic
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Jan 15, 2024
1 parent 7cbca95 commit 3127550
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from tribler.core.utilities.simpledefs import DownloadStatus, MAX_LIBTORRENT_RATE_LIMIT, STATEDIR_CHECKPOINT_DIR
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import bdecode_compat, has_bep33_support, parse_magnetlink
from tribler.core.utilities.utilities import bdecode_compat, has_bep33_support, parse_magnetlink, unshorten
from tribler.core.version import version_id

SOCKS5_PROXY_DEF = 2
Expand Down Expand Up @@ -573,6 +573,8 @@ def _map_call_on_ltsessions(self, hops, funcname, *args, **kwargs):

async def start_download_from_uri(self, uri, config=None):
self._logger.info(f'Start download from URI: {uri}')

uri = await unshorten(uri)
scheme = scheme_from_url(uri)

if scheme in (HTTP_SCHEME, HTTPS_SCHEME):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
url_to_path,
)
from tribler.core.utilities.unicode import hexlify, recursive_unicode
from tribler.core.utilities.utilities import bdecode_compat, froze_it, parse_magnetlink
from tribler.core.utilities.utilities import bdecode_compat, froze_it, parse_magnetlink, unshorten


async def query_http_uri(uri: str) -> bytes:
Expand Down Expand Up @@ -87,7 +87,7 @@ async def get_torrent_info(self, request):
if not uri:
return RESTResponse({"error": "uri parameter missing"}, status=HTTP_BAD_REQUEST)

metainfo = None
uri = await unshorten(uri)
scheme = scheme_from_url(uri)

if scheme == FILE_SCHEME:
Expand All @@ -110,7 +110,7 @@ async def get_torrent_info(self, request):
_, infohash, _ = parse_magnetlink(response)
except RuntimeError as e:
return RESTResponse(
{"error": f'Error while getting an ingo hash from magnet: {e.__class__.__name__}: {e}'},
{"error": f'Error while getting an infohash from magnet: {e.__class__.__name__}: {e}'},
status=HTTP_INTERNAL_SERVER_ERROR
)

Expand All @@ -124,7 +124,7 @@ async def get_torrent_info(self, request):
_, infohash, _ = parse_magnetlink(uri)
except RuntimeError as e:
return RESTResponse(
{"error": f'Error while getting an ingo hash from magnet: {e.__class__.__name__}: {e}'},
{"error": f'Error while getting an infohash from magnet: {e.__class__.__name__}: {e}'},
status=HTTP_BAD_REQUEST
)

Expand Down
47 changes: 45 additions & 2 deletions src/tribler/core/utilities/tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
from unittest.mock import MagicMock, Mock, patch
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from aiohttp import ClientSession, web
from aiohttp.hdrs import LOCATION, URI

from tribler.core.logger.logger import load_logger_config
from tribler.core.utilities.patch_import import patch_import
from tribler.core.utilities.tracker_utils import add_url_params
from tribler.core.utilities.utilities import (Query, extract_tags, get_normally_distributed_positive_integers,
is_channel_public_key, is_infohash, is_simple_match_query, is_valid_url,
parse_bool, parse_magnetlink, parse_query, random_infohash, safe_repr,
show_system_popup, to_fts_query)
show_system_popup, to_fts_query, unshorten)

# pylint: disable=import-outside-toplevel, import-error, redefined-outer-name
# fmt: off
Expand Down Expand Up @@ -370,3 +372,44 @@ class MyException(Exception):
obj = MagicMock(__repr__=Mock(side_effect=MyException("exception text")))
result = safe_repr(obj)
assert result == f'<Repr of {object.__repr__(obj)} raises MyException: exception text>'



UNSHORTEN_TEST_DATA = [
SimpleNamespace(
# Test that the `unshorten` function returns the unshorten URL if there is a redirect detected by
# the right status code and right header.
url='http://shorten',
response=SimpleNamespace(status=301, headers={LOCATION: 'http://unshorten'}),
expected='http://unshorten'
),
SimpleNamespace(
# Test that the `unshorten` function returns the same URL if there is wrong scheme
url='file://shorten',
response=SimpleNamespace(status=0, headers={}),
expected='file://shorten'
),
SimpleNamespace(
# Test that the `unshorten` function returns the same URL if there is no redirect detected by the wrong status
# code.
url='http://shorten',
response=SimpleNamespace(status=401, headers={LOCATION: 'http://unshorten'}),
expected='http://shorten'
),
SimpleNamespace(
# Test that the `unshorten` function returns the same URL if there is no redirect detected by the wrong header.
url='http://shorten',
response=SimpleNamespace(status=301, headers={URI: 'http://unshorten'}),
expected='http://shorten'
)
]


@pytest.mark.parametrize("test_data", UNSHORTEN_TEST_DATA)
async def test_unshorten(test_data):
# The function mocks the ClientSession.get method to return a mocked response with the given status and headers.
# It is used with the test data above to test the unshorten function.
response = MagicMock(status=test_data.response.status, headers=test_data.response.headers)
mocked_get = AsyncMock(return_value=AsyncMock(__aenter__=AsyncMock(return_value=response)))
with patch.object(ClientSession, 'get', mocked_get):
assert await unshorten(test_data.url) == test_data.expected
32 changes: 32 additions & 0 deletions src/tribler/core/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from typing import Dict, List, Optional, Set, Tuple, Union
from urllib.parse import urlsplit

from aiohttp import ClientSession
from aiohttp.hdrs import LOCATION

from tribler.core.components.libtorrent.utils.libtorrent_helper import libtorrent as lt
from tribler.core.utilities.rest_utils import HTTPS_SCHEME, HTTP_SCHEME, scheme_from_url
from tribler.core.utilities.sentinels import sentinel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -335,3 +339,31 @@ def safe_repr(obj):
return repr(obj)
except Exception as e: # pylint: disable=broad-except
return f'<Repr of {object.__repr__(obj)} raises {e.__class__.__name__}: {e}>'


async def unshorten(uri: str) -> str:
""" Unshorten a URI if it is a short URI. Return the original URI if it is not a short URI.
Args:
uri (str): A string representing the shortened URL that needs to be unshortened.
Returns:
str: The unshortened URL. If the original URL does not redirect to another URL, the original URL is returned.
"""

scheme = scheme_from_url(uri)
if scheme not in (HTTP_SCHEME, HTTPS_SCHEME):
return uri

logger.info(f'Unshortening URI: {uri}')

async with ClientSession() as session:
try:
async with await session.get(uri, allow_redirects=False) as response:
if response.status in (301, 302, 303, 307, 308):
uri = response.headers.get(LOCATION, uri)
except Exception as e:
logger.warning(f'Error while unshortening a URI: {e.__class__.__name__}: {e}', exc_info=e)

logger.info(f'Unshorted URI: {uri}')
return uri

0 comments on commit 3127550

Please sign in to comment.