Skip to content

Commit

Permalink
feat: keep Collection properties.created date when resupplying (#1160)
Browse files Browse the repository at this point in the history
### Motivation

Ensure that the STAC Collection `properties.created` datetime is not
changed when resupplying.

### Modifications

- Add `--odr-url` parameter to `collection_from_items.py`, to get the
created datetime from the published collection on resupply.
- Add `--current-datetime` parameter to `collection_from_items.py` to
use that instead of the actual clock time for created/updated datetimes.

### Verification

<!-- TODO: Say how you tested your changes. -->
  • Loading branch information
l0b0 authored Nov 26, 2024
1 parent 9d95b02 commit 8af3e21
Show file tree
Hide file tree
Showing 16 changed files with 380 additions and 171 deletions.
17 changes: 17 additions & 0 deletions scripts/collection_from_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def parse_args(args: List[str] | None) -> Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--uri", dest="uri", help="s3 path to items and collection.json write location", required=True)
parser.add_argument("--collection-id", dest="collection_id", help="Collection ID", required=True)
parser.add_argument(
"--odr-url",
dest="odr_url",
help="The path of the published dataset. Example: 's3://nz-imagery/wellington/porirua_2024_0.1m/rgb/2193/'",
required=False,
)
parser.add_argument(
"--category",
dest="category",
Expand Down Expand Up @@ -104,6 +110,15 @@ def parse_args(args: List[str] | None) -> Namespace:
help="Add a title suffix to the collection title based on the lifecycle. For example, '[TITLE] - Preview'",
required=False,
)
parser.add_argument(
"--current-datetime",
dest="current_datetime",
help=(
"The datetime that is used as current datetime in the metadata. "
"Format: RFC 3339 UTC datetime, `YYYY-MM-DDThh:mm:ssZ`."
),
required=True,
)

return parser.parse_args(args)

Expand Down Expand Up @@ -184,13 +199,15 @@ def main(args: List[str] | None = None) -> None:
collection_id=collection_id,
linz_slug=arguments.linz_slug,
collection_metadata=collection_metadata,
current_datetime=arguments.current_datetime,
producers=coalesce_multi_single(arguments.producer_list, arguments.producer),
licensors=coalesce_multi_single(arguments.licensor_list, arguments.licensor),
stac_items=items_to_add,
item_polygons=polygons,
add_capture_dates=arguments.capture_dates,
uri=uri,
add_title_suffix=arguments.add_title_suffix,
odr_url=arguments.odr_url,
)

destination = os.path.join(uri, "collection.json")
Expand Down
16 changes: 0 additions & 16 deletions scripts/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,6 @@ def format_rfc_3339_nz_midnight_datetime_string(datetime_object: datetime) -> st
return format_rfc_3339_datetime_string(datetime_utc)


def utc_now() -> datetime:
    """
    Return the current moment as a timezone-aware UTC datetime.

    Should be close to the wall clock:
    >>> now_ts = datetime.now().timestamp()
    >>> now_ts - 5 < utc_now().timestamp() < now_ts + 5
    True

    Should carry the UTC time zone:
    >>> utc_now().tzname()
    'UTC'
    """
    return datetime.now(timezone.utc)


class NaiveDatetimeError(Exception):
    """Raised when a naive datetime (one without tzinfo) cannot be converted to UTC."""

    def __init__(self) -> None:
        # Fixed message: callers never customise it
        message = "Can't convert naive datetime to UTC"
        super().__init__(message)
15 changes: 0 additions & 15 deletions scripts/files/fs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import os
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

from boto3 import client
from linz_logger import get_log
Expand All @@ -11,11 +8,6 @@
from scripts.files import fs_local, fs_s3
from scripts.stac.util.checksum import multihash_as_hex

if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client
else:
S3Client = dict


def write(destination: str, source: bytes, content_type: str | None = None) -> str:
"""Write a file from its source to a destination path.
Expand Down Expand Up @@ -87,13 +79,6 @@ def exists(path: str) -> bool:
return fs_local.exists(path)


def modified(path: str, s3_client: S3Client | None = None) -> datetime:
    """Get modified datetime for S3 URL or local path"""
    # Local paths take the simple route; S3 URLs are split into bucket + key first
    if not is_s3(path):
        return fs_local.modified(Path(path))
    bucket_name = fs_s3.bucket_name_from_path(path)
    key = fs_s3.prefix_from_path(path)
    return fs_s3.modified(bucket_name, key, s3_client)


def write_all(inputs: list[str], target: str, concurrency: int | None = 4, generate_name: bool | None = True) -> list[str]:
"""Writes list of files to target destination using multithreading.
Args:
Expand Down
8 changes: 0 additions & 8 deletions scripts/files/fs_local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from datetime import datetime, timezone
from pathlib import Path


def write(destination: str, source: bytes) -> None:
Expand Down Expand Up @@ -38,9 +36,3 @@ def exists(path: str) -> bool:
True if the path exists
"""
return os.path.exists(path)


def modified(path: Path) -> datetime:
"""Get path modified datetime as UTC"""
modified_timestamp = os.path.getmtime(path)
return datetime.fromtimestamp(modified_timestamp, tz=timezone.utc)
6 changes: 0 additions & 6 deletions scripts/files/fs_s3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections.abc import Generator
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING, Any

from boto3 import client
Expand Down Expand Up @@ -236,8 +235,3 @@ def get_object_parallel_multithreading(
yield key, future.result()
else:
yield key, exception


def modified(bucket_name: str, key: str, s3_client: S3Client | None) -> datetime:
    """Get the LastModified datetime of an S3 object, creating a client if none is given."""
    effective_client = s3_client if s3_client is not None else client("s3")
    response = _get_object(bucket_name, key, effective_client)
    return response["LastModified"]
12 changes: 1 addition & 11 deletions scripts/files/tests/fs_local_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
from pathlib import Path

import pytest

from scripts.files.fs_local import exists, modified, read, write
from scripts.tests.datetimes_test import any_epoch_datetime
from scripts.files.fs_local import exists, read, write


@pytest.mark.dependency(name="write")
Expand Down Expand Up @@ -45,11 +43,3 @@ def test_exists(setup: str) -> None:
def test_exists_file_not_found() -> None:
    """`exists` should be False for a path that is not on disk."""
    assert exists("/tmp/test.file") is False


def test_should_get_modified_datetime(setup: str) -> None:
    """`modified` should return the file's mtime as a datetime."""
    file_path = Path(setup) / "modified.file"
    file_path.touch()
    expected = any_epoch_datetime()
    # atime is set to an arbitrary datetime: only mtime should influence the result
    os.utime(file_path, times=(any_epoch_datetime().timestamp(), expected.timestamp()))
    assert modified(file_path) == expected
20 changes: 1 addition & 19 deletions scripts/files/tests/fs_s3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
from boto3 import client
from botocore.exceptions import ClientError
from moto import mock_aws
from moto.core.models import DEFAULT_ACCOUNT_ID
from moto.s3.models import s3_backends
from moto.s3.responses import DEFAULT_REGION_NAME
from moto.wafv2.models import GLOBAL_REGION
from mypy_boto3_s3 import S3Client
from pytest import CaptureFixture, raises
from pytest_subtests import SubTests

from scripts.files.files_helper import ContentType
from scripts.files.fs_s3 import exists, list_files_in_uri, modified, read, write
from scripts.tests.datetimes_test import any_epoch_datetime
from scripts.files.fs_s3 import exists, list_files_in_uri, read, write


@mock_aws
Expand Down Expand Up @@ -165,17 +161,3 @@ def test_list_files_in_uri(subtests: SubTests) -> None:

with subtests.test():
assert "data/image.tiff" not in files


@mock_aws
def test_should_get_modified_datetime() -> None:
    """`modified` should report the S3 object's LastModified time."""
    bucket = "any-bucket-name"
    object_key = "any-key"
    expected = any_epoch_datetime()

    s3: S3Client = client("s3", region_name=DEFAULT_REGION_NAME)
    s3.create_bucket(Bucket=bucket)
    s3.put_object(Bucket=bucket, Key=object_key, Body=b"any body")
    # moto offers no API to set LastModified, so poke the backend model directly
    s3_backends[DEFAULT_ACCOUNT_ID][GLOBAL_REGION].buckets[bucket].keys[object_key].last_modified = expected

    assert modified(bucket, object_key, s3) == expected
28 changes: 1 addition & 27 deletions scripts/files/tests/fs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@

from boto3 import client
from moto import mock_aws
from moto.core.models import DEFAULT_ACCOUNT_ID
from moto.s3.models import s3_backends
from moto.s3.responses import DEFAULT_REGION_NAME
from moto.wafv2.models import GLOBAL_REGION
from mypy_boto3_s3 import S3Client
from pytest import CaptureFixture, raises
from pytest_subtests import SubTests

from scripts.files.fs import NoSuchFileError, modified, read, write, write_all, write_sidecars
from scripts.tests.datetimes_test import any_epoch_datetime
from scripts.files.fs import NoSuchFileError, read, write, write_all, write_sidecars


def test_read_key_not_found_local() -> None:
Expand Down Expand Up @@ -103,25 +99,3 @@ def test_write_all_in_order(setup: str) -> None:
i += 1
written_files = write_all(inputs=inputs, target=setup, generate_name=False)
assert written_files == inputs


@mock_aws
def test_should_get_s3_object_modified_datetime() -> None:
    """`modified` on an s3:// URL should report the object's LastModified time."""
    bucket = "any-bucket-name"
    object_key = "any-key"
    expected = any_epoch_datetime()

    s3: S3Client = client("s3", region_name=DEFAULT_REGION_NAME)
    s3.create_bucket(Bucket=bucket)
    s3.put_object(Bucket=bucket, Key=object_key, Body=b"any body")
    # moto offers no API to set LastModified, so poke the backend model directly
    s3_backends[DEFAULT_ACCOUNT_ID][GLOBAL_REGION].buckets[bucket].keys[object_key].last_modified = expected

    assert modified(f"s3://{bucket}/{object_key}", s3) == expected


def test_should_get_local_file_modified_datetime(setup: str) -> None:
    """`modified` on a plain path should report the file's mtime."""
    file_path = os.path.join(setup, "modified.file")
    Path(file_path).touch()
    expected = any_epoch_datetime()
    # atime is set to an arbitrary datetime: only mtime should influence the result
    os.utime(file_path, times=(any_epoch_datetime().timestamp(), expected.timestamp()))
    assert modified(file_path) == expected
10 changes: 4 additions & 6 deletions scripts/stac/imagery/collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from collections.abc import Callable
from datetime import datetime
from typing import Any

import ulid
Expand Down Expand Up @@ -41,7 +39,8 @@ class ImageryCollection:
def __init__(
self,
metadata: CollectionMetadata,
now: Callable[[], datetime],
created_datetime: str,
updated_datetime: str,
linz_slug: str,
collection_id: str | None = None,
providers: list[Provider] | None = None,
Expand All @@ -52,7 +51,6 @@ def __init__(

self.metadata = metadata

now_string = format_rfc_3339_datetime_string(now())
self.stac = {
"type": "Collection",
"stac_version": STAC_VERSION,
Expand All @@ -67,8 +65,8 @@ def __init__(
"linz:region": metadata["region"],
"linz:security_classification": "unclassified",
"linz:slug": linz_slug,
"created": now_string,
"updated": now_string,
"created": created_datetime,
"updated": updated_datetime,
}

# Optional metadata
Expand Down
36 changes: 27 additions & 9 deletions scripts/stac/imagery/create_stac.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import os
from typing import Any
from typing import Any, TypeAlias, cast

from linz_logger import get_log
from shapely.geometry.base import BaseGeometry

from scripts.datetimes import utc_now
from scripts.files import fs
from scripts.files.files_helper import get_file_name_from_path
from scripts.files.fs import NoSuchFileError, read
Expand All @@ -20,19 +19,24 @@
from scripts.stac.util import checksum
from scripts.stac.util.media_type import StacMediaType

JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None
JSON_Dict: TypeAlias = dict[str, "JSON"]


# pylint: disable=too-many-arguments
def create_collection(
collection_id: str,
linz_slug: str,
collection_metadata: CollectionMetadata,
current_datetime: str,
producers: list[str],
licensors: list[str],
stac_items: list[dict[Any, Any]],
item_polygons: list[BaseGeometry],
add_capture_dates: bool,
uri: str,
add_title_suffix: bool = False,
odr_url: str | None = None,
) -> ImageryCollection:
"""Create an ImageryCollection object.
If `item_polygons` is not empty, it will add a generated capture area to the collection.
Expand All @@ -41,29 +45,30 @@ def create_collection(
collection_id: id of the collection
linz_slug: the linz:slug attribute for this collection
collection_metadata: metadata of the collection
current_datetime: datetime string that represents the current time when the item is created.
producers: producers of the dataset
licensors: licensors of the dataset
stac_items: items to link to the collection
item_polygons: polygons of the items linked to the collection
add_capture_dates: whether to add a capture-dates.geojson.gz file to the collection assets
uri: path of the dataset
add_title_suffix: whether to add a title suffix to the collection title based on the lifecycle
odr_url: path of the published dataset. Defaults to None.
Returns:
an ImageryCollection object
"""
providers: list[Provider] = []
for producer_name in producers:
providers.append({"name": producer_name, "roles": [ProviderRole.PRODUCER]})
for licensor_name in licensors:
providers.append({"name": licensor_name, "roles": [ProviderRole.LICENSOR]})
existing_collection = {}
if odr_url:
existing_collection = get_published_file_contents(odr_url, "collection")

collection = ImageryCollection(
metadata=collection_metadata,
now=utc_now,
created_datetime=cast(str, existing_collection.get("created", current_datetime)),
updated_datetime=current_datetime,
linz_slug=linz_slug,
collection_id=collection_id,
providers=providers,
providers=get_providers(licensors, producers),
add_title_suffix=add_title_suffix,
)

Expand All @@ -82,6 +87,15 @@ def create_collection(
return collection


def get_providers(licensors: list[str], producers: list[str]) -> list[Provider]:
    """Build the STAC providers list: all producers first, then all licensors."""
    producer_providers: list[Provider] = [
        {"name": name, "roles": [ProviderRole.PRODUCER]} for name in producers
    ]
    licensor_providers: list[Provider] = [
        {"name": name, "roles": [ProviderRole.LICENSOR]} for name in licensors
    ]
    return producer_providers + licensor_providers


def create_item(
asset_path: str,
start_datetime: str,
Expand Down Expand Up @@ -194,3 +208,7 @@ def create_or_load_base_item(
)

return ImageryItem(id_, stac_asset, stac_processing)


def get_published_file_contents(odr_url: str, filename: str) -> JSON_Dict:
    """Read and parse a JSON metadata file from the published dataset.

    Args:
        odr_url: path of the published dataset, e.g. "s3://nz-imagery/…/rgb/2193/"
        filename: base name without extension, e.g. "collection" → reads "collection.json"

    Returns:
        the parsed JSON document as a dict
    """
    # Bug fix: the f-string had no placeholder, so `filename` was ignored and a
    # literal constant filename was requested regardless of the argument.
    return cast(JSON_Dict, json.loads(read(os.path.join(odr_url, f"{filename}.json")).decode("UTF-8")))
Loading

0 comments on commit 8af3e21

Please sign in to comment.