From e33259486381da0506a26d73d21e1f63ac041756 Mon Sep 17 00:00:00 2001 From: Daniel Gafni Date: Tue, 12 Nov 2024 13:47:08 +0000 Subject: [PATCH] :rocket: dagster-ray v0.0.10 --- dagster_ray/kuberay/client/base.py | 11 +++++++---- dagster_ray/kuberay/client/raycluster/client.py | 8 +------- dagster_ray/kuberay/client/rayjob/client.py | 14 ++++---------- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/dagster_ray/kuberay/client/base.py b/dagster_ray/kuberay/client/base.py index 1bc8bb4..f51160f 100644 --- a/dagster_ray/kuberay/client/base.py +++ b/dagster_ray/kuberay/client/base.py @@ -1,5 +1,5 @@ import time -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar if TYPE_CHECKING: from kubernetes import client @@ -18,7 +18,10 @@ def load_kubeconfig(context: Optional[str] = None, config_file: Optional[str] = pass -class BaseKubeRayClient: +T_Status = TypeVar("T_Status") + + +class BaseKubeRayClient(Generic[T_Status]): def __init__( self, group: str, @@ -37,7 +40,7 @@ def __init__( self._api = client.CustomObjectsApi(api_client=api_client) self._core_v1_api = client.CoreV1Api(api_client=api_client) - def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 60): + def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 600): from kubernetes.client import ApiException start_time = time.time() @@ -63,7 +66,7 @@ def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_int time.sleep(poll_interval) - def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> Dict[str, Any]: + def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> T_Status: from kubernetes.client import ApiException while timeout > 0: diff --git a/dagster_ray/kuberay/client/raycluster/client.py b/dagster_ray/kuberay/client/raycluster/client.py index 12953bf..46dfc6c 100644 --- a/dagster_ray/kuberay/client/raycluster/client.py +++ b/dagster_ray/kuberay/client/raycluster/client.py @@ -77,7 +77,7 @@ class RayClusterStatus(TypedDict): state: NotRequired[str] -class RayClusterClient(BaseKubeRayClient): +class RayClusterClient(BaseKubeRayClient[RayClusterStatus]): def __init__( self, config_file: Optional[str] = None, @@ -91,12 +91,6 @@ def __init__( self.config_file = config_file self.context = context - def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayClusterStatus: # type: ignore - return cast( - RayClusterStatus, - super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval), - ) - def wait_until_ready( self, name: str, diff --git a/dagster_ray/kuberay/client/rayjob/client.py b/dagster_ray/kuberay/client/rayjob/client.py index adfdb02..d553c24 100644 --- a/dagster_ray/kuberay/client/rayjob/client.py +++ b/dagster_ray/kuberay/client/rayjob/client.py @@ -1,6 +1,6 @@ import logging import time -from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict, cast +from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict from typing_extensions import NotRequired @@ -31,7 +31,7 @@ class RayJobStatus(TypedDict): message: NotRequired[str] -class RayJobClient(BaseKubeRayClient): +class RayJobClient(BaseKubeRayClient[RayJobStatus]): def __init__( self, config_file: Optional[str] = None, @@ -46,12 +46,6 @@ def __init__( super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client) - def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore - return cast( - RayJobStatus, - super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval), - ) - def get_ray_cluster_name(self, name: str, namespace: str) -> str: return self.get_status(name, namespace)["rayClusterName"] @@ -66,7 +60,7 @@ def wait_until_running( self, name: str, namespace: str, - timeout: int = 300, + timeout: int = 600, poll_interval: int = 5, ) -> bool: start_time = time.time() @@ -103,7 +97,7 @@ def _wait_for_job_submission( self, name: str, namespace: str, - timeout: int = 300, + timeout: int = 600, poll_interval: int = 10, ): start_time = time.time()