Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#47] Adding filesystem support for save_df #48

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions raydar/task_tracker/task_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import logging
import os
from collections.abc import Iterable
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Type

import coolname
import pandas as pd
import polars as pl
import pyarrow.fs as fs
import pyarrow.parquet as pq
import ray
from packaging.version import Version
from ray.serve import shutdown
Expand Down Expand Up @@ -88,8 +90,9 @@ def __init__(
self,
name: str,
namespace: str,
path: Optional[str] = None,
enable_perspective_dashboard: bool = False,
filesystem: Type[fs.FileSystem] = fs.LocalFileSystem,
filesystem_kwargs: Optional[dict] = None,
):
"""An async Ray Actor Class to track task level metadata.

Expand All @@ -114,13 +117,13 @@ def __init__(
lifetime="detached",
get_if_exists=True,
).remote(name, namespace)
self.path = path
self.df = None
self.finished_tasks = {}
self.user_defined_metadata = {}
self.perspective_dashboard_enabled = enable_perspective_dashboard
self.pending_tasks = []
self.perspective_table_name = f"{name}_data"
self.filesystem = filesystem(**(filesystem_kwargs or dict()))

# WARNING: Do not move this import. Importing these modules elsewhere can cause
# difficult-to-diagnose "There is no current event loop in thread 'ray_client_server_" errors.
Expand Down Expand Up @@ -306,14 +309,10 @@ def get_proxy_server(self) -> ray.serve.handle.DeploymentHandle:
return self.proxy_server
raise Exception("This task_tracker has no active proxy_server.")

def save_df(self) -> bool:
    """Saves the internally maintained dataframe of task related information from the ray GCS.

    ``self.get_df()` is called first (presumably to refresh ``self.df`` --
    confirm against ``get_df``), then ``self.df`` is written as Parquet to
    ``self.path``.

    Returns:
        True when the dataframe was written to ``self.path``; False when
        either ``self.path`` or ``self.df`` is None and nothing was written.
    """
    self.get_df()
    # Nothing to write without both a destination and a dataframe.
    if self.path is None or self.df is None:
        return False
    logger.info(f"Writing DataFrame to {self.path}")
    self.df.write_parquet(self.path)
    return True
def save_df(self, path: str) -> None:
    """Persist the internally maintained task dataframe as a Parquet file.

    The frame returned by ``self.get_df()`` is converted to an Arrow table
    and written to ``path`` through the filesystem stored in
    ``self.filesystem``.

    Args:
        path: Destination handed to ``pyarrow.parquet.write_table``.
    """
    logger.info(f"Writing DataFrame to {path}")
    arrow_table = self.get_df().to_arrow()
    pq.write_table(arrow_table, path, filesystem=self.filesystem)

def clear_df(self) -> None:
"""Clears the internally maintained dataframe of task related information from the ray GCS"""
Expand Down Expand Up @@ -363,9 +362,9 @@ def get_df(self, process_user_metadata_column=False) -> pl.DataFrame:
return df_with_user_metadata
return df

def save_df(self) -> None:
def save_df(self, path: str) -> None:
"""Save the dataframe used by this object's AsyncMetadataTracker actor"""
return ray.get(self.tracker.save_df.remote())
return ray.get(self.tracker.save_df.remote(path))

def clear(self) -> None:
"""Clear the dataframe used by this object's AsyncMetadataTracker actor"""
Expand Down
15 changes: 15 additions & 0 deletions raydar/tests/test_task_tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import tempfile
import time

import pandas as pd
import pytest
import ray
import requests
Expand Down Expand Up @@ -39,3 +42,15 @@ def test_get_proxy_server(self):
time.sleep(2)
response = requests.get("http://localhost:8000/tables")
assert eval(response.text) == ["test_table"]

def test_save_df(self):
    """Check that ``save_df`` writes a Parquet file matching ``get_df``."""
    tracker = RayTaskTracker()
    object_refs = [do_some_work.remote() for _ in range(100)]
    tracker.process(object_refs)
    # Block until every submitted task has finished; the results are unused.
    ray.get(object_refs)
    expected_frame = tracker.get_df()
    with tempfile.TemporaryDirectory() as scratch_dir:
        parquet_path = os.path.join(scratch_dir, "output_dir")
        tracker.save_df(parquet_path)
        # Read back while the temporary directory still exists.
        round_tripped = pd.read_parquet(parquet_path)
        assert round_tripped.equals(expected_frame.to_pandas())