Add an S3 client with get_object_meta functionality
Adds an S3 client base with one method to get the e-tag and part size
for an object. This allows the code in remote.py to verify the e-tag.
MrCreosote committed Aug 17, 2024
1 parent 537478f commit 5b3a064
Showing 3 changed files with 375 additions and 5 deletions.
23 changes: 18 additions & 5 deletions README.md
@@ -4,15 +4,28 @@

Enables running jobs on remote compute from the KBase CDM cluster.

## Nomenclature

* CDM: Central Data Model
* The KBase data model
* CTS: CDM Task Service

## Service Requirements

* Python 3.11+
* [crane](https://github.com/google/go-containerregistry/blob/main/cmd/crane/README.md)
* An s3 instance for use as a file store, but see "S3 requirements" below

### S3 requirements

* Any objects provided to the service that were created with multipart uploads **must** use the
  same part size for all parts except the last (see the sketch below for why part size matters).
* The service does not support objects encrypted with customer supplied keys or with the
AWS key management service.
* The provided credentials must enable listing buckets, as the service performs that operation
to check the host and credentials on startup
* If using Minio, the minimum version is `2021-04-22T15-44-28Z` and the server must be run
in `--compat` mode.
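
The part size matters because S3 derives a multipart e-tag from the per-part MD5 digests
rather than from the object bytes, so the e-tag cannot be verified without knowing the part
size. A minimal sketch of the convention (a hypothetical helper, not part of this service;
assumes the object was uploaded in a single request when it fits in one part):

```python
import hashlib

def multipart_etag(data: bytes, part_size: int) -> str:
    """Compute the e-tag S3 would assign, assuming a uniform part size."""
    if len(data) <= part_size:
        # single-request upload: the e-tag is just the MD5 of the bytes
        return hashlib.md5(data).hexdigest()
    # multipart upload: MD5 of the concatenated per-part MD5 digests,
    # suffixed with "-<number of parts>"
    digests = [
        hashlib.md5(data[i:i + part_size]).digest()
        for i in range(0, len(data), part_size)
    ]
    return hashlib.md5(b"".join(digests)).hexdigest() + f"-{len(digests)}"
```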

## Development

207 changes: 207 additions & 0 deletions cdmtaskservice/s3/client.py
@@ -0,0 +1,207 @@
"""
An s3 client tailored for the needs of the CDM Task Service.
Note the client is *not threadsafe* as the underlying aiobotocore session is not threadsafe.
"""

import logging
from typing import Any, Self

from aiobotocore.session import get_session
from botocore.config import Config
from botocore.exceptions import ClientError, EndpointConnectionError
from botocore.parsers import ResponseParserError


class S3ObjectMeta:
""" Metadata about an object in S3 storage. """

__slots__ = ["path", "e_tag", "part_size", "size"]

def __init__(self, path: str, e_tag: str, size: int, part_size: int):
"""
Create the object meta. This is not expected to be instantiated manually.
path - the path of the object in s3 storage, including the bucket.
        e_tag - the object e-tag.
        size - the total size of the object.
part_size - the part size used in a multipart upload, if relevant.
"""
# input checking seems a little pointless. Could check that the path looks like a path
# and the e-tag looks like an e-tag, but since this module should be creating this
# class why bother. This is why access modifiers for methods are good
self.path = path
self.e_tag = e_tag
self.size = size
self.part_size = part_size
# could check that the part size makes sense given the e-tag... meh

@property
def has_parts(self) -> bool:
""" Returns true if the object was uploaded as multipart. """
return len(self.e_tag) > 32

@property
def num_parts(self) -> int:
"""
Returns the number of parts used in a multipart upload or 1 if the upload was not
multipart.
"""
        if self.has_parts:
return int(self.e_tag.split("-")[1])
return 1
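
# Usage sketch (values mirror the multipart test in this commit); a multipart
# e-tag ends with "-<part count>", which drives the two properties above:
#
#     meta = S3ObjectMeta(
#         path="test-bucket/big_test_file",
#         e_tag="e0fcd4584a5157e2d465bf0217ab8268-4",
#         size=180000009,
#         part_size=60000000,
#     )
#     meta.has_parts  # True - the e-tag is longer than a bare 32 char MD5
#     meta.num_parts  # 4 - parsed from the "-4" suffix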


class S3Client:
"""
The S3 client.
Note the client is *not threadsafe* as the underlying aiobotocore session is not threadsafe.
"""

@classmethod
async def create(
cls,
endpoint_url: str,
access_key: str,
secret_key: str,
        config: dict[str, Any] | None = None,
skip_connection_check: bool = False
) -> Self:
"""
Create the client.
endpoint_url - the URL of the s3 endpoint.
access_key - the s3 access key.
secret_key - the s3 secret key.
config - Any client configuration options provided as a dictionary. See
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
        skip_connection_check - don't try to list_buckets while creating the client to check
            the host and credentials are correct.
"""
s3c = S3Client(endpoint_url, access_key, secret_key, config)
if not skip_connection_check:
async def list_buckets(client):
return await client.list_buckets()
await s3c._run_command(list_buckets)
return s3c
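
    # Usage sketch (endpoint and keys are hypothetical; the config dict mirrors
    # the one exercised in this commit's tests):
    #
    #     s3c = await S3Client.create(
    #         "http://localhost:9000", "access_key", "secret_key",
    #         config={"connect_timeout": 0.2, "retries": {"total_max_attempts": 1}},
    #     )
    #     meta = await s3c.get_object_meta("some-bucket/some/key")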

def __init__(
self, endpoint_url: str, access_key: str, secret_key: str, config: dict[str, Any]
):
self._url = self._require_string(endpoint_url, "endpoint_url")
self._ak = self._require_string(access_key, "access_key")
self._sk = self._require_string(secret_key, "secret_key")
self._config = Config(**config) if config else None
self._sess = get_session()

def _require_string(self, string, name):
if not string or not string.strip():
raise ValueError(f"{name} is required")
return string.strip()

def _client(self):
# Creating a client seems to be pretty cheap, usually < 20ms.
return self._sess.create_client(
"s3",
endpoint_url=self._url,
aws_access_key_id=self._ak,
aws_secret_access_key=self._sk,
config=self._config,
)

async def _run_command(self, async_client_callable, path=None):
try:
async with self._client() as client:
return await async_client_callable(client)
except (ValueError, EndpointConnectionError) as e:
raise S3ClientConnectError(f"s3 connect failed: {e}") from e
except ResponseParserError as e:
# TODO TEST logging
# TODO LOGGING figure out how logging is going to work
logging.getLogger(__name__).error(
f"Unable to parse response from S3:\n{e}\n")
raise S3ClientConnectError(
f"s3 response from the server at {self._url} was not parseable. "
+ "See logs for details"
) from e
except ClientError as e:
code = e.response["Error"]["Code"]
if code == "SignatureDoesNotMatch":
raise S3ClientConnectError("s3 access credentials are invalid")
if code == "404":
raise S3PathError(f"The path '{path}' was not found on the s3 system") from e
if code == "AccessDenied" or code == "403": # why both? Both 403s
if not path:
raise S3ClientConnectError(
"Access denied to list buckets on the s3 system") from e
# may need to add other cases here
raise S3PathError(f"Access denied to path '{path}' on the s3 system") from e
# no way to test this since we're trying to cover all possible errors in tests
logging.getLogger(__name__).error(

Check warning on line 140 in cdmtaskservice/s3/client.py

View check run for this annotation

Codecov / codecov/patch

cdmtaskservice/s3/client.py#L140

Added line #L140 was not covered by tests
f"Unexpected response from S3. Response data:\n{e.response}\nTraceback:\n{e}\n")
raise S3UnexpectedError(str(e)) from e

Check warning on line 142 in cdmtaskservice/s3/client.py

View check run for this annotation

Codecov / codecov/patch

cdmtaskservice/s3/client.py#L142

Added line #L142 was not covered by tests


async def get_object_meta(self, path: str) -> S3ObjectMeta:
"""
Get metadata about an object.
path - the path of the object in s3, starting with the bucket.
"""
buk, key = _validate_and_split_path(path)
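        # HEADing with PartNumber=1 returns part 1's size in ContentLength and,
        # for multipart uploads, a PartsCount field plus the full object size at
        # the end of the content-range header, e.g. "bytes 0-59999999/180000009".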
async def head(client):
return await client.head_object(Bucket=buk, Key=key, PartNumber=1)
res = await self._run_command(head, path=path.strip())
size = res["ContentLength"]
part_size = None
if "PartsCount" in res:
part_size = size
content_range = res["ResponseMetadata"]["HTTPHeaders"]["content-range"]
size = int(content_range.split("/")[1])
return S3ObjectMeta(
path=path,
            e_tag=res["ETag"].strip('"'),
size=size,
part_size=part_size
)


def _validate_and_split_path(path: str) -> tuple[str, str]:
if not path or not path.strip():
raise S3PathError("An s3 path cannot be null or a whitespace string")
parts = [s.strip() for s in path.split("/", 1) if s.strip()]
if len(parts) != 2:
raise S3PathError(
f"path '{path.strip()}' must start with the s3 bucket and include a key")
_validate_bucket_name(parts[0])
return parts[0], parts[1]


def _validate_bucket_name(bucket_name: str):
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
bn = bucket_name.strip()
if len(bn) < 3 or len(bn) > 63:
raise S3PathError(f"Bucket names must be > 2 and < 64 characters: {bn}")
if "." in bn:
raise S3PathError(f"Buckets with `.` in the name are unsupported: {bn}")
if bn.startswith("-") or bn.endswith("-"):
raise S3PathError(f"Bucket names cannot start or end with '-': {bn}")
if not bn.replace("-", "").isalnum() or not bn.isascii() or not bn.islower():
raise S3PathError(
f"Bucket names may only contain '-' and lowercase ascii alphanumerics: {bn}")


class S3ClientError(Exception):
""" The base class for S3 client errors. """


class S3ClientConnectError(S3ClientError):
    """ Error thrown when the client cannot connect to the s3 instance. """


class S3PathError(S3ClientError):
""" Error thrown when an S3 path is incorrectly specified. """


class S3UnexpectedError(S3ClientError):
    """ Error thrown when an unexpected error is returned from s3. """
150 changes: 150 additions & 0 deletions test/s3/client_test.py
@@ -0,0 +1,150 @@
import pytest

from cdmtaskservice.s3.client import S3Client, S3ClientConnectError, S3PathError
from conftest import minio, minio_unauthed_user, assert_exception_correct


@pytest.mark.asyncio
async def test_create_fail_missing_args():
u = "https://localhost:1234"
await _create_fail(None, "foo", "bar", ValueError("endpoint_url is required"))
await _create_fail(" \t ", "foo", "bar", ValueError("endpoint_url is required"))
await _create_fail(u, None, "bar", ValueError("access_key is required"))
await _create_fail(u, " \t ", "bar", ValueError("access_key is required"))
await _create_fail(u, "foo", None, ValueError("secret_key is required"))
await _create_fail(u, "foo", " \t ", ValueError("secret_key is required"))


@pytest.mark.asyncio
async def test_create_fail_bad_args(minio, minio_unauthed_user):
bad_ep1 = f"localhost:{minio.port}"
await _create_fail(
bad_ep1, "foo", "bar",
S3ClientConnectError("s3 connect failed: Invalid endpoint: " + bad_ep1))
bad_ep2 = f"http://localhost:{minio.port + 1}"
await _create_fail(
bad_ep2, "foo", "bar",
S3ClientConnectError(
f's3 connect failed: Could not connect to the endpoint URL: "{bad_ep2}/"'),
{"connect_timeout": 0.2, "retries": {"total_max_attempts": 1}},
)
await _create_fail(
"https://google.com", "foo", "bar",
S3ClientConnectError(
"s3 response from the server at https://google.com was not parseable. "
+ "See logs for details"
),
)
await _create_fail(
f"http://localhost:{minio.port}", minio.access_key, "bar",
S3ClientConnectError("s3 access credentials are invalid"))
await _create_fail(
f"http://localhost:{minio.port}", minio_unauthed_user[0], minio_unauthed_user[1],
S3ClientConnectError("Access denied to list buckets on the s3 system"))


async def _create_fail(host, akey, skey, expected, config=None, print_stacktrace=False):
with pytest.raises(Exception) as got:
await S3Client.create(host, akey, skey, config)
assert_exception_correct(got.value, expected, print_stacktrace)


@pytest.mark.asyncio
async def test_get_object_meta_single_part(minio):
await minio.clean() # couldn't get this to work as a fixture
await minio.create_bucket("test-bucket")
await minio.upload_file("test-bucket/test_file", b"abcdefghij")

s3c = await client(minio)
objm = await s3c.get_object_meta("test-bucket/test_file")
check_obj_meta(
objm, "test-bucket/test_file", "a925576942e94b2ef57a066101b48876", 10, None, False, 1)


@pytest.mark.asyncio
async def test_get_object_meta_multipart(minio):
await minio.clean() # couldn't get this to work as a fixture
await minio.create_bucket("test-bucket")
await minio.upload_file(
"test-bucket/big_test_file", b"abcdefghij" * 6000000, 3, b"bigolfile")

s3c = await client(minio)
objm = await s3c.get_object_meta("test-bucket/big_test_file")
check_obj_meta(
objm,
"test-bucket/big_test_file",
"e0fcd4584a5157e2d465bf0217ab8268-4",
180000009,
60000000,
True,
4,
)


@pytest.mark.asyncio
async def test_get_object_meta_fail_bad_path(minio):
    # Will probably want to refactor these tests so they can be generically applied to
    # any endpoint that takes a path
await minio.clean()
await minio.create_bucket("fail-bucket")
await minio.upload_file("fail-bucket/foo/bar", b"foo")

charerr = "Bucket names may only contain '-' and lowercase ascii alphanumerics: "

testset = {
None: "An s3 path cannot be null or a whitespace string",
" \t ": "An s3 path cannot be null or a whitespace string",
" / ": "path '/' must start with the s3 bucket and include a key",
"foo / ": "path 'foo /' must start with the s3 bucket and include a key",
" / bar ": "path '/ bar' must start with the s3 bucket and include a key",
"il/foo": "Bucket names must be > 2 and < 64 characters: il",
("illegal-bu" * 6) + "cket/foo":
f"Bucket names must be > 2 and < 64 characters: {'illegal-bu' * 6}cket",
"illegal.bucket/foo": "Buckets with `.` in the name are unsupported: illegal.bucket",
"-illegal-bucket/foo": "Bucket names cannot start or end with '-': -illegal-bucket",
"illegal-bucket-/foo": "Bucket names cannot start or end with '-': illegal-bucket-",
"illegal*bucket/foo": charerr + "illegal*bucket",
"illegal_bucket/foo": charerr + "illegal_bucket",
"illegal-Bucket/foo": charerr + "illegal-Bucket",
"illegal-βucket/foo": charerr + "illegal-βucket",
"fake-bucket/foo/bar": "The path 'fake-bucket/foo/bar' was not found on the s3 system",
"fail-bucket/foo/baz": "The path 'fail-bucket/foo/baz' was not found on the s3 system",
}
for k, v in testset.items():
await _get_object_meta_fail(await client(minio), k, S3PathError(v))


@pytest.mark.asyncio
async def test_get_object_meta_fail_unauthed(minio, minio_unauthed_user):
    # Will probably want to refactor these tests so they can be generically applied to
    # any endpoint
await minio.clean()
await minio.create_bucket("fail-bucket")
await minio.upload_file("fail-bucket/foo/bar", b"foo")

user, pwd = minio_unauthed_user
s3c = await S3Client.create(minio.host, user, pwd, skip_connection_check=True)
await _get_object_meta_fail(
s3c, "fail-bucket/foo/bar",
S3PathError("Access denied to path 'fail-bucket/foo/bar' on the s3 system")
)


async def _get_object_meta_fail(s3c, path, expected, print_stacktrace=False):
with pytest.raises(Exception) as got:
await s3c.get_object_meta(path)
assert_exception_correct(got.value, expected, print_stacktrace)


async def client(minio):
return await S3Client.create(minio.host, minio.access_key, minio.secret_key)


def check_obj_meta(objm, path, e_tag, size, part_size, has_parts, num_parts):
assert objm.path == path
assert objm.e_tag == e_tag
assert objm.size == size
assert objm.part_size == part_size
assert objm.has_parts is has_parts
assert objm.num_parts == num_parts
