Skip to content

Commit

Permalink
feat: Add done method for pipeline, training, and batch prediction jo…
Browse files Browse the repository at this point in the history
…bs (#1062)

Added a done method via a `DoneMixin` class to check the status of long running jobs (returns True or False based on job state):
* Implemented by `PipelineJob`, `_Job`, and `_TrainingJob`
* Added system tests in `aiplatform/tests/system/aiplatform/test_e2e_tabular.py`
* Added pipeline job tests in `tests/unit/aiplatform/test_pipeline_jobs.py`
* Still need to add unit tests in `test_jobs` and `test_training_jobs`

Fixes b/215396514
  • Loading branch information
sararob authored Mar 9, 2022
1 parent 6002d5d commit f3338fc
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 3 deletions.
55 changes: 55 additions & 0 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,58 @@ def get_annotation_class(annotation: type) -> type:
return annotation.__args__[0]
else:
return annotation


class DoneMixin(abc.ABC):
    """Abstract mixin declaring a ``done`` check for long-running jobs.

    Subclasses must report whether the job they represent has finished.
    """

    @abc.abstractmethod
    def done(self) -> bool:
        """Returns True if the job has completed, False otherwise."""


class StatefulResource(DoneMixin):
    """Extends DoneMixin to check whether a job returning a stateful resource has completed."""

    @property
    @abc.abstractmethod
    def state(self):
        """The current state of the job."""
        pass

    # NOTE(review): stacking @property on @classmethod does not create a true
    # class-level property; subclasses satisfy this by defining a plain class
    # attribute (e.g. `_valid_done_states = _JOB_COMPLETE_STATES`), so this
    # abstract declaration serves only as interface documentation.
    @property
    @classmethod
    @abc.abstractmethod
    def _valid_done_states(cls):
        """A set() containing all job states associated with a completed job."""
        pass

    def done(self) -> bool:
        """Method indicating whether a job has completed.

        Returns:
            True if the job has completed.
        """
        # Membership in the subclass-provided terminal-state set is already a
        # boolean, so return it directly instead of branching on it.
        return self.state in self._valid_done_states


class VertexAiStatefulResource(VertexAiResourceNounWithFutureManager, StatefulResource):
    """Extends StatefulResource to include a check for self._gca_resource."""

    def done(self) -> bool:
        """Method indicating whether a job has completed.

        Returns:
            True if the job has completed.
        """
        resource = self._gca_resource
        # Without a created, named backing resource there is no state to
        # inspect, so the job cannot be done yet.
        if not (resource and resource.name):
            return False
        return super().done()
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)


class _Job(base.VertexAiResourceNounWithFutureManager):
class _Job(base.VertexAiStatefulResource):
"""Class that represents a general Job resource in Vertex AI.
Cannot be directly instantiated.
Expand All @@ -83,6 +83,9 @@ class _Job(base.VertexAiResourceNounWithFutureManager):

client_class = utils.JobClientWithOverride

# Required by the done() method
_valid_done_states = _JOB_COMPLETE_STATES

def __init__(
self,
job_name: str,
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _set_enable_caching_value(
task["cachingOptions"] = {"enableCache": enable_caching}


class PipelineJob(base.VertexAiResourceNounWithFutureManager):
class PipelineJob(base.VertexAiStatefulResource):

client_class = utils.PipelineJobClientWithOverride
_resource_noun = "pipelineJobs"
Expand All @@ -87,6 +87,9 @@ class PipelineJob(base.VertexAiResourceNounWithFutureManager):
_parse_resource_name_method = "parse_pipeline_job_path"
_format_resource_name_method = "pipeline_job_path"

# Required by the done() method
_valid_done_states = _PIPELINE_COMPLETE_STATES

def __init__(
self,
display_name: str,
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)


class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
class _TrainingJob(base.VertexAiStatefulResource):

client_class = utils.PipelineClientWithOverride
_resource_noun = "trainingPipelines"
Expand All @@ -76,6 +76,9 @@ class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
_parse_resource_name_method = "parse_training_pipeline_path"
_format_resource_name_method = "training_pipeline_path"

# Required by the done() method
_valid_done_states = _PIPELINE_COMPLETE_STATES

def __init__(
self,
display_name: str,
Expand Down
8 changes: 8 additions & 0 deletions tests/system/aiplatform/test_e2e_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def test_end_to_end_tabular(self, shared_state):

shared_state["resources"].append(custom_batch_prediction_job)

in_progress_done_check = custom_job.done()
custom_job.wait_for_resource_creation()

automl_job.wait_for_resource_creation()
custom_batch_prediction_job.wait_for_resource_creation()

Expand Down Expand Up @@ -174,6 +176,8 @@ def test_end_to_end_tabular(self, shared_state):
# Test lazy loading of Endpoint, check getter was never called after predict()
custom_endpoint = aiplatform.Endpoint(custom_endpoint.resource_name)
custom_endpoint.predict([_INSTANCE])

completion_done_check = custom_job.done()
assert custom_endpoint._skipped_getter_call()

assert (
Expand Down Expand Up @@ -201,3 +205,7 @@ def test_end_to_end_tabular(self, shared_state):
assert 200000 > custom_result > 50000
except KeyError as e:
raise RuntimeError("Unexpected prediction response structure:", e)

# Check done() method works correctly
assert in_progress_done_check is False
assert completion_done_check is True
29 changes: 29 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,14 @@ def test_batch_prediction_job_status(self, get_batch_prediction_job_mock):
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=base._DEFAULT_RETRY
)

def test_batch_prediction_job_done_get(self, get_batch_prediction_job_mock):
    """done() is False for a retrieved batch prediction job that is still running."""
    batch_prediction_job = jobs.BatchPredictionJob(
        batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
    )

    assert not batch_prediction_job.done()
    # Two GETs expected: presumably one from the constructor and one from
    # done() refreshing the resource state.
    assert get_batch_prediction_job_mock.call_count == 2

@pytest.mark.usefixtures("get_batch_prediction_job_gcs_output_mock")
def test_batch_prediction_iter_dirs_gcs(self, storage_list_blobs_mock):
bp = jobs.BatchPredictionJob(
Expand Down Expand Up @@ -507,6 +515,27 @@ def test_batch_predict_gcs_source_and_dest(
batch_prediction_job=expected_gapic_batch_prediction_job,
)

@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
    """done() on an async batch_predict call is False until wait() returns."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Launch the batch prediction without blocking on completion.
    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        sync=False,
    )

    job.wait_for_resource_creation()

    # Resource now exists, but the mocked job state is still non-terminal.
    assert not job.done()

    job.wait()

    assert job.done()

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_gcs_source_bq_dest(
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,39 @@ def test_submit_call_pipeline_service_pipeline_job_create(
gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
    "job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
)
def test_done_method_pipeline_service(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    job_spec_json,
    mock_load_json,
):
    """done() reports False right after submit() and True after wait()."""
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    pipeline_job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)

    # submit() does not block, so the job is still in a non-terminal state.
    assert not pipeline_job.done()

    pipeline_job.wait()

    assert pipeline_job.done()

@pytest.mark.parametrize(
"job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
)
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,65 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(

assert job._has_logged_custom_job

def test_custom_training_tabular_done(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_python_package_to_gcs,
    mock_tabular_dataset,
    mock_model_service_get,
):
    """Checks done() on a CustomTrainingJob: False while the async run is
    in flight, True after wait() observes a terminal state.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        credentials=_TEST_CREDENTIALS,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = training_jobs.CustomTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        labels=_TEST_LABELS,
        script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
        model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
        model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
        model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
        model_description=_TEST_MODEL_DESCRIPTION,
    )

    # sync=False makes run() return immediately, leaving the job running so
    # the first done() check can observe a non-terminal state.
    job.run(
        dataset=mock_tabular_dataset,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        args=_TEST_RUN_ARGS,
        environment_variables=_TEST_ENVIRONMENT_VARIABLES,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        model_display_name=_TEST_MODEL_DISPLAY_NAME,
        model_labels=_TEST_MODEL_LABELS,
        training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
        timestamp_split_column_name=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME,
        tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
        sync=False,
    )

    assert job.done() is False

    job.wait()

    assert job.done() is True

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_bigquery_destination(
self,
Expand Down Expand Up @@ -2323,6 +2382,59 @@ def setup_method(self):
def teardown_method(self):
    # Drain the SDK's global executor (presumably used for sync=False job
    # futures) so work from this test class doesn't leak into later tests.
    initializer.global_pool.shutdown(wait=True)

def test_custom_container_training_tabular_done(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_tabular_dataset,
    mock_model_service_get,
):
    """Checks done() on a CustomContainerTrainingJob: False while the async
    run is in flight, True after wait() observes a terminal state.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_BUCKET_NAME,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = training_jobs.CustomContainerTrainingJob(
        display_name=_TEST_DISPLAY_NAME,
        labels=_TEST_LABELS,
        container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
        command=_TEST_TRAINING_CONTAINER_CMD,
        model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
        model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
        model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
        model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
        model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
        model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
        model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
        model_description=_TEST_MODEL_DESCRIPTION,
    )

    # sync=False makes run() return immediately, leaving the job running so
    # the first done() check can observe a non-terminal state.
    job.run(
        dataset=mock_tabular_dataset,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        args=_TEST_RUN_ARGS,
        environment_variables=_TEST_ENVIRONMENT_VARIABLES,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        model_display_name=_TEST_MODEL_DISPLAY_NAME,
        model_labels=_TEST_MODEL_LABELS,
        predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
        service_account=_TEST_SERVICE_ACCOUNT,
        tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
        sync=False,
    )

    assert job.done() is False

    job.wait()

    assert job.done() is True

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_tabular_dataset(
self,
Expand Down

0 comments on commit f3338fc

Please sign in to comment.