Skip to content

Commit

Permalink
pass options to the fetch function (#404)
Browse files Browse the repository at this point in the history
* pass options to the fetch function

* Update CHANGES.md

Co-authored-by: Kyle Barron <[email protected]>

* move environment check in aws_get_object

* Update rio_tiler/io/stac.py

Co-authored-by: Kyle Barron <[email protected]>

* Update CHANGES.md

* Update rio_tiler/io/stac.py

Co-authored-by: Kyle Barron <[email protected]>
  • Loading branch information
vincentsarago and kylebarron authored Jul 28, 2021
1 parent 0c23066 commit 26d4733
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
7 changes: 6 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
## Unreleased

* add support for setting the S3 endpoint URL via the `AWS_S3_ENDPOINT` environment variable in the `aws_get_object` function using boto3 (https://github.com/cogeotiff/rio-tiler/pull/394)

* make `ImageStatistics.valid_percent` a value between 0 and 100 (instead of 0 and 1) (author @param-thakker, https://github.com/cogeotiff/rio-tiler/pull/400)
* add `fetch_options` to `STACReader` to allow custom configuration to the fetch client (https://github.com/cogeotiff/rio-tiler/pull/404)

```python
with STACReader("s3://...", fetch_options={"request_pays": True}):
    pass
```

## 2.1.0 (2021-05-17)

Expand Down
15 changes: 9 additions & 6 deletions rio_tiler/io/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import json
from typing import Dict, Iterator, Optional, Set, Type, Union
from typing import Any, Dict, Iterator, Optional, Set, Type, Union
from urllib.parse import urlparse

import attr
Expand Down Expand Up @@ -31,13 +31,14 @@


@functools.lru_cache(maxsize=512)
def fetch(filepath: str) -> Dict:
def fetch(filepath: str, **kwargs: Any) -> Dict:
"""Fetch STAC items.
An LRU cache is set on top of this function.
Args:
filepath (str): STAC item URL.
kwargs (any): additional options to pass to client.
Returns:
dict: STAC Item content.
Expand All @@ -47,10 +48,10 @@ def fetch(filepath: str) -> Dict:
if parsed.scheme == "s3":
bucket = parsed.netloc
key = parsed.path.strip("/")
return json.loads(aws_get_object(bucket, key))
return json.loads(aws_get_object(bucket, key, **kwargs))

elif parsed.scheme in ["https", "http", "ftp"]:
return requests.get(filepath).json()
return requests.get(filepath, **kwargs).json()

else:
with open(filepath, "r") as f:
Expand Down Expand Up @@ -130,7 +131,8 @@ class STACReader(MultiBaseReader):
include_asset_types (set of string, optional): Only include some assets based on their type.
exclude_asset_types (set of string, optional): Exclude some assets based on their type.
reader (rio_tiler.io.BaseReader, optional): rio-tiler Reader. Defaults to `rio_tiler.io.COGReader`.
reader_options (dict, optional): additional option to forward to the Reader. Defaults to `{}`.
reader_options (dict, optional): Additional options to forward to the Reader. Defaults to `{}`.
fetch_options (dict, optional): Options to pass to `rio_tiler.io.stac.fetch` function fetching the STAC Items. Defaults to `{}`.
Examples:
>>> with STACReader(stac_path) as stac:
Expand Down Expand Up @@ -162,11 +164,12 @@ class STACReader(MultiBaseReader):
exclude_asset_types: Optional[Set[str]] = attr.ib(default=None)
reader: Type[BaseReader] = attr.ib(default=COGReader)
reader_options: Dict = attr.ib(factory=dict)
fetch_options: Dict = attr.ib(factory=dict)

def __attrs_post_init__(self):
"""Fetch STAC Item and get list of valid assets."""
self.item = self.item or pystac.Item.from_dict(
fetch(self.filepath), self.filepath
fetch(self.filepath, **self.fetch_options), self.filepath
)
self.bounds = self.item.bbox
self.assets = list(
Expand Down
3 changes: 2 additions & 1 deletion rio_tiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def aws_get_object(
client = session.client("s3", endpoint_url=endpoint_url)

params = {"Bucket": bucket, "Key": key}
if request_pays:
if request_pays or os.environ.get("AWS_REQUEST_PAYER", "").lower() == "requester":
params["RequestPayer"] = "requester"

response = client.get_object(**params)
return response["Body"].read()

Expand Down
39 changes: 39 additions & 0 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,42 @@ def test_relative_assets():

for asset in stac.assets:
assert stac._get_asset_url(asset).startswith(PREFIX)


@patch("rio_tiler.io.stac.aws_get_object")
@patch("rio_tiler.io.stac.requests")
def test_fetch_stac_client_options(requests, s3_get):
    """Check that `fetch_options` are forwarded to the HTTP and S3 clients.

    Both `requests` and `aws_get_object` are patched inside
    `rio_tiler.io.stac`; the decorators are applied bottom-up, so the
    `requests` mock is the first argument and the S3 mock the second.
    """

    # HTTP
    class MockResponse:
        """Minimal stand-in for a `requests` response exposing `.json()`."""

        def __init__(self, data):
            self.data = data

        def json(self):
            return json.loads(self.data)

    with open(STAC_PATH, "r") as f:
        requests.get.return_value = MockResponse(f.read())

    with STACReader(
        "http://somewhereovertherainbow.io/mystac.json",
        fetch_options={"auth": ("user", "pass")},
    ) as stac:
        assert stac.assets == ["red", "green", "blue"]
        requests.get.assert_called_once()
        assert requests.get.call_args[1]["auth"] == ("user", "pass")
        s3_get.assert_not_called()

    # `reset_mock()` is the real Mock API; `mock_reset()` would only create a
    # new child mock and leave the recorded `requests.get` call in place.
    requests.reset_mock()

    # S3
    with open(STAC_PATH, "r") as f:
        s3_get.return_value = f.read()

    with STACReader(
        "s3://somewhereovertherainbow.io/mystac.json",
        fetch_options={"request_pays": True},
    ) as stac:
        assert stac.assets == ["red", "green", "blue"]
        # Assert on `requests.get` (what `fetch` actually calls); the module
        # mock itself is never called, so asserting on it would always pass.
        requests.get.assert_not_called()
        s3_get.assert_called_once()
        assert s3_get.call_args[1]["request_pays"]
        assert s3_get.call_args[0] == ("somewhereovertherainbow.io", "mystac.json")

0 comments on commit 26d4733

Please sign in to comment.