Merge pull request #19 from teamdatatonic/features/gcp_resources
feat: return gcp resources
felix-datatonic authored Jun 5, 2023
2 parents 4d259e3 + 15675d4 commit 44730dd
Showing 2 changed files with 31 additions and 6 deletions.
@@ -1,5 +1,5 @@
# Copyright 2022 Google LLC
from typing import List, Dict
from typing import List, Dict, NamedTuple

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@

@component(
    base_image="python:3.7",
    packages_to_install=["google-cloud-aiplatform==1.24.1"],
    packages_to_install=[
        "google-cloud-aiplatform==1.24.1",
        "google-cloud-pipeline-components==1.0.42",
    ],
)
def custom_train_job(
    train_script_uri: str,
@@ -41,7 +44,7 @@ def custom_train_job(
    accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
    accelerator_count: int = 0,
    parent_model: str = None,
):
) -> NamedTuple("Outputs", [("gcp_resources", str)]):
    """Run a custom training job using a training script.
    The provided script will be invoked by passing the following command-line arguments:
@@ -87,12 +90,15 @@ def custom_train_job(
    Returns:
        parent_model (str): Resource URI of the parent model (empty string if the
            trained model is the first model version of its kind).
        NamedTuple: gcp_resources for Vertex AI UI integration.
    """
    import json
    import logging
    import os.path
    import time
    import google.cloud.aiplatform as aip
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources
    from google.protobuf.json_format import MessageToJson

    logging.info(f"Using train script: {train_script_uri}")
    script_path = "/gcs/" + train_script_uri[5:]
@@ -143,3 +149,12 @@ def custom_train_job(
    for k, v in parsed_metrics.items():
        if type(v) is float:
            metrics.log_metric(k, v)

    # return GCP resource for Vertex AI UI integration
    custom_job_name = job.to_dict()["trainingTaskMetadata"]["backingCustomJob"]
    custom_train_job_resources = GcpResources()
    ctr = custom_train_job_resources.resources.add()
    ctr.resource_type = "CustomJob"
    ctr.resource_uri = custom_job_name
    gcp_resources = MessageToJson(custom_train_job_resources)
    return (gcp_resources,)
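
(Not part of the diff above — a minimal standalone sketch for orientation, assuming google-cloud-pipeline-components is installed and using a made-up CustomJob resource name. It shows the JSON shape MessageToJson produces for a single CustomJob entry, and why the test below asserts on the camelCase keys "resources" and "resourceUri".)

# Sketch of the serialized gcp_resources payload; the resource name below is
# an assumed example value, not something taken from this commit.
import json
from google.protobuf.json_format import MessageToJson
from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

resources = GcpResources()
entry = resources.resources.add()
entry.resource_type = "CustomJob"
entry.resource_uri = "projects/my-project/locations/europe-west2/customJobs/123"

# MessageToJson emits lowerCamelCase keys (resource_uri -> "resourceUri"),
# which is what the updated test asserts on.
print(json.loads(MessageToJson(resources)))
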
16 changes: 13 additions & 3 deletions components/vertex-components/tests/test_custom_training_job.py
@@ -2,7 +2,7 @@
from kfp.v2.dsl import Dataset, Metrics, Artifact
from unittest import mock
import pytest

import json

import vertex_components

@@ -24,7 +24,12 @@ def test_custom_train_job(mock_open, mock_exists, mock_job, tmpdir):
    mock_model = Artifact(uri=tmpdir, metadata={"resourceName": ""})
    mock_metrics = Metrics(uri=tmpdir)

    custom_train_job(
    mock_job_instance = mock_job.return_value
    mock_job_instance.to_dict.return_value = {
        "trainingTaskMetadata": {"backingCustomJob": "mock_custom_job_name"}
    }

    (gcp_resources,) = custom_train_job(
        train_script_uri="gs://my-bucket/train_script.py",
        train_data=mock_train_data,
        valid_data=mock_valid_data,
@@ -53,6 +58,11 @@ def test_custom_train_job(mock_open, mock_exists, mock_job, tmpdir):

    # Assert metrics loading
    mock_open.assert_called_once_with(tmpdir, "r")
    # Assert gcp_resources contains the expected value
    assert (
        json.loads(gcp_resources)["resources"][0]["resourceUri"]
        == "mock_custom_job_name"
    )


@mock.patch("google.cloud.aiplatform.CustomTrainingJob")
@@ -72,7 +82,7 @@ def test_custom_train_script_not_found(mock_open, mock_exists, mock_job, tmpdir)
    mock_metrics = Metrics(uri=tmpdir)

    with pytest.raises(ValueError):
        custom_train_job(
        (gcp_resources,) = custom_train_job(
            train_script_uri="gs://my-bucket/train_script.py",
            train_data=mock_train_data,
            valid_data=mock_valid_data,

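(Also not part of the diff — a hedged consumer-side sketch, again assuming google-cloud-pipeline-components is installed locally. The string returned by custom_train_job can be parsed back into a GcpResources message with google.protobuf.json_format.Parse, e.g. by a downstream step that wants the backing CustomJob name. The payload literal is an assumed example.)

# Round-trip a serialized gcp_resources string back into the proto message.
from google.protobuf import json_format
from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

serialized = (
    '{"resources": [{"resourceType": "CustomJob", '
    '"resourceUri": "projects/my-project/locations/europe-west2/customJobs/123"}]}'
)  # example payload shaped like the component's new output

parsed = json_format.Parse(serialized, GcpResources())
job_entry = parsed.resources[0]
print(job_entry.resource_type, job_entry.resource_uri)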