Skip to content

Commit

Permalink
Added a new fast_time_slicing parameter. If True, Xee performs an optimization that makes slicing an ImageCollection across time faster. This optimization loads EE images in a slice by ID, so any modifications to images in a computed ImageCollection will not be reflected.
Browse files Browse the repository at this point in the history

Also adds several new warnings:

- if a user enables `fast_time_slicing` but there are no image IDs, and
- if a user is indexing into a very large ImageCollection.

PiperOrigin-RevId: 623280839
  • Loading branch information
naschmitz authored and Xee authors committed Apr 9, 2024
1 parent 7fe930c commit 5fe56bb
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
39 changes: 34 additions & 5 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import functools
import importlib
import itertools
import logging
import math
import os
import sys
Expand Down Expand Up @@ -72,6 +73,12 @@
# trial & error.
REQUEST_BYTE_LIMIT = 2**20 * 48 # 48 MBs

# Xee uses the ee.ImageCollection.toList function for slicing into an
# ImageCollection. This function isn't optimized for large collections. If the
# end index of the slice is beyond 10k, display a warning to the user. This
# value was chosen by trial and error.
_TO_LIST_WARNING_LIMIT = 10000


def _check_request_limit(chunks: Dict[str, int], dtype_size: int, limit: int):
"""Checks that the actual number of bytes exceeds the limit."""
Expand Down Expand Up @@ -153,6 +160,7 @@ def open(
ee_init_if_necessary: bool = False,
executor_kwargs: Optional[Dict[str, Any]] = None,
getitem_kwargs: Optional[Dict[str, int]] = None,
fast_time_slicing: bool = False,
) -> 'EarthEngineStore':
if mode != 'r':
raise ValueError(
Expand All @@ -175,6 +183,7 @@ def open(
ee_init_if_necessary=ee_init_if_necessary,
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
)

def __init__(
Expand All @@ -194,9 +203,11 @@ def __init__(
ee_init_if_necessary: bool = False,
executor_kwargs: Optional[Dict[str, Any]] = None,
getitem_kwargs: Optional[Dict[str, int]] = None,
fast_time_slicing: bool = False,
):
self.ee_init_kwargs = ee_init_kwargs
self.ee_init_if_necessary = ee_init_if_necessary
self.fast_time_slicing = fast_time_slicing

# Initialize executor_kwargs
if executor_kwargs is None:
Expand Down Expand Up @@ -834,15 +845,27 @@ def _slice_collection(self, image_slice: slice) -> ee.Image:
self._ee_init_check()
start, stop, stride = image_slice.indices(self.shape[0])

# If the input images have IDs, just slice them. Otherwise, we need to do
# an expensive `toList()` operation.
if self.store.image_ids:
if self.store.fast_time_slicing and self.store.image_ids:
imgs = self.store.image_ids[start:stop:stride]
else:
if self.store.fast_time_slicing:
logging.warning(
"fast_time_slicing is enabled but ImageCollection images don't have"
' IDs. Reverting to default behavior.'
)
if stop > _TO_LIST_WARNING_LIMIT:
logging.warning(
'Xee is indexing into the ImageCollection beyond %s images. This'
' operation can be slow. To improve performance, consider filtering'
' the ImageCollection prior to using Xee or enabling'
' fast_time_slicing.',
_TO_LIST_WARNING_LIMIT,
)
# TODO(alxr, mahrsee): Find a way to make this case more efficient.
list_range = stop - start
col0 = self.store.image_collection
imgs = col0.toList(list_range, offset=start).slice(0, list_range, stride)
imgs = self.store.image_collection.toList(list_range, offset=start).slice(
0, list_range, stride
)

col = ee.ImageCollection(imgs)

Expand Down Expand Up @@ -1006,6 +1029,7 @@ def open_dataset(
ee_init_kwargs: Optional[Dict[str, Any]] = None,
executor_kwargs: Optional[Dict[str, Any]] = None,
getitem_kwargs: Optional[Dict[str, int]] = None,
fast_time_slicing: bool = False,
) -> xarray.Dataset: # type: ignore
"""Open an Earth Engine ImageCollection as an Xarray Dataset.
Expand Down Expand Up @@ -1084,6 +1108,10 @@ def open_dataset(
- 'max_retries', the maximum number of retry attempts. Defaults to 6.
- 'initial_delay', the initial delay in milliseconds before the first
retry. Defaults to 500.
fast_time_slicing (optional): Whether to perform an optimization that
makes slicing an ImageCollection across time faster. This optimization
loads EE images in a slice by ID, so any modifications to images in a
computed ImageCollection will not be reflected.
Returns:
An xarray.Dataset that streams in remote data from Earth Engine.
"""
Expand Down Expand Up @@ -1114,6 +1142,7 @@ def open_dataset(
ee_init_if_necessary=ee_init_if_necessary,
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
)

store_entrypoint = backends_store.StoreBackendEntrypoint()
Expand Down
32 changes: 32 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,38 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

def test_fast_time_slicing(self):
  """Compares default slicing against `fast_time_slicing=True`.

  A computed collection is built in which the first hourly image is swapped
  for a constant-zero image that carries the original image's `system:id`
  and `system:time_start`. The dataset is then opened twice, with and
  without `fast_time_slicing`, and the first time step of each is inspected.
  """
  band = 'temperature_2m'
  hourly = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY')
  hourly = hourly.filterDate('2024-01-01', '2024-01-02').select(band)
  first = hourly.first()

  # Stand-in for the first image: all zeros, but with the same ID and
  # timestamp properties copied over from the real image.
  zero_image = ee.Image(0).rename(band).copyProperties(
      first, ['system:id', 'system:time_start']
  )
  images = hourly.toList(count=hourly.size()).replace(first, zero_image)
  fake_collection = ee.ImageCollection(images)

  params = {
      'filename_or_obj': fake_collection,
      'engine': xee.EarthEngineBackendEntrypoint,
      'geometry': ee.Geometry.BBox(-83.86, 41.13, -76.83, 46.15),
      'projection': first.projection().atScale(100000),
  }

  # Default (slow) slicing is expected to return the modified, all-zero image.
  slow_slicing = xr.open_dataset(**params)
  slow_first_step = slow_slicing[{'time': 0}][band].as_numpy()
  self.assertTrue(np.all(slow_first_step == 0))

  # Fast slicing is expected to return the original image instead, so the
  # values are asserted to be strictly positive.
  fast_slicing = xr.open_dataset(**params, fast_time_slicing=True)
  fast_first_step = fast_slicing[{'time': 0}][band].as_numpy()
  self.assertTrue(np.all(fast_first_step > 0))

@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_write_projected_dataset_to_raster(self):
# ensure that a projected dataset written to a raster intersects with the
Expand Down

0 comments on commit 5fe56bb

Please sign in to comment.