From 4a532489a51b2e7f8fb57e265843eec8343e2127 Mon Sep 17 00:00:00 2001 From: Victor Engmark Date: Mon, 11 Dec 2023 16:46:49 +1300 Subject: [PATCH] refactor: Import from the same level consistently --- scripts/aws/aws_helper.py | 16 ++++----- scripts/files/fs_s3.py | 22 ++++++------- scripts/files/tests/fs_s3_test.py | 55 +++++++++++++++---------------- 3 files changed, 46 insertions(+), 47 deletions(-) diff --git a/scripts/aws/aws_helper.py b/scripts/aws/aws_helper.py index 656b57b3f..92a360ef1 100644 --- a/scripts/aws/aws_helper.py +++ b/scripts/aws/aws_helper.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, NamedTuple, Optional from urllib.parse import urlparse -import boto3 -import botocore +from boto3 import Session from botocore.credentials import AssumeRoleCredentialFetcher, DeferredRefreshableCredentials, ReadOnlyCredentials +from botocore.session import Session as BotocoreSession from linz_logger import get_log from scripts.aws.aws_credential_source import CredentialSource @@ -14,8 +14,8 @@ S3Path = NamedTuple("S3Path", [("bucket", str), ("key", str)]) aws_profile = environ.get("AWS_PROFILE") -session = boto3.Session(profile_name=aws_profile) -sessions: Dict[str, boto3.Session] = {} +session = Session(profile_name=aws_profile) +sessions: Dict[str, Session] = {} bucket_roles: List[CredentialSource] = [] @@ -40,14 +40,14 @@ def _init_roles() -> None: get_log().debug("bucket_config_loaded", config=bucket_config_path, prefix_count=len(bucket_roles)) -def _get_client_creator(local_session: boto3.Session) -> Any: +def _get_client_creator(local_session: Session) -> Any: def client_creator(service_name: str, **kwargs: Any) -> Any: return local_session.client(service_name, **kwargs) return client_creator -def get_session(prefix: str) -> boto3.Session: +def get_session(prefix: str) -> Session: """Get an AWS session to deal with an object on `s3`. Args: @@ -78,14 +78,14 @@ def get_session(prefix: str) -> boto3.Session: role_arn=cfg.roleArn, extra_args=extra_args, ) - botocore_session = botocore.session.Session() + botocore_session = BotocoreSession() # pylint:disable=protected-access botocore_session._credentials = DeferredRefreshableCredentials( method="assume-role", refresh_using=fetcher.fetch_credentials ) - current_session = boto3.Session(botocore_session=botocore_session) + current_session = Session(botocore_session=botocore_session) sessions[cfg.roleArn] = current_session get_log().info("role_assume", prefix=prefix, bucket=cfg.bucket, role_arn=cfg.roleArn) diff --git a/scripts/files/fs_s3.py b/scripts/files/fs_s3.py index af7e451ec..b72ee6009 100644 --- a/scripts/files/fs_s3.py +++ b/scripts/files/fs_s3.py @@ -2,8 +2,8 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Generator, List, Optional, Union -import boto3 -import botocore +from boto3 import client, resource +from botocore.exceptions import ClientError from linz_logger import get_log from scripts.aws.aws_helper import get_session, parse_path @@ -25,7 +25,7 @@ def write(destination: str, source: bytes, content_type: Optional[str] = None) - raise Exception("The 'source' is None.") s3_path = parse_path(destination) key = s3_path.key - s3 = boto3.resource("s3") + s3 = resource("s3") try: s3_object = s3.Object(s3_path.bucket, key) @@ -34,7 +34,7 @@ def write(destination: str, source: bytes, content_type: Optional[str] = None) - else: s3_object.put(Body=source) get_log().debug("write_s3_success", path=destination, duration=time_in_ms() - start_time) - except botocore.exceptions.ClientError as ce: + except ClientError as ce: get_log().error("write_s3_error", path=destination, error=f"Unable to write the file: {ce}") raise ce @@ -55,7 +55,7 @@ def read(path: str, needs_credentials: bool = False) -> bytes: start_time = time_in_ms() s3_path = parse_path(path) key = s3_path.key - s3 = boto3.resource("s3") + s3 = resource("s3") try: if needs_credentials: @@ -95,7 +95,7 @@ def exists(path: str, needs_credentials: bool = False) -> bool: True if the S3 Object exists """ s3_path, key = parse_path(path) - s3 = boto3.resource("s3") + s3 = resource("s3") try: if needs_credentials: @@ -168,7 +168,7 @@ def prefix_from_path(path: str) -> str: return path.replace(f"s3://{bucket_name}/", "") -def list_json_in_uri(uri: str, s3_client: Optional[boto3.client]) -> List[str]: +def list_json_in_uri(uri: str, s3_client: Optional[client]) -> List[str]: """Get the `JSON` files from a s3 path Args: @@ -179,7 +179,7 @@ def list_json_in_uri(uri: str, s3_client: Optional[boto3.client]) -> List[str]: a list of JSON files """ if not s3_client: - s3_client = boto3.client("s3") + s3_client = client("s3") files = [] paginator = s3_client.get_paginator("list_objects_v2") response_iterator = paginator.paginate(Bucket=bucket_name_from_path(uri), Prefix=prefix_from_path(uri)) @@ -195,7 +195,7 @@ def list_json_in_uri(uri: str, s3_client: Optional[boto3.client]) -> List[str]: return files -def _get_object(bucket: str, file_name: str, s3_client: boto3.client) -> Any: +def _get_object(bucket: str, file_name: str, s3_client: client) -> Any: """Get the object from `s3` Args: @@ -211,7 +211,7 @@ def _get_object(bucket: str, file_name: str, s3_client: boto3.client) -> Any: def get_object_parallel_multithreading( - bucket: str, files_to_read: List[str], s3_client: Optional[boto3.client], concurrency: int + bucket: str, files_to_read: List[str], s3_client: Optional[client], concurrency: int ) -> Generator[Any, Union[Any, BaseException], None]: """Get s3 objects in parallel @@ -225,7 +225,7 @@ def get_object_parallel_multithreading( the object when got """ if not s3_client: - s3_client = boto3.client("s3") + s3_client = client("s3") with ThreadPoolExecutor(max_workers=concurrency) as executor: future_to_key = {executor.submit(_get_object, bucket, key, s3_client): key for key in files_to_read} diff --git a/scripts/files/tests/fs_s3_test.py b/scripts/files/tests/fs_s3_test.py index 17b77825b..60aac0bf8 100644 --- a/scripts/files/tests/fs_s3_test.py +++ b/scripts/files/tests/fs_s3_test.py @@ -1,11 +1,10 @@ import json -import boto3 -import botocore -import pytest +from boto3 import client, resource +from botocore.exceptions import ClientError from moto import mock_s3 from moto.s3.responses import DEFAULT_REGION_NAME -from pytest import CaptureFixture +from pytest import CaptureFixture, raises from scripts.files.files_helper import ContentType from scripts.files.fs_s3 import exists, read, write @@ -13,35 +12,35 @@ @mock_s3 # type: ignore def test_write() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") write("s3://testbucket/test.file", b"test content") - resp = client.get_object(Bucket="testbucket", Key="test.file") + resp = boto3_client.get_object(Bucket="testbucket", Key="test.file") assert resp["Body"].read() == b"test content" assert resp["ContentType"] == "binary/octet-stream" @mock_s3 # type: ignore def test_write_content_type() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") write("s3://testbucket/test.tiff", b"test content", ContentType.GEOTIFF.value) - resp = client.get_object(Bucket="testbucket", Key="test.tiff") + resp = boto3_client.get_object(Bucket="testbucket", Key="test.tiff") assert resp["Body"].read() == b"test content" assert resp["ContentType"] == ContentType.GEOTIFF.value @mock_s3 # type: ignore def test_read() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") - client.put_object(Bucket="testbucket", Key="test.file", Body=b"test content") + boto3_client.put_object(Bucket="testbucket", Key="test.file", Body=b"test content") content = read("s3://testbucket/test.file") @@ -50,7 +49,7 @@ def test_read() -> None: @mock_s3 # type: ignore def test_read_bucket_not_found(capsys: CaptureFixture[str]) -> None: - with pytest.raises(botocore.exceptions.ClientError): + with raises(ClientError): read("s3://testbucket/test.file") # python-linz-logger uses structlog which doesn't use stdlib so can't capture the logs with `caplog` @@ -60,10 +59,10 @@ def test_read_bucket_not_found(capsys: CaptureFixture[str]) -> None: @mock_s3 # type: ignore def test_read_key_not_found(capsys: CaptureFixture[str]) -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") - with pytest.raises(botocore.exceptions.ClientError): + with raises(ClientError): read("s3://testbucket/test.file") logs = json.loads(capsys.readouterr().out) @@ -72,10 +71,10 @@ def test_read_key_not_found(capsys: CaptureFixture[str]) -> None: @mock_s3 # type: ignore def test_exists() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") - client.put_object(Bucket="testbucket", Key="test.file", Body=b"test content") + boto3_client.put_object(Bucket="testbucket", Key="test.file", Body=b"test content") file_exists = exists("s3://testbucket/test.file") @@ -84,10 +83,10 @@ def test_exists() -> None: @mock_s3 # type: ignore def test_directory_exists() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") - client.put_object(Bucket="testbucket", Key="hello/test.file", Body=b"test content") + boto3_client.put_object(Bucket="testbucket", Key="hello/test.file", Body=b"test content") directory_exists = exists("s3://testbucket/hello/") @@ -105,10 +104,10 @@ def test_exists_bucket_not_exists(capsys: CaptureFixture[str]) -> None: @mock_s3 # type: ignore def test_exists_object_not_exists() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") - client.put_object(Bucket="testbucket", Key="hello/another.file", Body=b"test content") + boto3_client.put_object(Bucket="testbucket", Key="hello/another.file", Body=b"test content") file_exists = exists("s3://testbucket/test.file") @@ -117,10 +116,10 @@ def test_exists_object_not_exists() -> None: @mock_s3 # type: ignore def test_exists_object_starting_with_not_exists() -> None: - s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) - client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + s3 = resource("s3", region_name=DEFAULT_REGION_NAME) + boto3_client = client("s3", region_name=DEFAULT_REGION_NAME) s3.create_bucket(Bucket="testbucket") - client.put_object(Bucket="testbucket", Key="hello/another.file", Body=b"test content") + boto3_client.put_object(Bucket="testbucket", Key="hello/another.file", Body=b"test content") file_exists = exists("s3://testbucket/hello/another.fi")