From 01e672bb514060ade66a8a1b6622fc77cb034e50 Mon Sep 17 00:00:00 2001
From: sanderegg <35365065+sanderegg@users.noreply.github.com>
Date: Wed, 19 Apr 2023 14:28:40 +0200
Subject: [PATCH] fix return value of get_repository

---
 .../modules/comp_scheduler/base_scheduler.py | 18 ++++++++----------
 .../modules/comp_scheduler/dask_scheduler.py |  4 ++--
 .../modules/comp_scheduler/factory.py        |  5 +----
 .../utils/scheduler.py                       |  8 +++++---
 4 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py
index 5cf5cb2aae04..8ea0a4bb9b12 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py
@@ -15,7 +15,6 @@
 import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import cast

 import networkx as nx
 from aiopg.sa.engine import Engine
@@ -93,7 +92,7 @@ async def run_new_pipeline(
         runs_repo: CompRunsRepository = get_repository(
             self.db_engine, CompRunsRepository
-        )  # type: ignore
+        )
         new_run: CompRunsAtDB = await runs_repo.create(
             user_id=user_id,
             project_id=project_id,
@@ -153,7 +152,7 @@ async def schedule_all_pipelines(self) -> None:
     async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph:
         comp_pipeline_repo: CompPipelinesRepository = get_repository(
             self.db_engine, CompPipelinesRepository
-        )  # type: ignore
+        )
         pipeline_at_db: CompPipelineAtDB = await comp_pipeline_repo.get_pipeline(
             project_id
         )
@@ -166,7 +165,7 @@ async def _get_pipeline_tasks(
     ) -> dict[str, CompTaskAtDB]:
         comp_tasks_repo: CompTasksRepository = get_repository(
             self.db_engine, CompTasksRepository
-        )  # type: ignore
+        )
         pipeline_comp_tasks: dict[str, CompTaskAtDB] = {
             f"{t.node_id}": t
             for t in await comp_tasks_repo.get_comp_tasks(project_id)
@@ -203,7 +202,7 @@ async def _set_run_result(
     ) -> None:
         comp_runs_repo: CompRunsRepository = get_repository(
             self.db_engine, CompRunsRepository
-        )  # type: ignore
+        )
         await comp_runs_repo.set_run_result(
             user_id=user_id,
             project_id=project_id,
@@ -225,9 +224,8 @@ async def _set_states_following_failed_to_aborted(
                 tasks[f"{task}"].state = RunningState.ABORTED
         if tasks_to_set_aborted:
             # update the current states back in DB
-            comp_tasks_repo: CompTasksRepository = cast(
-                CompTasksRepository,
-                get_repository(self.db_engine, CompTasksRepository),
+            comp_tasks_repo: CompTasksRepository = get_repository(
+                self.db_engine, CompTasksRepository
             )
             await comp_tasks_repo.set_project_tasks_state(
                 project_id,
@@ -426,7 +424,7 @@ async def _schedule_tasks_to_stop(
         # get any running task and stop them
         comp_tasks_repo: CompTasksRepository = get_repository(
             self.db_engine, CompTasksRepository
-        )  # type: ignore
+        )
         await comp_tasks_repo.mark_project_published_tasks_as_aborted(project_id)
         # stop any remaining running task, these are already submitted
         tasks_to_stop = [
@@ -469,7 +467,7 @@ async def _schedule_tasks_to_start(
         # Change the tasks state to PENDING
         comp_tasks_repo: CompTasksRepository = get_repository(
             self.db_engine, CompTasksRepository
-        )  # type: ignore
+        )
         await comp_tasks_repo.set_project_tasks_state(
             project_id, list(tasks_ready_to_start.keys()), RunningState.PENDING
         )
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py
index 9088c0962811..48e919e40b9a 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py
@@ -49,7 +49,7 @@ async def _cluster_dask_client(
     if cluster_id != DEFAULT_CLUSTER_ID:
         clusters_repo: ClustersRepository = get_repository(
             scheduler.db_engine, ClustersRepository
-        )  # type: ignore
+        )
         cluster = await clusters_repo.get_cluster(user_id, cluster_id)
     async with scheduler.dask_clients_pool.acquire(cluster) as client:
         yield client
@@ -94,7 +94,7 @@ async def _start_tasks(
         # update the database so we do have the correct job_ids there
         comp_tasks_repo: CompTasksRepository = get_repository(
             self.db_engine, CompTasksRepository
-        )  # type: ignore
+        )
         await asyncio.gather(
             *[
                 comp_tasks_repo.set_project_task_job_id(project_id, node_id, job_id)
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/factory.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/factory.py
index 87346d69aa8e..72118e12cded 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/factory.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/factory.py
@@ -1,5 +1,4 @@
 import logging
-from typing import cast

 from fastapi import FastAPI
 from models_library.clusters import DEFAULT_CLUSTER_ID
@@ -22,9 +21,7 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler:
             "Database connection is missing. Please check application configuration."
         )
     db_engine = app.state.engine
-    runs_repository: CompRunsRepository = cast(
-        CompRunsRepository, get_repository(db_engine, CompRunsRepository)
-    )
+    runs_repository: CompRunsRepository = get_repository(db_engine, CompRunsRepository)

     # get currently scheduled runs
     runs: list[CompRunsAtDB] = await runs_repository.list(
diff --git a/services/director-v2/src/simcore_service_director_v2/utils/scheduler.py b/services/director-v2/src/simcore_service_director_v2/utils/scheduler.py
index 716886952028..211dd003442e 100644
--- a/services/director-v2/src/simcore_service_director_v2/utils/scheduler.py
+++ b/services/director-v2/src/simcore_service_director_v2/utils/scheduler.py
@@ -1,9 +1,9 @@
+from typing import TypeVar
+
 from aiopg.sa.engine import Engine
 from models_library.projects_state import RunningState
 from pydantic import PositiveInt

-from ..modules.db.repositories import BaseRepository
-
 SCHEDULED_STATES: set[RunningState] = {
     RunningState.PUBLISHED,
     RunningState.PENDING,
@@ -29,8 +29,10 @@
     RunningState.UNKNOWN,
 }

+RepoType = TypeVar("RepoType")
+

-def get_repository(db_engine: Engine, repo_cls: type[BaseRepository]) -> BaseRepository:
+def get_repository(db_engine: Engine, repo_cls: type[RepoType]) -> RepoType:
     return repo_cls(db_engine=db_engine)
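
The patch above makes get_repository generic over the repository class, so a type
checker infers the concrete repository type at each call site and the cast(...) and
"# type: ignore" workarounds can be dropped. The following is a minimal, self-contained
sketch of that pattern, not the project's real code: FakeEngine, DemoBaseRepository,
DemoCompRunsRepository and list_runs are hypothetical stand-ins, and the TypeVar here is
bound to the demo base class for stricter checking, whereas the patch itself leaves
RepoType unbounded (inference works the same way in both cases).

# Illustrative sketch only; the names below are placeholders, not simcore classes.
from typing import TypeVar


class FakeEngine:
    """Placeholder for aiopg.sa.engine.Engine in this demo."""


class DemoBaseRepository:
    """Simplified repository base: stores the db engine it was built with."""

    def __init__(self, db_engine: FakeEngine) -> None:
        self.db_engine = db_engine


class DemoCompRunsRepository(DemoBaseRepository):
    def list_runs(self) -> list[str]:
        # Hypothetical method, only here so the call site below has
        # something subclass-specific to invoke.
        return []


# Bound TypeVar: only repository subclasses are accepted.
RepoType = TypeVar("RepoType", bound=DemoBaseRepository)


def get_repository(db_engine: FakeEngine, repo_cls: type[RepoType]) -> RepoType:
    # Because repo_cls is annotated as type[RepoType], the checker binds RepoType
    # to the concrete class passed in and uses it as the return type.
    return repo_cls(db_engine=db_engine)


runs_repo = get_repository(FakeEngine(), DemoCompRunsRepository)
runs_repo.list_runs()  # seen as DemoCompRunsRepository: no cast, no type: ignore

With the previous signature returning BaseRepository, every caller had to either cast the
result or silence the checker, which is exactly the noise the patch removes from the
scheduler modules.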