From 8ce069fb36497b0dc61c7428f40467ba713b4e35 Mon Sep 17 00:00:00 2001
From: sanderegg <35365065+sanderegg@users.noreply.github.com>
Date: Mon, 20 Nov 2023 10:51:00 +0100
Subject: [PATCH 01/21] ruff

---
 .../container_tasks/io.py                     | 35 ++++++++-----------
 .../src/simcore_service_dask_sidecar/tasks.py |  3 +-
 2 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/io.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/io.py
index 388ec63596d..4b63d2f6911 100644
--- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/io.py
+++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/io.py
@@ -1,7 +1,7 @@
 import json
 from contextlib import suppress
 from pathlib import Path
-from typing import Any, Optional, Union, cast
+from typing import Any, ClassVar, Union, cast

 from models_library.basic_regex import MIME_TYPE_RE
 from models_library.generics import DictModel
@@ -26,7 +26,7 @@ class PortSchema(BaseModel):
     class Config:
         extra = Extra.forbid

-        schema_extra: dict[str, Any] = {
+        schema_extra: ClassVar[dict[str, Any]] = {
             "examples": [
                 {
                     "required": True,
@@ -39,11 +39,11 @@ class Config:


 class FilePortSchema(PortSchema):
-    mapping: Optional[str] = None
+    mapping: str | None = None
     url: AnyUrl

     class Config(PortSchema.Config):
-        schema_extra = {
+        schema_extra: ClassVar[dict[str, Any]] = {
             "examples": [
                 {
                     "mapping": "some_filename.txt",
@@ -60,17 +60,17 @@ class Config(PortSchema.Config):

 class FileUrl(BaseModel):
     url: AnyUrl
-    file_mapping: Optional[str] = Field(
+    file_mapping: str | None = Field(
        default=None,
        description="Local file relpath name (if given), otherwise it takes the url filename",
    )
-    file_mime_type: Optional[str] = Field(
+    file_mime_type: str | None = Field(
        default=None, description="the file MIME type", regex=MIME_TYPE_RE
    )

     class Config:
         extra = Extra.forbid
-        schema_extra = {
+        schema_extra: ClassVar[dict[str, Any]] = {
             "examples": [
                 {"url": "https://some_file_url", "file_mime_type": "application/json"},
                 {
@@ -97,7 +97,7 @@ class Config:

 class TaskInputData(DictModel[PortKey, PortValue]):
     class Config(DictModel.Config):
-        schema_extra = {
+        schema_extra: ClassVar[dict[str, Any]] = {
             "examples": [
                 {
                     "boolean_input": False,
@@ -121,7 +121,7 @@ class TaskOutputDataSchema(DictModel[PortKey, PortSchemaValue]):
     #  sent as a json-schema instead of with a dynamically-created model class
     #
     class Config(DictModel.Config):
-        schema_extra = {
+        schema_extra: ClassVar[dict[str, Any]] = {
             "examples": [
                 {
                     "boolean_output": {"required": False},
@@ -159,8 +159,6 @@ def from_task_output(
         for output_key, output_params in schema.items():
             if isinstance(output_params, FilePortSchema):
                 file_relpath = output_params.mapping or output_key
-                # TODO: file_path is built here, saved truncated in file_mapping and
-                # then rebuild again int _retrieve_output_data. Review.
                 file_path = output_folder / file_relpath
                 if file_path.exists():
                     data[output_key] = {
@@ -168,20 +166,17 @@
                         "file_mapping": file_relpath,
                     }
                 elif output_params.required:
-                    raise ValueError(
-                        f"Could not locate '{file_path}' in {output_folder}"
-                    )
-            else:
-                if output_key not in data and output_params.required:
-                    raise ValueError(
-                        f"Could not locate '{output_key}' in {output_data_file}"
-                    )
+                    msg = f"Could not locate '{file_path}' in {output_folder}"
+                    raise ValueError(msg)
+            elif output_key not in data and output_params.required:
+                msg = f"Could not locate '{output_key}' in {output_data_file}"
+                raise ValueError(msg)

         # NOTE: this cast is necessary to make mypy happy
         return cast(TaskOutputData, cls.parse_obj(data))

     class Config(DictModel.Config):
-        schema_extra = {
+        schema_extra: ClassVar[dict[str, Any]] = {
             "examples": [
                 {
                     "boolean_output": False,
diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py
index 0690a45091f..73709d41942 100644
--- a/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py
+++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py
@@ -135,8 +135,7 @@ async def _run_computational_sidecar_async(  # pylint: disable=too-many-argument
     return output_data


-def run_computational_sidecar(
-    # pylint: disable=too-many-arguments
+def run_computational_sidecar(  # pylint: disable=too-many-arguments # noqa: PLR0913
     docker_auth: DockerBasicAuth,
     service_key: ContainerImage,
     service_version: ContainerTag,

From ab75c09e696b8423be1cd0a9bb70205a2888d839 Mon Sep 17 00:00:00 2001
From: sanderegg <35365065+sanderegg@users.noreply.github.com>
Date: Mon, 20 Nov 2023 16:43:11 +0100
Subject: [PATCH 02/21] ruff

---
 services/dask-sidecar/tests/unit/test_tasks.py | 78 ++++++++++---------
 1 file changed, 42 insertions(+), 36 deletions(-)

diff --git a/services/dask-sidecar/tests/unit/test_tasks.py b/services/dask-sidecar/tests/unit/test_tasks.py
index a88da15076f..9cbacfaab85 100644
--- a/services/dask-sidecar/tests/unit/test_tasks.py
+++ b/services/dask-sidecar/tests/unit/test_tasks.py
@@ -8,12 +8,13 @@
 import json
 import logging
 import re
+from collections.abc import Callable, Coroutine, Iterable

 # copied out from dask
 from dataclasses import dataclass
 from pprint import pformat
 from random import randint
-from typing import Any, Callable, Coroutine, Iterable
+from typing import Any
 from unittest import mock
 from uuid import uuid4
@@ -41,7 +42,6 @@
 from models_library.users import UserID
 from packaging import version
 from pydantic import AnyUrl, SecretStr, parse_obj_as
-from pytest import FixtureRequest, LogCaptureFixture
 from pytest_mock.plugin import MockerFixture
 from pytest_simcore.helpers.typing_env import EnvVarsDict
 from settings_library.s3 import S3Settings
@@ -143,7 +143,7 @@ class ServiceExampleParam:
     expected_output_data: TaskOutputData
     expected_logs: list[str]
     integration_version: version.Version
-    task_envs: dict[EnvVarsDict, str]
+    task_envs: dict[EnvVarKey, str]

     def sidecar_params(self) -> dict[str, Any]:
         return {
@@ -170,7 +170,7 @@ def _bash_check_env_exist(variable_name: str, variable_value: str) -> list[str]:


 @pytest.fixture(params=list(BootMode), ids=str)
-def boot_mode(request: FixtureRequest) -> BootMode:
+def boot_mode(request: pytest.FixtureRequest) -> BootMode:
     return request.param


@@ -182,7 +182,7 @@ def boot_mode(request: FixtureRequest) -> BootMode:
     ],
     ids=lambda v: f"integration.version.{v}",
 )
-def integration_version(request: FixtureRequest) -> version.Version:
+def integration_version(request: pytest.FixtureRequest) -> version.Version:
     print("--> Using service integration:", request.param)
     return version.Version(request.param)


@@ -274,7 +274,7 @@ def sleeper_task(
         f"echo '{faker.text(max_nb_chars=17216)}'",
         f"(test -f ${{INPUT_FOLDER}}/{input_json_file_name} || (echo ${{INPUT_FOLDER}}/{input_json_file_name} file does not exists && exit 1))",
         f"echo $(cat ${{INPUT_FOLDER}}/{input_json_file_name})",
-        f"sleep {randint(1,4)}",
+        f"sleep {randint(1,4)}",  # noqa: S311
     ]

     # defines the expected outputs
@@ -287,34 +287,38 @@ def sleeper_task(
     output_file_url = s3_remote_file_url(file_path="output_file")
     expected_output_keys = TaskOutputDataSchema.parse_obj(
         {
-            **{k: {"required": True} for k in jsonable_outputs.keys()},
-            **{
-                "pytest_file": {
-                    "required": True,
-                    "mapping": "a_outputfile",
-                    "url": f"{output_file_url}",
-                },
-                "pytest_file_with_mapping": {
-                    "required": True,
-                    "mapping": "subfolder/a_outputfile",
-                    "url": f"{output_file_url}",
-                },
-            },
+            **(
+                {k: {"required": True} for k in jsonable_outputs}
+                | {
+                    "pytest_file": {
+                        "required": True,
+                        "mapping": "a_outputfile",
+                        "url": f"{output_file_url}",
+                    },
+                    "pytest_file_with_mapping": {
+                        "required": True,
+                        "mapping": "subfolder/a_outputfile",
+                        "url": f"{output_file_url}",
+                    },
+                }
+            ),
         }
     )
     expected_output_data = TaskOutputData.parse_obj(
         {
-            **jsonable_outputs,
-            **{
-                "pytest_file": {
-                    "url": f"{output_file_url}",
-                    "file_mapping": "a_outputfile",
-                },
-                "pytest_file_with_mapping": {
-                    "url": f"{output_file_url}",
-                    "file_mapping": "subfolder/a_outputfile",
-                },
-            },
+            **(
+                jsonable_outputs
+                | {
+                    "pytest_file": {
+                        "url": f"{output_file_url}",
+                        "file_mapping": "a_outputfile",
+                    },
+                    "pytest_file_with_mapping": {
+                        "url": f"{output_file_url}",
+                        "file_mapping": "subfolder/a_outputfile",
+                    },
+                }
+            ),
         }
     )
     jsonized_outputs = json.dumps(jsonable_outputs).replace('"', '\\"')
@@ -423,7 +427,9 @@ def sleeper_task_unexpected_output(


 @pytest.fixture()
-def caplog_info_level(caplog: LogCaptureFixture) -> Iterable[LogCaptureFixture]:
+def caplog_info_level(
+    caplog: pytest.LogCaptureFixture,
+) -> Iterable[pytest.LogCaptureFixture]:
     with caplog.at_level(logging.INFO, logger="simcore_service_dask_sidecar"):
         yield caplog

@@ -436,16 +442,15 @@ def mocked_get_image_labels(
         ImageLabels, ServiceDockerData.Config.schema_extra["examples"][0]
     )
     labels.integration_version = f"{integration_version}"
-    mocked_get_image_labels = mocker.patch(
+    return mocker.patch(
         "simcore_service_dask_sidecar.computational_sidecar.core.get_image_labels",
         autospec=True,
         return_value=labels,
     )
-    return mocked_get_image_labels


 def test_run_computational_sidecar_real_fct(
-    caplog_info_level: LogCaptureFixture,
+    caplog_info_level: pytest.LogCaptureFixture,
     event_loop: asyncio.AbstractEventLoop,
     app_environment: EnvVarsDict,
     dask_subsystem_mock: dict[str, mock.Mock],
@@ -609,6 +614,7 @@ async def test_run_computational_sidecar_dask(
         ), f"Could not find {log} in worker_logs:\n {pformat(worker_logs, width=240)}"

     # check that the task produce the expected data, not less not more
+    assert isinstance(output_data, dict)
     for k, v in sleeper_task.expected_output_data.items():
         assert k in output_data
         assert output_data[k] == v
@@ -684,7 +690,7 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub
     "integration_version, boot_mode", [("1.0.0", BootMode.CPU)], indirect=True
 )
 def test_failing_service_raises_exception(
-    caplog_info_level: LogCaptureFixture,
+    caplog_info_level: pytest.LogCaptureFixture,
     app_environment: EnvVarsDict,
     dask_subsystem_mock: dict[str, mock.Mock],
     failing_ubuntu_task: ServiceExampleParam,
@@ -703,7 +709,7 @@ def test_failing_service_raises_exception(
     "integration_version, boot_mode", [("1.0.0", BootMode.CPU)], indirect=True
 )
 def test_running_service_that_generates_unexpected_data_raises_exception(
-    caplog_info_level: LogCaptureFixture,
+    caplog_info_level: pytest.LogCaptureFixture,
     app_environment: EnvVarsDict,
     dask_subsystem_mock: dict[str, mock.Mock],
     sleeper_task_unexpected_output: ServiceExampleParam,

From c6d128afe1af33440eb58c5f74d5ac84cffeac1d Mon Sep 17 00:00:00 2001
From: sanderegg <35365065+sanderegg@users.noreply.github.com>
Date: Mon, 20 Nov 2023 17:17:58 +0100
Subject: [PATCH 03/21] refactoring

---
 .../container_tasks/protocol.py               | 51 +++++++---
 .../computational_sidecar/core.py             | 71 +++++++----------
 .../src/simcore_service_dask_sidecar/tasks.py | 60 +++-----------
 .../dask-sidecar/tests/unit/test_tasks.py     | 79 ++++++++++++-------
 4 files changed, 130 insertions(+), 131 deletions(-)

diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py
index bae2b295715..4af5b3510a9 100644
--- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py
+++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py
@@ -1,9 +1,12 @@
-from typing import Protocol, TypeAlias
+from typing import Any, Protocol, TypeAlias

 from models_library.basic_types import EnvVarKey
 from models_library.docker import DockerLabelKey
+from models_library.projects import ProjectID
+from models_library.projects_nodes_io import NodeID
 from models_library.services_resources import BootMode
-from pydantic import AnyUrl
+from models_library.users import UserID
+from pydantic import AnyUrl, BaseModel, root_validator
 from settings_library.s3 import S3Settings

 from .docker import DockerBasicAuth
@@ -17,20 +20,46 @@
 ContainerLabelsDict: TypeAlias = dict[DockerLabelKey, str]


+class TaskOwner(BaseModel):
+    user_id: UserID
+    project_id: ProjectID
+    node_id: NodeID
+
+    parent_project_id: ProjectID | None
+    parent_node_id: NodeID | None
+
+    @root_validator
+    @classmethod
+    def check_parent_valid(cls, values: dict[str, Any]) -> dict[str, Any]:
+        parent_project_id = values.get("parent_project_id")
+        parent_node_id = values.get("parent_node_id")
+        if (parent_node_id is None and parent_project_id is not None) or (
+            parent_node_id is not None and parent_project_id is None
+        ):
+            msg = "either both parent_node_id and parent_project_id are None or both are set!"
+ raise ValueError(msg) + return values + + +class ContainerTaskParameters(BaseModel): + image: ContainerImage + tag: ContainerTag + input_data: TaskInputData + output_data_keys: TaskOutputDataSchema + command: ContainerCommands + envs: ContainerEnvsDict + labels: ContainerLabelsDict + boot_mode: BootMode + task_owner: TaskOwner + + class ContainerRemoteFct(Protocol): - def __call__( # pylint: disable=too-many-arguments # noqa: PLR0913 + def __call__( self, *, + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode, ) -> TaskOutputData: ... diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py index bd5446187e8..93d0e072917 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py @@ -3,30 +3,20 @@ import logging import os import socket +from collections.abc import Coroutine from dataclasses import dataclass from pathlib import Path from pprint import pformat from types import TracebackType -from typing import Coroutine, cast +from typing import cast from uuid import uuid4 from aiodocker import Docker from dask_task_models_library.container_tasks.docker import DockerBasicAuth from dask_task_models_library.container_tasks.errors import ServiceRuntimeError from dask_task_models_library.container_tasks.events import TaskLogEvent -from dask_task_models_library.container_tasks.io import ( - FileUrl, - TaskInputData, - TaskOutputData, - TaskOutputDataSchema, -) -from dask_task_models_library.container_tasks.protocol import ( - ContainerEnvsDict, - ContainerImage, - ContainerLabelsDict, - ContainerTag, -) -from models_library.services_resources import BootMode +from dask_task_models_library.container_tasks.io import FileUrl, TaskOutputData +from dask_task_models_library.container_tasks.protocol import ContainerTaskParameters from packaging import version from pydantic import ValidationError from pydantic.networks import AnyUrl @@ -53,20 +43,14 @@ CONTAINER_WAIT_TIME_SECS = 2 -@dataclass -class ComputationalSidecar: # pylint: disable=too-many-instance-attributes +@dataclass(kw_only=True, frozen=True, slots=True) +class ComputationalSidecar: + task_parameters: ContainerTaskParameters docker_auth: DockerBasicAuth - service_key: ContainerImage - service_version: ContainerTag - input_data: TaskInputData - output_data_keys: TaskOutputDataSchema log_file_url: AnyUrl - boot_mode: BootMode task_max_resources: dict[str, float] task_publishers: TaskPublisher s3_settings: S3Settings | None - task_envs: ContainerEnvsDict - task_labels: ContainerLabelsDict async def _write_input_data( self, @@ -80,7 +64,7 @@ async def _write_input_data( local_input_data_file = {} download_tasks = [] - for input_key, input_params in self.input_data.items(): + for input_key, input_params in self.task_parameters.input_data.items(): if isinstance(input_params, FileUrl): file_name = ( input_params.file_mapping @@ -125,11 +109,11 @@ async def _retrieve_output_data( ) logger.debug( "following outputs will be searched for:\n%s", - 
self.output_data_keys.json(indent=1), + self.task_parameters.output_data_keys.json(indent=1), ) output_data = TaskOutputData.from_task_output( - self.output_data_keys, + self.task_parameters.output_data_keys, task_volumes.outputs_folder, "outputs.json" if integration_version > LEGACY_INTEGRATION_VERSION @@ -160,8 +144,8 @@ async def _retrieve_output_data( except (ValueError, ValidationError) as exc: raise ServiceBadFormattedOutputError( - service_key=self.service_key, - service_version=self.service_version, + service_key=self.task_parameters.image, + service_version=self.task_parameters.tag, exc=exc, ) from exc @@ -177,7 +161,7 @@ async def _publish_sidecar_log( async def run(self, command: list[str]) -> TaskOutputData: # ensure we pass the initial logs and progress await self._publish_sidecar_log( - f"Starting task for {self.service_key}:{self.service_version} on {socket.gethostname()}..." + f"Starting task for {self.task_parameters.image}:{self.task_parameters.tag} on {socket.gethostname()}..." ) self.task_publishers.publish_progress(0) @@ -189,27 +173,30 @@ async def run(self, command: list[str]) -> TaskOutputData: await pull_image( docker_client, self.docker_auth, - self.service_key, - self.service_version, + self.task_parameters.image, + self.task_parameters.tag, self._publish_sidecar_log, ) image_labels: ImageLabels = await get_image_labels( - docker_client, self.docker_auth, self.service_key, self.service_version + docker_client, + self.docker_auth, + self.task_parameters.image, + self.task_parameters.tag, ) computational_shared_data_mount_point = ( await get_computational_shared_data_mount_point(docker_client) ) config = await create_container_config( docker_registry=self.docker_auth.server_address, - image=self.service_key, - tag=self.service_version, + image=self.task_parameters.image, + tag=self.task_parameters.tag, command=command, comp_volume_mount_point=f"{computational_shared_data_mount_point}/{run_id}", - boot_mode=self.boot_mode, + boot_mode=self.task_parameters.boot_mode, task_max_resources=self.task_max_resources, - envs=self.task_envs, - labels=self.task_labels, + envs=self.task_parameters.envs, + labels=self.task_parameters.labels, ) await self._write_input_data( task_volumes, image_labels.get_integration_version() @@ -219,12 +206,12 @@ async def run(self, command: list[str]) -> TaskOutputData: async with managed_container( docker_client, config, - name=f"{self.service_key.split(sep='/')[-1]}_{run_id}", + name=f"{self.task_parameters.image.split(sep='/')[-1]}_{run_id}", ) as container, managed_monitor_container_log_task( container=container, progress_regexp=image_labels.get_progress_regexp(), - service_key=self.service_key, - service_version=self.service_version, + service_key=self.task_parameters.image, + service_version=self.task_parameters.tag, task_publishers=self.task_publishers, integration_version=image_labels.get_integration_version(), task_volumes=task_volumes, @@ -241,8 +228,8 @@ async def run(self, command: list[str]) -> TaskOutputData: await asyncio.sleep(CONTAINER_WAIT_TIME_SECS) if container_data["State"]["ExitCode"] > os.EX_OK: raise ServiceRuntimeError( - service_key=self.service_key, - service_version=self.service_version, + service_key=self.task_parameters.image, + service_version=self.task_parameters.tag, container_id=container.id, exit_code=container_data["State"]["ExitCode"], service_logs=await cast( diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py index 
73709d41942..658e40656df 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py @@ -6,21 +6,12 @@ import distributed from dask_task_models_library.container_tasks.docker import DockerBasicAuth -from dask_task_models_library.container_tasks.io import ( - TaskInputData, - TaskOutputData, - TaskOutputDataSchema, -) +from dask_task_models_library.container_tasks.io import TaskOutputData from dask_task_models_library.container_tasks.protocol import ( - ContainerCommands, - ContainerEnvsDict, - ContainerImage, - ContainerLabelsDict, - ContainerTag, + ContainerTaskParameters, LogFileUploadURL, ) from distributed.worker import logger -from models_library.services_resources import BootMode from servicelib.logging_utils import config_all_loggers from settings_library.s3 import S3Settings @@ -90,25 +81,18 @@ async def dask_teardown(_worker: distributed.Worker) -> None: logger.warning("Tearing down worker!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") -async def _run_computational_sidecar_async( # pylint: disable=too-many-arguments # noqa: PLR0913 +async def _run_computational_sidecar_async( *, + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode, ) -> TaskOutputData: task_publishers = TaskPublisher() _logger.debug( "run_computational_sidecar %s", - f"{docker_auth=}, {service_key=}, {service_version=}, {input_data=}, {output_data_keys=}, {command=}, {s3_settings=}", + f"{task_parameters.dict()=}, {docker_auth=}, {log_file_url=}, {s3_settings=}", ) current_task = asyncio.current_task() assert current_task # nosec @@ -117,36 +101,23 @@ async def _run_computational_sidecar_async( # pylint: disable=too-many-argument ): task_max_resources = get_current_task_resources() async with ComputationalSidecar( - service_key=service_key, - service_version=service_version, - input_data=input_data, - output_data_keys=output_data_keys, - log_file_url=log_file_url, + task_parameters=task_parameters, docker_auth=docker_auth, - boot_mode=boot_mode, + log_file_url=log_file_url, + s3_settings=s3_settings, task_max_resources=task_max_resources, task_publishers=task_publishers, - s3_settings=s3_settings, - task_envs=task_envs, - task_labels=task_labels, ) as sidecar: - output_data = await sidecar.run(command=command) + output_data = await sidecar.run(command=task_parameters.command) _logger.debug("completed run of sidecar with result %s", f"{output_data=}") return output_data -def run_computational_sidecar( # pylint: disable=too-many-arguments # noqa: PLR0913 +def run_computational_sidecar( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode = BootMode.CPU, ) -> TaskOutputData: # NOTE: The event loop MUST BE created in the main thread prior to this # Dask creates threads to run these calls, and the loop shall be created before @@ -161,16 +132,9 @@ def run_computational_sidecar( 
# pylint: disable=too-many-arguments # noqa: PLR return asyncio.get_event_loop().run_until_complete( _run_computational_sidecar_async( + task_parameters=task_parameters, docker_auth=docker_auth, - service_key=service_key, - service_version=service_version, - input_data=input_data, - output_data_keys=output_data_keys, log_file_url=log_file_url, - command=command, - task_envs=task_envs, - task_labels=task_labels, s3_settings=s3_settings, - boot_mode=boot_mode, ) ) diff --git a/services/dask-sidecar/tests/unit/test_tasks.py b/services/dask-sidecar/tests/unit/test_tasks.py index 9cbacfaab85..29906dc8ead 100644 --- a/services/dask-sidecar/tests/unit/test_tasks.py +++ b/services/dask-sidecar/tests/unit/test_tasks.py @@ -16,7 +16,6 @@ from random import randint from typing import Any from unittest import mock -from uuid import uuid4 import distributed import fsspec @@ -33,6 +32,10 @@ TaskOutputData, TaskOutputDataSchema, ) +from dask_task_models_library.container_tasks.protocol import ( + ContainerTaskParameters, + TaskOwner, +) from faker import Faker from models_library.basic_types import EnvVarKey from models_library.projects import ProjectID @@ -68,18 +71,18 @@ def job_id() -> str: @pytest.fixture -def user_id() -> UserID: - return 1 +def user_id(faker: Faker) -> UserID: + return faker.pyint(min_value=1) @pytest.fixture -def project_id() -> ProjectID: - return uuid4() +def project_id(faker: Faker) -> ProjectID: + return faker.uuid4(cast_to=None) @pytest.fixture -def node_id() -> NodeID: - return uuid4() +def node_id(faker: Faker) -> NodeID: + return faker.uuid4(cast_to=None) @pytest.fixture() @@ -131,7 +134,7 @@ def dask_subsystem_mock(mocker: MockerFixture) -> dict[str, mock.Mock]: } -@dataclass +@dataclass(slots=True, kw_only=True) class ServiceExampleParam: docker_basic_auth: DockerBasicAuth service_key: str @@ -144,17 +147,24 @@ class ServiceExampleParam: expected_logs: list[str] integration_version: version.Version task_envs: dict[EnvVarKey, str] + task_owner: TaskOwner + boot_mode: BootMode def sidecar_params(self) -> dict[str, Any]: return { + "task_parameters": ContainerTaskParameters( + image=self.service_key, + tag=self.service_version, + input_data=self.input_data, + output_data_keys=self.output_data_keys, + command=self.command, + envs=self.task_envs, + labels={}, + task_owner=self.task_owner, + boot_mode=self.boot_mode, + ), "docker_auth": self.docker_basic_auth, - "service_key": self.service_key, - "service_version": self.service_version, - "input_data": self.input_data, - "output_data_keys": self.output_data_keys, "log_file_url": self.log_file_url, - "command": self.command, - "task_envs": self.task_envs, } @@ -192,6 +202,17 @@ def additional_envs(faker: Faker) -> dict[EnvVarKey, str]: return parse_obj_as(dict[EnvVarKey, str], faker.pydict(allowed_types=(str,))) +@pytest.fixture +def task_owner(user_id: UserID, project_id: ProjectID, node_id: NodeID) -> TaskOwner: + return TaskOwner( + user_id=user_id, + project_id=project_id, + node_id=node_id, + parent_project_id=None, + parent_node_id=None, + ) + + @pytest.fixture def sleeper_task( integration_version: version.Version, @@ -200,6 +221,10 @@ def sleeper_task( boot_mode: BootMode, additional_envs: dict[EnvVarKey, str], faker: Faker, + user_id: UserID, + project_id: ProjectID, + node_id: NodeID, + task_owner: TaskOwner, ) -> ServiceExampleParam: """Creates a console task in an ubuntu distro that checks for the expected files and error in case they are missing""" # let's have some input files on the file server @@ -379,6 +404,8 
@@ def sleeper_task( ], integration_version=integration_version, task_envs=additional_envs, + task_owner=task_owner, + boot_mode=boot_mode, ) @@ -389,6 +416,7 @@ def sidecar_task( s3_remote_file_url: Callable[..., AnyUrl], boot_mode: BootMode, faker: Faker, + task_owner: TaskOwner, ) -> Callable[..., ServiceExampleParam]: def _creator(command: list[str] | None = None) -> ServiceExampleParam: return ServiceExampleParam( @@ -406,6 +434,8 @@ def _creator(command: list[str] | None = None) -> ServiceExampleParam: expected_logs=[], integration_version=integration_version, task_envs={}, + task_owner=task_owner, + boot_mode=boot_mode, ) return _creator @@ -455,15 +485,15 @@ def test_run_computational_sidecar_real_fct( app_environment: EnvVarsDict, dask_subsystem_mock: dict[str, mock.Mock], sleeper_task: ServiceExampleParam, - s3_settings: S3Settings, - boot_mode: BootMode, mocked_get_image_labels: mock.Mock, + user_id: UserID, + project_id: ProjectID, + node_id: NodeID, + s3_settings: S3Settings, ): output_data = run_computational_sidecar( **sleeper_task.sidecar_params(), s3_settings=s3_settings, - boot_mode=boot_mode, - task_labels={}, ) mocked_get_image_labels.assert_called_once_with( mock.ANY, @@ -522,9 +552,8 @@ def test_run_computational_sidecar_real_fct( def test_run_multiple_computational_sidecar_dask( dask_client: distributed.Client, sleeper_task: ServiceExampleParam, - s3_settings: S3Settings, - boot_mode: BootMode, mocked_get_image_labels: mock.Mock, + s3_settings: S3Settings, ): NUMBER_OF_TASKS = 50 @@ -534,8 +563,6 @@ def test_run_multiple_computational_sidecar_dask( **sleeper_task.sidecar_params(), s3_settings=s3_settings, resources={}, - boot_mode=boot_mode, - task_labels={}, ) for _ in range(NUMBER_OF_TASKS) ] @@ -572,7 +599,6 @@ async def test_run_computational_sidecar_dask( dask_client: distributed.Client, sleeper_task: ServiceExampleParam, s3_settings: S3Settings, - boot_mode: BootMode, log_sub: distributed.Sub, progress_sub: distributed.Sub, mocked_get_image_labels: mock.Mock, @@ -582,8 +608,6 @@ async def test_run_computational_sidecar_dask( **sleeper_task.sidecar_params(), s3_settings=s3_settings, resources={}, - boot_mode=boot_mode, - task_labels={}, ) worker_name = next(iter(dask_client.scheduler_info()["workers"])) @@ -638,7 +662,6 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub dask_client: distributed.Client, sidecar_task: Callable[..., ServiceExampleParam], s3_settings: S3Settings, - boot_mode: BootMode, log_sub: distributed.Sub, progress_sub: distributed.Sub, mocked_get_image_labels: mock.Mock, @@ -660,8 +683,6 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub ).sidecar_params(), s3_settings=s3_settings, resources={}, - boot_mode=boot_mode, - task_labels={}, ) output_data = future.result() assert output_data is not None @@ -701,7 +722,6 @@ def test_failing_service_raises_exception( run_computational_sidecar( **failing_ubuntu_task.sidecar_params(), s3_settings=s3_settings, - task_labels={}, ) @@ -719,5 +739,4 @@ def test_running_service_that_generates_unexpected_data_raises_exception( run_computational_sidecar( **sleeper_task_unexpected_output.sidecar_params(), s3_settings=s3_settings, - task_labels={}, ) From d451eb7e5e48771980827b727bb196b516b5d0f3 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 20 Nov 2023 17:57:39 +0100 Subject: [PATCH 04/21] fix assertion --- services/dask-sidecar/tests/unit/test_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/services/dask-sidecar/tests/unit/test_tasks.py b/services/dask-sidecar/tests/unit/test_tasks.py index 29906dc8ead..02ba3cfb006 100644 --- a/services/dask-sidecar/tests/unit/test_tasks.py +++ b/services/dask-sidecar/tests/unit/test_tasks.py @@ -638,7 +638,7 @@ async def test_run_computational_sidecar_dask( ), f"Could not find {log} in worker_logs:\n {pformat(worker_logs, width=240)}" # check that the task produce the expected data, not less not more - assert isinstance(output_data, dict) + assert isinstance(output_data, TaskOutputData) for k, v in sleeper_task.expected_output_data.items(): assert k in output_data assert output_data[k] == v From 495205eab1031072135b9c95acd91908bdd7aa3f Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 20 Nov 2023 18:03:48 +0100 Subject: [PATCH 05/21] remove unused fixture --- .../dask-sidecar/tests/unit/test_tasks.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/services/dask-sidecar/tests/unit/test_tasks.py b/services/dask-sidecar/tests/unit/test_tasks.py index 02ba3cfb006..0df4f494cae 100644 --- a/services/dask-sidecar/tests/unit/test_tasks.py +++ b/services/dask-sidecar/tests/unit/test_tasks.py @@ -149,6 +149,7 @@ class ServiceExampleParam: task_envs: dict[EnvVarKey, str] task_owner: TaskOwner boot_mode: BootMode + s3_settings: S3Settings def sidecar_params(self) -> dict[str, Any]: return { @@ -165,6 +166,7 @@ def sidecar_params(self) -> dict[str, Any]: ), "docker_auth": self.docker_basic_auth, "log_file_url": self.log_file_url, + "s3_settings": self.s3_settings, } @@ -225,6 +227,7 @@ def sleeper_task( project_id: ProjectID, node_id: NodeID, task_owner: TaskOwner, + s3_settings: S3Settings, ) -> ServiceExampleParam: """Creates a console task in an ubuntu distro that checks for the expected files and error in case they are missing""" # let's have some input files on the file server @@ -406,6 +409,7 @@ def sleeper_task( task_envs=additional_envs, task_owner=task_owner, boot_mode=boot_mode, + s3_settings=s3_settings, ) @@ -417,6 +421,7 @@ def sidecar_task( boot_mode: BootMode, faker: Faker, task_owner: TaskOwner, + s3_settings: S3Settings, ) -> Callable[..., ServiceExampleParam]: def _creator(command: list[str] | None = None) -> ServiceExampleParam: return ServiceExampleParam( @@ -436,6 +441,7 @@ def _creator(command: list[str] | None = None) -> ServiceExampleParam: task_envs={}, task_owner=task_owner, boot_mode=boot_mode, + s3_settings=s3_settings, ) return _creator @@ -486,14 +492,10 @@ def test_run_computational_sidecar_real_fct( dask_subsystem_mock: dict[str, mock.Mock], sleeper_task: ServiceExampleParam, mocked_get_image_labels: mock.Mock, - user_id: UserID, - project_id: ProjectID, - node_id: NodeID, s3_settings: S3Settings, ): output_data = run_computational_sidecar( **sleeper_task.sidecar_params(), - s3_settings=s3_settings, ) mocked_get_image_labels.assert_called_once_with( mock.ANY, @@ -553,7 +555,6 @@ def test_run_multiple_computational_sidecar_dask( dask_client: distributed.Client, sleeper_task: ServiceExampleParam, mocked_get_image_labels: mock.Mock, - s3_settings: S3Settings, ): NUMBER_OF_TASKS = 50 @@ -561,7 +562,6 @@ def test_run_multiple_computational_sidecar_dask( dask_client.submit( run_computational_sidecar, **sleeper_task.sidecar_params(), - s3_settings=s3_settings, resources={}, ) for _ in range(NUMBER_OF_TASKS) @@ -598,15 +598,14 @@ def progress_sub(dask_client: distributed.Client) -> distributed.Sub: async def 
test_run_computational_sidecar_dask( dask_client: distributed.Client, sleeper_task: ServiceExampleParam, - s3_settings: S3Settings, log_sub: distributed.Sub, progress_sub: distributed.Sub, mocked_get_image_labels: mock.Mock, + s3_settings: S3Settings, ): future = dask_client.submit( run_computational_sidecar, **sleeper_task.sidecar_params(), - s3_settings=s3_settings, resources={}, ) @@ -661,7 +660,6 @@ async def test_run_computational_sidecar_dask( async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub( dask_client: distributed.Client, sidecar_task: Callable[..., ServiceExampleParam], - s3_settings: S3Settings, log_sub: distributed.Sub, progress_sub: distributed.Sub, mocked_get_image_labels: mock.Mock, @@ -681,7 +679,6 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub ), ], ).sidecar_params(), - s3_settings=s3_settings, resources={}, ) output_data = future.result() @@ -715,14 +712,10 @@ def test_failing_service_raises_exception( app_environment: EnvVarsDict, dask_subsystem_mock: dict[str, mock.Mock], failing_ubuntu_task: ServiceExampleParam, - s3_settings: S3Settings, mocked_get_image_labels: mock.Mock, ): with pytest.raises(ServiceRuntimeError): - run_computational_sidecar( - **failing_ubuntu_task.sidecar_params(), - s3_settings=s3_settings, - ) + run_computational_sidecar(**failing_ubuntu_task.sidecar_params()) @pytest.mark.parametrize( @@ -733,10 +726,8 @@ def test_running_service_that_generates_unexpected_data_raises_exception( app_environment: EnvVarsDict, dask_subsystem_mock: dict[str, mock.Mock], sleeper_task_unexpected_output: ServiceExampleParam, - s3_settings: S3Settings, ): with pytest.raises(ServiceBadFormattedOutputError): run_computational_sidecar( **sleeper_task_unexpected_output.sidecar_params(), - s3_settings=s3_settings, ) From 5df68edd35c79c7b32aa77cdbfc5b242d281cdb2 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 20 Nov 2023 18:11:04 +0100 Subject: [PATCH 06/21] ruff --- .../dask_task_models_library/container_tasks/events.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py index 27b43cea55d..e4e9a2a22e2 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import TypeAlias, Union +from typing import Any, ClassVar, TypeAlias from distributed.worker import get_worker from pydantic import BaseModel, Extra, validator @@ -31,7 +31,7 @@ def from_dask_worker(cls, progress: float) -> "TaskProgressEvent": return cls(job_id=get_worker().get_current_task(), progress=progress) class Config(BaseTaskEvent.Config): - schema_extra = { + schema_extra: ClassVar[dict[str, Any]] = { "examples": [ { "job_id": "simcore/services/comp/sleeper:1.1.0:projectid_ec7e595a-63ee-46a1-a04a-901b11b649f8:nodeid_39467d89-b659-4914-9359-c40b1b6d1d6d:uuid_5ee5c655-450d-4711-a3ec-32ffe16bc580", @@ -69,7 +69,7 @@ def from_dask_worker(cls, log: str, log_level: LogLevelInt) -> "TaskLogEvent": return cls(job_id=get_worker().get_current_task(), log=log, log_level=log_level) class Config(BaseTaskEvent.Config): - schema_extra = { + schema_extra: ClassVar[dict[str, Any]] = { 
"examples": [ { "job_id": "simcore/services/comp/sleeper:1.1.0:projectid_ec7e595a-63ee-46a1-a04a-901b11b649f8:nodeid_39467d89-b659-4914-9359-c40b1b6d1d6d:uuid_5ee5c655-450d-4711-a3ec-32ffe16bc580", @@ -78,6 +78,3 @@ class Config(BaseTaskEvent.Config): }, ] } - - -DaskTaskEvents = type[Union[TaskLogEvent, TaskProgressEvent]] From d33485c2b5c58166c71cd166383e5d8da6be370b Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 20 Nov 2023 18:21:52 +0100 Subject: [PATCH 07/21] added examples --- .../container_tasks/events.py | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py index e4e9a2a22e2..878e53a7969 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py @@ -5,9 +5,12 @@ from distributed.worker import get_worker from pydantic import BaseModel, Extra, validator +from .protocol import TaskOwner + class BaseTaskEvent(BaseModel, ABC): job_id: str + task_owner: TaskOwner msg: str | None = None @staticmethod @@ -27,8 +30,14 @@ def topic_name() -> str: return "task_progress" @classmethod - def from_dask_worker(cls, progress: float) -> "TaskProgressEvent": - return cls(job_id=get_worker().get_current_task(), progress=progress) + def from_dask_worker( + cls, progress: float, *, task_owner: TaskOwner + ) -> "TaskProgressEvent": + return cls( + job_id=get_worker().get_current_task(), + progress=progress, + task_owner=task_owner, + ) class Config(BaseTaskEvent.Config): schema_extra: ClassVar[dict[str, Any]] = { @@ -36,10 +45,24 @@ class Config(BaseTaskEvent.Config): { "job_id": "simcore/services/comp/sleeper:1.1.0:projectid_ec7e595a-63ee-46a1-a04a-901b11b649f8:nodeid_39467d89-b659-4914-9359-c40b1b6d1d6d:uuid_5ee5c655-450d-4711-a3ec-32ffe16bc580", "progress": 0, + "task_owner": { + "user_id": 32, + "project_id": "ec7e595a-63ee-46a1-a04a-901b11b649f8", + "node_id": "39467d89-b659-4914-9359-c40b1b6d1d6d", + "parent_project_id": None, + "parent_node_id": None, + }, }, { "job_id": "simcore/services/comp/sleeper:1.1.0:projectid_ec7e595a-63ee-46a1-a04a-901b11b649f8:nodeid_39467d89-b659-4914-9359-c40b1b6d1d6d:uuid_5ee5c655-450d-4711-a3ec-32ffe16bc580", "progress": 1.0, + "task_owner": { + "user_id": 32, + "project_id": "ec7e595a-63ee-46a1-a04a-901b11b649f8", + "node_id": "39467d89-b659-4914-9359-c40b1b6d1d6d", + "parent_project_id": "887e595a-63ee-46a1-a04a-901b11b649f8", + "parent_node_id": "aa467d89-b659-4914-9359-c40b1b6d1d6d", + }, }, ] } @@ -65,8 +88,15 @@ def topic_name() -> str: return "task_logs" @classmethod - def from_dask_worker(cls, log: str, log_level: LogLevelInt) -> "TaskLogEvent": - return cls(job_id=get_worker().get_current_task(), log=log, log_level=log_level) + def from_dask_worker( + cls, log: str, log_level: LogLevelInt, *, task_owner: TaskOwner + ) -> "TaskLogEvent": + return cls( + job_id=get_worker().get_current_task(), + log=log, + log_level=log_level, + task_owner=task_owner, + ) class Config(BaseTaskEvent.Config): schema_extra: ClassVar[dict[str, Any]] = { @@ -75,6 +105,13 @@ class Config(BaseTaskEvent.Config): "job_id": "simcore/services/comp/sleeper:1.1.0:projectid_ec7e595a-63ee-46a1-a04a-901b11b649f8:nodeid_39467d89-b659-4914-9359-c40b1b6d1d6d:uuid_5ee5c655-450d-4711-a3ec-32ffe16bc580", 
"log": "some logs", "log_level": logging.INFO, + "task_owner": { + "user_id": 32, + "project_id": "ec7e595a-63ee-46a1-a04a-901b11b649f8", + "node_id": "39467d89-b659-4914-9359-c40b1b6d1d6d", + "parent_project_id": None, + "parent_node_id": None, + }, }, ] } From 2924a33e0006ec8be3b724193cd0b6fb1fd7e601 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 20 Nov 2023 18:25:19 +0100 Subject: [PATCH 08/21] getting there --- .../container_tasks/protocol.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py index 4af5b3510a9..322c1f8aa79 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py @@ -1,4 +1,4 @@ -from typing import Any, Protocol, TypeAlias +from typing import Any, ClassVar, Protocol, TypeAlias from models_library.basic_types import EnvVarKey from models_library.docker import DockerLabelKey @@ -40,6 +40,26 @@ def check_parent_valid(cls, values: dict[str, Any]) -> dict[str, Any]: raise ValueError(msg) return values + class Config: + schema_extra: ClassVar[dict[str, Any]] = { + "examples": [ + { + "user_id": 32, + "project_id": "ec7e595a-63ee-46a1-a04a-901b11b649f8", + "node_id": "39467d89-b659-4914-9359-c40b1b6d1d6d", + "parent_project_id": None, + "parent_node_id": None, + }, + { + "user_id": 32, + "project_id": "ec7e595a-63ee-46a1-a04a-901b11b649f8", + "node_id": "39467d89-b659-4914-9359-c40b1b6d1d6d", + "parent_project_id": "887e595a-63ee-46a1-a04a-901b11b649f8", + "parent_node_id": "aa467d89-b659-4914-9359-c40b1b6d1d6d", + }, + ] + } + class ContainerTaskParameters(BaseModel): image: ContainerImage From 996d824247adcfa7f3637c45a20ad6060c1e359b Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 20 Nov 2023 18:25:39 +0100 Subject: [PATCH 09/21] getting there --- .../tests/container_tasks/test_events.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/packages/dask-task-models-library/tests/container_tasks/test_events.py b/packages/dask-task-models-library/tests/container_tasks/test_events.py index 55c5bb1d1bf..309f45f9a25 100644 --- a/packages/dask-task-models-library/tests/container_tasks/test_events.py +++ b/packages/dask-task-models-library/tests/container_tasks/test_events.py @@ -13,6 +13,8 @@ TaskLogEvent, TaskProgressEvent, ) +from dask_task_models_library.container_tasks.protocol import TaskOwner +from faker import Faker from pytest_mock.plugin import MockerFixture @@ -45,6 +47,11 @@ def mocked_dask_worker_job_id(mocker: MockerFixture) -> str: return fake_job_id +@pytest.fixture() +def task_owner(faker: Faker) -> TaskOwner: + return False + + def test_task_progress_from_worker(mocked_dask_worker_job_id: str): event = TaskProgressEvent.from_dask_worker(0.7) From ffa9647722fa71fbca0e5255055fc3437aece050 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 09:22:33 +0100 Subject: [PATCH 10/21] install settings lib as requirement test 98% --- .../requirements/_base.in | 1 + .../requirements/ci.txt | 1 + .../requirements/dev.txt | 1 + .../container_tasks/protocol.py | 19 +++++++++++ .../tests/container_tasks/test_events.py | 33 ++++++++++++++----- 
.../tests/container_tasks/test_protocol.py | 31 +++++++++++++++++ 6 files changed, 78 insertions(+), 8 deletions(-) create mode 100644 packages/dask-task-models-library/tests/container_tasks/test_protocol.py diff --git a/packages/dask-task-models-library/requirements/_base.in b/packages/dask-task-models-library/requirements/_base.in index b137bd88365..3cdef671c4b 100644 --- a/packages/dask-task-models-library/requirements/_base.in +++ b/packages/dask-task-models-library/requirements/_base.in @@ -3,6 +3,7 @@ # --constraint ../../../requirements/constraints.txt --requirement ../../../packages/models-library/requirements/_base.in +--requirement ../../../packages/settings-library/requirements/_base.in dask[distributed] pydantic[email] diff --git a/packages/dask-task-models-library/requirements/ci.txt b/packages/dask-task-models-library/requirements/ci.txt index e4f199bc3f8..b8e5d4577dc 100644 --- a/packages/dask-task-models-library/requirements/ci.txt +++ b/packages/dask-task-models-library/requirements/ci.txt @@ -14,6 +14,7 @@ # installs this repo's packages ../pytest-simcore/ ../models-library/ +../settings-library/ # current module . diff --git a/packages/dask-task-models-library/requirements/dev.txt b/packages/dask-task-models-library/requirements/dev.txt index 33506f6a8be..0edd20961ac 100644 --- a/packages/dask-task-models-library/requirements/dev.txt +++ b/packages/dask-task-models-library/requirements/dev.txt @@ -14,6 +14,7 @@ # installs this repo's packages --editable ../pytest-simcore/ --editable ../models-library/ +--editable ../settings-library/ # current module --editable . diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py index 322c1f8aa79..ad4adfdd9f3 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py @@ -72,6 +72,25 @@ class ContainerTaskParameters(BaseModel): boot_mode: BootMode task_owner: TaskOwner + class Config: + schema_extra: ClassVar[dict[str, Any]] = { + "examples": [ + { + "image": "ubuntu", + "tag": "latest", + "input_data": TaskInputData.Config.schema_extra["examples"][0], + "output_data_keys": TaskOutputDataSchema.Config.schema_extra[ + "examples" + ][0], + "command": ["sleep 10", "echo hello"], + "envs": {"MYENV": "is an env"}, + "labels": {"io.simcore.thelabel": "is amazing"}, + "boot_mode": BootMode.CPU.value, + "task_owner": TaskOwner.Config.schema_extra["examples"][0], + }, + ] + } + class ContainerRemoteFct(Protocol): def __call__( diff --git a/packages/dask-task-models-library/tests/container_tasks/test_events.py b/packages/dask-task-models-library/tests/container_tasks/test_events.py index 309f45f9a25..ce085c43e88 100644 --- a/packages/dask-task-models-library/tests/container_tasks/test_events.py +++ b/packages/dask-task-models-library/tests/container_tasks/test_events.py @@ -14,7 +14,6 @@ TaskProgressEvent, ) from dask_task_models_library.container_tasks.protocol import TaskOwner -from faker import Faker from pytest_mock.plugin import MockerFixture @@ -47,23 +46,41 @@ def mocked_dask_worker_job_id(mocker: MockerFixture) -> str: return fake_job_id -@pytest.fixture() -def task_owner(faker: Faker) -> TaskOwner: - return False +@pytest.fixture(params=TaskOwner.Config.schema_extra["examples"]) +def task_owner(request: pytest.FixtureRequest) -> TaskOwner: + return 
TaskOwner(**request.param) -def test_task_progress_from_worker(mocked_dask_worker_job_id: str): - event = TaskProgressEvent.from_dask_worker(0.7) +def test_task_progress_from_worker( + mocked_dask_worker_job_id: str, task_owner: TaskOwner +): + event = TaskProgressEvent.from_dask_worker(0.7, task_owner=task_owner) assert event.job_id == mocked_dask_worker_job_id assert event.progress == 0.7 -def test_task_log_from_worker(mocked_dask_worker_job_id: str): +def test_task_log_from_worker(mocked_dask_worker_job_id: str, task_owner: TaskOwner): event = TaskLogEvent.from_dask_worker( - log="here is the amazing logs", log_level=logging.INFO + log="here is the amazing logs", log_level=logging.INFO, task_owner=task_owner ) assert event.job_id == mocked_dask_worker_job_id assert event.log == "here is the amazing logs" assert event.log_level == logging.INFO + + +@pytest.mark.parametrize( + "progress_value, expected_progress", [(1.5, 1), (-0.5, 0), (0.75, 0.75)] +) +def test_task_progress_progress_value_is_capped_between_0_and_1( + mocked_dask_worker_job_id: str, + task_owner: TaskOwner, + progress_value: float, + expected_progress: float, +): + event = TaskProgressEvent( + job_id=mocked_dask_worker_job_id, task_owner=task_owner, progress=progress_value + ) + assert event + assert event.progress == expected_progress diff --git a/packages/dask-task-models-library/tests/container_tasks/test_protocol.py b/packages/dask-task-models-library/tests/container_tasks/test_protocol.py new file mode 100644 index 00000000000..d17202adabd --- /dev/null +++ b/packages/dask-task-models-library/tests/container_tasks/test_protocol.py @@ -0,0 +1,31 @@ +import pytest +from dask_task_models_library.container_tasks.protocol import ( + ContainerTaskParameters, + TaskOwner, +) +from faker import Faker +from pydantic import ValidationError + + +@pytest.mark.parametrize("model_cls", [TaskOwner, ContainerTaskParameters]) +def test_events_models_examples(model_cls): + examples = model_cls.Config.schema_extra["examples"] + + for index, example in enumerate(examples): + print(f"{index:-^10}:\n", example) + + model_instance = model_cls(**example) + assert model_instance + + +def test_task_owner_parent_valid(faker: Faker): + invalid_task_owner_example = TaskOwner.Config.schema_extra["examples"][0] + invalid_task_owner_example["parent_project_id"] = faker.uuid4() + assert invalid_task_owner_example["parent_node_id"] is None + with pytest.raises(ValidationError, match=r".+ are None or both are set!"): + TaskOwner(**invalid_task_owner_example) + + invalid_task_owner_example["parent_project_id"] = None + invalid_task_owner_example["parent_node_id"] = faker.uuid4() + with pytest.raises(ValidationError, match=r".+ are None or both are set!"): + TaskOwner(**invalid_task_owner_example) From 3518066fea6e2f0c0501aaff7a896b9b8d065c0c Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 14:10:23 +0100 Subject: [PATCH 11/21] tests are passing again --- .../container_tasks/events.py | 9 ++- .../tests/container_tasks/test_events.py | 7 +- .../computational_sidecar/core.py | 9 ++- .../dask_utils.py | 48 +++++++------ .../src/simcore_service_dask_sidecar/tasks.py | 4 +- services/dask-sidecar/tests/unit/conftest.py | 67 +++++++++++++++---- .../tests/unit/test_dask_utils.py | 46 +++++++++---- .../dask-sidecar/tests/unit/test_tasks.py | 31 --------- 8 files changed, 131 insertions(+), 90 deletions(-) diff --git 
a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py index 878e53a7969..33f41cbef88 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py @@ -33,8 +33,11 @@ def topic_name() -> str: def from_dask_worker( cls, progress: float, *, task_owner: TaskOwner ) -> "TaskProgressEvent": + worker = get_worker() + job_id = worker.get_current_task() + return cls( - job_id=get_worker().get_current_task(), + job_id=job_id, progress=progress, task_owner=task_owner, ) @@ -91,8 +94,10 @@ def topic_name() -> str: def from_dask_worker( cls, log: str, log_level: LogLevelInt, *, task_owner: TaskOwner ) -> "TaskLogEvent": + worker = get_worker() + job_id = worker.get_current_task() return cls( - job_id=get_worker().get_current_task(), + job_id=job_id, log=log, log_level=log_level, task_owner=task_owner, diff --git a/packages/dask-task-models-library/tests/container_tasks/test_events.py b/packages/dask-task-models-library/tests/container_tasks/test_events.py index ce085c43e88..528e208bf56 100644 --- a/packages/dask-task-models-library/tests/container_tasks/test_events.py +++ b/packages/dask-task-models-library/tests/container_tasks/test_events.py @@ -37,13 +37,12 @@ def test_events_models_examples(model_cls): @pytest.fixture() -def mocked_dask_worker_job_id(mocker: MockerFixture) -> str: +def mocked_dask_worker_job_id(mocker: MockerFixture, job_id: str) -> str: mock_get_worker = mocker.patch( "dask_task_models_library.container_tasks.events.get_worker", autospec=True ) - fake_job_id = "some_fake_job_id" - mock_get_worker.return_value.get_current_task.return_value = fake_job_id - return fake_job_id + mock_get_worker.return_value.get_current_task.return_value = job_id + return job_id @pytest.fixture(params=TaskOwner.Config.schema_extra["examples"]) diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py index 93d0e072917..ff43704d339 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/computational_sidecar/core.py @@ -14,7 +14,6 @@ from aiodocker import Docker from dask_task_models_library.container_tasks.docker import DockerBasicAuth from dask_task_models_library.container_tasks.errors import ServiceRuntimeError -from dask_task_models_library.container_tasks.events import TaskLogEvent from dask_task_models_library.container_tasks.io import FileUrl, TaskOutputData from dask_task_models_library.container_tasks.protocol import ContainerTaskParameters from packaging import version @@ -24,7 +23,7 @@ from settings_library.s3 import S3Settings from yarl import URL -from ..dask_utils import TaskPublisher, publish_event +from ..dask_utils import TaskPublisher from ..file_utils import pull_file_from_remote, push_file_to_remote from ..settings import Settings from .docker_utils import ( @@ -152,10 +151,10 @@ async def _retrieve_output_data( async def _publish_sidecar_log( self, log: LogMessageStr, log_level: LogLevelInt = logging.INFO ) -> None: - publish_event( - self.task_publishers.logs, - TaskLogEvent.from_dask_worker(log=f"[sidecar] {log}", log_level=log_level), + self.task_publishers.publish_logs( + message=f"[sidecar] 
{log}", log_level=log_level ) + logger.log(log_level, log) async def run(self, command: list[str]) -> TaskOutputData: diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py index 1bf14129833..a0defe028eb 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py @@ -1,8 +1,9 @@ import asyncio import contextlib import logging +from collections.abc import AsyncIterator from dataclasses import dataclass, field -from typing import AsyncIterator, Final +from typing import Final import distributed from dask_task_models_library.container_tasks.errors import TaskCancelledError @@ -12,6 +13,7 @@ TaskProgressEvent, ) from dask_task_models_library.container_tasks.io import TaskCancelEventName +from dask_task_models_library.container_tasks.protocol import TaskOwner from distributed.worker import get_worker from distributed.worker_state_machine import TaskState from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch @@ -55,8 +57,9 @@ def get_current_task_resources() -> dict[str, float]: return current_task_resources -@dataclass() +@dataclass(slots=True, kw_only=True) class TaskPublisher: + task_owner: TaskOwner progress: distributed.Pub = field(init=False) _last_published_progress_value: float = -1 logs: distributed.Pub = field(init=False) @@ -68,11 +71,14 @@ def __post_init__(self) -> None: def publish_progress(self, value: float) -> None: rounded_value = round(value, ndigits=2) if rounded_value > self._last_published_progress_value: - publish_event( - self.progress, - TaskProgressEvent.from_dask_worker(progress=rounded_value), - ) - self._last_published_progress_value = rounded_value + with log_catch(logger=logger, reraise=False): + publish_event( + self.progress, + TaskProgressEvent.from_dask_worker( + progress=rounded_value, task_owner=self.task_owner + ), + ) + self._last_published_progress_value = rounded_value def publish_logs( self, @@ -80,9 +86,13 @@ def publish_logs( message: LogMessageStr, log_level: LogLevelInt, ) -> None: - publish_event( - self.logs, TaskLogEvent.from_dask_worker(log=message, log_level=log_level) - ) + with log_catch(logger=logger, reraise=False): + publish_event( + self.logs, + TaskLogEvent.from_dask_worker( + log=message, log_level=log_level, task_owner=self.task_owner + ), + ) _TASK_ABORTION_INTERVAL_CHECK_S: int = 2 @@ -90,7 +100,7 @@ def publish_logs( @contextlib.asynccontextmanager async def monitor_task_abortion( - task_name: str, log_publisher: distributed.Pub + task_name: str, task_publishers: TaskPublisher ) -> AsyncIterator[None]: """This context manager periodically checks whether the client cancelled the monitored task. If that is the case, the monitored task will be cancelled (e.g. 
@@ -101,13 +111,9 @@ async def cancel_task(task_name: str) -> None: if task := next( (t for t in asyncio.all_tasks() if t.get_name() == task_name), None ): - publish_event( - log_publisher, - TaskLogEvent.from_dask_worker( - log="[sidecar] cancelling task...", log_level=logging.INFO - ), + task_publishers.publish_logs( + message="[sidecar] cancelling task...", log_level=logging.INFO ) - logger.debug("cancelling %s....................", f"{task=}") task.cancel() async def periodicaly_check_if_aborted(task_name: str) -> None: @@ -125,12 +131,10 @@ async def periodicaly_check_if_aborted(task_name: str) -> None: yield except asyncio.CancelledError as exc: - publish_event( - log_publisher, - TaskLogEvent.from_dask_worker( - log="[sidecar] task run was aborted", log_level=logging.INFO - ), + task_publishers.publish_logs( + message="[sidecar] task run was aborted", log_level=logging.INFO ) + raise TaskCancelledError from exc finally: if periodically_checking_task: diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py index 658e40656df..dfb7c971981 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/tasks.py @@ -88,7 +88,7 @@ async def _run_computational_sidecar_async( log_file_url: LogFileUploadURL, s3_settings: S3Settings | None, ) -> TaskOutputData: - task_publishers = TaskPublisher() + task_publishers = TaskPublisher(task_owner=task_parameters.task_owner) _logger.debug( "run_computational_sidecar %s", @@ -97,7 +97,7 @@ async def _run_computational_sidecar_async( current_task = asyncio.current_task() assert current_task # nosec async with monitor_task_abortion( - task_name=current_task.get_name(), log_publisher=task_publishers.logs + task_name=current_task.get_name(), task_publishers=task_publishers ): task_max_resources = get_current_task_resources() async with ComputationalSidecar( diff --git a/services/dask-sidecar/tests/unit/conftest.py b/services/dask-sidecar/tests/unit/conftest.py index 40fa1466bc3..79005bff754 100644 --- a/services/dask-sidecar/tests/unit/conftest.py +++ b/services/dask-sidecar/tests/unit/conftest.py @@ -3,19 +3,23 @@ # pylint: disable=unused-variable # pylint: disable=too-many-arguments +from collections.abc import AsyncIterator, Callable, Iterator from pathlib import Path from pprint import pformat -from typing import AsyncIterator, Callable, Iterator import dask +import dask.config import distributed import fsspec import pytest import simcore_service_dask_sidecar from aiobotocore.session import AioBaseClient, get_session +from dask_task_models_library.container_tasks.protocol import TaskOwner from faker import Faker +from models_library.projects import ProjectID +from models_library.projects_nodes_io import NodeID +from models_library.users import UserID from pydantic import AnyUrl, parse_obj_as -from pytest import MonkeyPatch from pytest_localftpserver.servers import ProcessFTPServer from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict @@ -72,7 +76,9 @@ def shared_data_folder( @pytest.fixture def app_environment( - monkeypatch: MonkeyPatch, env_devel_dict: EnvVarsDict, shared_data_folder: Path + monkeypatch: pytest.MonkeyPatch, + env_devel_dict: EnvVarsDict, + shared_data_folder: Path, ) -> EnvVarsDict: # configured as worker envs = setenvs_from_dict( @@ -99,10 +105,8 @@ def local_cluster(app_environment: EnvVarsDict) -> 
Iterator[distributed.LocalClu print(pformat(dask.config.get("distributed"))) with distributed.LocalCluster( worker_class=distributed.Worker, - **{ - "resources": {"CPU": 10, "GPU": 10}, - "preload": "simcore_service_dask_sidecar.tasks", - }, + resources={"CPU": 10, "GPU": 10}, + preload="simcore_service_dask_sidecar.tasks", ) as cluster: assert cluster assert isinstance(cluster, distributed.LocalCluster) @@ -124,10 +128,8 @@ async def async_local_cluster( print(pformat(dask.config.get("distributed"))) async with distributed.LocalCluster( worker_class=distributed.Worker, - **{ - "resources": {"CPU": 10, "GPU": 10}, - "preload": "simcore_service_dask_sidecar.tasks", - }, + resources={"CPU": 10, "GPU": 10}, + preload="simcore_service_dask_sidecar.tasks", asynchronous=True, ) as cluster: assert cluster @@ -201,7 +203,7 @@ async def bucket( assert response["Buckets"] assert len(response["Buckets"]) == 1 bucket_name = response["Buckets"][0]["Name"] - yield bucket_name + return bucket_name # await _clean_bucket_content(aiobotocore_s3_client, bucket_name) @@ -243,3 +245,44 @@ def creator() -> AnyUrl: fs = fsspec.filesystem("s3", **s3_storage_kwargs) for file in list_of_created_files: fs.delete(file.partition(f"{file.scheme}://")[2]) + + +@pytest.fixture +def job_id() -> str: + return "some_incredible_string" + + +@pytest.fixture +def user_id(faker: Faker) -> UserID: + return faker.pyint(min_value=1) + + +@pytest.fixture +def project_id(faker: Faker) -> ProjectID: + return faker.uuid4(cast_to=None) + + +@pytest.fixture +def node_id(faker: Faker) -> NodeID: + return faker.uuid4(cast_to=None) + + +@pytest.fixture(params=["no_parent_node", "with_parent_node"]) +def task_owner( + user_id: UserID, + project_id: ProjectID, + node_id: NodeID, + request: pytest.FixtureRequest, + faker: Faker, +) -> TaskOwner: + return TaskOwner( + user_id=user_id, + project_id=project_id, + node_id=node_id, + parent_project_id=None + if request.param == "no_parent_node" + else faker.uuid4(cast_to=None), + parent_node_id=None + if request.param == "no_parent_node" + else faker.uuid4(cast_to=None), + ) diff --git a/services/dask-sidecar/tests/unit/test_dask_utils.py b/services/dask-sidecar/tests/unit/test_dask_utils.py index 544597a20da..36597cb287e 100644 --- a/services/dask-sidecar/tests/unit/test_dask_utils.py +++ b/services/dask-sidecar/tests/unit/test_dask_utils.py @@ -8,15 +8,18 @@ import concurrent.futures import logging import time -from typing import Any, AsyncIterator, Callable, Coroutine +from collections.abc import AsyncIterator, Callable, Coroutine +from typing import Any import distributed import pytest from dask_task_models_library.container_tasks.errors import TaskCancelledError from dask_task_models_library.container_tasks.events import TaskLogEvent from dask_task_models_library.container_tasks.io import TaskCancelEventName +from dask_task_models_library.container_tasks.protocol import TaskOwner from simcore_service_dask_sidecar.dask_utils import ( _DEFAULT_MAX_RESOURCES, + TaskPublisher, get_current_task_resources, is_current_task_aborted, monitor_task_abortion, @@ -31,11 +34,16 @@ DASK_TESTING_TIMEOUT_S = 25 -def test_publish_event(dask_client: distributed.Client): +def test_publish_event( + dask_client: distributed.Client, job_id: str, task_owner: TaskOwner +): dask_pub = distributed.Pub("some_topic", client=dask_client) dask_sub = distributed.Sub("some_topic", client=dask_client) event_to_publish = TaskLogEvent( - job_id="some_fake_job_id", log="the log", log_level=logging.INFO + job_id=job_id, + 
log="the log", + log_level=logging.INFO, + task_owner=task_owner, ) publish_event(dask_pub=dask_pub, event=event_to_publish) @@ -48,11 +56,13 @@ def test_publish_event(dask_client: distributed.Client): assert received_task_log_event == event_to_publish -async def test_publish_event_async(async_dask_client: distributed.Client): +async def test_publish_event_async( + async_dask_client: distributed.Client, job_id: str, task_owner: TaskOwner +): dask_pub = distributed.Pub("some_topic", client=async_dask_client) dask_sub = distributed.Sub("some_topic", client=async_dask_client) event_to_publish = TaskLogEvent( - job_id="some_fake_job_id", log="the log", log_level=logging.INFO + job_id=job_id, log="the log", log_level=logging.INFO, task_owner=task_owner ) publish_event(dask_pub=dask_pub, event=event_to_publish) @@ -86,6 +96,8 @@ def _creator(coro: Coroutine) -> asyncio.Task: async def test_publish_event_async_using_task( async_dask_client: distributed.Client, asyncio_task: Callable[[Coroutine], asyncio.Task], + job_id: str, + task_owner: TaskOwner, ): dask_pub = distributed.Pub("some_topic", client=async_dask_client) dask_sub = distributed.Sub("some_topic", client=async_dask_client) @@ -106,7 +118,10 @@ async def _dask_publisher_task(pub: distributed.Pub) -> None: print("--> starting publisher task") for n in range(NUMBER_OF_MESSAGES): event_to_publish = TaskLogEvent( - job_id="some_fake_job_id", log=f"the log {n}", log_level=logging.INFO + job_id=job_id, + log=f"the log {n}", + log_level=logging.INFO, + task_owner=task_owner, ) publish_event(dask_pub=pub, event=event_to_publish) print("<-- finished publisher task") @@ -177,16 +192,20 @@ def test_task_is_aborted_using_event(dask_client: distributed.Client): assert result == -1 -def _some_long_running_task_with_monitoring() -> int: +def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int: assert is_current_task_aborted() is False # we are started now start_event = distributed.Event(DASK_TASK_STARTED_EVENT) start_event.set() async def _long_running_task_async() -> int: - log_publisher = distributed.Pub(TaskLogEvent.topic_name()) + task_publishers = TaskPublisher(task_owner=task_owner) _notify_task_is_started_and_ready() - async with monitor_task_abortion(task_name=asyncio.current_task().get_name(), log_publisher=log_publisher): # type: ignore + current_task = asyncio.current_task() + assert current_task + async with monitor_task_abortion( + task_name=current_task.get_name(), task_publishers=task_publishers + ): for i in range(300): print("running iteration", i) await asyncio.sleep(0.5) @@ -201,9 +220,12 @@ async def _long_running_task_async() -> int: return asyncio.get_event_loop().run_until_complete(_long_running_task_async()) -def test_monitor_task_abortion(dask_client: distributed.Client): - job_id = "myfake_job_id" - future = dask_client.submit(_some_long_running_task_with_monitoring, key=job_id) +def test_monitor_task_abortion( + dask_client: distributed.Client, job_id: str, task_owner: TaskOwner +): + future = dask_client.submit( + _some_long_running_task_with_monitoring, task_owner=task_owner, key=job_id + ) _wait_for_task_to_start() # trigger cancellation dask_event = distributed.Event(TaskCancelEventName.format(job_id)) diff --git a/services/dask-sidecar/tests/unit/test_tasks.py b/services/dask-sidecar/tests/unit/test_tasks.py index 0df4f494cae..f2751ff835d 100644 --- a/services/dask-sidecar/tests/unit/test_tasks.py +++ b/services/dask-sidecar/tests/unit/test_tasks.py @@ -65,26 +65,6 @@ logger = 
logging.getLogger(__name__) -@pytest.fixture -def job_id() -> str: - return "some_incredible_string" - - -@pytest.fixture -def user_id(faker: Faker) -> UserID: - return faker.pyint(min_value=1) - - -@pytest.fixture -def project_id(faker: Faker) -> ProjectID: - return faker.uuid4(cast_to=None) - - -@pytest.fixture -def node_id(faker: Faker) -> NodeID: - return faker.uuid4(cast_to=None) - - @pytest.fixture() def dask_subsystem_mock(mocker: MockerFixture) -> dict[str, mock.Mock]: # mock dask client @@ -204,17 +184,6 @@ def additional_envs(faker: Faker) -> dict[EnvVarKey, str]: return parse_obj_as(dict[EnvVarKey, str], faker.pydict(allowed_types=(str,))) -@pytest.fixture -def task_owner(user_id: UserID, project_id: ProjectID, node_id: NodeID) -> TaskOwner: - return TaskOwner( - user_id=user_id, - project_id=project_id, - node_id=node_id, - parent_project_id=None, - parent_node_id=None, - ) - - @pytest.fixture def sleeper_task( integration_version: version.Version, From ce24b88cc2c96fbabba22dadc3087baf5d946630 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 14:49:12 +0100 Subject: [PATCH 12/21] refactor --- .../modules/dask_client.py | 92 ++++++++----------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py index d2dd0c1926a..7f5c3274360 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py @@ -66,23 +66,7 @@ from ..models.comp_tasks import Image from ..models.dask_subsystem import DaskClientTaskState from ..modules.storage import StorageClient -from ..utils.dask import ( - check_communication_with_scheduler_is_open, - check_if_cluster_is_able_to_run_pipeline, - check_maximize_workers, - check_scheduler_is_still_the_same, - check_scheduler_status, - compute_input_data, - compute_output_data_schema, - compute_service_log_file_upload_link, - compute_task_envs, - compute_task_labels, - create_node_ports, - dask_sub_consumer_task, - from_node_reqs_to_dask_resources, - generate_dask_job_id, - wrap_client_async_routine, -) +from ..utils import dask as dask_utils from ..utils.dask_client_utils import ( DaskSubSystem, TaskHandlers, @@ -158,7 +142,7 @@ async def create( backend = await create_internal_client_based_on_auth( endpoint, authentication ) - check_scheduler_status(backend.client) + dask_utils.check_scheduler_status(backend.client) instance = cls( app=app, backend=backend, @@ -195,7 +179,7 @@ def register_handlers(self, task_handlers: TaskHandlers) -> None: ] self._subscribed_tasks = [ asyncio.create_task( - dask_sub_consumer_task(dask_sub, handler), + dask_utils.dask_sub_consumer_task(dask_sub, handler), name=f"{dask_sub.name}_dask_sub_consumer_task", ) for dask_sub, handler in _event_consumer_map @@ -252,7 +236,7 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 remote_fct = _comp_sidecar_fct list_of_node_id_to_job_id: list[tuple[NodeID, str]] = [] for node_id, node_image in tasks.items(): - job_id = generate_dask_job_id( + job_id = dask_utils.generate_dask_job_id( service_key=node_image.name, service_version=node_image.tag, user_id=user_id, @@ -260,7 +244,7 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 node_id=node_id, ) assert node_image.node_requirements # nosec - dask_resources = 
from_node_reqs_to_dask_resources( + dask_resources = dask_utils.from_node_reqs_to_dask_resources( node_image.node_requirements ) if hardware_info.aws_ec2_instances: @@ -268,12 +252,12 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 f"{DASK_TASK_EC2_RESOURCE_RESTRICTION_KEY}:{hardware_info.aws_ec2_instances[0]}" ] = 1 - check_scheduler_is_still_the_same( + dask_utils.check_scheduler_is_still_the_same( self.backend.scheduler_id, self.backend.client ) - check_communication_with_scheduler_is_open(self.backend.client) - check_scheduler_status(self.backend.client) - await check_maximize_workers(self.backend.gateway_cluster) + dask_utils.check_communication_with_scheduler_is_open(self.backend.client) + dask_utils.check_scheduler_status(self.backend.client) + await dask_utils.check_maximize_workers(self.backend.gateway_cluster) # NOTE: in case it's a gateway or it is an on-demand cluster # we do not check a priori if the task # is runnable because we CAN'T. A cluster might auto-scale, the worker(s) @@ -283,7 +267,7 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 self.backend.gateway is None ): _logger.warning("cluster type: %s", self.cluster_type) - check_if_cluster_is_able_to_run_pipeline( + dask_utils.check_if_cluster_is_able_to_run_pipeline( project_id=project_id, node_id=node_id, scheduler_info=self.backend.client.scheduler_info(), @@ -302,40 +286,40 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 raise ComputationalBackendNoS3AccessError from err # This instance is created only once so it can be reused in calls below - node_ports = await create_node_ports( + node_ports = await dask_utils.create_node_ports( db_engine=self.app.state.engine, user_id=user_id, project_id=project_id, node_id=node_id, ) # NOTE: for download there is no need to go with S3 links - input_data = await compute_input_data( + input_data = await dask_utils.compute_input_data( project_id=project_id, node_id=node_id, node_ports=node_ports, file_link_type=FileLinkType.PRESIGNED, ) - output_data_keys = await compute_output_data_schema( + output_data_keys = await dask_utils.compute_output_data_schema( user_id=user_id, project_id=project_id, node_id=node_id, node_ports=node_ports, file_link_type=self.tasks_file_link_type, ) - log_file_url = await compute_service_log_file_upload_link( + log_file_url = await dask_utils.compute_service_log_file_upload_link( user_id, project_id, node_id, file_link_type=self.tasks_file_link_type, ) - task_labels = compute_task_labels( + task_labels = dask_utils.compute_task_labels( user_id=user_id, project_id=project_id, node_id=node_id, run_metadata=metadata, node_requirements=node_image.node_requirements, ) - task_envs = await compute_task_envs( + task_envs = await dask_utils.compute_task_envs( self.app, user_id=user_id, project_id=project_id, @@ -373,7 +357,7 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 task_future.add_done_callback(lambda _: callback()) list_of_node_id_to_job_id.append((node_id, job_id)) - await wrap_client_async_routine( + await dask_utils.wrap_client_async_routine( self.backend.client.publish_dataset(task_future, name=job_id) ) @@ -384,17 +368,17 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 ) except Exception: # Dask raises a base Exception here in case of connection error, this will raise a more precise one - check_scheduler_status(self.backend.client) + dask_utils.check_scheduler_status(self.backend.client) # if the 
connection is good, then the problem is different, so we re-raise raise return list_of_node_id_to_job_id async def get_tasks_status(self, job_ids: list[str]) -> list[DaskClientTaskState]: - check_scheduler_is_still_the_same( + dask_utils.check_scheduler_is_still_the_same( self.backend.scheduler_id, self.backend.client ) - check_communication_with_scheduler_is_open(self.backend.client) - check_scheduler_status(self.backend.client) + dask_utils.check_communication_with_scheduler_is_open(self.backend.client) + dask_utils.check_scheduler_status(self.backend.client) # try to get the task from the scheduler def _get_pipeline_statuses( @@ -405,7 +389,7 @@ def _get_pipeline_statuses( ) return statuses - task_statuses = await wrap_client_async_routine( + task_statuses = await dask_utils.wrap_client_async_routine( self.backend.client.run_on_scheduler(_get_pipeline_statuses) ) _logger.debug("found dask task statuses: %s", f"{task_statuses=}") @@ -415,7 +399,7 @@ def _get_pipeline_statuses( dask_status = task_statuses.get(job_id, "lost") if dask_status == "erred": # find out if this was a cancellation - exception = await wrap_client_async_routine( + exception = await dask_utils.wrap_client_async_routine( distributed.Future(job_id).exception( timeout=_DASK_DEFAULT_TIMEOUT_S ) @@ -449,16 +433,18 @@ async def abort_computation_task(self, job_id: str) -> None: # process, and report when it is finished and properly cancelled. _logger.debug("cancelling task with %s", f"{job_id=}") try: - task_future: distributed.Future = await wrap_client_async_routine( - self.backend.client.get_dataset(name=job_id) + task_future: distributed.Future = ( + await dask_utils.wrap_client_async_routine( + self.backend.client.get_dataset(name=job_id) + ) ) # NOTE: It seems there is a bug in the pubsub system in dask # Event are more robust to connections/disconnections cancel_event = await distributed.Event( name=TaskCancelEventName.format(job_id), client=self.backend.client ) - await wrap_client_async_routine(cancel_event.set()) - await wrap_client_async_routine(task_future.cancel()) + await dask_utils.wrap_client_async_routine(cancel_event.set()) + await dask_utils.wrap_client_async_routine(task_future.cancel()) _logger.debug("Dask task %s cancelled", task_future.key) except KeyError: _logger.warning("Unknown task cannot be aborted: %s", f"{job_id=}") @@ -466,10 +452,12 @@ async def abort_computation_task(self, job_id: str) -> None: async def get_task_result(self, job_id: str) -> TaskOutputData: _logger.debug("getting result of %s", f"{job_id=}") try: - task_future: distributed.Future = await wrap_client_async_routine( - self.backend.client.get_dataset(name=job_id) + task_future: distributed.Future = ( + await dask_utils.wrap_client_async_routine( + self.backend.client.get_dataset(name=job_id) + ) ) - return await wrap_client_async_routine( + return await dask_utils.wrap_client_async_routine( task_future.result(timeout=_DASK_DEFAULT_TIMEOUT_S) ) except KeyError as exc: @@ -481,21 +469,21 @@ async def release_task_result(self, job_id: str) -> None: _logger.debug("releasing results for %s", f"{job_id=}") try: # first check if the key exists - await wrap_client_async_routine( + await dask_utils.wrap_client_async_routine( self.backend.client.get_dataset(name=job_id) ) - await wrap_client_async_routine( + await dask_utils.wrap_client_async_routine( self.backend.client.unpublish_dataset(name=job_id) ) except KeyError: _logger.warning("Unknown task cannot be unpublished: %s", f"{job_id=}") async def get_cluster_details(self) -> 
ClusterDetails: - check_scheduler_is_still_the_same( + dask_utils.check_scheduler_is_still_the_same( self.backend.scheduler_id, self.backend.client ) - check_communication_with_scheduler_is_open(self.backend.client) - check_scheduler_status(self.backend.client) + dask_utils.check_communication_with_scheduler_is_open(self.backend.client) + dask_utils.check_scheduler_status(self.backend.client) scheduler_info = self.backend.client.scheduler_info() scheduler_status = self.backend.client.status dashboard_link = self.backend.client.dashboard_link @@ -512,7 +500,7 @@ def _get_worker_used_resources( # NOTE: this runs directly on the dask-scheduler and may rise exceptions used_resources_per_worker: dict[ str, dict[str, Any] - ] = await wrap_client_async_routine( + ] = await dask_utils.wrap_client_async_routine( self.backend.client.run_on_scheduler(_get_worker_used_resources) ) From fdf3d18d152f2fc62475510afc7f59a433c203b8 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 14:58:23 +0100 Subject: [PATCH 13/21] upgraded dask client --- .../modules/dask_client.py | 51 +++++++------------ .../simcore_service_director_v2/utils/dask.py | 20 +++++++- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py index 7f5c3274360..829cd6f0a4f 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py @@ -24,17 +24,11 @@ from dask_task_models_library.container_tasks.errors import TaskCancelledError from dask_task_models_library.container_tasks.io import ( TaskCancelEventName, - TaskInputData, TaskOutputData, - TaskOutputDataSchema, ) from dask_task_models_library.container_tasks.protocol import ( - ContainerCommands, - ContainerEnvsDict, - ContainerImage, - ContainerLabelsDict, ContainerRemoteFct, - ContainerTag, + ContainerTaskParameters, LogFileUploadURL, ) from distributed.scheduler import TaskStateState as DaskSchedulerTaskState @@ -44,7 +38,6 @@ from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.resource_tracker import HardwareInfo -from models_library.services_resources import BootMode from models_library.users import UserID from pydantic import parse_obj_as from pydantic.networks import AnyUrl @@ -200,36 +193,22 @@ async def send_computation_tasks( """actually sends the function remote_fct to be remotely executed. 
if None is kept then the default function that runs container will be started.""" - def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 + def _comp_sidecar_fct( *, + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode, ) -> TaskOutputData: """This function is serialized by the Dask client and sent over to the Dask sidecar(s) Therefore, (screaming here) DO NOT MOVE THAT IMPORT ANYWHERE ELSE EVER!!""" from simcore_service_dask_sidecar.tasks import run_computational_sidecar return run_computational_sidecar( + task_parameters=task_parameters, docker_auth=docker_auth, - service_key=service_key, - service_version=service_version, - input_data=input_data, - output_data_keys=output_data_keys, log_file_url=log_file_url, - command=command, - task_envs=task_envs, - task_labels=task_labels, s3_settings=s3_settings, - boot_mode=boot_mode, ) if remote_fct is None: @@ -327,6 +306,9 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 node_image=node_image, metadata=metadata, ) + task_owner = dask_utils.compute_task_owner( + user_id, project_id, node_id, metadata.get("project_metadata", {}) + ) try: assert self.app.state # nosec @@ -334,21 +316,24 @@ def _comp_sidecar_fct( # pylint: disable=too-many-arguments # noqa: PLR0913 settings: AppSettings = self.app.state.settings task_future = self.backend.client.submit( remote_fct, + task_parameters=ContainerTaskParameters( + image=node_image.name, + tag=node_image.tag, + input_data=input_data, + output_data_keys=output_data_keys, + command=node_image.command, + envs=task_envs, + labels=task_labels, + boot_mode=node_image.boot_mode, + task_owner=task_owner, + ), docker_auth=DockerBasicAuth( server_address=settings.DIRECTOR_V2_DOCKER_REGISTRY.resolved_registry_url, username=settings.DIRECTOR_V2_DOCKER_REGISTRY.REGISTRY_USER, password=settings.DIRECTOR_V2_DOCKER_REGISTRY.REGISTRY_PW, ), - service_key=node_image.name, - service_version=node_image.tag, - input_data=input_data, - output_data_keys=output_data_keys, log_file_url=log_file_url, - command=node_image.command, - task_envs=task_envs, - task_labels=task_labels, s3_settings=s3_settings, - boot_mode=node_image.boot_mode, key=job_id, resources=dask_resources, retries=0, diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dask.py b/services/director-v2/src/simcore_service_director_v2/utils/dask.py index de82bcdeecb..e96bd146364 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dask.py +++ b/services/director-v2/src/simcore_service_director_v2/utils/dask.py @@ -18,6 +18,7 @@ from dask_task_models_library.container_tasks.protocol import ( ContainerEnvsDict, ContainerLabelsDict, + TaskOwner, ) from fastapi import FastAPI from models_library.api_schemas_directorv2.services import NodeRequirements @@ -39,8 +40,8 @@ from simcore_sdk.node_ports_v2 import FileLinkType, Port, links, port_utils from simcore_sdk.node_ports_v2.links import ItemValue as _NPItemValue from simcore_sdk.node_ports_v2.ports_mapping import PortKey -from simcore_service_director_v2.constants import UNDEFINED_DOCKER_LABEL +from ..constants import UNDEFINED_DOCKER_LABEL from ..core.errors import ( 
ComputationalBackendNotConnectedError, ComputationalSchedulerChangedError, @@ -48,7 +49,7 @@ MissingComputationalResourcesError, PortsValidationError, ) -from ..models.comp_runs import RunMetadataDict +from ..models.comp_runs import ProjectMetadataDict, RunMetadataDict from ..models.comp_tasks import Image from ..modules.osparc_variables_substitutions import ( resolve_and_substitute_session_variables_in_specs, @@ -623,3 +624,18 @@ async def wrap_client_async_routine( a union of types. this wrapper makes both mypy and pylance happy""" assert client_coroutine # nosec return await client_coroutine + + +def compute_task_owner( + user_id: UserID, + project_id: ProjectID, + node_id: ProjectID, + project_metadata: ProjectMetadataDict, +) -> TaskOwner: + return TaskOwner( + user_id=user_id, + project_id=project_id, + node_id=node_id, + parent_node_id=project_metadata.get("parent_node_id"), + parent_project_id=project_metadata.get("parent_project_id"), + ) From 1e46bfa2149eba4ca691374b5c385425e48514c4 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:03:08 +0100 Subject: [PATCH 14/21] clean --- .../director-v2/tests/unit/_dask_helpers.py | 53 +------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/services/director-v2/tests/unit/_dask_helpers.py b/services/director-v2/tests/unit/_dask_helpers.py index c942c6a6f36..9bf9a739946 100644 --- a/services/director-v2/tests/unit/_dask_helpers.py +++ b/services/director-v2/tests/unit/_dask_helpers.py @@ -1,21 +1,9 @@ # pylint:disable=unused-variable # pylint:disable=unused-argument -from typing import Any, NamedTuple +from typing import NamedTuple from dask_gateway_server.app import DaskGateway -from dask_task_models_library.container_tasks.docker import DockerBasicAuth -from dask_task_models_library.container_tasks.io import ( - TaskInputData, - TaskOutputData, - TaskOutputDataSchema, -) -from dask_task_models_library.container_tasks.protocol import ( - ContainerCommands, - ContainerImage, - ContainerTag, - LogFileUploadURL, -) class DaskGatewayServer(NamedTuple): @@ -23,42 +11,3 @@ class DaskGatewayServer(NamedTuple): proxy_address: str password: str server: DaskGateway - - -def fake_sidecar_fct( - docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, - log_file_url: LogFileUploadURL, - command: ContainerCommands, - expected_annotations: dict[str, Any], -) -> TaskOutputData: - import time - - from dask.distributed import get_worker - - # sleep a bit in case someone is aborting us - time.sleep(1) - - # get the task data - worker = get_worker() - task = worker.state.tasks.get(worker.get_current_task()) - assert task is not None - assert task.annotations == expected_annotations - - return TaskOutputData.parse_obj({"some_output_key": 123}) - - -def fake_failing_sidecar_fct( - docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, - log_file_url: LogFileUploadURL, - command: ContainerCommands, -) -> TaskOutputData: - err_msg = "sadly we are failing to execute anything cause we are dumb..." 
- raise ValueError(err_msg) From e7e09c3b5b9a64667173a2fc8382f9be65992bf6 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:03:16 +0100 Subject: [PATCH 15/21] fix mocks --- .../director-v2/tests/unit/test_modules_dask_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/services/director-v2/tests/unit/test_modules_dask_client.py b/services/director-v2/tests/unit/test_modules_dask_client.py index 680d91846b2..2719a996856 100644 --- a/services/director-v2/tests/unit/test_modules_dask_client.py +++ b/services/director-v2/tests/unit/test_modules_dask_client.py @@ -374,20 +374,20 @@ def image_params( @pytest.fixture def _mocked_node_ports(mocker: MockerFixture) -> None: mocker.patch( - "simcore_service_director_v2.modules.dask_client.create_node_ports", + "simcore_service_director_v2.modules.dask_client.dask_utils.create_node_ports", return_value=None, ) mocker.patch( - "simcore_service_director_v2.modules.dask_client.compute_input_data", + "simcore_service_director_v2.modules.dask_client.dask_utils.compute_input_data", return_value=TaskInputData.parse_obj({}), ) mocker.patch( - "simcore_service_director_v2.modules.dask_client.compute_output_data_schema", + "simcore_service_director_v2.modules.dask_client.dask_utils.compute_output_data_schema", return_value=TaskOutputDataSchema.parse_obj({}), ) mocker.patch( - "simcore_service_director_v2.modules.dask_client.compute_service_log_file_upload_link", + "simcore_service_director_v2.modules.dask_client.dask_utils.compute_service_log_file_upload_link", return_value=parse_obj_as(AnyUrl, "file://undefined"), ) From 6741117cc5d9662ac88d43278a811d87368a0c51 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:10:34 +0100 Subject: [PATCH 16/21] adjusted remote fct syntax --- .../tests/unit/test_modules_dask_client.py | 86 ++++--------------- 1 file changed, 16 insertions(+), 70 deletions(-) diff --git a/services/director-v2/tests/unit/test_modules_dask_client.py b/services/director-v2/tests/unit/test_modules_dask_client.py index 2719a996856..99022798fd2 100644 --- a/services/director-v2/tests/unit/test_modules_dask_client.py +++ b/services/director-v2/tests/unit/test_modules_dask_client.py @@ -32,11 +32,9 @@ TaskOutputDataSchema, ) from dask_task_models_library.container_tasks.protocol import ( - ContainerCommands, ContainerEnvsDict, - ContainerImage, ContainerLabelsDict, - ContainerTag, + ContainerTaskParameters, LogFileUploadURL, ) from distributed import Event, Scheduler @@ -55,7 +53,6 @@ from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.resource_tracker import HardwareInfo -from models_library.services_resources import BootMode from models_library.users import UserID from pydantic import AnyUrl, ByteSize, SecretStr from pydantic.tools import parse_obj_as @@ -523,17 +520,10 @@ async def test_send_computation_task( # NOTE: this must be inlined so that the test works, # the dask-worker must be able to import the function def fake_sidecar_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode, 
expected_annotations: dict[str, Any], expected_envs: ContainerEnvsDict, expected_labels: ContainerLabelsDict, @@ -543,9 +533,9 @@ def fake_sidecar_fct( task = worker.state.tasks.get(worker.get_current_task()) assert task is not None assert task.annotations == expected_annotations - assert task_envs == expected_envs - assert task_labels == expected_labels - assert command == ["run"] + assert task_parameters.envs == expected_envs + assert task_parameters.labels == expected_labels + assert task_parameters.command == ["run"] event = distributed.Event(_DASK_EVENT_NAME) event.wait(timeout=25) @@ -553,12 +543,10 @@ def fake_sidecar_fct( # NOTE: We pass another fct so it can run in our localy created dask cluster # NOTE2: since there is only 1 task here, it's ok to pass the nodeID - assert image_params.fake_tasks[node_id].node_requirements is not None - assert isinstance( - image_params.fake_tasks[node_id].node_requirements, NodeRequirements - ) - assert image_params.fake_tasks[node_id].node_requirements.cpu - assert image_params.fake_tasks[node_id].node_requirements.ram + node_params = image_params.fake_tasks[node_id] + assert node_params.node_requirements is not None + assert node_params.node_requirements.cpu + assert node_params.node_requirements.ram assert "product_name" in comp_run_metadata assert "simcore_user_agent" in comp_run_metadata node_id_to_job_ids = await dask_client.send_computation_tasks( @@ -649,17 +637,10 @@ async def test_computation_task_is_persisted_on_dask_scheduler( # NOTE: this must be inlined so that the test works, # the dask-worker must be able to import the function def fake_sidecar_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode = BootMode.CPU, ) -> TaskOutputData: # get the task data worker = get_worker() @@ -735,17 +716,10 @@ async def test_abort_computation_tasks( # NOTE: this must be inlined so that the test works, # the dask-worker must be able to import the function def fake_remote_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode = BootMode.CPU, ) -> TaskOutputData: # get the task data worker = get_worker() @@ -826,17 +800,10 @@ async def test_failed_task_returns_exceptions( # NOTE: this must be inlined so that the test works, # the dask-worker must be able to import the function def fake_failing_sidecar_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode = BootMode.CPU, ) -> TaskOutputData: err_msg = "sadly we are failing to execute anything cause we are dumb..." 
raise ValueError(err_msg) @@ -1099,17 +1066,10 @@ async def test_get_tasks_status( _DASK_EVENT_NAME = faker.pystr() def fake_remote_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode = BootMode.CPU, ) -> TaskOutputData: # wait here until the client allows us to continue start_event = Event(_DASK_EVENT_NAME) @@ -1190,17 +1150,10 @@ async def test_dask_sub_handlers( _DASK_START_EVENT = "start" def fake_remote_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode = BootMode.CPU, ) -> TaskOutputData: progress_pub = distributed.Pub(TaskProgressEvent.topic_name()) logs_pub = distributed.Pub(TaskLogEvent.topic_name()) @@ -1271,17 +1224,10 @@ async def test_get_cluster_details( # send a fct that uses resources def fake_sidecar_fct( + task_parameters: ContainerTaskParameters, docker_auth: DockerBasicAuth, - service_key: ContainerImage, - service_version: ContainerTag, - input_data: TaskInputData, - output_data_keys: TaskOutputDataSchema, log_file_url: LogFileUploadURL, - command: ContainerCommands, - task_envs: ContainerEnvsDict, - task_labels: ContainerLabelsDict, s3_settings: S3Settings | None, - boot_mode: BootMode, expected_annotations, ) -> TaskOutputData: # get the task data @@ -1289,7 +1235,7 @@ def fake_sidecar_fct( task = worker.state.tasks.get(worker.get_current_task()) assert task is not None assert task.annotations == expected_annotations - assert command == ["run"] + assert task_parameters.command == ["run"] event = distributed.Event(_DASK_EVENT_NAME) event.wait(timeout=25) From 0eb10030ac0a19de00a7365d8afe6a8c92429e68 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:24:01 +0100 Subject: [PATCH 17/21] send logs to parent node as well --- .../container_tasks/protocol.py | 4 +++ .../modules/comp_scheduler/dask_scheduler.py | 25 +++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py index ad4adfdd9f3..00f89d96d94 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/protocol.py @@ -28,6 +28,10 @@ class TaskOwner(BaseModel): parent_project_id: ProjectID | None parent_node_id: NodeID | None + @property + def has_parent(self) -> bool: + return bool(self.parent_node_id and self.parent_project_id) + @root_validator @classmethod def check_parent_valid(cls, values: dict[str, Any]) -> dict[str, Any]: diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py index 
14a4b0d42c3..18ea17f0a49 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py @@ -323,10 +323,9 @@ async def _task_progress_change_handler(self, event: str) -> None: with log_catch(_logger, reraise=False): task_progress_event = TaskProgressEvent.parse_raw(event) _logger.debug("received task progress update: %s", task_progress_event) - *_, user_id, project_id, node_id = parse_dask_job_id( - task_progress_event.job_id - ) - + user_id = task_progress_event.task_owner.user_id + project_id = task_progress_event.task_owner.project_id + node_id = task_progress_event.task_owner.node_id comp_tasks_repo = CompTasksRepository(self.db_engine) task = await comp_tasks_repo.get_task(project_id, node_id) if task.progress is None: @@ -355,12 +354,22 @@ async def _task_log_change_handler(self, event: str) -> None: with log_catch(_logger, reraise=False): task_log_event = TaskLogEvent.parse_raw(event) _logger.debug("received task log update: %s", task_log_event) - *_, user_id, project_id, node_id = parse_dask_job_id(task_log_event.job_id) await publish_service_log( self.rabbitmq_client, - user_id=user_id, - project_id=project_id, - node_id=node_id, + user_id=task_log_event.task_owner.user_id, + project_id=task_log_event.task_owner.project_id, + node_id=task_log_event.task_owner.node_id, log=task_log_event.log, log_level=task_log_event.log_level, ) + if task_log_event.task_owner.has_parent: + assert task_log_event.task_owner.parent_project_id # nosec + assert task_log_event.task_owner.parent_node_id # nosec + await publish_service_log( + self.rabbitmq_client, + user_id=task_log_event.task_owner.user_id, + project_id=task_log_event.task_owner.parent_project_id, + node_id=task_log_event.task_owner.parent_node_id, + log=task_log_event.log, + log_level=task_log_event.log_level, + ) From b81bfec87377e0c49423efb1e7db0b00350ae4ff Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:31:23 +0100 Subject: [PATCH 18/21] fixing --- ...t_modules_comp_scheduler_dask_scheduler.py | 66 +++++++++++++++++-- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index 164a1f5f090..3e12b32fd01 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -25,6 +25,7 @@ from dask_task_models_library.container_tasks.errors import TaskCancelledError from dask_task_models_library.container_tasks.events import TaskProgressEvent from dask_task_models_library.container_tasks.io import TaskOutputData +from dask_task_models_library.container_tasks.protocol import TaskOwner from faker import Faker from fastapi.applications import FastAPI from models_library.clusters import DEFAULT_CLUSTER_ID @@ -39,6 +40,7 @@ RabbitResourceTrackingStartedMessage, RabbitResourceTrackingStoppedMessage, ) +from models_library.users import UserID from pydantic import parse_obj_as, parse_raw_as from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict @@ -573,8 +575,25 @@ async def _send_computation_tasks( mocked_dask_client.send_computation_tasks.side_effect = _send_computation_tasks -async def 
_trigger_progress_event(scheduler: BaseCompScheduler, *, job_id: str) -> None: - event = TaskProgressEvent(job_id=job_id, progress=0) +async def _trigger_progress_event( + scheduler: BaseCompScheduler, + *, + job_id: str, + user_id: UserID, + project_id: ProjectID, + node_id: NodeID, +) -> None: + event = TaskProgressEvent( + job_id=job_id, + progress=0, + task_owner=TaskOwner( + user_id=user_id, + project_id=project_id, + node_id=node_id, + parent_project_id=None, + parent_node_id=None, + ), + ) await cast(DaskScheduler, scheduler)._task_progress_change_handler( # noqa: SLF001 event.json() ) @@ -659,7 +678,16 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta # 3. the "worker" starts processing a task # here we trigger a progress from the worker assert exp_started_task.job_id - await _trigger_progress_event(scheduler, job_id=exp_started_task.job_id) + assert exp_started_task.project_id + assert exp_started_task.node_id + assert published_project.project.prj_owner + await _trigger_progress_event( + scheduler, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) await run_comp_scheduler(scheduler) # comp_run, the comp_task switch to STARTED @@ -809,7 +837,13 @@ async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskSta mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_running # trigger the scheduler, run state should keep to STARTED, task should be as well assert exp_started_task.job_id - await _trigger_progress_event(scheduler, job_id=exp_started_task.job_id) + await _trigger_progress_event( + scheduler, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) await run_comp_scheduler(scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) await _assert_comp_tasks_db( @@ -956,11 +990,22 @@ async def test_task_progress_triggers( # send some progress started_task = expected_pending_tasks[0] assert started_task.job_id + assert published_project.project.prj_owner for progress in [-1, 0, 0.3, 0.5, 1, 1.5, 0.7, 0, 20]: progress_event = TaskProgressEvent( - job_id=started_task.job_id, progress=progress + job_id=started_task.job_id, + progress=progress, + task_owner=TaskOwner( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + node_id=started_task.node_id, + parent_node_id=None, + parent_project_id=None, + ), ) - await cast(DaskScheduler, scheduler)._task_progress_change_handler( + await cast( + DaskScheduler, scheduler + )._task_progress_change_handler( # noqa: SLF001 progress_event.json() ) # NOTE: not sure whether it should switch to STARTED.. 
it would make sense @@ -1252,7 +1297,14 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running assert exp_started_task.job_id - await _trigger_progress_event(scheduler, job_id=exp_started_task.job_id) + assert published_project.project.prj_owner + await _trigger_progress_event( + scheduler, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) await run_comp_scheduler(scheduler) messages = await _assert_message_received( From bea6a03b05013897ecb3fa4352eb03ae315d86a2 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 18:16:12 +0100 Subject: [PATCH 19/21] missing fixture --- .../tests/container_tasks/test_events.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/dask-task-models-library/tests/container_tasks/test_events.py b/packages/dask-task-models-library/tests/container_tasks/test_events.py index 528e208bf56..16a308e11e0 100644 --- a/packages/dask-task-models-library/tests/container_tasks/test_events.py +++ b/packages/dask-task-models-library/tests/container_tasks/test_events.py @@ -14,6 +14,7 @@ TaskProgressEvent, ) from dask_task_models_library.container_tasks.protocol import TaskOwner +from faker import Faker from pytest_mock.plugin import MockerFixture @@ -36,6 +37,11 @@ def test_events_models_examples(model_cls): assert model_instance.topic_name() +@pytest.fixture +def job_id(faker: Faker) -> str: + return faker.pystr() + + @pytest.fixture() def mocked_dask_worker_job_id(mocker: MockerFixture, job_id: str) -> str: mock_get_worker = mocker.patch( From 4e6c03c34030bc978ab244ab9fb3ab5c85d30cea Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 18:17:33 +0100 Subject: [PATCH 20/21] linter --- services/dask-sidecar/tests/unit/test_tasks.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/services/dask-sidecar/tests/unit/test_tasks.py b/services/dask-sidecar/tests/unit/test_tasks.py index f2751ff835d..efb073571cd 100644 --- a/services/dask-sidecar/tests/unit/test_tasks.py +++ b/services/dask-sidecar/tests/unit/test_tasks.py @@ -38,11 +38,8 @@ ) from faker import Faker from models_library.basic_types import EnvVarKey -from models_library.projects import ProjectID -from models_library.projects_nodes_io import NodeID from models_library.services import ServiceDockerData from models_library.services_resources import BootMode -from models_library.users import UserID from packaging import version from pydantic import AnyUrl, SecretStr, parse_obj_as from pytest_mock.plugin import MockerFixture @@ -192,9 +189,6 @@ def sleeper_task( boot_mode: BootMode, additional_envs: dict[EnvVarKey, str], faker: Faker, - user_id: UserID, - project_id: ProjectID, - node_id: NodeID, task_owner: TaskOwner, s3_settings: S3Settings, ) -> ServiceExampleParam: From faadab76bb9a4f2b33e15e01bf17a0cde0991551 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 21 Nov 2023 18:18:43 +0100 Subject: [PATCH 21/21] ruff --- services/dask-sidecar/tests/unit/test_file_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/services/dask-sidecar/tests/unit/test_file_utils.py b/services/dask-sidecar/tests/unit/test_file_utils.py index 6f16e962779..2ca64008962 100644 --- 
a/services/dask-sidecar/tests/unit/test_file_utils.py +++ b/services/dask-sidecar/tests/unit/test_file_utils.py @@ -5,16 +5,16 @@ import asyncio import mimetypes import zipfile +from collections.abc import AsyncIterable from dataclasses import dataclass from pathlib import Path -from typing import Any, AsyncIterable, cast +from typing import Any, cast from unittest import mock import fsspec import pytest from faker import Faker from pydantic import AnyUrl, parse_obj_as -from pytest import FixtureRequest from pytest_localftpserver.servers import ProcessFTPServer from pytest_mock.plugin import MockerFixture from settings_library.s3 import S3Settings @@ -79,7 +79,7 @@ class StorageParameters: @pytest.fixture(params=["ftp", "s3"]) def remote_parameters( - request: FixtureRequest, + request: pytest.FixtureRequest, ftp_remote_file_url: AnyUrl, s3_remote_file_url: AnyUrl, s3_settings: S3Settings, @@ -314,9 +314,8 @@ async def test_pull_compressed_zip_file_from_remote( mode="wb", **storage_kwargs, ), - ) as dest_fp: - with local_zip_file_path.open("rb") as src_fp: - dest_fp.write(src_fp.read()) + ) as dest_fp, local_zip_file_path.open("rb") as src_fp: + dest_fp.write(src_fp.read()) # now we want to download that file so it becomes the source src_url = destination_url