Skip to content

Commit

Permalink
🚀 dagster-ray v0.0.10
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Nov 12, 2024
1 parent f08e7bd commit e332594
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 21 deletions.
11 changes: 7 additions & 4 deletions dagster_ray/kuberay/client/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar

if TYPE_CHECKING:
from kubernetes import client
Expand All @@ -18,7 +18,10 @@ def load_kubeconfig(context: Optional[str] = None, config_file: Optional[str] =
pass


class BaseKubeRayClient:
T_Status = TypeVar("T_Status")


class BaseKubeRayClient(Generic[T_Status]):
def __init__(
self,
group: str,
Expand All @@ -37,7 +40,7 @@ def __init__(
self._api = client.CustomObjectsApi(api_client=api_client)
self._core_v1_api = client.CoreV1Api(api_client=api_client)

def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 60):
def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 600):
from kubernetes.client import ApiException

start_time = time.time()
Expand All @@ -63,7 +66,7 @@ def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_int

time.sleep(poll_interval)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> Dict[str, Any]:
def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> T_Status:
from kubernetes.client import ApiException

while timeout > 0:
Expand Down
8 changes: 1 addition & 7 deletions dagster_ray/kuberay/client/raycluster/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class RayClusterStatus(TypedDict):
state: NotRequired[str]


class RayClusterClient(BaseKubeRayClient):
class RayClusterClient(BaseKubeRayClient[RayClusterStatus]):
def __init__(
self,
config_file: Optional[str] = None,
Expand All @@ -91,12 +91,6 @@ def __init__(
self.config_file = config_file
self.context = context

# Typed wrapper around the parent class's get_status for Ray clusters.
# Forwards name, namespace, timeout and poll_interval unchanged to
# super().get_status(...) and wraps the result in cast(RayClusterStatus, ...)
# so that callers are given the cluster-specific status type.
# NOTE(review): the `# type: ignore` presumably silences the mismatch between
# this return annotation and the parent method's annotation — confirm.
def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayClusterStatus: # type: ignore
return cast(
RayClusterStatus,
super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval),
)

def wait_until_ready(
self,
name: str,
Expand Down
14 changes: 4 additions & 10 deletions dagster_ray/kuberay/client/rayjob/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import time
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict, cast
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict

from typing_extensions import NotRequired

Expand Down Expand Up @@ -31,7 +31,7 @@ class RayJobStatus(TypedDict):
message: NotRequired[str]


class RayJobClient(BaseKubeRayClient):
class RayJobClient(BaseKubeRayClient[RayJobStatus]):
def __init__(
self,
config_file: Optional[str] = None,
Expand All @@ -46,12 +46,6 @@ def __init__(

super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

# Typed wrapper around the parent class's get_status for Ray jobs.
# Forwards name, namespace, timeout and poll_interval unchanged to
# super().get_status(...) and wraps the result in cast(RayJobStatus, ...)
# so that callers are given the job-specific status type.
# NOTE(review): the `# type: ignore` presumably silences the mismatch between
# this return annotation and the parent method's annotation — confirm.
def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore
return cast(
RayJobStatus,
super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval),
)

# Return the value stored under the "rayClusterName" key of the status
# obtained from self.get_status(name, namespace); get_status is called with
# its default timeout and poll_interval.
# NOTE(review): the key is indexed directly, so a status without
# "rayClusterName" would presumably raise KeyError — confirm the field is
# always populated by the time this is called.
def get_ray_cluster_name(self, name: str, namespace: str) -> str:
return self.get_status(name, namespace)["rayClusterName"]

Expand All @@ -66,7 +60,7 @@ def wait_until_running(
self,
name: str,
namespace: str,
timeout: int = 300,
timeout: int = 600,
poll_interval: int = 5,
) -> bool:
start_time = time.time()
Expand Down Expand Up @@ -103,7 +97,7 @@ def _wait_for_job_submission(
self,
name: str,
namespace: str,
timeout: int = 300,
timeout: int = 600,
poll_interval: int = 10,
):
start_time = time.time()
Expand Down

0 comments on commit e332594

Please sign in to comment.