
Commit

address review comments
dougbrn committed Jan 11, 2024
1 parent 4297e77 commit fcf67da
Showing 4 changed files with 84 additions and 47 deletions.
53 changes: 29 additions & 24 deletions src/tape/ensemble.py
@@ -34,6 +34,8 @@

DEFAULT_FRAME_LABEL = "result" # A base default label for an Ensemble's result frames.

METADATA_FILENAME = "ensemble_metadata.json"


class Ensemble:
"""Ensemble object is a collection of light curve ids"""
@@ -1286,9 +1288,9 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
# Determine the path
ens_path = os.path.join(path, dirname)

# First look for an existing metadata.json file in the path
# First look for an existing metadata file in the path
try:
with open(os.path.join(ens_path, "metadata.json"), "r") as oldfile:
with open(os.path.join(ens_path, METADATA_FILENAME), "r") as oldfile:
# Reading from json file
old_metadata = json.load(oldfile)
old_subdirs = old_metadata["subdirs"]
@@ -1302,7 +1304,7 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
if additional_frames is True:
frames_to_save = list(self.frames.keys()) # save all frames
elif additional_frames is False:
frames_to_save = ["object", "source"] # save just object and source
frames_to_save = [OBJECT_FRAME_LABEL, SOURCE_FRAME_LABEL] # save just object and source
elif isinstance(additional_frames, Iterable):
frames_to_save = set(additional_frames)
invalid_frames = frames_to_save.difference(set(self.frames.keys()))
@@ -1314,14 +1316,14 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
frames_to_save = list(frames_to_save)

# Make sure object and source are in the frame list
if "object" not in frames_to_save:
frames_to_save.append("object")
if "source" not in frames_to_save:
frames_to_save.append("source")
if OBJECT_FRAME_LABEL not in frames_to_save:
frames_to_save.append(OBJECT_FRAME_LABEL)
if SOURCE_FRAME_LABEL not in frames_to_save:
frames_to_save.append(SOURCE_FRAME_LABEL)
else:
raise ValueError("Invalid input to `additional_frames`, must be boolean or list-like")

# Save the frame list to disk
# Generate the metadata first
created_subdirs = [] # track the list of created subdirectories
divisions_known = [] # log whether divisions were known for each frame
for frame_label in frames_to_save:
@@ -1331,28 +1333,29 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
# When the frame has no columns, avoid the save as parquet doesn't handle it
# Most commonly this applies to the object table when it's built from source
if len(frame.columns) == 0:
print(f"Frame: {frame_label} was not saved as no columns were present.")
print(f"Frame: {frame_label} will not be saved as no columns are present.")
continue

# creates a subdirectory for the frame partition files
frame.to_parquet(os.path.join(ens_path, frame_label), write_metadata_file=True, **kwargs)
created_subdirs.append(frame_label)
divisions_known.append(frame.known_divisions)

# Save a metadata file
col_map = self.make_column_map() # grab the current column_mapper

metadata = {
"subdirs": created_subdirs,
"known_divisions": divisions_known,
"column_mapper": col_map.map,
}
json_metadata = json.dumps(metadata, indent=4)

with open(os.path.join(ens_path, "metadata.json"), "w") as outfile:
# Make the directory if it doesn't already exist
os.makedirs(ens_path, exist_ok=True)
with open(os.path.join(ens_path, METADATA_FILENAME), "w") as outfile:
outfile.write(json_metadata)

# np.save(os.path.join(ens_path, "column_mapper.npy"), col_map.map)
# Now write out the frames to subdirectories
for subdir in created_subdirs:
self.frames[subdir].to_parquet(os.path.join(ens_path, subdir), write_metadata_file=True, **kwargs)

print(f"Saved to {os.path.join(path, dirname)}")

@@ -1390,8 +1393,8 @@ def from_ensemble(
The ensemble object.
"""

# Read in the metadata.json file
with open(os.path.join(dirpath, "metadata.json"), "r") as metadatafile:
# Read in the metadata file
with open(os.path.join(dirpath, METADATA_FILENAME), "r") as metadatafile:
# Reading from json file
metadata = json.load(metadatafile)

@@ -1405,26 +1408,26 @@ def from_ensemble(
# Load Object and Source

# Check whether object is present; it's not saved when no columns are present
if "object" in subdirs:
if OBJECT_FRAME_LABEL in subdirs:
# divisions should be known for both tables to use the sorted kwarg
use_sorted = (
frame_known_divisions[subdirs.index("object")]
and frame_known_divisions[subdirs.index("source")]
frame_known_divisions[subdirs.index(OBJECT_FRAME_LABEL)]
and frame_known_divisions[subdirs.index(SOURCE_FRAME_LABEL)]
)

self.from_parquet(
os.path.join(dirpath, "source"),
os.path.join(dirpath, "object"),
os.path.join(dirpath, SOURCE_FRAME_LABEL),
os.path.join(dirpath, OBJECT_FRAME_LABEL),
column_mapper=column_mapper,
sorted=use_sorted,
sort=False,
sync_tables=False, # a sync should always be performed just before saving
**kwargs,
)
else:
use_sorted = frame_known_divisions[subdirs.index("source")]
use_sorted = frame_known_divisions[subdirs.index(SOURCE_FRAME_LABEL)]
self.from_parquet(
os.path.join(dirpath, "source"),
os.path.join(dirpath, SOURCE_FRAME_LABEL),
column_mapper=column_mapper,
sorted=use_sorted,
sort=False,
@@ -1446,7 +1449,9 @@ def from_ensemble(

# Filter out object and source from additional frames
frames_to_load = [
frame for frame in frames_to_load if os.path.split(frame)[1] not in ["object", "source"]
frame
for frame in frames_to_load
if os.path.split(frame)[1] not in [OBJECT_FRAME_LABEL, SOURCE_FRAME_LABEL]
]
if len(frames_to_load) > 0:
for frame in frames_to_load:
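Reading the directory back goes through `from_ensemble`, which consumes the metadata written above. A hedged round-trip sketch (import path and client handling are assumptions; the directory mirrors the save example):

```python
from tape import Ensemble  # import path assumed

new_ens = Ensemble(client=False)  # client handling is an assumption

# additional_frames=True loads every saved result frame and a list loads
# only the named ones; object and source are always loaded. Per the code
# above, the sorted kwarg is passed only when the metadata records known
# divisions for the saved object and source frames.
new_ens.from_ensemble("./ensemble", additional_frames=True)
```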
17 changes: 0 additions & 17 deletions src/tape/ensemble_readers.py
@@ -15,8 +15,6 @@ def read_ensemble(
additional_frames=True,
column_mapper=None,
dask_client=True,
additional_cols=True,
partition_size=None,
**kwargs,
):
"""Load an ensemble from an on-disk ensemble.
@@ -37,19 +35,6 @@ def read_ensemble(
Supplies a ColumnMapper to the Ensemble, if None (default) searches
for a column_mapper.npy file in the directory, which should be
created when the ensemble is saved.
additional_cols: 'bool', optional
Boolean to indicate whether to carry in columns beyond the
critical columns, true will, while false will only load the columns
containing the critical quantities (id,time,flux,err,band)
partition_size: `int`, optional
If specified, attempts to repartition the ensemble to partitions
of size `partition_size`.
sorted: bool, optional
If the index column is already sorted in increasing order.
Defaults to False
sort: `bool`, optional
If True, sorts the DataFrame by the id column. Otherwise set the
index on the individual existing partitions. Defaults to False.
dask_client: `dask.distributed.client` or `bool`, optional
Accepts an existing `dask.distributed.Client`, or creates one if
`client=True`, passing any additional kwargs to a
@@ -68,8 +53,6 @@ def read_ensemble(
dirpath,
additional_frames=additional_frames,
column_mapper=column_mapper,
additional_cols=additional_cols,
partition_size=partition_size,
**kwargs,
)
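
A usage sketch of the trimmed reader; the package-level export and the `dask_client=False` behavior (no client created) are assumptions:

```python
from tape import read_ensemble  # export location assumed

ens = read_ensemble(
    "./ensemble",
    additional_frames=True,  # or False, or a list of result-frame labels
    column_mapper=None,      # None falls back to the saved metadata
    dask_client=False,       # pass an existing distributed.Client to reuse one
)
```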

57 changes: 52 additions & 5 deletions tests/tape_tests/conftest.py
@@ -233,6 +233,7 @@ def parquet_ensemble_without_client():

return ens


@pytest.fixture
def parquet_files_and_ensemble_without_client():
"""Create an Ensemble from parquet data without a dask client."""
@@ -246,12 +247,10 @@ def parquet_files_and_ensemble_without_client():
err_col="psFluxErr",
band_col="filterName",
)
ens = ens.from_parquet(
source_file,
object_file,
column_mapper=colmap)
ens = ens.from_parquet(source_file, object_file, column_mapper=colmap)
return ens, source_file, object_file, colmap


# pylint: disable=redefined-outer-name
@pytest.fixture
def parquet_ensemble(dask_client):
@@ -270,6 +269,25 @@ def parquet_ensemble(dask_client):
return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def parquet_ensemble_partition_size(dask_client):
"""Create an Ensemble from parquet data."""
ens = Ensemble(client=dask_client)
ens.from_parquet(
"tests/tape_tests/data/source/test_source.parquet",
"tests/tape_tests/data/object/test_object.parquet",
id_col="ps1_objid",
time_col="midPointTai",
band_col="filterName",
flux_col="psFlux",
err_col="psFluxErr",
partition_size="1MB",
)

return ens
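
The `partition_size` argument exercised by this new fixture is forwarded through `from_parquet`; per the docstring removed from `ensemble_readers.py` above, it attempts to repartition the ensemble into partitions of roughly that size, presumably via Dask's `repartition`. A standalone sketch of that Dask behavior (the data is made up):

```python
import dask.dataframe as dd
import pandas as pd

# Start from a single large partition, then aim for ~1MB chunks.
df = dd.from_pandas(pd.DataFrame({"x": range(200_000)}), npartitions=1)
df = df.repartition(partition_size="1MB")
print(df.npartitions)  # typically greater than 1 after repartitioning
```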


# pylint: disable=redefined-outer-name
@pytest.fixture
def parquet_ensemble_with_divisions(dask_client):
@@ -386,6 +404,34 @@ def dask_dataframe_ensemble(dask_client):
return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def dask_dataframe_ensemble_partition_size(dask_client):
"""Create an Ensemble from parquet data."""
ens = Ensemble(client=dask_client)

num_points = 1000
all_bands = np.array(["r", "g", "b", "i"])
rows = {
"id": 8000 + (np.arange(num_points) % 5),
"time": np.arange(num_points),
"flux": np.arange(num_points) % len(all_bands),
"band": np.repeat(all_bands, num_points / len(all_bands)),
"err": 0.1 * (np.arange(num_points) % 10),
"count": np.arange(num_points),
"something_else": np.full(num_points, None),
}
cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")

ens.from_dask_dataframe(
source_frame=dd.from_dict(rows, npartitions=1),
column_mapper=cmap,
partition_size="1MB",
)

return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def dask_dataframe_with_object_ensemble(dask_client):
Expand Down Expand Up @@ -490,6 +536,7 @@ def pandas_with_object_ensemble(dask_client):

return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def ensemble_from_source_dict(dask_client):
@@ -511,4 +558,4 @@ def ensemble_from_source_dict(dask_client):
cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="error", band_col="band")
ens.from_source_dict(source_dict, column_mapper=cmap)

return ens, source_dict
return ens, source_dict
4 changes: 3 additions & 1 deletion tests/tape_tests/test_ensemble.py
@@ -50,6 +50,7 @@ def test_with_client():
"parquet_ensemble_from_hipscat",
"parquet_ensemble_with_column_mapper",
"parquet_ensemble_with_known_column_mapper",
"parquet_ensemble_partition_size",
"read_parquet_ensemble",
"read_parquet_ensemble_without_client",
"read_parquet_ensemble_from_source",
@@ -102,6 +103,7 @@ def test_parquet_construction(data_fixture, request):
"data_fixture",
[
"dask_dataframe_ensemble",
"dask_dataframe_ensemble_partition_size",
"dask_dataframe_with_object_ensemble",
"pandas_ensemble",
"pandas_with_object_ensemble",
@@ -533,7 +535,7 @@ def test_save_and_load_ensemble(dask_client, tmp_path, add_frames, obj_nocols, u
dircontents = os.listdir(os.path.join(save_path, "ensemble"))

assert "source" in dircontents # Source should always be there
assert "metadata.json" in dircontents # should make a metadata file
assert "ensemble_metadata.json" in dircontents # should make a metadata file
if obj_nocols: # object shouldn't if it was empty
assert "object" not in dircontents
else: # otherwise it should be present
