diff --git a/pygeoapi_kubernetes_papermill/argo.py b/pygeoapi_kubernetes_papermill/argo.py
index 4cc43b1..7a181a1 100644
--- a/pygeoapi_kubernetes_papermill/argo.py
+++ b/pygeoapi_kubernetes_papermill/argo.py
@@ -29,26 +29,51 @@
 from __future__ import annotations
+import datetime
 import logging
-from typing import Optional, Any
+from typing import Optional, Any, cast

 from kubernetes import client as k8s_client, config as k8s_config

-from pygeoapi.process.manager.base import BaseManager
+from http import HTTPStatus
+import json
+
+import kubernetes.client.rest
+
+
+from pygeoapi.process.manager.base import BaseManager, DATETIME_FORMAT
 from pygeoapi.util import (
     JobStatus,
     Subscriber,
     RequestedResponse,
 )
-# TODO: move elsewhere if we keep this
-from .kubernetes import JobDict
+from pygeoapi.process.base import (
+    JobNotFoundError,
+)
+from .common import (
+    k8s_job_name,
+    current_namespace,
+    format_annotation_key,
+    now_str,
+    parse_annotation_key,
+    hide_secret_values,
+    JobDict,
+)
-from .common import current_namespace, k8s_job_name

 LOGGER = logging.getLogger(__name__)

+WORKFLOWS_API_GROUP = "argoproj.io"
+WORKFLOWS_API_VERSION = "v1alpha1"
+
+K8S_CUSTOM_OBJECT_WORKFLOWS = {
+    "group": WORKFLOWS_API_GROUP,
+    "version": WORKFLOWS_API_VERSION,
+    "plural": "workflows",
+}
+

 class ArgoManager(BaseManager):
     def __init__(self, manager_def: dict) -> None:
@@ -95,7 +120,18 @@ def get_job(self, job_id) -> Optional[JobDict]:

         :returns: `dict`  # `pygeoapi.process.manager.Job`
         """
-        raise NotImplementedError
+        try:
+            k8s_wf: dict = self.custom_objects_api.get_namespaced_custom_object(
+                **K8S_CUSTOM_OBJECT_WORKFLOWS,
+                name=k8s_job_name(job_id=job_id),
+                namespace=self.namespace,
+            )
+            return job_from_k8s_wf(k8s_wf)
+        except kubernetes.client.rest.ApiException as e:
+            if e.status == HTTPStatus.NOT_FOUND:
+                raise JobNotFoundError
+            else:
+                raise

     def add_job(self, job_metadata):
         """
@@ -164,19 +200,23 @@ def _execute_handler_async(
         and JobStatus.accepted (i.e. initial job status)
         """
-        api_group = "argoproj.io"
-        api_version = "v1alpha1"
+        annotations = {
+            "identifier": job_id,
+            "process_id": p.metadata.get("id"),
+            "job_start_datetime": now_str(),
+        }

-        # TODO test with this
-        # https://github.com/argoproj/argo-workflows/blob/main/examples/workflow-template/workflow-template-ref-with-entrypoint-arg-passing.yaml
         body = {
-            "apiVersion": f"{api_group}/{api_version}",
+            "apiVersion": f"{WORKFLOWS_API_GROUP}/{WORKFLOWS_API_VERSION}",
             "kind": "Workflow",
             "metadata": {
                 "name": k8s_job_name(job_id),
                 "namespace": self.namespace,
                 # TODO: labels to identify our jobs?
# "labels": {} + "annotations": { + format_annotation_key(k): v for k, v in annotations.items() + }, }, "spec": { "arguments": { @@ -190,10 +230,75 @@ def _execute_handler_async( }, } self.custom_objects_api.create_namespaced_custom_object( - group=api_group, - version=api_version, + **K8S_CUSTOM_OBJECT_WORKFLOWS, namespace=self.namespace, - plural="workflows", body=body, ) return ("application/json", {}, JobStatus.accepted) + + +def job_from_k8s_wf(workflow: dict) -> JobDict: + annotations = workflow["metadata"]["annotations"] or {} + metadata = { + parsed_key: v + for orig_key, v in annotations.items() + if (parsed_key := parse_annotation_key(orig_key)) + } + + metadata["parameters"] = json.dumps( + hide_secret_values( + { + param["name"]: param["value"] + for param in workflow["spec"]["arguments"]["parameters"] + } + ) + ) + + status = status_from_argo_phase(workflow["status"]["phase"]) + + if started_at := workflow["status"].get("startedAt"): + metadata["job_start_datetime"] = argo_date_str_to_pygeoapi_date_str(started_at) + if finished_at := workflow["status"].get("finishedAt"): + metadata["job_end_datetime"] = argo_date_str_to_pygeoapi_date_str(finished_at) + default_progress = "100" if status == JobStatus.successful else "1" + # TODO: parse progress fromm wf status progress "1/2" + + return cast( + JobDict, + { + # need this key in order not to crash, overridden by metadata: + "identifier": "", + "process_id": "", + "job_start_datetime": "", + "status": status.value, + "mimetype": None, # we don't know this in general + "message": "", # TODO: what to show here? + "progress": default_progress, + **metadata, + }, + ) + + +def argo_date_str_to_pygeoapi_date_str(argo_date_str: str) -> str: + ARGO_DATE_FORMAT = "%Y-%m-%dT%H:%M:%SZ" + return datetime.datetime.strptime( + argo_date_str, + ARGO_DATE_FORMAT, + ).strftime(DATETIME_FORMAT) + + +def status_from_argo_phase(phase: str) -> JobStatus: + if phase == "Pending": + return JobStatus.accepted + elif phase == "Running": + return JobStatus.running + elif phase == "Succeeded": + return JobStatus.successful + elif phase == "Failed": + return JobStatus.failed + elif phase == "Error": + return JobStatus.failed + elif phase == "": + return JobStatus.accepted + else: + raise AssertionError(f"Invalid argo wf phase {phase}") diff --git a/pygeoapi_kubernetes_papermill/common.py b/pygeoapi_kubernetes_papermill/common.py index 5f2b0e3..d120713 100644 --- a/pygeoapi_kubernetes_papermill/common.py +++ b/pygeoapi_kubernetes_papermill/common.py @@ -31,12 +31,15 @@ import functools import logging import operator -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, TypedDict import re from pathlib import PurePath from http import HTTPStatus +from datetime import datetime, timezone + from pygeoapi.process.base import ProcessorExecuteError +from pygeoapi.process.manager.base import DATETIME_FORMAT from kubernetes import client as k8s_client @@ -351,7 +354,47 @@ def extra_secret_env_config(secret_name: str, num: int) -> ExtraConfig: ) +_ANNOTATIONS_PREFIX = "pygeoapi.io/" + + +def parse_annotation_key(key: str) -> Optional[str]: + matched = re.match(f"^{_ANNOTATIONS_PREFIX}(.+)", key) + return matched.group(1) if matched else None + + +def format_annotation_key(key: str) -> str: + return _ANNOTATIONS_PREFIX + key + + def current_namespace(): # getting the current namespace like this is documented, so it should be fine: # https://kubernetes.io/docs/tasks/access-application-cluster/access-cluster/ return 
open("/var/run/secrets/kubernetes.io/serviceaccount/namespace").read() + + +def hide_secret_values(d: dict[str, str]) -> dict[str, str]: + def transform_value(k, v): + return ( + "*" + if any(trigger in k.lower() for trigger in ["secret", "key", "password"]) + else v + ) + + return {k: transform_value(k, v) for k, v in d.items()} + + +def now_str() -> str: + return datetime.now(timezone.utc).strftime(DATETIME_FORMAT) + + +JobDict = TypedDict( + "JobDict", + { + "identifier": str, + "status": str, + "result-notebook": str, + "message": str, + "job_end_datetime": Optional[str], + }, + total=False, +) diff --git a/pygeoapi_kubernetes_papermill/kubernetes.py b/pygeoapi_kubernetes_papermill/kubernetes.py index 3a1e2f9..f9f50cb 100644 --- a/pygeoapi_kubernetes_papermill/kubernetes.py +++ b/pygeoapi_kubernetes_papermill/kubernetes.py @@ -30,14 +30,13 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime from http import HTTPStatus import json import logging -import re import time from threading import Thread -from typing import Literal, Optional, Any, TypedDict, cast +from typing import Literal, Optional, Any, cast import os from kubernetes import client as k8s_client, config as k8s_config @@ -56,7 +55,17 @@ ) from pygeoapi.process.manager.base import BaseManager, DATETIME_FORMAT -from .common import is_k8s_job_name, k8s_job_name, current_namespace +from .common import ( + is_k8s_job_name, + k8s_job_name, + parse_annotation_key, + JobDict, + current_namespace, + format_annotation_key, + hide_secret_values, + now_str, +) + LOGGER = logging.getLogger(__name__) @@ -87,19 +96,6 @@ def execute(self): ) -JobDict = TypedDict( - "JobDict", - { - "identifier": str, - "status": str, - "result-notebook": str, - "message": str, - "job_end_datetime": Optional[str], - }, - total=False, -) - - class KubernetesManager(BaseManager): def __init__(self, manager_def: dict) -> None: super().__init__(manager_def) @@ -448,18 +444,6 @@ def _pod_for_job(self, job: k8s_client.V1Job) -> Optional[k8s_client.V1Pod]: return next(iter(pods.items), None) -_ANNOTATIONS_PREFIX = "pygeoapi.io/" - - -def parse_annotation_key(key: str) -> Optional[str]: - matched = re.match(f"^{_ANNOTATIONS_PREFIX}(.+)", key) - return matched.group(1) if matched else None - - -def format_annotation_key(key: str) -> str: - return _ANNOTATIONS_PREFIX + key - - def job_status_from_k8s(status: k8s_client.V1JobStatus) -> JobStatus: # we assume only 1 run without retries @@ -526,17 +510,6 @@ def job_from_k8s(job: k8s_client.V1Job, message: Optional[str]) -> JobDict: ) -def hide_secret_values(d: dict[str, str]) -> dict[str, str]: - def transform_value(k, v): - return ( - "*" - if any(trigger in k.lower() for trigger in ["secret", "key", "password"]) - else v - ) - - return {k: transform_value(k, v) for k, v in d.items()} - - def get_completion_time(job: k8s_client.V1Job, status: JobStatus) -> Optional[datetime]: if status == JobStatus.failed: # failed jobs have special completion time field @@ -608,7 +581,3 @@ def get_jobs_by_status( return [ job for job in jobs if job_status_from_k8s(job.status) == JobStatus.failed ] - - -def now_str() -> str: - return datetime.now(timezone.utc).strftime(DATETIME_FORMAT) diff --git a/pygeoapi_kubernetes_papermill/notebook.py b/pygeoapi_kubernetes_papermill/notebook.py index 782a120..14f0022 100644 --- a/pygeoapi_kubernetes_papermill/notebook.py +++ b/pygeoapi_kubernetes_papermill/notebook.py @@ -48,7 +48,6 @@ from kubernetes import 

 from .kubernetes import (
-    JobDict,
     KubernetesProcessor,
     current_namespace,
     format_annotation_key,
@@ -62,6 +61,7 @@
     JOVIAN_UID,
     JOVIAN_GID,
     setup_byoa_results_dir_cmd,
+    JobDict,
 )

 LOGGER = logging.getLogger(__name__)
diff --git a/tests/test_argo_manager.py b/tests/test_argo_manager.py
index db88fec..b40cd54 100644
--- a/tests/test_argo_manager.py
+++ b/tests/test_argo_manager.py
@@ -28,6 +28,7 @@
 # =================================================================

 from unittest import mock
+import json

 import pytest
 from kubernetes import client as k8s_client
@@ -80,16 +81,27 @@ def test_execute_process_starts_async_job(
     ]
     assert job_id in job["metadata"]["name"]

-    # TODO
-    # $ assert job.metadata.annotations["pygeoapi.io/identifier"] == job_id
-    # $ assert (
-    # $     job.metadata.annotations["pygeoapi.io/success-uri"]
-    # $     == "https://example.com/success"
-    # $ )
-    # $ assert (
-    # $     job.metadata.annotations["pygeoapi.io/failed-uri"]
-    # $     == "https://example.com/failed"
-    # $ )
+    assert job["metadata"]["annotations"]["pygeoapi.io/identifier"] == job_id
+    # assert (
+    #     job.metadata.annotations["pygeoapi.io/success-uri"]
+    #     == "https://example.com/success"
+    # )
+    # assert (
+    #     job.metadata.annotations["pygeoapi.io/failed-uri"]
+    #     == "https://example.com/failed"
+    # )
+
+
+def test_get_job_returns_workflow(
+    manager: ArgoManager,
+    mock_get_workflow,
+):
+    job_id = "abc"
+    job = manager.get_job(job_id=job_id)
+    assert job["identifier"] == "annotations-identifier"
+    assert json.loads(job["parameters"]) == {"inpfile": "test2.txt"}
+    assert job["job_start_datetime"] == "2024-09-18T12:01:02.000000Z"
+    assert job["status"] == "successful"


 @pytest.fixture()
@@ -102,38 +114,43 @@ def mock_create_workflow():
         yield mocker


-"""
-
-@contextmanager
-def mock_list_jobs_with(*args):
-    with mock.patch(
-        "pygeoapi_kubernetes_papermill." "kubernetes.k8s_client.CustomObjectsApi.XXX",
-        return_value=k8s_client.V1JobList(items=args),
-    ):
-        yield
-
-
-@pytest.fixture()
-def mock_list_jobs(k8s_job):
-    with mock_list_jobs_with(k8s_job):
-        yield
-
-
-@pytest.fixture()
-def mock_list_jobs_accepted(k8s_job: k8s_client.V1Job):
-    k8s_job.status = k8s_client.V1JobStatus()
-    with mock_list_jobs_with(k8s_job):
-        yield
-
+MOCK_WORKFLOW = {
+    "apiVersion": "argoproj.io/v1alpha1",
+    "kind": "Workflow",
+    "metadata": {
+        "name": "workflow-test-instance-4",
+        "namespace": "test",
+        "annotations": {
+            "pygeoapi.io/identifier": "annotations-identifier",
+        },
+    },
+    "spec": {
+        "arguments": {"parameters": [{"name": "inpfile", "value": "test2.txt"}]},
+        "entrypoint": "test",
+        "workflowTemplateRef": {"name": "workflow-template-test"},
+    },
+    "status": {
+        "artifactGCStatus": {"notSpecified": True},
+        "artifactRepositoryRef": {"artifactRepository": {}, "default": True},
+        "conditions": [
+            {"status": "False", "type": "PodRunning"},
+            {"status": "True", "type": "Completed"},
+        ],
+        "finishedAt": "2024-09-18T12:01:12Z",
+        "phase": "Succeeded",
+        "progress": "1/1",
+        "resourcesDuration": {"cpu": 0, "memory": 3},
+        "startedAt": "2024-09-18T12:01:02Z",
+        "taskResultsCompletionStatus": {"workflow-test-instance-4": True},
+    },
+}


 @pytest.fixture()
-def mock_patch_job():
+def mock_get_workflow():
     with mock.patch(
         "pygeoapi_kubernetes_papermill."
-        "kubernetes.k8s_client.BatchV1Api.patch_namespaced_job",
+        "kubernetes.k8s_client.CustomObjectsApi.get_namespaced_custom_object",
+        return_value=MOCK_WORKFLOW,
     ) as mocker:
         yield mocker
-
-
-"""
diff --git a/tests/test_notebook_processor.py b/tests/test_notebook_processor.py
index 9c8f93b..7127358 100644
--- a/tests/test_notebook_processor.py
+++ b/tests/test_notebook_processor.py
@@ -41,7 +41,7 @@

 from pygeoapi.process.base import ProcessorExecuteError

-from pygeoapi_kubernetes_papermill.kubernetes import JobDict
+from pygeoapi_kubernetes_papermill.common import JobDict
 from pygeoapi_kubernetes_papermill.notebook import (
     CONTAINER_HOME,
     PapermillNotebookKubernetesProcessor,