Skip to content

Commit

Permalink
[#47] Adding filesystem support for save_df
Browse files Browse the repository at this point in the history
...

Signed-off-by: Todd Gaugler <[email protected]>

...

...
  • Loading branch information
gauglertodd committed Oct 23, 2024
1 parent 8868835 commit 0558819
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
25 changes: 12 additions & 13 deletions raydar/task_tracker/task_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import logging
import os
from collections.abc import Iterable
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Type

import coolname
import pandas as pd
import polars as pl
import pyarrow.fs as fs
import pyarrow.parquet as pq
import ray
from packaging.version import Version
from ray.serve import shutdown
Expand Down Expand Up @@ -88,8 +90,9 @@ def __init__(
self,
name: str,
namespace: str,
path: Optional[str] = None,
enable_perspective_dashboard: bool = False,
filesystem: Type[fs.FileSystem] = fs.LocalFileSystem,
filesystem_kwargs: dict = dict(),
):
"""An async Ray Actor Class to track task level metadata.
Expand All @@ -114,13 +117,13 @@ def __init__(
lifetime="detached",
get_if_exists=True,
).remote(name, namespace)
self.path = path
self.df = None
self.finished_tasks = {}
self.user_defined_metadata = {}
self.perspective_dashboard_enabled = enable_perspective_dashboard
self.pending_tasks = []
self.perspective_table_name = f"{name}_data"
self.filesystem = filesystem(**filesystem_kwargs)

# WARNING: Do not move this import. Importing these modules elsewhere can cause
# difficult to diagnose, "There is no current event loop in thread 'ray_client_server_" errors.
Expand Down Expand Up @@ -306,14 +309,10 @@ def get_proxy_server(self) -> ray.serve.handle.DeploymentHandle:
return self.proxy_server
raise Exception("This task_tracker has no active proxy_server.")

def save_df(self) -> bool:
    """Write the internally maintained dataframe of task information to ``self.path``.

    ``self.get_df()`` is called first, then ``self.df`` is written as parquet.

    Returns:
        True when both ``self.path`` and ``self.df`` are set and the parquet
        write was issued; False when either is ``None`` (nothing is written).
    """
    # Return value is discarded; presumably this call refreshes self.df
    # as a side effect -- confirm against get_df.
    self.get_df()
    if self.path is None or self.df is None:
        return False
    logger.info(f"Writing DataFrame to {self.path}")
    self.df.write_parquet(self.path)
    return True
def save_df(self, path: str) -> None:
    """Write the tracked task dataframe to ``path`` as parquet.

    The frame returned by ``self.get_df()`` is converted to an Arrow table and
    written through the filesystem object stored on ``self.filesystem``.
    """
    logger.info(f"Writing DataFrame to {path}")
    arrow_table = self.get_df().to_arrow()
    pq.write_table(arrow_table, path, filesystem=self.filesystem)

def clear_df(self) -> None:
"""Clears the internally maintained dataframe of task related information from the ray GCS"""
Expand Down Expand Up @@ -363,9 +362,9 @@ def get_df(self, process_user_metadata_column=False) -> pl.DataFrame:
return df_with_user_metadata
return df

def save_df(self) -> None:
def save_df(self, path: str) -> None:
"""Save the dataframe used by this object's AsyncMetadataTracker actor"""
return ray.get(self.tracker.save_df.remote())
return ray.get(self.tracker.save_df.remote(path))

def clear(self) -> None:
"""Clear the dataframe used by this object's AsyncMetadataTracker actor"""
Expand Down
15 changes: 15 additions & 0 deletions raydar/tests/test_task_tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import tempfile
import time

import pandas as pd
import pytest
import ray
import requests
Expand Down Expand Up @@ -39,3 +42,15 @@ def test_get_proxy_server(self):
time.sleep(2)
response = requests.get("http://localhost:8000/tables")
assert eval(response.text) == ["test_table"]

def test_save_df(self):
    """Round-trip check: a frame written by save_df reads back equal from parquet."""
    tracker = RayTaskTracker()
    object_refs = [do_some_work.remote() for _ in range(100)]
    tracker.process(object_refs)
    # Block until every task has finished before reading the tracked frame.
    ray.get(object_refs)
    tracked_df = tracker.get_df()
    with tempfile.TemporaryDirectory() as scratch_dir:
        parquet_path = os.path.join(scratch_dir, "output_dir")
        tracker.save_df(parquet_path)
        reloaded = pd.read_parquet(parquet_path)
        assert reloaded.equals(tracked_df.to_pandas())

0 comments on commit 0558819

Please sign in to comment.