Skip to content

Commit

Permalink
Merge pull request #216 from allenai/r2
Browse files Browse the repository at this point in the history
Support for R2
  • Loading branch information
dirkgr authored Feb 22, 2024
2 parents 23bde11 + 1b492ce commit 8b3e94b
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added support for R2 (`r2://*`).
- `verbose` parameter for `find_latest_cached()`
- Added support for extracting RAR files.

Expand Down
5 changes: 3 additions & 2 deletions cached_path/schemes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .hf import hf_get_from_cache
from .http import HttpClient
from .s3 import S3Client
from .r2 import R2Client
from .scheme_client import SchemeClient

__all__ = ["GsClient", "HttpClient", "S3Client", "SchemeClient", "hf_get_from_cache"]
__all__ = ["GsClient", "HttpClient", "S3Client", "R2Client", "SchemeClient", "hf_get_from_cache"]

try:
from .beaker import BeakerClient
Expand Down Expand Up @@ -36,7 +37,7 @@ def add_scheme_client(client: Type[SchemeClient]) -> None:
raise ValueError(f"Unexpected type for {client} scheme: {client.scheme}")


for client in (HttpClient, S3Client, GsClient):
for client in (HttpClient, S3Client, R2Client, GsClient):
add_scheme_client(client) # type: ignore
if BeakerClient is not None:
add_scheme_client(BeakerClient)
Expand Down
80 changes: 80 additions & 0 deletions cached_path/schemes/r2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Cloudflare R2.
"""
import io
import os
from typing import Optional

import boto3.session
from botocore.config import Config
import botocore.exceptions

from .scheme_client import SchemeClient
from ..common import _split_cloud_path


class R2Client(SchemeClient):
    """
    Scheme client for Cloudflare R2 resources addressed as ``r2://<bucket>/<key>``.

    Configuration is read from environment variables when the client is constructed:

    - ``R2_ENDPOINT_URL`` (required): the endpoint URL passed to the S3 client.
    - ``R2_ACCESS_KEY_ID`` and ``R2_SECRET_ACCESS_KEY``: explicit credentials.
      When both are set they take precedence over ``R2_PROFILE``.
    - ``R2_PROFILE``: name of a configured credentials profile, used when the
      explicit key pair is not fully set.
    """

    recoverable_errors = SchemeClient.recoverable_errors + (
        botocore.exceptions.HTTPClientError,
        botocore.exceptions.ConnectionError,
    )

    scheme = "r2"

    def __init__(self, resource: str) -> None:
        """
        Parse ``resource`` into bucket and key and build the S3 client for R2.

        Raises ``ValueError`` if ``R2_ENDPOINT_URL`` is not set, or if neither
        the access key pair nor ``R2_PROFILE`` is available.
        """
        SchemeClient.__init__(self, resource)
        self.bucket_name, self.path = _split_cloud_path(resource, "r2")

        # find credentials
        endpoint_url = os.environ.get("R2_ENDPOINT_URL")
        if endpoint_url is None:
            raise ValueError(
                "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
            )
        profile_name = os.environ.get("R2_PROFILE")
        access_key_id = os.environ.get("R2_ACCESS_KEY_ID")
        secret_access_key = os.environ.get("R2_SECRET_ACCESS_KEY")
        if access_key_id is not None and secret_access_key is not None:
            session_kwargs = {
                "aws_access_key_id": access_key_id,
                "aws_secret_access_key": secret_access_key,
            }
        elif profile_name is not None:
            session_kwargs = {"profile_name": profile_name}
        else:
            raise ValueError(
                "To authenticate for R2, you either have to set the 'R2_PROFILE' env var and set up this profile, "
                "or set R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY."
            )

        # Credentials go on the Session rather than on client(): `profile_name`
        # is a Session argument and is not accepted by client(), so passing it
        # there fails with a TypeError. The key pair is accepted by both.
        session = boto3.session.Session(**session_kwargs)
        self.s3 = session.client(
            service_name="s3",
            endpoint_url=endpoint_url,
            region_name="auto",
            config=Config(retries={"max_attempts": 10, "mode": "standard"}),
        )
        # Cached `head_object` response; populated on first metadata access.
        self.object_info = None

    def _ensure_object_info(self):
        """Fetch and cache the object's `head_object` response if not already cached."""
        if self.object_info is None:
            self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)

    def get_etag(self) -> Optional[str]:
        """Return the ``ETag`` entry of the object's metadata, or ``None`` if absent."""
        self._ensure_object_info()
        assert self.object_info is not None
        return self.object_info.get("ETag")

    def get_size(self) -> Optional[int]:
        """Return the ``ContentLength`` entry of the object's metadata, or ``None`` if absent."""
        self._ensure_object_info()
        assert self.object_info is not None
        return self.object_info.get("ContentLength")

    def get_resource(self, temp_file: io.BufferedWriter) -> None:
        """Download the whole object into ``temp_file``."""
        self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path)

    def get_bytes_range(self, index: int, length: int) -> bytes:
        """
        Return ``length`` bytes of the object starting at byte offset ``index``.

        A non-positive ``length`` yields ``b""`` without issuing a request.
        """
        if length <= 0:
            # The Range header below would end before it starts; nothing to read.
            return b""
        # HTTP byte ranges are inclusive on both ends, hence the `- 1`.
        response = self.s3.get_object(
            Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}"
        )
        return response["Body"].read()
25 changes: 25 additions & 0 deletions tests/cached_path_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,31 @@ def test_bytes_and_cache_object(self):
assert bytes_snippet_2 == bytes_snippet


class TestCachedPathR2(BaseTestClass):
    """Ranged-read and download round trip against an ``r2://`` resource."""

    @flaky
    @pytest.mark.skip(
        reason="R2 doesn't support publically readable buckets and we don't want to require authentication to run tests."
    )
    def test_bytes_and_cache_object(self):
        """A byte range read before and after caching the file must be identical."""
        url = "r2://allennlp/test_file_for_cached_path_unittests.txt"
        offset, count = 5, 10

        # Ranged read before the file has been downloaded.
        first_read = get_bytes_range(url, offset, count)
        assert len(first_read) == count

        # Fetch the full file; its metadata must carry an etag.
        local_path = cached_path(url)
        assert local_path.is_file()
        recorded_meta = Meta.from_path(_meta_file_path(local_path))
        assert recorded_meta.etag is not None

        # Same range again; this read is expected to be served from the cached copy.
        second_read = get_bytes_range(url, offset, count)
        assert second_read == first_read


class TestCachedPathHf(BaseTestClass):
@flaky
def test_cached_download_no_user_or_org(self):
Expand Down

0 comments on commit 8b3e94b

Please sign in to comment.