diff --git a/src/tribler-core/tribler_core/components/restapi/restapi_component.py b/src/tribler-core/tribler_core/components/restapi/restapi_component.py index c22e61cd439..8b619a963b2 100644 --- a/src/tribler-core/tribler_core/components/restapi/restapi_component.py +++ b/src/tribler-core/tribler_core/components/restapi/restapi_component.py @@ -2,12 +2,10 @@ from typing import Any, Dict, List, Set, Tuple from ipv8_service import IPv8 - from tribler_common.reported_error import ReportedError from tribler_common.simpledefs import STATE_START_API - from tribler_core.components.base import Component -from tribler_core.components.reporter.exception_handler import CoreExceptionHandler +from tribler_core.components.reporter.exception_handler import CoreExceptionHandler, default_core_exception_handler from tribler_core.components.reporter.reporter_component import ReporterComponent from tribler_core.components.restapi.rest.debug_endpoint import DebugEndpoint from tribler_core.components.restapi.rest.events_endpoint import EventsEndpoint @@ -75,6 +73,7 @@ class RESTComponent(Component): _events_endpoint: EventsEndpoint _state_endpoint: StateEndpoint + _core_exception_handler: CoreExceptionHandler = default_core_exception_handler async def run(self): await super().run() @@ -115,10 +114,13 @@ def report_callback(reported_error: ReportedError): self._events_endpoint.on_tribler_exception(reported_error) self._state_endpoint.on_tribler_exception(reported_error.text) - CoreExceptionHandler.report_callback = report_callback + self._core_exception_handler.report_callback = report_callback async def shutdown(self): await super().shutdown() - CoreExceptionHandler.report_callback = None + + if self._core_exception_handler: + self._core_exception_handler.report_callback = None + if self.rest_manager: await self.rest_manager.stop() diff --git a/src/tribler-core/tribler_core/components/restapi/tests/test_restapi_component.py b/src/tribler-core/tribler_core/components/restapi/tests/test_restapi_component.py index eaefe8519eb..40bcde37ddd 100644 --- a/src/tribler-core/tribler_core/components/restapi/tests/test_restapi_component.py +++ b/src/tribler-core/tribler_core/components/restapi/tests/test_restapi_component.py @@ -1,12 +1,12 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from tribler_common.reported_error import ReportedError - from tribler_core.components.base import Session from tribler_core.components.key.key_component import KeyComponent from tribler_core.components.reporter.exception_handler import CoreExceptionHandler +from tribler_core.components.restapi.rest.rest_manager import RESTManager from tribler_core.components.restapi.restapi_component import RESTComponent pytestmark = pytest.mark.asyncio @@ -14,17 +14,6 @@ # pylint: disable=protected-access, not-callable -def assert_report_callback_is_correct(component: RESTComponent): - assert CoreExceptionHandler.report_callback - component._events_endpoint.on_tribler_exception = MagicMock() - component._state_endpoint.on_tribler_exception = MagicMock() - - error = ReportedError(type='', text='text', event={}) - CoreExceptionHandler.report_callback(error) - - component._events_endpoint.on_tribler_exception.assert_called_with(error) - component._state_endpoint.on_tribler_exception.assert_called_with(error.text) - async def test_restful_component(tribler_config): components = [KeyComponent(), RESTComponent()] @@ -35,5 +24,27 @@ async def test_restful_component(tribler_config): comp = RESTComponent.instance() assert comp.started_event.is_set() and not comp.failed assert comp.rest_manager - assert_report_callback_is_correct(comp) await session.shutdown() + + +@patch.object(RESTComponent, 'get_component', new=AsyncMock()) +@patch.object(RESTManager, 'start', new=AsyncMock()) +async def test_report_callback_set_up_correct(): + component = RESTComponent() + component.session = MagicMock() + + component._core_exception_handler = CoreExceptionHandler() + + await component.run() + + # mock callbacks + component._events_endpoint.on_tribler_exception = MagicMock() + component._state_endpoint.on_tribler_exception = MagicMock() + + # try to call report_callback from core_exception_handler and assert + # that corresponding methods in events_endpoint and state_endpoint have been called + + error = ReportedError(type='', text='text', event={}) + component._core_exception_handler.report_callback(error) + component._events_endpoint.on_tribler_exception.assert_called_with(error) + component._state_endpoint.on_tribler_exception.assert_called_with(error.text)