Skip to content

Commit

Permalink
refactor: PEP-8 compliance (#771)
Browse files Browse the repository at this point in the history
* docs: Add missing parameter docs

* fix: Remove docs for non-existent parameter

* docs: Fix reference to renamed parameter

* refactor: Fix PEP-8 E713

"test for membership should be ‘not in’"
<https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes>.

* refactor: Use lowercase function names

* refactor: Avoid shadowing built-in name

* refactor: Import from the same level consistently
  • Loading branch information
l0b0 authored Dec 18, 2023
1 parent 250b82c commit 140f32a
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 64 deletions.
16 changes: 8 additions & 8 deletions scripts/aws/aws_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
from typing import Any, Dict, List, NamedTuple, Optional
from urllib.parse import urlparse

import boto3
import botocore
from boto3 import Session
from botocore.credentials import AssumeRoleCredentialFetcher, DeferredRefreshableCredentials, ReadOnlyCredentials
from botocore.session import Session as BotocoreSession
from linz_logger import get_log

from scripts.aws.aws_credential_source import CredentialSource

S3Path = NamedTuple("S3Path", [("bucket", str), ("key", str)])

aws_profile = environ.get("AWS_PROFILE")
session = boto3.Session(profile_name=aws_profile)
sessions: Dict[str, boto3.Session] = {}
session = Session(profile_name=aws_profile)
sessions: Dict[str, Session] = {}

bucket_roles: List[CredentialSource] = []

Expand All @@ -40,14 +40,14 @@ def _init_roles() -> None:
get_log().debug("bucket_config_loaded", config=bucket_config_path, prefix_count=len(bucket_roles))


def _get_client_creator(local_session: boto3.Session) -> Any:
def _get_client_creator(local_session: Session) -> Any:
def client_creator(service_name: str, **kwargs: Any) -> Any:
return local_session.client(service_name, **kwargs)

return client_creator


def get_session(prefix: str) -> boto3.Session:
def get_session(prefix: str) -> Session:
"""Get an AWS session to deal with an object on `s3`.
Args:
Expand Down Expand Up @@ -78,14 +78,14 @@ def get_session(prefix: str) -> boto3.Session:
role_arn=cfg.roleArn,
extra_args=extra_args,
)
botocore_session = botocore.session.Session()
botocore_session = BotocoreSession()

# pylint:disable=protected-access
botocore_session._credentials = DeferredRefreshableCredentials(
method="assume-role", refresh_using=fetcher.fetch_credentials
)

current_session = boto3.Session(botocore_session=botocore_session)
current_session = Session(botocore_session=botocore_session)
sessions[cfg.roleArn] = current_session

get_log().info("role_assume", prefix=prefix, bucket=cfg.bucket, role_arn=cfg.roleArn)
Expand Down
4 changes: 2 additions & 2 deletions scripts/files/files_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def is_tiff(path: str) -> bool:
return path.lower().endswith((".tiff", ".tif"))


def is_GTiff(path: str, gdalinfo_data: Optional[GdalInfo] = None) -> bool:
def is_geotiff(path: str, gdalinfo_data: Optional[GdalInfo] = None) -> bool:
"""Verifies if a file is a GTiff based on the presence of the
`coordinateSystem`.
Expand All @@ -64,7 +64,7 @@ def is_GTiff(path: str, gdalinfo_data: Optional[GdalInfo] = None) -> bool:
"""
if not gdalinfo_data:
gdalinfo_data = gdal_info(path)
if not "coordinateSystem" in gdalinfo_data:
if "coordinateSystem" not in gdalinfo_data:
return False
if gdalinfo_data["driverShortName"] == "GTiff":
return True
Expand Down
9 changes: 7 additions & 2 deletions scripts/files/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,16 @@ def write_all(inputs: List[str], target: str, concurrency: Optional[int] = 4) ->
Args:
inputs: list of files to read
target: target folder to write to
concurrency: max thread pool workers
Returns:
list of written file paths
"""
written_tiffs: List[str] = []
with ThreadPoolExecutor(max_workers=concurrency) as executor:
futuress = {
executor.submit(write, os.path.join(target, f"{os.path.basename(input)}"), read(input)): input for input in inputs
executor.submit(write, os.path.join(target, f"{os.path.basename(input_)}"), read(input_)): input_
for input_ in inputs
}
for future in as_completed(futuress):
if future.exception():
Expand All @@ -86,6 +88,7 @@ def find_sidecars(inputs: List[str], extensions: List[str], concurrency: Optiona
Args:
inputs: list of input files to search for extensions
extensions: the sidecar file extensions
concurrency: max thread pool workers
Returns:
list of existing sidecar files
Expand All @@ -100,7 +103,9 @@ def _validate_path(path: str) -> Optional[str]:
sidecars: List[str] = []
with ThreadPoolExecutor(max_workers=concurrency) as executor:
for extension in extensions:
futuress = {executor.submit(_validate_path, f"{os.path.splitext(input)[0]}{extension}"): input for input in inputs}
futuress = {
executor.submit(_validate_path, f"{os.path.splitext(input_)[0]}{extension}"): input_ for input_ in inputs
}
for future in as_completed(futuress):
if future.exception():
get_log().warn("Find sidecar failed", error=future.exception())
Expand Down
22 changes: 11 additions & 11 deletions scripts/files/fs_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Generator, List, Optional, Union

import boto3
import botocore
from boto3 import client, resource
from botocore.exceptions import ClientError
from linz_logger import get_log

from scripts.aws.aws_helper import get_session, parse_path
Expand All @@ -25,7 +25,7 @@ def write(destination: str, source: bytes, content_type: Optional[str] = None) -
raise Exception("The 'source' is None.")
s3_path = parse_path(destination)
key = s3_path.key
s3 = boto3.resource("s3")
s3 = resource("s3")

try:
s3_object = s3.Object(s3_path.bucket, key)
Expand All @@ -34,7 +34,7 @@ def write(destination: str, source: bytes, content_type: Optional[str] = None) -
else:
s3_object.put(Body=source)
get_log().debug("write_s3_success", path=destination, duration=time_in_ms() - start_time)
except botocore.exceptions.ClientError as ce:
except ClientError as ce:
get_log().error("write_s3_error", path=destination, error=f"Unable to write the file: {ce}")
raise ce

Expand All @@ -55,7 +55,7 @@ def read(path: str, needs_credentials: bool = False) -> bytes:
start_time = time_in_ms()
s3_path = parse_path(path)
key = s3_path.key
s3 = boto3.resource("s3")
s3 = resource("s3")

try:
if needs_credentials:
Expand Down Expand Up @@ -95,7 +95,7 @@ def exists(path: str, needs_credentials: bool = False) -> bool:
True if the S3 Object exists
"""
s3_path, key = parse_path(path)
s3 = boto3.resource("s3")
s3 = resource("s3")

try:
if needs_credentials:
Expand Down Expand Up @@ -168,7 +168,7 @@ def prefix_from_path(path: str) -> str:
return path.replace(f"s3://{bucket_name}/", "")


def list_json_in_uri(uri: str, s3_client: Optional[boto3.client]) -> List[str]:
def list_json_in_uri(uri: str, s3_client: Optional[client]) -> List[str]:
"""Get the `JSON` files from a s3 path
Args:
Expand All @@ -179,7 +179,7 @@ def list_json_in_uri(uri: str, s3_client: Optional[boto3.client]) -> List[str]:
a list of JSON files
"""
if not s3_client:
s3_client = boto3.client("s3")
s3_client = client("s3")
files = []
paginator = s3_client.get_paginator("list_objects_v2")
response_iterator = paginator.paginate(Bucket=bucket_name_from_path(uri), Prefix=prefix_from_path(uri))
Expand All @@ -195,7 +195,7 @@ def list_json_in_uri(uri: str, s3_client: Optional[boto3.client]) -> List[str]:
return files


def _get_object(bucket: str, file_name: str, s3_client: boto3.client) -> Any:
def _get_object(bucket: str, file_name: str, s3_client: client) -> Any:
"""Get the object from `s3`
Args:
Expand All @@ -211,7 +211,7 @@ def _get_object(bucket: str, file_name: str, s3_client: boto3.client) -> Any:


def get_object_parallel_multithreading(
bucket: str, files_to_read: List[str], s3_client: Optional[boto3.client], concurrency: int
bucket: str, files_to_read: List[str], s3_client: Optional[client], concurrency: int
) -> Generator[Any, Union[Any, BaseException], None]:
"""Get s3 objects in parallel
Expand All @@ -225,7 +225,7 @@ def get_object_parallel_multithreading(
the object when got
"""
if not s3_client:
s3_client = boto3.client("s3")
s3_client = client("s3")
with ThreadPoolExecutor(max_workers=concurrency) as executor:
future_to_key = {executor.submit(_get_object, bucket, key, s3_client): key for key in files_to_read}

Expand Down
6 changes: 3 additions & 3 deletions scripts/files/tests/file_helper_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from scripts.files.files_helper import is_GTiff, is_tiff
from scripts.files.files_helper import is_geotiff, is_tiff
from scripts.gdal.tests.gdalinfo import fake_gdal_info


Expand All @@ -21,5 +21,5 @@ def test_is_geotiff() -> None:
gdalinfo_not_geotiff["driverShortName"] = "GTiff"
gdalinfo_geotiff["coordinateSystem"] = {"wkt": "PROJCRS['NZGD2000 / New Zealand Transverse Mercator 2000']"}

assert is_GTiff("file.tiff", gdalinfo_geotiff) is True
assert is_GTiff("file.tiff", gdalinfo_not_geotiff) is False
assert is_geotiff("file.tiff", gdalinfo_geotiff) is True
assert is_geotiff("file.tiff", gdalinfo_not_geotiff) is False
10 changes: 5 additions & 5 deletions scripts/files/tests/file_tiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_check_band_count_invalid_4() -> None:
assert file_tiff.get_errors()


def test_check_band_count_valid_1_DEM() -> None:
def test_check_band_count_valid_1_dem() -> None:
"""
tests check_band_count when the input layer has a valid band count
which is 1 bands and a DEM preset
Expand All @@ -80,7 +80,7 @@ def test_check_band_count_valid_1_DEM() -> None:
assert not file_tiff.get_errors()


def test_check_band_count_invalid_alpha_DEM() -> None:
def test_check_band_count_invalid_alpha_dem() -> None:
"""
tests check_band_count when the input layer has a valid band count
which is 2 bands where the second band is Alpha and DEM preset
Expand All @@ -95,7 +95,7 @@ def test_check_band_count_invalid_alpha_DEM() -> None:
assert file_tiff.get_errors()


def test_check_band_count_invalid_3_DEM() -> None:
def test_check_band_count_invalid_3_dem() -> None:
"""
tests check_band_count when the input layer has an invalid band count
which is 3 bands where the preset is DEM.
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_check_color_interpretation_invalid() -> None:
assert file_tiff.get_errors()


def test_check_color_interpretation_valid_DEM() -> None:
def test_check_color_interpretation_valid_dem() -> None:
"""
tests check_color_interpretation with the correct color interpretation
"""
Expand All @@ -170,7 +170,7 @@ def test_check_color_interpretation_valid_DEM() -> None:
assert not file_tiff.get_errors()


def test_check_color_interpretation_invalid_DEM() -> None:
def test_check_color_interpretation_invalid_dem() -> None:
"""
tests check_color_interpretation with the incorrect color interpretation
"""
Expand Down
Loading

0 comments on commit 140f32a

Please sign in to comment.