diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py
index 3fc32336..8b03e5e1 100644
--- a/jupyter_scheduler/extension.py
+++ b/jupyter_scheduler/extension.py
@@ -83,7 +83,9 @@ def initialize_settings(self):
             dask_client_future=dask_client_future,
         )
 
-        job_files_manager = self.job_files_manager_class(scheduler=scheduler)
+        job_files_manager = self.job_files_manager_class(
+            scheduler=scheduler, dask_client_future=dask_client_future
+        )
 
         self.settings.update(
             environments_manager=environments_manager,
diff --git a/jupyter_scheduler/job_files_manager.py b/jupyter_scheduler/job_files_manager.py
index 0e39c2b7..878c0fbf 100644
--- a/jupyter_scheduler/job_files_manager.py
+++ b/jupyter_scheduler/job_files_manager.py
@@ -1,10 +1,11 @@
 import os
 import random
 import tarfile
-from multiprocessing import Process
-from typing import Dict, List, Optional, Type
+from typing import Awaitable, Dict, List, Optional, Type
 
 import fsspec
+from dask.distributed import Client as DaskClient
+from dask.distributed import fire_and_forget
 from jupyter_server.utils import ensure_async
 
 from jupyter_scheduler.exceptions import SchedulerError
@@ -14,8 +15,13 @@
 class JobFilesManager:
     scheduler = None
 
-    def __init__(self, scheduler: Type[BaseScheduler]):
+    def __init__(
+        self,
+        scheduler: Type[BaseScheduler],
+        dask_client_future: Awaitable[DaskClient],
+    ):
         self.scheduler = scheduler
+        self.dask_client_future = dask_client_future
 
     async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = False):
         job = await ensure_async(self.scheduler.get_job(job_id, False))
@@ -23,8 +29,9 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
         output_filenames = self.scheduler.get_job_filenames(job)
         output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True)
 
-        p = Process(
-            target=Downloader(
+        dask_client: DaskClient = await self.dask_client_future
+        download_future = dask_client.submit(
+            Downloader(
                 output_formats=job.output_formats,
                 output_filenames=output_filenames,
                 staging_paths=staging_paths,
@@ -33,7 +40,9 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
                 include_staging_files=job.package_input_folder,
             ).download
         )
-        p.start()
+        # Dask only runs a task while some Future for it is still referenced;
+        # fire_and_forget keeps the download scheduled after this coroutine returns.
+        fire_and_forget(download_future)
 
 
 class Downloader: