forked from kubeflow/pipelines
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test(components): Added integration test for Sagemaker TrainingJob co…
…mponent v2. (kubeflow#8247) * Integration tests for sagemaker training v2 * removed redundant check * removed redundant print * added safety check * pr changes * updated python and kubernetes * reverting dependency versions * Revert "updated python and kubernetes" This reverts commit e92034d. * added linting
- Loading branch information
Showing
9 changed files
with
277 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
components/aws/sagemaker/tests/integration_tests/component_tests/test_v2_train_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import pytest | ||
import os | ||
import utils | ||
from utils import kfp_client_utils | ||
from utils import minio_utils | ||
from utils import ack_utils | ||
import ast | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"test_file_dir", | ||
[pytest.param("resources/config/ack-training-job", marks=pytest.mark.canary_test)], | ||
) | ||
def test_trainingjobV2(kfp_client, experiment_id, test_file_dir): | ||
k8s_client = ack_utils.k8s_client() | ||
test_file_dir = "resources/config/ack-training-job" | ||
download_dir = utils.mkdir(os.path.join(test_file_dir + "/generated")) | ||
test_params = utils.load_params( | ||
utils.replace_placeholders( | ||
os.path.join(test_file_dir, "config.yaml"), | ||
os.path.join(download_dir, "ack-training-job.yaml"), | ||
) | ||
) | ||
input_job_name = utils.generate_random_string(10) + "-trn-job" | ||
test_params["Arguments"]["training_job_name"] = input_job_name | ||
|
||
_, _, workflow_json = kfp_client_utils.compile_run_monitor_pipeline( | ||
kfp_client, | ||
experiment_id, | ||
test_params["PipelineDefinition"], | ||
test_params["Arguments"], | ||
download_dir, | ||
test_params["TestName"], | ||
test_params["Timeout"], | ||
) | ||
outputs = { | ||
"sagemaker-trainingjob": [ | ||
"model_artifacts", | ||
] | ||
} | ||
|
||
# Get output data | ||
output_files = minio_utils.artifact_download_iterator( | ||
workflow_json, outputs, download_dir | ||
) | ||
model_artifact = utils.read_from_file_in_tar( | ||
output_files["sagemaker-trainingjob"]["model_artifacts"] | ||
) | ||
|
||
# Verify Training job was successful on SageMaker | ||
print(f"training job name: {input_job_name}") | ||
train_response = ack_utils.describe_training_job(k8s_client, input_job_name) | ||
assert train_response["status"]["trainingJobStatus"] == "Completed" | ||
|
||
# Verify model artifacts output was generated from this run | ||
model_uri = ast.literal_eval(model_artifact)["s3ModelArtifacts"] | ||
print(f"model_artifact_url: {model_uri}") | ||
assert model_uri == train_response["status"]["modelArtifacts"]["s3ModelArtifacts"] | ||
assert input_job_name in model_uri | ||
|
||
utils.remove_dir(download_dir) | ||
|
||
|
||
def test_terminate_trainingjob(kfp_client, experiment_id): | ||
k8s_client = ack_utils.k8s_client() | ||
test_file_dir = "resources/config/ack-training-job" | ||
download_dir = utils.mkdir(os.path.join(test_file_dir + "/generated_terminate")) | ||
|
||
test_params = utils.load_params( | ||
utils.replace_placeholders( | ||
os.path.join(test_file_dir, "config.yaml"), | ||
os.path.join(download_dir, "woof.yaml"), | ||
) | ||
) | ||
input_job_name = utils.generate_random_string(4) + "-terminate-job" | ||
test_params["Arguments"]["training_job_name"] = input_job_name | ||
|
||
run_id, _, _ = kfp_client_utils.compile_run_monitor_pipeline( | ||
kfp_client, | ||
experiment_id, | ||
test_params["PipelineDefinition"], | ||
test_params["Arguments"], | ||
download_dir, | ||
test_params["TestName"], | ||
60, | ||
"running", | ||
) | ||
print(f"Terminating run: {run_id} where Training job_name: {input_job_name}") | ||
kfp_client_utils.terminate_run(kfp_client, run_id) | ||
desiredStatuses = ["Stopping", "Stopped"] | ||
training_status_reached = ack_utils.wait_for_trainingjob_status( | ||
k8s_client, input_job_name, desiredStatuses, 10, 6 | ||
) | ||
assert training_status_reached | ||
|
||
utils.remove_dir(download_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
...nents/aws/sagemaker/tests/integration_tests/resources/config/ack-training-job/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
PipelineDefinition: resources/definition/trainingv2_pipeline.py | ||
TestName: ack-training-job | ||
Timeout: 600 | ||
ExpectedTrainingImage: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1 | ||
Arguments: | ||
region: ((REGION)) | ||
algorithm_specification: | ||
trainingImage: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1 | ||
trainingInputMode: File | ||
hyper_parameters: | ||
k: "10" | ||
feature_dim: "784" | ||
input_data_config: | ||
- channelName: train | ||
dataSource: | ||
s3DataSource: | ||
s3URI: s3://((DATA_BUCKET))/mnist_kmeans_example/data | ||
s3DataType: S3Prefix | ||
s3DataDistributionType: FullyReplicated | ||
compressionType: None | ||
recordWrapperType: None | ||
inputMode: File | ||
resource_config: | ||
instanceCount: 1 | ||
instanceType: ml.m4.xlarge | ||
volumeSizeInGB: 20 | ||
stopping_condition: | ||
maxRuntimeInSeconds: 3600 | ||
output_data_config: | ||
s3OutputPath: s3://((DATA_BUCKET))/mnist_kmeans_example/output | ||
enable_network_isolation: "True" | ||
enable_managed_spot_training: "False" | ||
enable_inter_container_traffic_encryption: "False" | ||
role_arn: ((SAGEMAKER_ROLE_ARN)) |
39 changes: 39 additions & 0 deletions
39
components/aws/sagemaker/tests/integration_tests/resources/definition/trainingv2_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import kfp | ||
from kfp import components | ||
from kfp import dsl | ||
|
||
sagemaker_TrainingJob_op = components.load_component_from_file("../../TrainingJob/component.yaml") | ||
|
||
@dsl.pipeline(name="TrainingJob", description="SageMaker TrainingJob component") | ||
def TrainingJob( | ||
region="", | ||
algorithm_specification="", | ||
enable_inter_container_traffic_encryption="", | ||
enable_managed_spot_training="", | ||
enable_network_isolation="", | ||
hyper_parameters="", | ||
input_data_config="", | ||
output_data_config="", | ||
resource_config="", | ||
role_arn="", | ||
stopping_condition="", | ||
training_job_name="", | ||
): | ||
TrainingJob = sagemaker_TrainingJob_op( | ||
region=region, | ||
algorithm_specification=algorithm_specification, | ||
enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption, | ||
enable_managed_spot_training=enable_managed_spot_training, | ||
enable_network_isolation=enable_network_isolation, | ||
hyper_parameters=hyper_parameters, | ||
input_data_config=input_data_config, | ||
output_data_config=output_data_config, | ||
resource_config=resource_config, | ||
role_arn=role_arn, | ||
stopping_condition=stopping_condition, | ||
training_job_name=training_job_name, | ||
) | ||
if __name__ == "__main__": | ||
kfp.compiler.Compiler().compile( | ||
TrainingJob, "SageMaker_trainingJob_pipeline" + ".yaml" | ||
) |
13 changes: 13 additions & 0 deletions
13
components/aws/sagemaker/tests/integration_tests/scripts/ack-rbac.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
apiVersion: rbac.authorization.k8s.io/v1 | ||
kind: ClusterRoleBinding | ||
metadata: | ||
name: ack-sagemaker-controller-rolebinding | ||
namespace: kubeflow | ||
roleRef: | ||
apiGroup: rbac.authorization.k8s.io | ||
kind: ClusterRole | ||
name: ack-sagemaker-controller | ||
subjects: | ||
- kind: ServiceAccount | ||
name: pipeline-runner | ||
namespace: kubeflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
components/aws/sagemaker/tests/integration_tests/utils/ack_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from time import sleep | ||
from kubernetes import client,config | ||
import os | ||
|
||
def k8s_client(): | ||
return config.new_client_from_config() | ||
|
||
def _get_resource(k8s_client,job_name,kvars): | ||
"""Get the custom resource detail similar to: kubectl describe <resource> JOB_NAME -n NAMESPACE. | ||
Returns: | ||
None or object: None if the resource doesnt exist in server, otherwise the | ||
custom object. | ||
""" | ||
_api = client.CustomObjectsApi(k8s_client) | ||
namespace = os.environ.get("NAMESPACE") | ||
job_description = _api.get_namespaced_custom_object( | ||
kvars["group"].lower(), | ||
kvars["version"].lower(), | ||
namespace.lower(), | ||
kvars["plural"].lower(), | ||
job_name.lower() | ||
) | ||
return job_description | ||
|
||
def describe_training_job(k8s_client,training_job_name): | ||
training_vars = { | ||
"group":"sagemaker.services.k8s.aws", | ||
"version":"v1alpha1", | ||
"plural":"trainingjobs", | ||
} | ||
return _get_resource(k8s_client,training_job_name,training_vars) | ||
|
||
#TODO: Make this a generalized function for non-job resources. | ||
def wait_for_trainingjob_status(k8s_client,training_job_name,desiredStatuses,wait_periods,period_length): | ||
for _ in range(wait_periods): | ||
response = describe_training_job(k8s_client,training_job_name) | ||
if response["status"]["trainingJobStatus"] in desiredStatuses:return True | ||
sleep(period_length) | ||
return False |