Skip to content

Commit

Permalink
fix return value of get_repository
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg committed Apr 19, 2023
1 parent 58c2855 commit 01e672b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import cast

import networkx as nx
from aiopg.sa.engine import Engine
Expand Down Expand Up @@ -93,7 +92,7 @@ async def run_new_pipeline(

runs_repo: CompRunsRepository = get_repository(
self.db_engine, CompRunsRepository
) # type: ignore
)
new_run: CompRunsAtDB = await runs_repo.create(
user_id=user_id,
project_id=project_id,
Expand Down Expand Up @@ -153,7 +152,7 @@ async def schedule_all_pipelines(self) -> None:
async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph:
comp_pipeline_repo: CompPipelinesRepository = get_repository(
self.db_engine, CompPipelinesRepository
) # type: ignore
)
pipeline_at_db: CompPipelineAtDB = await comp_pipeline_repo.get_pipeline(
project_id
)
Expand All @@ -166,7 +165,7 @@ async def _get_pipeline_tasks(
) -> dict[str, CompTaskAtDB]:
comp_tasks_repo: CompTasksRepository = get_repository(
self.db_engine, CompTasksRepository
) # type: ignore
)
pipeline_comp_tasks: dict[str, CompTaskAtDB] = {
f"{t.node_id}": t
for t in await comp_tasks_repo.get_comp_tasks(project_id)
Expand Down Expand Up @@ -203,7 +202,7 @@ async def _set_run_result(
) -> None:
comp_runs_repo: CompRunsRepository = get_repository(
self.db_engine, CompRunsRepository
) # type: ignore
)
await comp_runs_repo.set_run_result(
user_id=user_id,
project_id=project_id,
Expand All @@ -225,9 +224,8 @@ async def _set_states_following_failed_to_aborted(
tasks[f"{task}"].state = RunningState.ABORTED
if tasks_to_set_aborted:
# update the current states back in DB
comp_tasks_repo: CompTasksRepository = cast(
CompTasksRepository,
get_repository(self.db_engine, CompTasksRepository),
comp_tasks_repo: CompTasksRepository = get_repository(
self.db_engine, CompTasksRepository
)
await comp_tasks_repo.set_project_tasks_state(
project_id,
Expand Down Expand Up @@ -426,7 +424,7 @@ async def _schedule_tasks_to_stop(
# get any running task and stop them
comp_tasks_repo: CompTasksRepository = get_repository(
self.db_engine, CompTasksRepository
) # type: ignore
)
await comp_tasks_repo.mark_project_published_tasks_as_aborted(project_id)
# stop any remaining running task, these are already submitted
tasks_to_stop = [
Expand Down Expand Up @@ -469,7 +467,7 @@ async def _schedule_tasks_to_start(
# Change the tasks state to PENDING
comp_tasks_repo: CompTasksRepository = get_repository(
self.db_engine, CompTasksRepository
) # type: ignore
)
await comp_tasks_repo.set_project_tasks_state(
project_id, list(tasks_ready_to_start.keys()), RunningState.PENDING
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def _cluster_dask_client(
if cluster_id != DEFAULT_CLUSTER_ID:
clusters_repo: ClustersRepository = get_repository(
scheduler.db_engine, ClustersRepository
) # type: ignore
)
cluster = await clusters_repo.get_cluster(user_id, cluster_id)
async with scheduler.dask_clients_pool.acquire(cluster) as client:
yield client
Expand Down Expand Up @@ -94,7 +94,7 @@ async def _start_tasks(
# update the database so we do have the correct job_ids there
comp_tasks_repo: CompTasksRepository = get_repository(
self.db_engine, CompTasksRepository
) # type: ignore
)
await asyncio.gather(
*[
comp_tasks_repo.set_project_task_job_id(project_id, node_id, job_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import cast

from fastapi import FastAPI
from models_library.clusters import DEFAULT_CLUSTER_ID
Expand All @@ -22,9 +21,7 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler:
"Database connection is missing. Please check application configuration."
)
db_engine = app.state.engine
runs_repository: CompRunsRepository = cast(
CompRunsRepository, get_repository(db_engine, CompRunsRepository)
)
runs_repository: CompRunsRepository = get_repository(db_engine, CompRunsRepository)

# get currently scheduled runs
runs: list[CompRunsAtDB] = await runs_repository.list(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import TypeVar

from aiopg.sa.engine import Engine
from models_library.projects_state import RunningState
from pydantic import PositiveInt

from ..modules.db.repositories import BaseRepository

SCHEDULED_STATES: set[RunningState] = {
RunningState.PUBLISHED,
RunningState.PENDING,
Expand All @@ -29,8 +29,10 @@
RunningState.UNKNOWN,
}

RepoType = TypeVar("RepoType")


def get_repository(db_engine: Engine, repo_cls: type[BaseRepository]) -> BaseRepository:
def get_repository(db_engine: Engine, repo_cls: type[RepoType]) -> RepoType:
return repo_cls(db_engine=db_engine)


Expand Down

0 comments on commit 01e672b

Please sign in to comment.