Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Webserver: refactor computation module #4155

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .application_settings import setup_settings
from .catalog import setup_catalog
from .clusters.plugin import setup_clusters
from .computation import setup_computation
from .db import setup_db
from .diagnostics import setup_diagnostics
from .director.plugin import setup_director
Expand All @@ -26,10 +25,10 @@
from .login.plugin import setup_login
from .long_running_tasks import setup_long_running_tasks
from .meta_modeling.plugin import setup_meta_modeling
from .notifications.rabbitmq import setup_rabbitmq
from .products import setup_products
from .projects.plugin import setup_projects
from .publications import setup_publications
from .rabbitmq import setup_rabbitmq
from .redis import setup_redis
from .remote_debug import setup_remote_debugging
from .resource_manager.plugin import setup_resource_manager
Expand Down Expand Up @@ -79,7 +78,6 @@ def create_application() -> web.Application:
# monitoring
setup_diagnostics(app)
setup_activity(app)
setup_computation(app)
setup_socketio(app)

# login
Expand Down
9 changes: 3 additions & 6 deletions services/web/server/src/simcore_service_webserver/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import logging
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncIterator

from aiohttp import web
from aiopg.sa import Engine, create_engine
Expand All @@ -28,7 +28,6 @@

@retry(**PostgresRetryPolicyUponInitialization(log).kwargs)
async def _ensure_pg_ready(settings: PostgresSettings) -> Engine:

log.info("Connecting to postgres with %s", f"{settings=}")
engine = await create_engine(
settings.dsn,
Expand All @@ -48,7 +47,6 @@ async def _ensure_pg_ready(settings: PostgresSettings) -> Engine:


async def postgres_cleanup_ctx(app: web.Application) -> AsyncIterator[None]:

settings = get_plugin_settings(app)
aiopg_engine = await _ensure_pg_ready(settings)
app[APP_DB_ENGINE_KEY] = aiopg_engine
Expand Down Expand Up @@ -83,7 +81,7 @@ async def is_service_responsive(app: web.Application):


def get_engine_state(app: web.Application) -> dict[str, Any]:
engine: Optional[Engine] = app.get(APP_DB_ENGINE_KEY)
engine: Engine | None = app.get(APP_DB_ENGINE_KEY)
if engine:
return get_pg_engine_stateinfo(engine)
return {}
Expand All @@ -96,8 +94,7 @@ def get_database_engine(app: web.Application) -> Engine:
@app_module_setup(
__name__, ModuleCategory.ADDON, settings_name="WEBSERVER_DB", logger=log
)
def setup_db(app: web.Application):

def setup_db(app: web.Application) -> None:
# ensures keys exist
app[APP_DB_ENGINE_KEY] = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ServiceWaitingForManualIntervention,
)
from .director_v2_settings import DirectorV2Settings, get_plugin_settings
from .rabbitmq import get_rabbitmq_client
from .notifications.rabbitmq import get_rabbitmq_client

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from contextlib import suppress
from pprint import pformat
from typing import AsyncIterator

from aiohttp import web
from aiopg.sa import Engine
Expand All @@ -21,9 +22,9 @@
from simcore_postgres_database.webserver_models import DB_CHANNEL_NAME, projects
from sqlalchemy.sql import select

from ..projects import projects_api, projects_exceptions
from ..projects.projects_nodes_utils import update_node_outputs
from .computation_utils import convert_state_from_db
from .projects import projects_api, projects_exceptions
from .projects.projects_nodes_utils import update_node_outputs

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -167,7 +168,7 @@ async def comp_tasks_listening_task(app: web.Application) -> None:
await asyncio.sleep(3)


async def create_comp_tasks_listening_task(app: web.Application):
async def create_comp_tasks_listening_task(app: web.Application) -> AsyncIterator[None]:
task = asyncio.create_task(
comp_tasks_listening_task(app), name="computation db listener"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
import logging
from typing import Union
from typing import AsyncIterator, Union

from aiohttp import web
from models_library.rabbitmq_messages import (
Expand All @@ -22,10 +22,10 @@
from servicelib.logging_utils import log_context
from servicelib.rabbitmq import RabbitMQClient

from .projects import projects_api
from .projects.projects_exceptions import NodeNotFoundError, ProjectNotFoundError
from .rabbitmq import get_rabbitmq_client
from .socketio.events import (
from ..projects import projects_api
from ..projects.projects_exceptions import NodeNotFoundError, ProjectNotFoundError
from ..rabbitmq import get_rabbitmq_client
from ..socketio.events import (
SOCKET_IO_EVENT,
SOCKET_IO_LOG_EVENT,
SOCKET_IO_NODE_PROGRESS_EVENT,
Expand All @@ -35,7 +35,7 @@
send_messages,
)

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def _handle_computation_running_progress(
Expand Down Expand Up @@ -65,21 +65,21 @@ async def _handle_computation_running_progress(
await send_messages(app, f"{message.user_id}", messages)
return True
except ProjectNotFoundError:
log.warning(
logger.warning(
"project related to received rabbitMQ progress message not found: '%s'",
json_dumps(message, indent=2),
)
return True
except NodeNotFoundError:
log.warning(
logger.warning(
"node related to received rabbitMQ progress message not found: '%s'",
json_dumps(message, indent=2),
)
return True
return False


async def progress_message_parser(app: web.Application, data: bytes) -> bool:
async def _progress_message_parser(app: web.Application, data: bytes) -> bool:
# update corresponding project, node, progress value
rabbit_message = parse_raw_as(
Union[ProgressRabbitMessageNode, ProgressRabbitMessageProject], data
Expand Down Expand Up @@ -110,7 +110,7 @@ async def progress_message_parser(app: web.Application, data: bytes) -> bool:
return True


async def log_message_parser(app: web.Application, data: bytes) -> bool:
async def _log_message_parser(app: web.Application, data: bytes) -> bool:
rabbit_message = LoggerRabbitMessage.parse_raw(data)

if not await projects_api.is_project_hidden(app, rabbit_message.project_id):
Expand All @@ -124,7 +124,7 @@ async def log_message_parser(app: web.Application, data: bytes) -> bool:
return True


async def instrumentation_message_parser(app: web.Application, data: bytes) -> bool:
async def _instrumentation_message_parser(app: web.Application, data: bytes) -> bool:
rabbit_message = InstrumentationRabbitMessage.parse_raw(data)
if rabbit_message.metrics == "service_started":
service_started(
Expand All @@ -137,7 +137,7 @@ async def instrumentation_message_parser(app: web.Application, data: bytes) -> b
return True


async def events_message_parser(app: web.Application, data: bytes) -> bool:
async def _events_message_parser(app: web.Application, data: bytes) -> bool:
rabbit_message = EventRabbitMessage.parse_raw(data)

socket_messages: list[SocketMessageDict] = [
Expand All @@ -156,32 +156,35 @@ async def events_message_parser(app: web.Application, data: bytes) -> bool:
EXCHANGE_TO_PARSER_CONFIG = (
(
LoggerRabbitMessage.get_channel_name(),
log_message_parser,
_log_message_parser,
{},
),
(
ProgressRabbitMessageNode.get_channel_name(),
progress_message_parser,
_progress_message_parser,
{},
),
(
InstrumentationRabbitMessage.get_channel_name(),
instrumentation_message_parser,
_instrumentation_message_parser,
dict(exclusive_queue=False),
),
(
EventRabbitMessage.get_channel_name(),
events_message_parser,
_events_message_parser,
{},
),
)


async def setup_rabbitmq_consumers(app: web.Application) -> None:
with log_context(log, logging.INFO, msg="Subscribing to rabbitmq channels"):
async def setup_rabbitmq_consumers(app: web.Application) -> AsyncIterator[None]:
with log_context(logger, logging.INFO, msg="Subscribing to rabbitmq channels"):
rabbit_client: RabbitMQClient = get_rabbitmq_client(app)

for exchange_name, parser_fct, queue_kwargs in EXCHANGE_TO_PARSER_CONFIG:
await rabbit_client.subscribe(
exchange_name, functools.partial(parser_fct, app), **queue_kwargs
)
yield

# cleanup?
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from aiohttp import web
from servicelib.aiohttp.application_setup import ModuleCategory, app_module_setup

from .computation_comp_tasks_listening_task import create_comp_tasks_listening_task
from .computation_subscribe import setup_rabbitmq_consumers
from .rabbitmq import setup_rabbitmq
from ..db import setup_db
from ..rabbitmq import setup_rabbitmq
from ._computation_comp_tasks_listening_task import create_comp_tasks_listening_task
from ._rabbitmq_consumers import setup_rabbitmq_consumers

log = logging.getLogger(__name__)

Expand All @@ -27,7 +28,8 @@ def setup_computation(app: web.Application):
setup_rabbitmq(app)
# Subscribe to rabbit upon startup for logs, progress and other
# metrics on the execution reported by sidecars
app.on_startup.append(setup_rabbitmq_consumers)
app.cleanup_ctx.append(setup_rabbitmq_consumers)

# Creates a task to listen to comp_task pg-db's table events
setup_db(app)
app.cleanup_ctx.append(create_comp_tasks_listening_task)
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def _rabbitmq_client_cleanup_ctx(app: web.Application) -> AsyncIterator[No
logger=log,
depends=[],
)
def setup_rabbitmq(app: web.Application):
def setup_rabbitmq(app: web.Application) -> None:
app.cleanup_ctx.append(_rabbitmq_client_cleanup_ctx)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
- settings
"""

from typing import Optional

from aiohttp.web import Application
from settings_library.rabbit import RabbitSettings

from ._constants import APP_SETTINGS_KEY
from .._constants import APP_SETTINGS_KEY


def get_plugin_settings(app: Application) -> RabbitSettings:
settings: Optional[RabbitSettings] = app[APP_SETTINGS_KEY].WEBSERVER_RABBITMQ
settings: RabbitSettings | None = app[APP_SETTINGS_KEY].WEBSERVER_RABBITMQ
assert settings, "setup_settings not called?" # nosec
assert isinstance(settings, RabbitSettings) # nosec
return settings
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# pylint: disable=redefined-outer-name
# pylint: disable=protected-access

from typing import Iterator
from unittest.mock import AsyncMock

import pytest
Expand All @@ -9,21 +11,22 @@
ProgressRabbitMessageProject,
ProgressType,
)
from pydantic import BaseModel
from pytest_mock import MockerFixture
from simcore_service_webserver import computation_subscribe
from simcore_service_webserver.notifications import _rabbitmq_consumers

_faker = Faker()


@pytest.fixture
def mock_send_messages(mocker: MockerFixture) -> dict:
def mock_send_messages(mocker: MockerFixture) -> Iterator[dict]:
reference = {}

async def mock_send_message(*args) -> None:
reference["args"] = args

mocker.patch.object(
computation_subscribe, "send_messages", side_effect=mock_send_message
_rabbitmq_consumers, "send_messages", side_effect=mock_send_message
)

yield reference
Expand Down Expand Up @@ -60,9 +63,9 @@ async def mock_send_message(*args) -> None:
],
)
async def test_regression_progress_message_parser(
mock_send_messages: dict, raw_data: bytes, class_type: type
mock_send_messages: dict, raw_data: bytes, class_type: type[BaseModel]
):
await computation_subscribe.progress_message_parser(AsyncMock(), raw_data)
await _rabbitmq_consumers._progress_message_parser(AsyncMock(), raw_data)
serialized_sent_data = mock_send_messages["args"][2][0]["data"]
# check that all fields are sent as expected
assert class_type.parse_obj(serialized_sent_data)
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
(StateType.NOT_STARTED, RunningState.NOT_STARTED),
],
)
def test_convert_state_from_db(db_state: int, expected_state: RunningState):
def test_convert_state_from_db(db_state: StateType, expected_state: RunningState):
assert convert_state_from_db(db_state) == expected_state