Skip to content

Commit

Permalink
make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
scottyhq committed Oct 1, 2020
1 parent d3e0327 commit 9cdc792
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 60 deletions.
30 changes: 11 additions & 19 deletions intake_stac/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import warnings

import satstac
import yaml
from intake.catalog import Catalog
from intake.catalog.local import LocalCatalogEntry
from pkg_resources import get_distribution
Expand Down Expand Up @@ -68,12 +67,9 @@ def serialize(self):
Returns
-------
A string with the yaml-formatted catalog.
A string with the yaml-formatted catalog (just top-level).
"""
output = {'metadata': self.metadata, 'sources': {}}
for key, entry in self.items():
output['sources'][key] = yaml.safe_load(entry.yaml())['sources']
return yaml.dump(output)
return self.yaml()


class StacCatalog(AbstractStacCatalog):
Expand Down Expand Up @@ -268,13 +264,20 @@ def _get_band_info(self):
)
return band_info

def stack_bands(self, bands, regrid=False):
def stack_bands(self, bands):
"""
Stack the listed bands over the ``band`` dimension.
This method only works for STAC Items using the 'eo' Extension
https://github.com/radiantearth/stac-spec/tree/master/extensions/eo
NOTE: This method is not aware of geotransform information. It *assumes*
bands for a given STAC Item have the same coordinate reference system (CRS).
This is usually the case for a given multi-band satellite acquisition.
Coordinate alignment is performed automatically upon calling the
`to_dask()` method to load into an Xarray DataArray if bands have diffent
ground sample distance (gsd) or array shapes.
Parameters
----------
bands : list of strings representing the different bands
Expand Down Expand Up @@ -307,7 +310,7 @@ def stack_bands(self, bands, regrid=False):
if info is not None:
band = info.get('id', info.get('name'))

if band not in assets or (regrid is False and info is None):
if band not in assets or info is None:
valid_band_names = []
for b in band_info:
valid_band_names.append(b.get('id', b.get('name')))
Expand All @@ -331,17 +334,6 @@ def stack_bands(self, bands, regrid=False):
'band info in a fixed section of the url'
)

if regrid is False:
gsd = info.get('gsd')
if 'gsd' not in item:
item['gsd'] = gsd
elif item['gsd'] != gsd:
raise ValueError(
f'Stacking failed: {band} has different ground '
f'sampling distance ({gsd}) than other bands '
f'({item["gsd"]})'
)

titles.append(band)
item['urlpath'].append(href)

Expand Down
64 changes: 24 additions & 40 deletions intake_stac/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def test_init_catalog_from_url(stac_cat_url):
assert isinstance(cat, intake.catalog.Catalog)
assert cat.name == 'stac-catalog'
assert cat.discover()['container'] == 'catalog'
assert int(cat.metadata['stac_version'][0]) >= 1

cat = StacCatalog.from_url(stac_cat_url)
assert isinstance(cat, intake.catalog.Catalog)
assert cat.name == 'stac-catalog'
assert cat.discover()['container'] == 'catalog'
assert int(cat.metadata['stac_version'][0]) >= 1

# test kwargs are passed through
cat = StacCatalog.from_url(stac_cat_url, name='intake-stac-test')
Expand Down Expand Up @@ -88,7 +90,6 @@ def test_init_catalog_with_bad_url_raises():
StacCatalog('foo.bar')


@pytest.mark.xfail(reason='need to fix serialization')
def test_serialize(cat):
cat_str = cat.serialize()
assert isinstance(cat_str, str)
Expand Down Expand Up @@ -121,75 +122,58 @@ def test_cat_from_item(stac_item_obj):
assert 'B5' in cat


@pytest.mark.xfail(reason='need to fix stack bands')
def test_cat_item_stacking(stac_item_obj):
items = StacItem(stac_item_obj)
item = StacItem(stac_item_obj)
list_of_bands = ['B1', 'B2']
new_entry = items.stack_bands(list_of_bands)
assert new_entry.description == 'Band 1 (coastal), Band 2 (blue)'
new_entry = item.stack_bands(list_of_bands)
assert isinstance(new_entry, StacEntry)
assert new_entry._description == 'B1, B2'
assert new_entry.name == 'B1_B2'
new_da = new_entry.to_dask()
new_da = new_entry().to_dask()
assert sorted([dim for dim in new_da.dims]) == ['band', 'x', 'y']
assert (new_da.band == list_of_bands).all()


@pytest.mark.xfail(reason='need to fix stack bands')
def test_cat_item_stacking_using_common_name(stac_item_obj):
items = StacItem(stac_item_obj)
item = StacItem(stac_item_obj)
list_of_bands = ['coastal', 'blue']
new_entry = items.stack_bands(list_of_bands)
assert new_entry.description == 'Band 1 (coastal), Band 2 (blue)'
new_entry = item.stack_bands(list_of_bands)
assert isinstance(new_entry, StacEntry)
assert new_entry._description == 'B1, B2'
assert new_entry.name == 'coastal_blue'
new_da = new_entry.to_dask()
new_da = new_entry().to_dask()
assert sorted([dim for dim in new_da.dims]) == ['band', 'x', 'y']
assert (new_da.band == ['B1', 'B2']).all()


@pytest.mark.xfail(reason='need to fix stack bands')
def test_cat_item_stacking_dims_of_different_type_raises_error(stac_item_obj):
items = StacItem(stac_item_obj)
item = StacItem(stac_item_obj)
list_of_bands = ['B1', 'ANG']
with pytest.raises(ValueError, match=('ANG not found in list of eo:bands in collection')):
items.stack_bands(list_of_bands)
item.stack_bands(list_of_bands)


@pytest.mark.xfail(reason='need to fix stack bands')
def test_cat_item_stacking_dims_with_nonexistent_band_raises_error(stac_item_obj,): # noqa: E501
items = StacItem(stac_item_obj)
item = StacItem(stac_item_obj)
list_of_bands = ['B1', 'foo']
with pytest.raises(ValueError, match="'B8', 'B9', 'blue', 'cirrus'"):
items.stack_bands(list_of_bands)
item.stack_bands(list_of_bands)


@pytest.mark.xfail(reason='need to fix stack bands')
def test_cat_item_stacking_dims_of_different_size_regrids(stac_item_obj):
items = StacItem(stac_item_obj)
item = StacItem(stac_item_obj)
list_of_bands = ['B1', 'B8']
B1_da = items.B1.to_dask()
assert B1_da.shape == (1, 7801, 7641)
B8_da = items.B8.to_dask()
assert B8_da.shape == (1, 15601, 15281)
new_entry = items.stack_bands(list_of_bands, regrid=True)
new_da = new_entry.to_dask()
assert new_da.shape == (2, 15601, 15281)
B1_da = item.B1.to_dask()
assert B1_da.shape == (1, 7791, 7651)
B8_da = item.B8.to_dask()
assert B8_da.shape == (1, 15581, 15301)
new_entry = item.stack_bands(list_of_bands)
new_da = new_entry().to_dask()
assert new_da.shape == (2, 15581, 15301)
assert sorted([dim for dim in new_da.dims]) == ['band', 'x', 'y']
assert (new_da.band == list_of_bands).all()


@pytest.mark.xfail(reason='need to fix stack bands')
def test_cat_item_stacking_dims_of_different_size_raises_error_by_default(
stac_item_obj,
): # noqa: E501
items = StacItem(stac_item_obj)
list_of_bands = ['B1', 'B8']
B1_da = items.B1.to_dask()
assert B1_da.shape == (1, 7801, 7641)
B8_da = items.B8.to_dask()
assert B8_da.shape == (1, 15601, 15281)
with pytest.raises(ValueError, match='B8 has different ground sampling'):
items.stack_bands(list_of_bands)


def test_stac_entry_constructor():
key = 'B1'
item = {
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=intake_stac
known_third_party=intake,pkg_resources,pytest,satstac,setuptools,yaml
known_third_party=intake,pkg_resources,pytest,satstac,setuptools
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
Expand Down

0 comments on commit 9cdc792

Please sign in to comment.