🚩 Generalize region subset method to DataFrames
Make bounding box subsetting work on DataFrames too! This includes pandas, dask and cudf DataFrames. Included a parametrized test for pandas and dask; the cudf one should work too since the APIs are similar. The original xarray.Dataset subsetter code will still work.
weiji14 committed Aug 19, 2020
1 parent 8c2bf32 commit eb61ff6
Showing 6 changed files with 45 additions and 15 deletions.
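For context, a minimal sketch of the generalized call this commit enables, reusing the toy bounding box and point data from the new test added below (Region takes a name followed by xmin, xmax, ymin and ymax; the columns here are illustrative only):

import numpy as np
import pandas as pd

from deepicedrain import Region

# Toy region and points, mirroring the values used in test_region_subset_dataframe below
region = Region("South Pole", -100, 100, -100, 100)
dataframe = pd.DataFrame(
    data={
        "x": np.linspace(start=-200, stop=200, num=50),
        "y": np.linspace(start=-160, stop=160, num=50),
    }
)

# The same keyword argument now accepts an xarray.Dataset or a pandas/dask/cudf DataFrame
df_subset = region.subset(data=dataframe)
assert len(df_subset) == 24  # only the points inside the bounding box remain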
2 changes: 1 addition & 1 deletion atl11_play.ipynb
@@ -303,7 +303,7 @@
"# Do the actual computation to find data points within region of interest\n",
"placename: str = \"kamb\" # Select Kamb Ice Stream region\n",
"region: deepicedrain.Region = regions[placename]\n",
"ds_subset: xr.Dataset = region.subset(ds=ds)\n",
"ds_subset: xr.Dataset = region.subset(data=ds)\n",
"ds_subset = ds_subset.unify_chunks()\n",
"ds_subset = ds_subset.compute()"
]
2 changes: 1 addition & 1 deletion atl11_play.py
@@ -181,7 +181,7 @@
# Do the actual computation to find data points within region of interest
placename: str = "kamb" # Select Kamb Ice Stream region
region: deepicedrain.Region = regions[placename]
ds_subset: xr.Dataset = region.subset(ds=ds)
ds_subset: xr.Dataset = region.subset(data=ds)
ds_subset = ds_subset.unify_chunks()
ds_subset = ds_subset.compute()

4 changes: 2 additions & 2 deletions atlxi_dhdt.ipynb
@@ -218,7 +218,7 @@
"# Subset dataset to geographic region of interest\n",
"placename: str = \"antarctica\"\n",
"region: deepicedrain.Region = regions[placename]\n",
"# ds = region.subset(ds=ds)"
"# ds = region.subset(data=ds)"
]
},
{
@@ -901,7 +901,7 @@
"region: deepicedrain.Region = regions[placename]\n",
"if not os.path.exists(f\"ATLXI/df_dhdt_{placename}.parquet\"):\n",
"    # Subset dataset to geographic region of interest\n",
"    ds_subset: xr.Dataset = region.subset(ds=ds_dhdt)\n",
"    ds_subset: xr.Dataset = region.subset(data=ds_dhdt)\n",
"    # Add a UTC_time column to the dataframe\n",
"    ds_subset[\"utc_time\"] = deepicedrain.deltatime_to_utctime(\n",
"        dataarray=ds_subset.delta_time\n"
4 changes: 2 additions & 2 deletions atlxi_dhdt.py
@@ -140,7 +140,7 @@
# Subset dataset to geographic region of interest
placename: str = "antarctica"
region: deepicedrain.Region = regions[placename]
# ds = region.subset(ds=ds)
# ds = region.subset(data=ds)

# %%
# We need at least 2 points to draw a trend line or compute differences
@@ -398,7 +398,7 @@
region: deepicedrain.Region = regions[placename]
if not os.path.exists(f"ATLXI/df_dhdt_{placename}.parquet"):
# Subset dataset to geographic region of interest
ds_subset: xr.Dataset = region.subset(ds=ds_dhdt)
ds_subset: xr.Dataset = region.subset(data=ds_dhdt)
# Add a UTC_time column to the dataframe
ds_subset["utc_time"] = deepicedrain.deltatime_to_utctime(
        dataarray=ds_subset.delta_time
20 changes: 14 additions & 6 deletions deepicedrain/spatiotemporal.py
@@ -74,18 +74,26 @@ def datashade(
)

def subset(
self, ds: xr.Dataset, x_dim: str = "x", y_dim: str = "y", drop: bool = True
self, data: xr.Dataset, x_dim: str = "x", y_dim: str = "y", drop: bool = True
) -> xr.Dataset:
"""
Convenience function to find datapoints in an xarray.Dataset
that fit within the bounding boxes of this region
Convenience function to find datapoints in an xarray.Dataset or
pandas.DataFrame that fit within the bounding boxes of this region.
Note that the 'drop' boolean flag is only valid for xarray.Dataset.
"""
cond = np.logical_and(
np.logical_and(ds[x_dim] > self.xmin, ds[x_dim] < self.xmax),
np.logical_and(ds[y_dim] > self.ymin, ds[y_dim] < self.ymax),
np.logical_and(data[x_dim] > self.xmin, data[x_dim] < self.xmax),
np.logical_and(data[y_dim] > self.ymin, data[y_dim] < self.ymax),
)

return ds.where(cond=cond, drop=drop)
try:
# xarray.DataArray subset method
data_subset = data.where(cond=cond, drop=drop)
except TypeError:
# pandas.DataFrame subset method
data_subset = data.loc[cond]

return data_subset


def deltatime_to_utctime(
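A note on the try/except dispatch in subset() above: pandas.DataFrame.where (and the dask/cudf equivalents that mirror it) has no drop keyword, so the xarray-style call raises TypeError and execution falls through to plain boolean .loc indexing. A standalone sketch with toy values:

import numpy as np
import pandas as pd

df = pd.DataFrame(data={"x": [-150.0, 0.0, 150.0], "y": [0.0, 50.0, 0.0]})
cond = np.logical_and(
    np.logical_and(df["x"] > -100, df["x"] < 100),
    np.logical_and(df["y"] > -100, df["y"] < 100),
)

try:
    # works for xarray objects; pandas rejects the unexpected 'drop' keyword
    subset = df.where(cond=cond, drop=True)
except TypeError:
    # boolean-mask indexing; also the path taken by dask (tested below) and,
    # per the commit message, cudf
    subset = df.loc[cond]

print(subset)  # keeps only the middle row at x=0.0, y=50.0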
28 changes: 25 additions & 3 deletions deepicedrain/tests/test_region.py
@@ -7,6 +7,7 @@
import pytest
import xarray as xr

import dask.dataframe
from deepicedrain import Region, catalog, lonlat_to_xy


@@ -51,7 +52,7 @@ def test_region_datashade():

atl11_dataset: xr.Dataset = catalog.test_data.atl11_test_case.to_dask()
atl11_dataset["x"], atl11_dataset["y"] = lonlat_to_xy(
longitude=atl11_dataset.longitude, latitude=atl11_dataset.latitude, epsg=3995,
longitude=atl11_dataset.longitude, latitude=atl11_dataset.latitude, epsg=3995
)
atl11_dataset = atl11_dataset.set_coords(["x", "y"])
df: pd.DataFrame = atl11_dataset.h_corr.to_dataframe()
@@ -64,7 +65,7 @@ def test_region_datashade():
npt.assert_allclose(agg_grid.max(), 1798.066285)


def test_region_subset():
def test_region_subset_xarray_dataset():
"""
Test that we can subset an xarray.Dataset based on the region's bounds
"""
@@ -76,6 +77,27 @@ def test_region_subset():
"y": np.linspace(start=-160, stop=160, num=50),
},
)
ds_subset = region.subset(ds=dataset)
ds_subset = region.subset(data=dataset)
assert isinstance(ds_subset, xr.Dataset)
assert ds_subset.h_corr.shape == (24, 30)


@pytest.mark.parametrize("dataframe_type", [pd.DataFrame, dask.dataframe.DataFrame])
def test_region_subset_dataframe(dataframe_type):
"""
Test that we can subset a pandas or dask DataFrame based on the region's
bounds
"""
region = Region("South Pole", -100, 100, -100, 100)
dataframe = pd.DataFrame(
data={
"x": np.linspace(start=-200, stop=200, num=50),
"y": np.linspace(start=-160, stop=160, num=50),
"dhdt": np.random.rand(50),
}
)
if dataframe_type == dask.dataframe.core.DataFrame:
dataframe = dask.dataframe.from_pandas(data=dataframe, npartitions=2)
df_subset = region.subset(data=dataframe)
assert isinstance(df_subset, dataframe_type)
assert len(df_subset.dhdt) == 24
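The parametrization covers pandas and dask; cudf is not exercised here, but as the commit message notes it should take the same .loc code path. An untested sketch of that case, assuming a CUDA GPU and the cudf package are available:

import cudf  # GPU-only dependency, not exercised in this test suite

gdf = cudf.from_pandas(dataframe)     # reuse the pandas DataFrame built above
gdf_subset = region.subset(data=gdf)  # should fall through to gdf.loc[cond]
assert len(gdf_subset) == 24          # expected to match the pandas/dask result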
