From adfa4b26a604b92a1ac2e0b8de5eb7597bf39fab Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Wed, 2 Sep 2020 20:29:15 -0700 Subject: [PATCH] Enable custom training job in SageMaker plugin (#113) * checking key existence in GetRole() * (training job plugin only) inject hp; refactor image uri getting mechanism; some other smaller refactoring * change the key and value of the flyte sm command * get rid of an obsolete function * add and change unit tests * add unit tests * lint * refactor a dummy object generation function * refactor a dummy object generation function * add unit tests * lint * forming the runner cmd and inject it as a hyperparameter in SageMaker plugin * fix unit tests * fix unit tests * add custom training plugin * revert overridden changes * compose the __FLYTE_SAGEMAKER_CMD__ in custom training job plugin * add unit test * lint * fix getEventInfoForJob * fix image getting * fix unit tests * stick with inputs.pb and modify hp injecting logic accordingly * fix args converting logic * use default file-based output for custom training job * expanding PluginContext interface with necessary methods so SM plugin can access the DataStore and such * lint error * add unit tests * add logic to inject env vars into hyperparameters * fix output prefix * fix output prefix * remove job name from output prefix for now * fix a unit test * accommodating new arg and env var passing syntax * injecting an env var to force disable statsd for sagemaker custom training * correcting variable name * remove unused constant * remove comments * fix unit tests * merge template.go * pr comments * add guarding statement wrt algorithm name for custom training plugin and built-in training plugin * refactor file structures: splitting the code into multiple files to optimize for readability * add documentation to a set of constants, and fix a constant's name * split tests into multiple files * correcting error types: make permanent failures * refactor --- 
go/tasks/pluginmachinery/k8s/plugin.go | 8 + .../plugins/k8s/sagemaker/builtin_training.go | 255 +++++++ .../k8s/sagemaker/builtin_training_test.go | 271 ++++++++ .../plugins/k8s/sagemaker/config/config.go | 4 + go/tasks/plugins/k8s/sagemaker/constants.go | 35 + .../plugins/k8s/sagemaker/custom_training.go | 201 ++++++ .../k8s/sagemaker/custom_training_test.go | 214 ++++++ .../k8s/sagemaker/hyperparameter_tuning.go | 283 ++++++++ .../sagemaker/hyperparameter_tuning_test.go | 126 ++++ go/tasks/plugins/k8s/sagemaker/outputs.go | 71 ++ .../plugins/k8s/sagemaker/outputs_test.go | 34 + go/tasks/plugins/k8s/sagemaker/plugin.go | 99 +++ go/tasks/plugins/k8s/sagemaker/plugin_test.go | 60 ++ ...sagemaker_test.go => plugin_test_utils.go} | 348 +++++----- go/tasks/plugins/k8s/sagemaker/sagemaker.go | 631 ------------------ .../k8s/sagemaker/testdata/config.yaml | 4 +- go/tasks/plugins/k8s/sagemaker/utils.go | 186 ++++-- go/tasks/plugins/k8s/sagemaker/utils_test.go | 132 +++- 18 files changed, 2088 insertions(+), 874 deletions(-) create mode 100644 go/tasks/plugins/k8s/sagemaker/builtin_training.go create mode 100644 go/tasks/plugins/k8s/sagemaker/builtin_training_test.go create mode 100644 go/tasks/plugins/k8s/sagemaker/custom_training.go create mode 100644 go/tasks/plugins/k8s/sagemaker/custom_training_test.go create mode 100644 go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go create mode 100644 go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go create mode 100644 go/tasks/plugins/k8s/sagemaker/outputs.go create mode 100644 go/tasks/plugins/k8s/sagemaker/outputs_test.go create mode 100644 go/tasks/plugins/k8s/sagemaker/plugin.go create mode 100644 go/tasks/plugins/k8s/sagemaker/plugin_test.go rename go/tasks/plugins/k8s/sagemaker/{sagemaker_test.go => plugin_test_utils.go} (58%) delete mode 100644 go/tasks/plugins/k8s/sagemaker/sagemaker.go diff --git a/go/tasks/pluginmachinery/k8s/plugin.go b/go/tasks/pluginmachinery/k8s/plugin.go index cebeb09bc..6a198ae0f 
100644 --- a/go/tasks/pluginmachinery/k8s/plugin.go +++ b/go/tasks/pluginmachinery/k8s/plugin.go @@ -3,6 +3,8 @@ package k8s import ( "context" + "github.com/lyft/flytestdlib/storage" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" "k8s.io/apimachinery/pkg/runtime" @@ -48,6 +50,12 @@ type PluginContext interface { // Provides an output sync of type io.OutputWriter OutputWriter() io.OutputWriter + + // Returns a handle to the currently configured storage backend that can be used to communicate with the tasks or write metadata + DataStore() *storage.DataStore + + // Returns the max allowed dataset size that the outputwriter will accept + MaxDatasetSizeBytes() int64 } // Defines a simplified interface to author plugins for k8s resources. diff --git a/go/tasks/plugins/k8s/sagemaker/builtin_training.go b/go/tasks/plugins/k8s/sagemaker/builtin_training.go new file mode 100644 index 000000000..8926c0da3 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/builtin_training.go @@ -0,0 +1,255 @@ +package sagemaker + +import ( + "context" + "fmt" + "strings" + "time" + + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob" + + awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils" + + pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" + "github.com/aws/aws-sdk-go/service/sagemaker" + + taskError "github.com/lyft/flyteplugins/go/tasks/errors" + + flyteSageMakerIdl 
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" + + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" +) + +func (m awsSagemakerPlugin) buildResourceForTrainingJob( + ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { + + logger.Infof(ctx, "Building a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + taskTemplate, err := getTaskTemplate(ctx, taskCtx) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Failed to get the task template of the training job task") + } + + // Unmarshal the custom field of the task template back into the Hyperparameter Tuning Job struct generated in flyteidl + sagemakerTrainingJob := flyteSageMakerIdl.TrainingJob{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerTrainingJob) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "invalid TrainingJob task specification: not able to unmarshal the custom field to [%s]", m.TaskType) + } + if sagemakerTrainingJob.GetTrainingJobResourceConfig() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TrainingJobResourceConfig] of the TrainingJob does not exist") + } + if sagemakerTrainingJob.GetAlgorithmSpecification() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [AlgorithmSpecification] does not exist") + } + if sagemakerTrainingJob.GetAlgorithmSpecification().GetAlgorithmName() == flyteSageMakerIdl.AlgorithmName_CUSTOM { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Custom algorithm is not supported by the built-in training job plugin") + } + + taskInput, err := taskCtx.InputReader().Get(ctx) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "unable to fetch task inputs") + } + + // Get inputs from literals + 
inputLiterals := taskInput.GetLiterals() + err = checkIfRequiredInputLiteralsExist(inputLiterals, + []string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable}) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist") + } + + trainPathLiteral := inputLiterals[TrainPredefinedInputVariable] + validationPathLiteral := inputLiterals[ValidationPredefinedInputVariable] + staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable] + + if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[train] Input is required and should be of Type [Scalar.Blob]") + } + if validationPathLiteral.GetScalar() == nil || validationPathLiteral.GetScalar().GetBlob() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[validation] Input is required and should be of Type [Scalar.Blob]") + } + + // Convert the hyperparameters to the spec value + staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type") + } + + outputPath := createOutputPath(taskCtx.OutputWriter().GetRawOutputPrefix().String(), TrainingJobOutputPathSubDir) + + jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + trainingImageStr, err := getTrainingJobImage(ctx, taskCtx, &sagemakerTrainingJob) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image") + } + + logger.Infof(ctx, "The Sagemaker TrainingJob Task plugin received static hyperparameters [%v]", staticHyperparams) + + cfg := config.GetSagemakerConfig() + + var metricDefinitions 
[]commonv1.MetricDefinition + idlMetricDefinitions := sagemakerTrainingJob.GetAlgorithmSpecification().GetMetricDefinitions() + for _, md := range idlMetricDefinitions { + metricDefinitions = append(metricDefinitions, + commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)}) + } + + apiContentType, err := getAPIContentType(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType()) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]", sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType().String()) + } + + inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String())) + + role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations()) + if role == "" { + role = cfg.RoleArn + } + + trainingJob := &trainingjobv1.TrainingJob{ + Spec: trainingjobv1.TrainingJobSpec{ + AlgorithmSpecification: &commonv1.AlgorithmSpecification{ + // If the specify a value for this AlgorithmName parameter, the user can't specify a value for TrainingImage. 
+ // in this Flyte plugin, we always use the algorithm name and version the user provides via Flytekit to map to an image + // so we intentionally leave this field nil + AlgorithmName: nil, + TrainingImage: ToStringPtr(trainingImageStr), + TrainingInputMode: commonv1.TrainingInputMode(inputModeString), + MetricDefinitions: metricDefinitions, + }, + // The support of spot training will come in a later version + EnableManagedSpotTraining: nil, + HyperParameters: staticHyperparams, + InputDataConfig: []commonv1.Channel{ + { + ChannelName: ToStringPtr(TrainPredefinedInputVariable), + DataSource: &commonv1.DataSource{ + S3DataSource: &commonv1.S3DataSource{ + S3DataType: "S3Prefix", + S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()), + }, + }, + ContentType: ToStringPtr(apiContentType), + InputMode: inputModeString, + }, + { + ChannelName: ToStringPtr(ValidationPredefinedInputVariable), + DataSource: &commonv1.DataSource{ + S3DataSource: &commonv1.S3DataSource{ + S3DataType: "S3Prefix", + S3Uri: ToStringPtr(validationPathLiteral.GetScalar().GetBlob().GetUri()), + }, + }, + ContentType: ToStringPtr(apiContentType), + InputMode: inputModeString, + }, + }, + OutputDataConfig: &commonv1.OutputDataConfig{ + S3OutputPath: ToStringPtr(outputPath), + }, + CheckpointConfig: nil, + ResourceConfig: &commonv1.ResourceConfig{ + InstanceType: sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceType(), + InstanceCount: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceCount()), + VolumeSizeInGB: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetVolumeSizeInGb()), + VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. 
Need to add to proto and flytekit in the future + }, + RoleArn: ToStringPtr(role), + Region: ToStringPtr(cfg.Region), + StoppingCondition: &commonv1.StoppingCondition{ + MaxRuntimeInSeconds: ToInt64Ptr(86400), // TODO: decide how to coordinate this and Flyte's timeout + MaxWaitTimeInSeconds: nil, // TODO: decide how to coordinate this and Flyte's timeout and queueing budget + }, + TensorBoardOutputConfig: nil, + Tags: nil, + TrainingJobName: &jobName, + }, + } + logger.Infof(ctx, "Successfully built a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + return trainingJob, nil +} + +func (m awsSagemakerPlugin) getTaskPhaseForTrainingJob( + ctx context.Context, pluginContext k8s.PluginContext, trainingJob *trainingjobv1.TrainingJob) (pluginsCore.PhaseInfo, error) { + + logger.Infof(ctx, "Getting task phase for sagemaker training job [%v]", trainingJob.Status.SageMakerTrainingJobName) + info, err := m.getEventInfoForTrainingJob(ctx, trainingJob) + if err != nil { + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "Failed to get event info for the job") + } + + occurredAt := time.Now() + + switch trainingJob.Status.TrainingJobStatus { + case trainingjobController.ReconcilingTrainingJobStatus: + logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", trainingJob.Status.Additional) + // TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call pluginErrors + execError := &flyteIdlCore.ExecutionError{ + Message: trainingJob.Status.Additional, + Kind: flyteIdlCore.ExecutionError_USER, + Code: trainingjobController.ReconcilingTrainingJobStatus, + } + return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil + case sagemaker.TrainingJobStatusFailed: + execError := &flyteIdlCore.ExecutionError{ + Message: trainingJob.Status.Additional, + Kind: flyteIdlCore.ExecutionError_USER, 
+ Code: sagemaker.TrainingJobStatusFailed, + } + return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, info), nil + case sagemaker.TrainingJobStatusStopped: + reason := fmt.Sprintf("Training Job Stopped") + return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil + case sagemaker.TrainingJobStatusCompleted: + // Now that it is a success we will set the outputs as expected by the task + + // We have specified an output path in the CRD, and we know SageMaker will automatically upload the + // model tarball to s3:////output/model.tar.gz + + // Therefore, here we create a output literal map, where we fill in the above path to the URI field of the + // blob output, which will later be written out by the OutputWriter to the outputs.pb remotely on S3 + outputLiteralMap, err := getOutputLiteralMapFromTaskInterface(ctx, pluginContext.TaskReader(), + createModelOutputPath(trainingJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(), trainingJob.Status.SageMakerTrainingJobName)) + if err != nil { + logger.Errorf(ctx, "Failed to create outputs, err: %s", err) + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to create outputs for the task") + } + // Instantiate a output reader with the literal map, and write the output to the remote location referred to by the OutputWriter + if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputLiteralMap, nil)); err != nil { + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unable to write output to the remote location") + } + logger.Debugf(ctx, "Successfully produced and returned outputs") + return pluginsCore.PhaseInfoSuccess(info), nil + case "": + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil + } + + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, 
info), nil +} + +func (m awsSagemakerPlugin) getEventInfoForTrainingJob(ctx context.Context, trainingJob *trainingjobv1.TrainingJob) (*pluginsCore.TaskInfo, error) { + + var jobRegion, jobName, jobTypeInURL, sagemakerLinkName string + jobRegion = *trainingJob.Spec.Region + jobName = *trainingJob.Spec.TrainingJobName + jobTypeInURL = "jobs" + sagemakerLinkName = TrainingJobSageMakerLinkName + + logger.Infof(ctx, "Getting event information for SageMaker BuiltinAlgorithmTrainingJob task, job region: [%v], job name: [%v], "+ + "job type in url: [%v], sagemaker link name: [%v]", jobRegion, jobName, jobTypeInURL, sagemakerLinkName) + + return createTaskInfo(ctx, jobRegion, jobName, jobTypeInURL, sagemakerLinkName) +} diff --git a/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go b/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go new file mode 100644 index 000000000..0df234443 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go @@ -0,0 +1,271 @@ +package sagemaker + +import ( + "context" + "fmt" + "testing" + + trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob" + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + taskError "github.com/lyft/flyteplugins/go/tasks/errors" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" + + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" + + stdConfig "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + "github.com/aws/aws-sdk-go/service/sagemaker" + sagemakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" + "github.com/stretchr/testify/assert" +) + +func Test_awsSagemakerPlugin_BuildResourceForTrainingJob(t *testing.T) { + 
// Default config does not contain a roleAnnotationKey -> expecting to get the role from default config + ctx := context.TODO() + defaultCfg := config.GetSagemakerConfig() + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + + t.Run("If roleAnnotationKey has a match, the role from the metadata should be fetched", func(t *testing.T) { + // Injecting a config which contains a matching roleAnnotationKey -> expecting to get the role from metadata + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + SearchPaths: []string{"testdata/config.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: trainingJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job x", tjObj) + + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate, false)) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + assert.Equal(t, "metadata_role", *trainingJob.Spec.RoleArn) + }) + + t.Run("If roleAnnotationKey does not have a match, the role from the config should be fetched", func(t *testing.T) { + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: trainingJobTaskType} + + tjObj := 
generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job y", tjObj) + + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate, false)) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + assert.Equal(t, "config_role", *trainingJob.Spec.RoleArn) + }) + + t.Run("In a custom training job we should see the FLYTE_SAGEMAKER_CMD being injected", func(t *testing.T) { + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: trainingJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate, false)) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + assert.Equal(t, "config_role", *trainingJob.Spec.RoleArn) + + expectedHPs := []*commonv1.KeyValuePair{ + {Name: "a", Value: "1"}, + {Name: "b", Value: "2"}, + } + + assert.ElementsMatch(t, + func(kvs 
[]*commonv1.KeyValuePair) []commonv1.KeyValuePair { + ret := make([]commonv1.KeyValuePair, 0, len(kvs)) + for _, kv := range kvs { + ret = append(ret, *kv) + } + return ret + }(expectedHPs), + func(kvs []*commonv1.KeyValuePair) []commonv1.KeyValuePair { + ret := make([]commonv1.KeyValuePair, 0, len(kvs)) + for _, kv := range kvs { + ret = append(ret, *kv) + } + return ret + }(trainingJob.Spec.HyperParameters)) + }) +} + +func Test_awsSagemakerPlugin_GetTaskPhaseForTrainingJob(t *testing.T) { + ctx := context.TODO() + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: trainingJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockTrainingJobTaskContext(taskTemplate, false) + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + t.Run("ReconcilingTrainingJobStatus should lead to a retryable failure", func(t *testing.T) { + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + trainingJob.Status.TrainingJobStatus = trainingjobController.ReconcilingTrainingJobStatus + phaseInfo, err := awsSageMakerTrainingJobHandler.getTaskPhaseForTrainingJob(ctx, taskCtx, trainingJob) + assert.Nil(t, err) + assert.Equal(t, phaseInfo.Phase(), pluginsCore.PhaseRetryableFailure) + assert.Equal(t, 
phaseInfo.Err().GetKind(), flyteIdlCore.ExecutionError_USER) + assert.Equal(t, phaseInfo.Err().GetCode(), trainingjobController.ReconcilingTrainingJobStatus) + assert.Equal(t, phaseInfo.Err().GetMessage(), "") + }) + t.Run("TrainingJobStatusFailed should be a permanent failure", func(t *testing.T) { + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + trainingJob.Status.TrainingJobStatus = sagemaker.TrainingJobStatusFailed + + phaseInfo, err := awsSageMakerTrainingJobHandler.getTaskPhaseForTrainingJob(ctx, taskCtx, trainingJob) + assert.Nil(t, err) + assert.Equal(t, phaseInfo.Phase(), pluginsCore.PhasePermanentFailure) + assert.Equal(t, phaseInfo.Err().GetKind(), flyteIdlCore.ExecutionError_USER) + assert.Equal(t, phaseInfo.Err().GetCode(), sagemaker.TrainingJobStatusFailed) + assert.Equal(t, phaseInfo.Err().GetMessage(), "") + }) + t.Run("TrainingJobStatusFailed should be a permanent failure", func(t *testing.T) { + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + trainingJob.Status.TrainingJobStatus = sagemaker.TrainingJobStatusStopped + + phaseInfo, err := awsSageMakerTrainingJobHandler.getTaskPhaseForTrainingJob(ctx, taskCtx, trainingJob) + assert.Nil(t, err) + assert.Equal(t, phaseInfo.Phase(), pluginsCore.PhaseRetryableFailure) + assert.Equal(t, phaseInfo.Err().GetKind(), flyteIdlCore.ExecutionError_USER) + assert.Equal(t, phaseInfo.Err().GetCode(), taskError.DownstreamSystemError) + // We have a default message for TrainingJobStatusStopped + assert.Equal(t, phaseInfo.Err().GetMessage(), "Training Job Stopped") + }) +} + +func Test_awsSagemakerPlugin_getEventInfoForTrainingJob(t *testing.T) { + // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config + ctx := context.TODO() + defaultCfg := config.GetSagemakerConfig() + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + + t.Run("get event info should return 
correctly formatted log links for training job", func(t *testing.T) { + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: trainingJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockTrainingJobTaskContext(taskTemplate, false) + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + + taskInfo, err := awsSageMakerTrainingJobHandler.getEventInfoForTrainingJob(ctx, trainingJob) + if err != nil { + panic(err) + } + + expectedTaskLogs := []*flyteIdlCore.TaskLog{ + { + Uri: fmt.Sprintf("https://%s.console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=/aws/sagemaker/TrainingJobs;prefix=%s;streamFilter=typeLogStreamPrefix", + "us-west-2", "us-west-2", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), + Name: CloudWatchLogLinkName, + MessageFormat: flyteIdlCore.TaskLog_JSON, + }, + { + Uri: fmt.Sprintf("https://%s.console.aws.amazon.com/sagemaker/home?region=%s#/%s/%s", + "us-west-2", "us-west-2", "jobs", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), + Name: TrainingJobSageMakerLinkName, + MessageFormat: flyteIdlCore.TaskLog_UNKNOWN, + }, + } + + expectedCustomInfo, _ := 
utils.MarshalObjToStruct(map[string]string{}) + assert.Equal(t, + func(tis []*flyteIdlCore.TaskLog) []flyteIdlCore.TaskLog { + ret := make([]flyteIdlCore.TaskLog, 0, len(tis)) + for _, ti := range tis { + ret = append(ret, *ti) + } + return ret + }(expectedTaskLogs), + func(tis []*flyteIdlCore.TaskLog) []flyteIdlCore.TaskLog { + ret := make([]flyteIdlCore.TaskLog, 0, len(tis)) + for _, ti := range tis { + ret = append(ret, *ti) + } + return ret + }(taskInfo.Logs)) + assert.Equal(t, *expectedCustomInfo, *taskInfo.CustomInfo) + }) +} diff --git a/go/tasks/plugins/k8s/sagemaker/config/config.go b/go/tasks/plugins/k8s/sagemaker/config/config.go index 036a1fbaa..1316f88a3 100644 --- a/go/tasks/plugins/k8s/sagemaker/config/config.go +++ b/go/tasks/plugins/k8s/sagemaker/config/config.go @@ -60,3 +60,7 @@ func GetSagemakerConfig() *Config { func SetSagemakerConfig(cfg *Config) error { return sagemakerConfigSection.SetConfig(cfg) } + +func ResetSagemakerConfig() error { + return sagemakerConfigSection.SetConfig(&defaultConfig) +} diff --git a/go/tasks/plugins/k8s/sagemaker/constants.go b/go/tasks/plugins/k8s/sagemaker/constants.go index fdd8a4f6a..83113afa2 100644 --- a/go/tasks/plugins/k8s/sagemaker/constants.go +++ b/go/tasks/plugins/k8s/sagemaker/constants.go @@ -5,6 +5,11 @@ const ( trainingJobTaskType = "sagemaker_training_job_task" ) +const ( + customTrainingJobTaskPluginID = "sagemaker_custom_training" + customTrainingJobTaskType = "sagemaker_custom_training_job_task" +) + const ( hyperparameterTuningJobTaskPluginID = "sagemaker_hyperparameter_tuning" hyperparameterTuningJobTaskType = "sagemaker_hyperparameter_tuning_job_task" @@ -13,3 +18,33 @@ const ( const ( TEXTCSVInputContentType string = "text/csv" ) + +const ( + FlyteSageMakerEnvVarKeyPrefix string = "__FLYTE_ENV_VAR_" + FlyteSageMakerKeySuffix string = "__" + FlyteSageMakerCmdKeyPrefix string = "__FLYTE_CMD_" + FlyteSageMakerCmdDummyValue string = "__FLYTE_CMD_DUMMY_VALUE__" + 
FlyteSageMakerEnvVarKeyStatsdDisabled string = "FLYTE_STATSD_DISABLED" +) + +const ( + TrainingJobOutputPathSubDir = "training_outputs" + HyperparameterOutputPathSubDir = "hyperparameter_tuning_outputs" +) + +const ( + // These constants are the default input channel names for built-in algorithms + // When dealing with built-in algorithm training tasks, the plugin would assume these inputs exist and + // access these keys in the input literal map + // The same keys (except for the static hyperparameter key) would also be used when filling out the inputDataConfig fields of the CRD + TrainPredefinedInputVariable = "train" + ValidationPredefinedInputVariable = "validation" + StaticHyperparametersPredefinedInputVariable = "static_hyperparameters" +) + +const ( + CloudWatchLogLinkName = "CloudWatch Logs" + TrainingJobSageMakerLinkName = "SageMaker Built-in Algorithm Training Job" + CustomTrainingJobSageMakerLinkName = "SageMaker Custom Training Job" + HyperparameterTuningJobSageMakerLinkName = "SageMaker Hyperparameter Tuning Job" +) diff --git a/go/tasks/plugins/k8s/sagemaker/custom_training.go b/go/tasks/plugins/k8s/sagemaker/custom_training.go new file mode 100644 index 000000000..75f409d84 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/custom_training.go @@ -0,0 +1,201 @@ +package sagemaker + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob" + "github.com/aws/aws-sdk-go/service/sagemaker" + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + taskError "github.com/lyft/flyteplugins/go/tasks/errors" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flytestdlib/logger" + + pluginErrors 
"github.com/lyft/flyteplugins/go/tasks/errors" + + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" + + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" + + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + flyteSageMakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" +) + +func (m awsSagemakerPlugin) buildResourceForCustomTrainingJob( + ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { + + logger.Infof(ctx, "Building a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + taskTemplate, err := getTaskTemplate(ctx, taskCtx) + if err != nil { + return nil, err + } + + // Unmarshal the custom field of the task template back into the TrainingJob struct generated in flyteidl + sagemakerTrainingJob := flyteSageMakerIdl.TrainingJob{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerTrainingJob) + if err != nil { + + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "invalid TrainingJob task specification: not able to unmarshal the custom field to [%s]", m.TaskType) + } + + if sagemakerTrainingJob.GetAlgorithmSpecification() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The unmarshaled training job does not have a AlgorithmSpecification field") + } + if sagemakerTrainingJob.GetAlgorithmSpecification().GetAlgorithmName() != flyteSageMakerIdl.AlgorithmName_CUSTOM { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The algorithm name [%v] is not supported by the custom training job plugin", + sagemakerTrainingJob.GetAlgorithmSpecification().GetAlgorithmName().String()) + } + + inputChannels := make([]commonv1.Channel, 0) + inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String())) + + jobName := 
taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + outputPath := taskCtx.OutputWriter().GetOutputPrefixPath().String() + + if taskTemplate.GetContainer() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The task template points to a nil container") + } + + if taskTemplate.GetContainer().GetImage() == "" { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Invalid image of the container") + } + + trainingImageStr := taskTemplate.GetContainer().GetImage() + + cfg := config.GetSagemakerConfig() + + var metricDefinitions []commonv1.MetricDefinition + idlMetricDefinitions := sagemakerTrainingJob.GetAlgorithmSpecification().GetMetricDefinitions() + for _, md := range idlMetricDefinitions { + metricDefinitions = append(metricDefinitions, + commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)}) + } + + // TODO: When dealing with HPO job, we will need to deal with the following + // jobOutputPath := NewJobOutputPaths(ctx, taskCtx.DataStore(), taskCtx.OutputWriter().GetOutputPrefixPath(), jobName) + + hyperParameters, err := injectArgsAndEnvVars(ctx, taskCtx, taskTemplate) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Failed to inject the task template's container env vars to the hyperparameter list") + } + + // Statsd is not available in SM because we are not allowed to kick off a sidecar container, nor are we allowed to customize our own AMI in an SM environment + // Therefore we need to inject an env var to disable statsd for SageMaker tasks + statsdDisableEnvVarName := fmt.Sprintf("%s%s%s", FlyteSageMakerEnvVarKeyPrefix, FlyteSageMakerEnvVarKeyStatsdDisabled, FlyteSageMakerKeySuffix) + logger.Infof(ctx, "Injecting %v=%v to force disable statsd for SageMaker tasks only", statsdDisableEnvVarName, strconv.FormatBool(true)) + hyperParameters = append(hyperParameters, &commonv1.KeyValuePair{ + Name: statsdDisableEnvVarName, + Value: 
strconv.FormatBool(true), + }) + + logger.Infof(ctx, "The Sagemaker TrainingJob Task plugin received static hyperparameters [%v]", hyperParameters) + + trainingJob := &trainingjobv1.TrainingJob{ + Spec: trainingjobv1.TrainingJobSpec{ + AlgorithmSpecification: &commonv1.AlgorithmSpecification{ + // If the user specifies a value for this AlgorithmName parameter, they can't specify a value for TrainingImage. + // In this custom training plugin, the training image is always taken from the task template's container, + // so we intentionally leave this field nil + AlgorithmName: nil, + TrainingImage: ToStringPtr(trainingImageStr), + TrainingInputMode: commonv1.TrainingInputMode(inputModeString), + MetricDefinitions: metricDefinitions, + }, + // The support of spot training will come in a later version + EnableManagedSpotTraining: nil, + HyperParameters: hyperParameters, + InputDataConfig: inputChannels, + OutputDataConfig: &commonv1.OutputDataConfig{ + S3OutputPath: ToStringPtr(outputPath), + }, + CheckpointConfig: nil, + ResourceConfig: &commonv1.ResourceConfig{ + InstanceType: sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceType(), + InstanceCount: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceCount()), + VolumeSizeInGB: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetVolumeSizeInGb()), + VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. 
Need to add to proto and flytekit in the future + }, + RoleArn: ToStringPtr(cfg.RoleArn), + Region: ToStringPtr(cfg.Region), + StoppingCondition: &commonv1.StoppingCondition{ + MaxRuntimeInSeconds: ToInt64Ptr(86400), // TODO: decide how to coordinate this and Flyte's timeout + MaxWaitTimeInSeconds: nil, // TODO: decide how to coordinate this and Flyte's timeout and queueing budget + }, + TensorBoardOutputConfig: nil, + Tags: nil, + TrainingJobName: &jobName, + }, + } + logger.Infof(ctx, "Successfully built a custom training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + return trainingJob, nil +} + +func (m awsSagemakerPlugin) getTaskPhaseForCustomTrainingJob( + ctx context.Context, pluginContext k8s.PluginContext, trainingJob *trainingjobv1.TrainingJob) (pluginsCore.PhaseInfo, error) { + + logger.Infof(ctx, "Getting task phase for sagemaker training job [%v]", trainingJob.Status.SageMakerTrainingJobName) + info, err := m.getEventInfoForCustomTrainingJob(ctx, trainingJob) + if err != nil { + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "Failed to get event info for the job") + } + + occurredAt := time.Now() + + switch trainingJob.Status.TrainingJobStatus { + case trainingjobController.ReconcilingTrainingJobStatus: + logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", trainingJob.Status.Additional) + // TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call pluginErrors + execError := &flyteIdlCore.ExecutionError{ + Message: trainingJob.Status.Additional, + Kind: flyteIdlCore.ExecutionError_USER, + Code: trainingjobController.ReconcilingTrainingJobStatus, + } + return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil + case sagemaker.TrainingJobStatusFailed: + execError := &flyteIdlCore.ExecutionError{ + Message: trainingJob.Status.Additional, + Kind: 
flyteIdlCore.ExecutionError_USER, + Code: sagemaker.TrainingJobStatusFailed, + } + return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, info), nil + case sagemaker.TrainingJobStatusStopped: + reason := fmt.Sprintf("Training Job Stopped") + return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil + case sagemaker.TrainingJobStatusCompleted: + // Now that it is a success we will set the outputs as expected by the task + + logger.Infof(ctx, "Looking for the output.pb under %s", pluginContext.OutputWriter().GetOutputPrefixPath()) + outputReader := ioutils.NewRemoteFileOutputReader(ctx, pluginContext.DataStore(), pluginContext.OutputWriter(), pluginContext.MaxDatasetSizeBytes()) + + // Pass the remote-file output reader created above to the OutputWriter, which writes the output to the remote location it refers to + if err := pluginContext.OutputWriter().Put(ctx, outputReader); err != nil { + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Failed to write output to the remote location") + } + logger.Debugf(ctx, "Successfully produced and returned outputs") + return pluginsCore.PhaseInfoSuccess(info), nil + case "": + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil + } + + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil +} + +func (m awsSagemakerPlugin) getEventInfoForCustomTrainingJob(ctx context.Context, trainingJob *trainingjobv1.TrainingJob) (*pluginsCore.TaskInfo, error) { + + var jobRegion, jobName, jobTypeInURL, sagemakerLinkName string + jobRegion = *trainingJob.Spec.Region + jobName = *trainingJob.Spec.TrainingJobName + jobTypeInURL = "jobs" + sagemakerLinkName = CustomTrainingJobSageMakerLinkName + + logger.Infof(ctx, "Getting event information for SageMaker CustomTrainingJob task, job region: [%v], job name: [%v], "+ + "job type in url: [%v], sagemaker link name: 
[%v]", jobRegion, jobName, jobTypeInURL, sagemakerLinkName) + + return createTaskInfo(ctx, jobRegion, jobName, jobTypeInURL, sagemakerLinkName) +} diff --git a/go/tasks/plugins/k8s/sagemaker/custom_training_test.go b/go/tasks/plugins/k8s/sagemaker/custom_training_test.go new file mode 100644 index 000000000..823f84c42 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/custom_training_test.go @@ -0,0 +1,214 @@ +package sagemaker + +import ( + "context" + "fmt" + "strconv" + "testing" + + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" + + sagemakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" + "github.com/stretchr/testify/assert" + + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + "github.com/aws/aws-sdk-go/service/sagemaker" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + stdConfig "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" +) + +func Test_awsSagemakerPlugin_BuildResourceForCustomTrainingJob(t *testing.T) { + // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config + ctx := context.TODO() + defaultCfg := config.GetSagemakerConfig() + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + t.Run("In a custom training job we should see the FLYTE_SAGEMAKER_CMD being injected", func(t *testing.T) { + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + 
awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: customTrainingJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockCustomTrainingJobTaskContext(taskTemplate, false)) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + assert.Equal(t, "config_role", *trainingJob.Spec.RoleArn) + //assert.Equal(t, 1, len(trainingJob.Spec.HyperParameters)) + fmt.Printf("%v", trainingJob.Spec.HyperParameters) + expectedHPs := []*commonv1.KeyValuePair{ + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 0, "service_venv", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 1, "pyflyte-execute", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 2, "--test-opt1", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 3, "value1", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 4, "--test-opt2", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 5, "value2", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 6, "--test-flag", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%s%s", FlyteSageMakerEnvVarKeyPrefix, "Env_Var", FlyteSageMakerKeySuffix), 
Value: "Env_Val"}, + {Name: fmt.Sprintf("%s%s%s", FlyteSageMakerEnvVarKeyPrefix, FlyteSageMakerEnvVarKeyStatsdDisabled, FlyteSageMakerKeySuffix), Value: strconv.FormatBool(true)}, + } + assert.Equal(t, len(expectedHPs), len(trainingJob.Spec.HyperParameters)) + for i := range expectedHPs { + assert.Equal(t, expectedHPs[i].Name, trainingJob.Spec.HyperParameters[i].Name) + assert.Equal(t, expectedHPs[i].Value, trainingJob.Spec.HyperParameters[i].Value) + } + + assert.Equal(t, testImage, *trainingJob.Spec.AlgorithmSpecification.TrainingImage) + }) +} + +func Test_awsSagemakerPlugin_GetTaskPhaseForCustomTrainingJob(t *testing.T) { + ctx := context.TODO() + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: customTrainingJobTaskType} + + t.Run("TrainingJobStatusCompleted", func(t *testing.T) { + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, false) + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx) + assert.Error(t, err) + assert.Nil(t, trainingJobResource) + }) + + t.Run("TrainingJobStatusCompleted", func(t *testing.T) { + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := 
generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, false) + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + + trainingJob.Status.TrainingJobStatus = sagemaker.TrainingJobStatusCompleted + phaseInfo, err := awsSageMakerTrainingJobHandler.getTaskPhaseForCustomTrainingJob(ctx, taskCtx, trainingJob) + assert.Nil(t, err) + assert.Equal(t, phaseInfo.Phase(), pluginsCore.PhaseSuccess) + }) + t.Run("OutputWriter.Put returns an error", func(t *testing.T) { + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, true) + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + + trainingJob.Status.TrainingJobStatus = sagemaker.TrainingJobStatusCompleted + phaseInfo, err := awsSageMakerTrainingJobHandler.getTaskPhaseForCustomTrainingJob(ctx, taskCtx, trainingJob) + assert.NotNil(t, err) + assert.Equal(t, phaseInfo.Phase(), pluginsCore.PhaseUndefined) + }) +} + +func Test_awsSagemakerPlugin_getEventInfoForCustomTrainingJob(t *testing.T) { + // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config + ctx := context.TODO() + defaultCfg := config.GetSagemakerConfig() + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + + t.Run("get event info should return correctly 
formatted log links for custom training job", func(t *testing.T) { + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: customTrainingJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, false) + trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx) + assert.NoError(t, err) + assert.NotNil(t, trainingJobResource) + + trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) + assert.True(t, ok) + + taskInfo, err := awsSageMakerTrainingJobHandler.getEventInfoForCustomTrainingJob(ctx, trainingJob) + if err != nil { + panic(err) + } + + expectedTaskLogs := []*flyteIdlCore.TaskLog{ + { + Uri: fmt.Sprintf("https://%s.console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=/aws/sagemaker/TrainingJobs;prefix=%s;streamFilter=typeLogStreamPrefix", + "us-west-2", "us-west-2", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), + Name: CloudWatchLogLinkName, + MessageFormat: flyteIdlCore.TaskLog_JSON, + }, + { + Uri: fmt.Sprintf("https://%s.console.aws.amazon.com/sagemaker/home?region=%s#/%s/%s", + "us-west-2", "us-west-2", "jobs", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), + Name: CustomTrainingJobSageMakerLinkName, + MessageFormat: flyteIdlCore.TaskLog_UNKNOWN, + }, + } + + expectedCustomInfo, _ := 
utils.MarshalObjToStruct(map[string]string{}) + assert.Equal(t, + func(tis []*flyteIdlCore.TaskLog) []flyteIdlCore.TaskLog { + ret := make([]flyteIdlCore.TaskLog, 0, len(tis)) + for _, ti := range tis { + ret = append(ret, *ti) + } + return ret + }(expectedTaskLogs), + func(tis []*flyteIdlCore.TaskLog) []flyteIdlCore.TaskLog { + ret := make([]flyteIdlCore.TaskLog, 0, len(tis)) + for _, ti := range tis { + ret = append(ret, *ti) + } + return ret + }(taskInfo.Logs)) + assert.Equal(t, *expectedCustomInfo, *taskInfo.CustomInfo) + }) +} diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go new file mode 100644 index 000000000..4c5a91ebb --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -0,0 +1,283 @@ +package sagemaker + +import ( + "context" + "fmt" + "strings" + "time" + + awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils" + + hpojobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/hyperparametertuningjob" + pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" + hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" + "github.com/aws/aws-sdk-go/service/sagemaker" + + taskError "github.com/lyft/flyteplugins/go/tasks/errors" + + flyteSageMakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" + + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" +) + +func (m awsSagemakerPlugin) 
buildResourceForHyperparameterTuningJob( + ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { + + logger.Infof(ctx, "Building a hyperparameter tuning job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + + taskTemplate, err := getTaskTemplate(ctx, taskCtx) + if err != nil { + return nil, err + } + + // Unmarshal the custom field of the task template back into the HyperparameterTuningJob struct generated in flyteidl + sagemakerHPOJob := flyteSageMakerIdl.HyperparameterTuningJob{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerHPOJob) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "invalid HyperparameterTuningJob task specification: not able to unmarshal the custom field to [%s]", hyperparameterTuningJobTaskType) + } + if sagemakerHPOJob.GetTrainingJob() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TrainingJob] of the HyperparameterTuningJob does not exist") + } + if sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [AlgorithmSpecification] of the HyperparameterTuningJob's underlying TrainingJob does not exist") + } + if sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TrainingJobResourceConfig] of the HyperparameterTuningJob's underlying TrainingJob does not exist") + } + + taskInput, err := taskCtx.InputReader().Get(ctx) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "unable to fetch task inputs") + } + + // Get inputs from literals + inputLiterals := taskInput.GetLiterals() + err = checkIfRequiredInputLiteralsExist(inputLiterals, + []string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, 
StaticHyperparametersPredefinedInputVariable}) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist") + } + + trainPathLiteral := inputLiterals[TrainPredefinedInputVariable] + validatePathLiteral := inputLiterals[ValidationPredefinedInputVariable] + staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable] + hpoJobConfigLiteral := inputLiterals["hyperparameter_tuning_job_config"] + if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", TrainPredefinedInputVariable) + } + if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", ValidationPredefinedInputVariable) + } + // Convert the hyperparameters to the spec value + staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type") + } + // hyperparameter_tuning_job_config is marshaled into a byte array in flytekit, so will have to unmarshal it back + hpoJobConfig, err := convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to convert hyperparameter tuning job config literal to spec type") + } + + outputPath := createOutputPath(taskCtx.OutputWriter().GetRawOutputPrefix().String(), HyperparameterOutputPathSubDir) + + if hpoJobConfig.GetTuningObjective() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TuningObjective] does not 
exist") + } + + // Deleting the conflicting static hyperparameters: if a hyperparameter exist in both the map of static hyperparameter + // and the map of the tunable hyperparameter inside the Hyperparameter Tuning Job Config, we delete the entry + // in the static map and let the one in the map of the tunable hyperparameters take precedence + staticHyperparams = deleteConflictingStaticHyperparameters(ctx, staticHyperparams, hpoJobConfig.GetHyperparameterRanges().GetParameterRangeMap()) + + jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + trainingImageStr, err := getTrainingJobImage(ctx, taskCtx, sagemakerHPOJob.GetTrainingJob()) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image") + } + + hpoJobParameterRanges := buildParameterRanges(hpoJobConfig) + + logger.Infof(ctx, "The Sagemaker HyperparameterTuningJob Task plugin received the following inputs: \n"+ + "static hyperparameters: [%v]\n"+ + "hyperparameter tuning job config: [%v]\n"+ + "parameter ranges: [%v]", staticHyperparams, hpoJobConfig, hpoJobParameterRanges) + + cfg := config.GetSagemakerConfig() + + var metricDefinitions []commonv1.MetricDefinition + idlMetricDefinitions := sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetMetricDefinitions() + for _, md := range idlMetricDefinitions { + metricDefinitions = append(metricDefinitions, + commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)}) + } + + apiContentType, err := getAPIContentType(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType()) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]", + sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType().String()) + } + + inputModeString := 
strings.Title(strings.ToLower(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputMode().String())) + tuningStrategyString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningStrategy().String())) + tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String())) + trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String())) + + role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations()) + if role == "" { + role = cfg.RoleArn + } + + hpoJob := &hpojobv1.HyperparameterTuningJob{ + Spec: hpojobv1.HyperparameterTuningJobSpec{ + HyperParameterTuningJobName: &jobName, + HyperParameterTuningJobConfig: &commonv1.HyperParameterTuningJobConfig{ + ResourceLimits: &commonv1.ResourceLimits{ + MaxNumberOfTrainingJobs: ToInt64Ptr(sagemakerHPOJob.GetMaxNumberOfTrainingJobs()), + MaxParallelTrainingJobs: ToInt64Ptr(sagemakerHPOJob.GetMaxParallelTrainingJobs()), + }, + Strategy: commonv1.HyperParameterTuningJobStrategyType(tuningStrategyString), + HyperParameterTuningJobObjective: &commonv1.HyperParameterTuningJobObjective{ + Type: commonv1.HyperParameterTuningJobObjectiveType(tuningObjectiveTypeString), + MetricName: ToStringPtr(hpoJobConfig.GetTuningObjective().GetMetricName()), + }, + ParameterRanges: hpoJobParameterRanges, + TrainingJobEarlyStoppingType: commonv1.TrainingJobEarlyStoppingType(trainingJobEarlyStoppingTypeString), + }, + TrainingJobDefinition: &commonv1.HyperParameterTrainingJobDefinition{ + StaticHyperParameters: staticHyperparams, + AlgorithmSpecification: &commonv1.HyperParameterAlgorithmSpecification{ + TrainingImage: ToStringPtr(trainingImageStr), + TrainingInputMode: commonv1.TrainingInputMode(inputModeString), + MetricDefinitions: metricDefinitions, + AlgorithmName: nil, + }, + InputDataConfig: []commonv1.Channel{ + { + ChannelName: ToStringPtr(TrainPredefinedInputVariable), + 
DataSource: &commonv1.DataSource{ + S3DataSource: &commonv1.S3DataSource{ + S3DataType: "S3Prefix", + S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()), + }, + }, + ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata + InputMode: inputModeString, + }, + { + ChannelName: ToStringPtr(ValidationPredefinedInputVariable), + DataSource: &commonv1.DataSource{ + S3DataSource: &commonv1.S3DataSource{ + S3DataType: "S3Prefix", + S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()), + }, + }, + ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata + InputMode: inputModeString, + }, + }, + OutputDataConfig: &commonv1.OutputDataConfig{ + S3OutputPath: ToStringPtr(outputPath), + }, + ResourceConfig: &commonv1.ResourceConfig{ + InstanceType: sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetInstanceType(), + InstanceCount: ToInt64Ptr(sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetInstanceCount()), + VolumeSizeInGB: ToInt64Ptr(sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetVolumeSizeInGb()), + VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. 
Need to add to proto and flytekit in the future + }, + RoleArn: ToStringPtr(role), + StoppingCondition: &commonv1.StoppingCondition{ + MaxRuntimeInSeconds: ToInt64Ptr(86400), + MaxWaitTimeInSeconds: nil, + }, + }, + Region: ToStringPtr(cfg.Region), + }, + } + + logger.Infof(ctx, "Successfully built a hyperparameter tuning job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) + return hpoJob, nil +} + +func (m awsSagemakerPlugin) getTaskPhaseForHyperparameterTuningJob( + ctx context.Context, pluginContext k8s.PluginContext, hpoJob *hpojobv1.HyperparameterTuningJob) (pluginsCore.PhaseInfo, error) { + + logger.Infof(ctx, "Getting task phase for hyperparameter tuning job [%v]", hpoJob.Status.SageMakerHyperParameterTuningJobName) + info, err := m.getEventInfoForHyperparameterTuningJob(ctx, hpoJob) + if err != nil { + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "Failed to get event info for the job") + } + + occurredAt := time.Now() + + switch hpoJob.Status.HyperParameterTuningJobStatus { + case hpojobController.ReconcilingTuningJobStatus: + logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", hpoJob.Status.Additional) + // TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call pluginErrors + execError := &flyteIdlCore.ExecutionError{ + Message: hpoJob.Status.Additional, + Kind: flyteIdlCore.ExecutionError_USER, + Code: hpojobController.ReconcilingTuningJobStatus, + } + return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil + case sagemaker.HyperParameterTuningJobStatusFailed: + execError := &flyteIdlCore.ExecutionError{ + Message: hpoJob.Status.Additional, + Kind: flyteIdlCore.ExecutionError_USER, + Code: sagemaker.HyperParameterTuningJobStatusFailed, + } + return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, info), nil + case 
sagemaker.HyperParameterTuningJobStatusStopped: + reason := fmt.Sprintf("Hyperparameter tuning job stopped") + return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil + case sagemaker.HyperParameterTuningJobStatusCompleted: + // Now that it is a success we will set the outputs as expected by the task + + // TODO: + // Check task template -> custom training job -> if custom: assume output.pb exist, and fail if it doesn't. If it exists, then + // -> if not custom: check model.tar.gz + out, err := getOutputLiteralMapFromTaskInterface(ctx, pluginContext.TaskReader(), + createModelOutputPath(hpoJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(), + *hpoJob.Status.BestTrainingJob.TrainingJobName)) + if err != nil { + logger.Errorf(ctx, "Failed to create outputs, err: %s", err) + return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to create outputs for the task") + } + if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(out, nil)); err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + logger.Debugf(ctx, "Successfully produced and returned outputs") + return pluginsCore.PhaseInfoSuccess(info), nil + case "": + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil + } + + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil +} + +func (m awsSagemakerPlugin) getEventInfoForHyperparameterTuningJob(ctx context.Context, hpoJob *hpojobv1.HyperparameterTuningJob) (*pluginsCore.TaskInfo, error) { + + var jobRegion, jobName, jobTypeInURL, sagemakerLinkName string + jobRegion = *hpoJob.Spec.Region + jobName = *hpoJob.Spec.HyperParameterTuningJobName + jobTypeInURL = "hyper-tuning-jobs" + sagemakerLinkName = HyperparameterTuningJobSageMakerLinkName + + logger.Infof(ctx, "Getting event information for SageMaker HyperparameterTuningJob task, job region: [%v], job name: 
[%v], "+ + "job type in url: [%v], sagemaker link name: [%v]", jobRegion, jobName, jobTypeInURL, sagemakerLinkName) + + return createTaskInfo(ctx, jobRegion, jobName, jobTypeInURL, sagemakerLinkName) +} diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go new file mode 100644 index 000000000..c43dd49e8 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go @@ -0,0 +1,126 @@ +package sagemaker + +import ( + "context" + "fmt" + "testing" + + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + stdConfig "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" + + hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" + sagemakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" + "github.com/stretchr/testify/assert" +) + +func Test_awsSagemakerPlugin_BuildResourceForHyperparameterTuningJob(t *testing.T) { + // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config + ctx := context.TODO() + err := config.ResetSagemakerConfig() + if err != nil { + panic(err) + } + defaultCfg := config.GetSagemakerConfig() + awsSageMakerHPOJobHandler := awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) + taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) + hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, 
generateMockHyperparameterTuningJobTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, hpoJobResource) + + hpoJob, ok := hpoJobResource.(*hpojobv1.HyperparameterTuningJob) + assert.True(t, ok) + assert.NotNil(t, hpoJob.Spec.TrainingJobDefinition) + assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.IntegerParameterRanges)) + assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.ContinuousParameterRanges)) + assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.CategoricalParameterRanges)) + assert.Equal(t, "us-east-1", *hpoJob.Spec.Region) + assert.Equal(t, "default_role", *hpoJob.Spec.TrainingJobDefinition.RoleArn) + + err = config.SetSagemakerConfig(defaultCfg) + if err != nil { + panic(err) + } +} + +func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T) { + // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config + ctx := context.TODO() + defaultCfg := config.GetSagemakerConfig() + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + + t.Run("get event info should return correctly formatted log links for custom training job", func(t *testing.T) { + // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + // Use a different + SearchPaths: []string{"testdata/config2.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + awsSageMakerHPOJobHandler := awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) + htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) + 
taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) + taskCtx := generateMockHyperparameterTuningJobTaskContext(taskTemplate) + hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, taskCtx) + assert.NoError(t, err) + assert.NotNil(t, hpoJobResource) + + hpoJob, ok := hpoJobResource.(*hpojobv1.HyperparameterTuningJob) + assert.True(t, ok) + + taskInfo, err := awsSageMakerHPOJobHandler.getEventInfoForHyperparameterTuningJob(ctx, hpoJob) + if err != nil { + panic(err) + } + + expectedTaskLogs := []*flyteIdlCore.TaskLog{ + { + Uri: fmt.Sprintf("https://%s.console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=/aws/sagemaker/TrainingJobs;prefix=%s;streamFilter=typeLogStreamPrefix", + "us-west-2", "us-west-2", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), + Name: CloudWatchLogLinkName, + MessageFormat: flyteIdlCore.TaskLog_JSON, + }, + { + Uri: fmt.Sprintf("https://%s.console.aws.amazon.com/sagemaker/home?region=%s#/%s/%s", + "us-west-2", "us-west-2", "hyper-tuning-jobs", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()), + Name: HyperparameterTuningJobSageMakerLinkName, + MessageFormat: flyteIdlCore.TaskLog_UNKNOWN, + }, + } + + expectedCustomInfo, _ := utils.MarshalObjToStruct(map[string]string{}) + assert.Equal(t, + func(tis []*flyteIdlCore.TaskLog) []flyteIdlCore.TaskLog { + ret := make([]flyteIdlCore.TaskLog, 0, len(tis)) + for _, ti := range tis { + ret = append(ret, *ti) + } + return ret + }(expectedTaskLogs), + func(tis []*flyteIdlCore.TaskLog) []flyteIdlCore.TaskLog { + ret := make([]flyteIdlCore.TaskLog, 0, len(tis)) + for _, ti := range tis { + ret = append(ret, *ti) + } + return ret + }(taskInfo.Logs)) + assert.Equal(t, *expectedCustomInfo, *taskInfo.CustomInfo) + }) +} diff --git a/go/tasks/plugins/k8s/sagemaker/outputs.go b/go/tasks/plugins/k8s/sagemaker/outputs.go new file mode 100644 index 000000000..f4cb82ad5 --- /dev/null +++ 
b/go/tasks/plugins/k8s/sagemaker/outputs.go @@ -0,0 +1,71 @@ +package sagemaker + +import ( + "context" + "fmt" + + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" + + hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" +) + +func createOutputLiteralMap(tk *core.TaskTemplate, outputPath string) *core.LiteralMap { + op := &core.LiteralMap{} + for k := range tk.Interface.Outputs.Variables { + // if v != core.LiteralType_Blob{} + op.Literals = make(map[string]*core.Literal) + op.Literals[k] = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Blob{ + Blob: &core.Blob{ + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{Dimensionality: core.BlobType_SINGLE}, + }, + Uri: outputPath, + }, + }, + }, + }, + } + } + return op +} + +func getOutputLiteralMapFromTaskInterface(ctx context.Context, tr pluginsCore.TaskReader, outputPath string) (*flyteIdlCore.LiteralMap, error) { + tk, err := tr.Read(ctx) + if err != nil { + return nil, err + } + if tk.Interface.Outputs != nil && tk.Interface.Outputs.Variables == nil { + logger.Warnf(ctx, "No outputs declared in the output interface. 
Ignoring the generated outputs.") + return nil, nil + } + + // We know that for XGBoost task there is only one output to be generated + if len(tk.Interface.Outputs.Variables) > 1 { + return nil, fmt.Errorf("expected to generate more than one outputs of type [%v]", tk.Interface.Outputs.Variables) + } + op := createOutputLiteralMap(tk, outputPath) + return op, nil +} + +func createOutputPath(prefix string, subdir string) string { + return fmt.Sprintf("%s/%s", prefix, subdir) +} + +func createModelOutputPath(job k8s.Resource, prefix, jobName string) string { + switch job.(type) { + case *trainingjobv1.TrainingJob: + return fmt.Sprintf("%s/%s/output/model.tar.gz", createOutputPath(prefix, TrainingJobOutputPathSubDir), jobName) + case *hpojobv1.HyperparameterTuningJob: + return fmt.Sprintf("%s/%s/output/model.tar.gz", createOutputPath(prefix, HyperparameterOutputPathSubDir), jobName) + default: + return fmt.Sprintf("") + } +} diff --git a/go/tasks/plugins/k8s/sagemaker/outputs_test.go b/go/tasks/plugins/k8s/sagemaker/outputs_test.go new file mode 100644 index 000000000..153797ee6 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/outputs_test.go @@ -0,0 +1,34 @@ +package sagemaker + +import ( + "testing" + + hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" +) + +func Test_createModelOutputPath(t *testing.T) { + type args struct { + job k8s.Resource + prefix string + bestExperiment string + } + tests := []struct { + name string + args args + want string + }{ + {name: "training job: simple output path", args: args{job: &trainingjobv1.TrainingJob{}, prefix: "s3://my-bucket", bestExperiment: "job-ABC"}, + want: "s3://my-bucket/training_outputs/job-ABC/output/model.tar.gz"}, + {name: "hpo job: simple output path", args: args{job: &hpojobv1.HyperparameterTuningJob{}, prefix: 
"s3://my-bucket", bestExperiment: "job-ABC"}, + want: "s3://my-bucket/hyperparameter_tuning_outputs/job-ABC/output/model.tar.gz"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := createModelOutputPath(tt.args.job, tt.args.prefix, tt.args.bestExperiment); got != tt.want { + t.Errorf("createModelOutputPath() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go/tasks/plugins/k8s/sagemaker/plugin.go b/go/tasks/plugins/k8s/sagemaker/plugin.go new file mode 100644 index 000000000..4e87eb665 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/plugin.go @@ -0,0 +1,99 @@ +package sagemaker + +import ( + "context" + + pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors" + + "k8s.io/client-go/kubernetes/scheme" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" + + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" + hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" +) + +// Sanity test that the plugin implements method of k8s.Plugin +var _ k8s.Plugin = awsSagemakerPlugin{} + +type awsSagemakerPlugin struct { + TaskType pluginsCore.TaskType +} + +func (m awsSagemakerPlugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) (k8s.Resource, error) { + if m.TaskType == trainingJobTaskType || m.TaskType == customTrainingJobTaskType { + return &trainingjobv1.TrainingJob{}, nil + } + if m.TaskType == hyperparameterTuningJobTaskType { + return &hpojobv1.HyperparameterTuningJob{}, nil + } + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The sagemaker plugin is unable to build identity resource for an unknown task type [%v]", m.TaskType) +} + +func (m awsSagemakerPlugin) BuildResource(ctx 
context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { + + // Unmarshal the custom field of the task template back into the HyperparameterTuningJob struct generated in flyteidl + if m.TaskType == trainingJobTaskType { + return m.buildResourceForTrainingJob(ctx, taskCtx) + } + if m.TaskType == customTrainingJobTaskType { + return m.buildResourceForCustomTrainingJob(ctx, taskCtx) + } + if m.TaskType == hyperparameterTuningJobTaskType { + return m.buildResourceForHyperparameterTuningJob(ctx, taskCtx) + } + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The SageMaker plugin is unable to build resource for unknown task type [%s]", m.TaskType) +} + +func (m awsSagemakerPlugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource k8s.Resource) (pluginsCore.PhaseInfo, error) { + if m.TaskType == trainingJobTaskType { + job := resource.(*trainingjobv1.TrainingJob) + return m.getTaskPhaseForTrainingJob(ctx, pluginContext, job) + } else if m.TaskType == customTrainingJobTaskType { + job := resource.(*trainingjobv1.TrainingJob) + return m.getTaskPhaseForCustomTrainingJob(ctx, pluginContext, job) + } else if m.TaskType == hyperparameterTuningJobTaskType { + job := resource.(*hpojobv1.HyperparameterTuningJob) + return m.getTaskPhaseForHyperparameterTuningJob(ctx, pluginContext, job) + } + return pluginsCore.PhaseInfoUndefined, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "cannot get task phase for unknown task type [%s]", m.TaskType) +} + +func init() { + if err := commonv1.AddToScheme(scheme.Scheme); err != nil { + panic(err) + } + + // Registering the plugin for HyperparameterTuningJob + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: hyperparameterTuningJobTaskPluginID, + RegisteredTaskTypes: []pluginsCore.TaskType{hyperparameterTuningJobTaskType}, + ResourceToWatch: &hpojobv1.HyperparameterTuningJob{}, + Plugin: awsSagemakerPlugin{TaskType: 
hyperparameterTuningJobTaskType}, + IsDefault: false, + }) + + // Registering the plugin for standalone TrainingJob + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: trainingJobTaskPluginID, + RegisteredTaskTypes: []pluginsCore.TaskType{trainingJobTaskType}, + ResourceToWatch: &trainingjobv1.TrainingJob{}, + Plugin: awsSagemakerPlugin{TaskType: trainingJobTaskType}, + IsDefault: false, + }) + + // Registering the plugin for custom TrainingJob + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: customTrainingJobTaskPluginID, + RegisteredTaskTypes: []pluginsCore.TaskType{customTrainingJobTaskType}, + ResourceToWatch: &trainingjobv1.TrainingJob{}, + Plugin: awsSagemakerPlugin{TaskType: customTrainingJobTaskType}, + IsDefault: false, + }) +} diff --git a/go/tasks/plugins/k8s/sagemaker/plugin_test.go b/go/tasks/plugins/k8s/sagemaker/plugin_test.go new file mode 100644 index 000000000..ef5f592c8 --- /dev/null +++ b/go/tasks/plugins/k8s/sagemaker/plugin_test.go @@ -0,0 +1,60 @@ +package sagemaker + +import ( + "context" + "reflect" + "testing" + + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils/labeled" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" + + hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" + trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" +) + +func Test_awsSagemakerPlugin_BuildIdentityResource(t *testing.T) { + ctx := context.TODO() + type fields struct { + TaskType pluginsCore.TaskType + } + type args struct { + in0 context.Context + in1 pluginsCore.TaskExecutionMetadata + } + tests := []struct { + name string + fields fields + args args + want k8s.Resource + wantErr bool + }{ + {name: "Training Job Identity Resource", fields: fields{TaskType: trainingJobTaskType}, + args: args{in0: ctx, in1: 
genMockTaskExecutionMetadata()}, want: &trainingjobv1.TrainingJob{}, wantErr: false}, + {name: "HPO Job Identity Resource", fields: fields{TaskType: hyperparameterTuningJobTaskType}, + args: args{in0: ctx, in1: genMockTaskExecutionMetadata()}, want: &hpojobv1.HyperparameterTuningJob{}, wantErr: false}, + {name: "Unsupported Job Identity Resource", fields: fields{TaskType: "bad type"}, + args: args{in0: ctx, in1: genMockTaskExecutionMetadata()}, want: nil, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := awsSagemakerPlugin{ + TaskType: tt.fields.TaskType, + } + got, err := m.BuildIdentityResource(tt.args.in0, tt.args.in1) + if (err != nil) != tt.wantErr { + t.Errorf("BuildIdentityResource() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("BuildIdentityResource() got = %v, want %v", got, tt.want) + } + }) + } +} + +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} diff --git a/go/tasks/plugins/k8s/sagemaker/sagemaker_test.go b/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go similarity index 58% rename from go/tasks/plugins/k8s/sagemaker/sagemaker_test.go rename to go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go index 6aa2c2a36..f468c639f 100644 --- a/go/tasks/plugins/k8s/sagemaker/sagemaker_test.go +++ b/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go @@ -1,18 +1,12 @@ package sagemaker import ( - "context" - "testing" + "github.com/lyft/flytestdlib/promutils" - "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" + "github.com/pkg/errors" "github.com/golang/protobuf/proto" - stdConfig "github.com/lyft/flytestdlib/config" - "github.com/lyft/flytestdlib/config/viper" - - hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" - trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" "github.com/golang/protobuf/jsonpb" structpb 
"github.com/golang/protobuf/ptypes/struct" flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" @@ -23,7 +17,6 @@ import ( pluginIOMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/lyft/flytestdlib/storage" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -39,7 +32,18 @@ var ( } testArgs = []string{ - "test-args", + "service_venv", + "pyflyte-execute", + "--test-opt1", + "value1", + "--test-opt2", + "value2", + "--test-flag", + } + + testCmds = []string{ + "test-cmds1", + "test-cmds2", } resourceRequirements = &corev1.ResourceRequirements{ @@ -74,9 +78,10 @@ func generateMockTrainingJobTaskTemplate(id string, trainingJobCustomObj *sagema Type: "container", Target: &flyteIdlCore.TaskTemplate_Container{ Container: &flyteIdlCore.Container{ - Image: testImage, - Args: testArgs, - Env: dummyEnvVars, + Command: testCmds, + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, }, }, Custom: &structObj, @@ -107,10 +112,113 @@ func generateMockHyperparameterTuningJobTaskTemplate(id string, hpoJobCustomObj }, }, Custom: &structObj, + Interface: &flyteIdlCore.TypedInterface{ + Inputs: &flyteIdlCore.VariableMap{ + Variables: map[string]*flyteIdlCore.Variable{ + "input": { + Type: &flyteIdlCore.LiteralType{ + Type: &flyteIdlCore.LiteralType_CollectionType{ + CollectionType: &flyteIdlCore.LiteralType{Type: &flyteIdlCore.LiteralType_Simple{Simple: flyteIdlCore.SimpleType_INTEGER}}, + }, + }, + }, + }, + }, + Outputs: &flyteIdlCore.VariableMap{ + Variables: map[string]*flyteIdlCore.Variable{ + "output": { + Type: &flyteIdlCore.LiteralType{ + Type: &flyteIdlCore.LiteralType_CollectionType{ + CollectionType: &flyteIdlCore.LiteralType{Type: &flyteIdlCore.LiteralType_Simple{Simple: flyteIdlCore.SimpleType_INTEGER}}, + }, + }, + }, + }, + }, + }, + } +} + +// nolint +func 
generateMockCustomTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate, outputReaderPutError bool) pluginsCore.TaskExecutionContext { + taskCtx := &mocks.TaskExecutionContext{} + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return(storage.DataReference("/input/prefix")) + inputReader.OnGetInputPath().Return(storage.DataReference("/input")) + + trainBlobLoc := storage.DataReference("train-blob-loc") + validationBlobLoc := storage.DataReference("validation-blob-loc") + + inputReader.OnGetMatch(mock.Anything).Return( + &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "train": generateMockBlobLiteral(trainBlobLoc), + "validation": generateMockBlobLiteral(validationBlobLoc), + "hp_int": utils.MustMakeLiteral(1), + "hp_float": utils.MustMakeLiteral(1.5), + "hp_bool": utils.MustMakeLiteral(false), + "hp_string": utils.MustMakeLiteral("a"), + }, + }, nil) + taskCtx.OnInputReader().Return(inputReader) + + outputReader := &pluginIOMocks.OutputWriter{} + outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) + outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) + outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("/raw/")) + if outputReaderPutError { + outputReader.OnPutMatch(mock.Anything, mock.Anything).Return(errors.Errorf("err")) + } else { + outputReader.OnPutMatch(mock.Anything, mock.Anything).Return(nil) + } + taskCtx.OnOutputWriter().Return(outputReader) + + taskReader := &mocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(taskReader) + + tID := &mocks.TaskExecutionID{} + tID.OnGetID().Return(flyteIdlCore.TaskExecutionIdentifier{ + NodeExecutionId: &flyteIdlCore.NodeExecutionIdentifier{ + ExecutionId: &flyteIdlCore.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + 
tID.OnGetGeneratedName().Return("some-acceptable-name") + + resources := &mocks.TaskOverrides{} + resources.OnGetResources().Return(resourceRequirements) + + taskExecutionMetadata := &mocks.TaskExecutionMetadata{} + taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) + taskExecutionMetadata.OnGetNamespace().Return("test-namespace") + taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"}) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskExecutionMetadata.OnIsInterruptible().Return(true) + taskExecutionMetadata.OnGetOverrides().Return(resources) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + if err != nil { + panic(err) } + taskCtx.OnDataStore().Return(dataStore) + + taskCtx.OnMaxDatasetSizeBytes().Return(10000) + + return taskCtx } -func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate) pluginsCore.TaskExecutionContext { +// nolint +func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate, outputReaderPutError bool) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return(storage.DataReference("/input/prefix")) @@ -120,43 +228,12 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate) validationBlobLoc := storage.DataReference("validation-blob-loc") shp := map[string]string{"a": "1", "b": "2"} shpStructObj, _ := utils.MarshalObjToStruct(shp) + inputReader.OnGetMatch(mock.Anything).Return( &flyteIdlCore.LiteralMap{ Literals: map[string]*flyteIdlCore.Literal{ - "train": { - Value: 
&flyteIdlCore.Literal_Scalar{ - Scalar: &flyteIdlCore.Scalar{ - Value: &flyteIdlCore.Scalar_Blob{ - Blob: &flyteIdlCore.Blob{ - Uri: trainBlobLoc.String(), - Metadata: &flyteIdlCore.BlobMetadata{ - Type: &flyteIdlCore.BlobType{ - Dimensionality: flyteIdlCore.BlobType_SINGLE, - Format: "csv", - }, - }, - }, - }, - }, - }, - }, - "validation": { - Value: &flyteIdlCore.Literal_Scalar{ - Scalar: &flyteIdlCore.Scalar{ - Value: &flyteIdlCore.Scalar_Blob{ - Blob: &flyteIdlCore.Blob{ - Uri: validationBlobLoc.String(), - Metadata: &flyteIdlCore.BlobMetadata{ - Type: &flyteIdlCore.BlobType{ - Dimensionality: flyteIdlCore.BlobType_SINGLE, - Format: "csv", - }, - }, - }, - }, - }, - }, - }, + "train": generateMockBlobLiteral(trainBlobLoc), + "validation": generateMockBlobLiteral(validationBlobLoc), "static_hyperparameters": utils.MakeGenericLiteral(shpStructObj), }, }, nil) @@ -165,6 +242,10 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate) outputReader := &pluginIOMocks.OutputWriter{} outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) + outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("/raw/")) + if outputReaderPutError { + outputReader.OnPutMatch(mock.Anything).Return(errors.Errorf("err")) + } taskCtx.OnOutputWriter().Return(outputReader) taskReader := &mocks.TaskReader{} @@ -202,6 +283,26 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate) return taskCtx } +func generateMockBlobLiteral(loc storage.DataReference) *flyteIdlCore.Literal { + return &flyteIdlCore.Literal{ + Value: &flyteIdlCore.Literal_Scalar{ + Scalar: &flyteIdlCore.Scalar{ + Value: &flyteIdlCore.Scalar_Blob{ + Blob: &flyteIdlCore.Blob{ + Uri: loc.String(), + Metadata: &flyteIdlCore.BlobMetadata{ + Type: &flyteIdlCore.BlobType{ + Dimensionality: flyteIdlCore.BlobType_SINGLE, + Format: "csv", + }, + }, + }, + }, + }, + }, + } 
+} + func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} @@ -238,40 +339,8 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T inputReader.OnGetMatch(mock.Anything).Return( &flyteIdlCore.LiteralMap{ Literals: map[string]*flyteIdlCore.Literal{ - "train": { - Value: &flyteIdlCore.Literal_Scalar{ - Scalar: &flyteIdlCore.Scalar{ - Value: &flyteIdlCore.Scalar_Blob{ - Blob: &flyteIdlCore.Blob{ - Uri: trainBlobLoc.String(), - Metadata: &flyteIdlCore.BlobMetadata{ - Type: &flyteIdlCore.BlobType{ - Dimensionality: flyteIdlCore.BlobType_SINGLE, - Format: "csv", - }, - }, - }, - }, - }, - }, - }, - "validation": { - Value: &flyteIdlCore.Literal_Scalar{ - Scalar: &flyteIdlCore.Scalar{ - Value: &flyteIdlCore.Scalar_Blob{ - Blob: &flyteIdlCore.Blob{ - Uri: validationBlobLoc.String(), - Metadata: &flyteIdlCore.BlobMetadata{ - Type: &flyteIdlCore.BlobType{ - Dimensionality: flyteIdlCore.BlobType_SINGLE, - Format: "csv", - }, - }, - }, - }, - }, - }, - }, + "train": generateMockBlobLiteral(trainBlobLoc), + "validation": generateMockBlobLiteral(validationBlobLoc), "static_hyperparameters": utils.MakeGenericLiteral(shpStructObj), "hyperparameter_tuning_job_config": utils.MakeBinaryLiteral(hpoJobConfigByteArray), }, @@ -281,12 +350,18 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T outputReader := &pluginIOMocks.OutputWriter{} outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) + outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("/raw/")) taskCtx.OnOutputWriter().Return(outputReader) taskReader := &mocks.TaskReader{} taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) taskCtx.OnTaskReader().Return(taskReader) + 
taskExecutionMetadata := genMockTaskExecutionMetadata() + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + return taskCtx +} +func genMockTaskExecutionMetadata() *mocks.TaskExecutionMetadata { tID := &mocks.TaskExecutionID{} tID.OnGetID().Return(flyteIdlCore.TaskExecutionIdentifier{ NodeExecutionId: &flyteIdlCore.NodeExecutionIdentifier{ @@ -297,6 +372,7 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T }, }, }) + tID.OnGetGeneratedName().Return("some-acceptable-name") resources := &mocks.TaskOverrides{} @@ -314,8 +390,7 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T taskExecutionMetadata.OnIsInterruptible().Return(true) taskExecutionMetadata.OnGetOverrides().Return(resources) taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) - taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) - return taskCtx + return taskExecutionMetadata } // nolint @@ -347,108 +422,3 @@ func generateMockHyperparameterTuningJobCustomObj( MaxParallelTrainingJobs: maxParallelTrainingJobs, } } - -func Test_awsSagemakerPlugin_BuildResourceForTrainingJob(t *testing.T) { - // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config - ctx := context.TODO() - defaultCfg := config.GetSagemakerConfig() - awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: trainingJobTaskType} - - tjObj := generateMockTrainingJobCustomObj( - sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, - sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) - taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj) - - trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate)) - assert.NoError(t, err) - assert.NotNil(t, trainingJobResource) - - trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob) - 
assert.True(t, ok) - assert.Equal(t, "default_role", *trainingJob.Spec.RoleArn) - assert.Equal(t, "File", string(trainingJob.Spec.AlgorithmSpecification.TrainingInputMode)) - - // Injecting a config which contains a matching roleAnnotationKey -> expecting to get the role from metadata - configAccessor := viper.NewAccessor(stdConfig.Options{ - StrictMode: true, - SearchPaths: []string{"testdata/config.yaml"}, - }) - - err = configAccessor.UpdateConfig(context.TODO()) - assert.NoError(t, err) - - awsSageMakerTrainingJobHandler = awsSagemakerPlugin{TaskType: trainingJobTaskType} - - tjObj = generateMockTrainingJobCustomObj( - sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, - sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) - taskTemplate = generateMockTrainingJobTaskTemplate("the job", tjObj) - - trainingJobResource, err = awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate)) - assert.NoError(t, err) - assert.NotNil(t, trainingJobResource) - - trainingJob, ok = trainingJobResource.(*trainingjobv1.TrainingJob) - assert.True(t, ok) - assert.Equal(t, "metadata_role", *trainingJob.Spec.RoleArn) - - // Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config - configAccessor = viper.NewAccessor(stdConfig.Options{ - StrictMode: true, - // Use a different - SearchPaths: []string{"testdata/config2.yaml"}, - }) - - err = configAccessor.UpdateConfig(context.TODO()) - assert.NoError(t, err) - - awsSageMakerTrainingJobHandler = awsSagemakerPlugin{TaskType: trainingJobTaskType} - - tjObj = generateMockTrainingJobCustomObj( - sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, - sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) - taskTemplate = generateMockTrainingJobTaskTemplate("the job", tjObj) - - trainingJobResource, err = 
awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate)) - assert.NoError(t, err) - assert.NotNil(t, trainingJobResource) - - trainingJob, ok = trainingJobResource.(*trainingjobv1.TrainingJob) - assert.True(t, ok) - assert.Equal(t, "config_role", *trainingJob.Spec.RoleArn) - - err = config.SetSagemakerConfig(defaultCfg) - if err != nil { - panic(err) - } -} - -func Test_awsSagemakerPlugin_BuildResourceForHyperparameterTuningJob(t *testing.T) { - // Default config does not contain a roleAnnotationKey -> expecting to get the role from default config - ctx := context.TODO() - defaultCfg := config.GetSagemakerConfig() - awsSageMakerHPOJobHandler := awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType} - - tjObj := generateMockTrainingJobCustomObj( - sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, - sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25) - htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) - taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) - hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, generateMockHyperparameterTuningJobTaskContext(taskTemplate)) - assert.NoError(t, err) - assert.NotNil(t, hpoJobResource) - - hpoJob, ok := hpoJobResource.(*hpojobv1.HyperparameterTuningJob) - assert.True(t, ok) - assert.NotNil(t, hpoJob.Spec.TrainingJobDefinition) - assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.IntegerParameterRanges)) - assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.ContinuousParameterRanges)) - assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.CategoricalParameterRanges)) - assert.Equal(t, "us-east-1", *hpoJob.Spec.Region) - assert.Equal(t, "default_role", *hpoJob.Spec.TrainingJobDefinition.RoleArn) - - err = config.SetSagemakerConfig(defaultCfg) - if err != nil { - 
panic(err) - } -} diff --git a/go/tasks/plugins/k8s/sagemaker/sagemaker.go b/go/tasks/plugins/k8s/sagemaker/sagemaker.go deleted file mode 100644 index 0bcbb3899..000000000 --- a/go/tasks/plugins/k8s/sagemaker/sagemaker.go +++ /dev/null @@ -1,631 +0,0 @@ -package sagemaker - -import ( - "context" - "fmt" - "strings" - "time" - - awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils" - - hpojobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/hyperparametertuningjob" - trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob" - "github.com/lyft/flytestdlib/logger" - "github.com/pkg/errors" - "k8s.io/client-go/kubernetes/scheme" - - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" - - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - - pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" - - commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" - hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" - trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" - "github.com/aws/aws-sdk-go/service/sagemaker" - - taskError "github.com/lyft/flyteplugins/go/tasks/errors" - - sagemakerSpec "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" - - "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" -) - -// Sanity test that the plugin implements method of k8s.Plugin -var _ k8s.Plugin = awsSagemakerPlugin{} - -type awsSagemakerPlugin struct { - TaskType pluginsCore.TaskType -} - -func (m awsSagemakerPlugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) (k8s.Resource, error) { - if m.TaskType == trainingJobTaskType { - return 
&trainingjobv1.TrainingJob{}, nil - } - if m.TaskType == hyperparameterTuningJobTaskType { - return &hpojobv1.HyperparameterTuningJob{}, nil - } - return nil, errors.Errorf("The sagemaker plugin is unable to build identity resource for an unknown task type [%v]", m.TaskType) -} - -func (m awsSagemakerPlugin) BuildResourceForTrainingJob( - ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { - - logger.Infof(ctx, "Building a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) - taskTemplate, err := getTaskTemplate(ctx, taskCtx) - if err != nil { - return nil, err - } - - // Unmarshal the custom field of the task template back into the Hyperparameter Tuning Job struct generated in flyteidl - sagemakerTrainingJob := sagemakerSpec.TrainingJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerTrainingJob) - if err != nil { - return nil, errors.Wrapf(err, "invalid TrainingJob task specification: not able to unmarshal the custom field to [%s]", m.TaskType) - } - if sagemakerTrainingJob.GetTrainingJobResourceConfig() == nil { - return nil, errors.Errorf("Required field [TrainingJobResourceConfig] of the TrainingJob does not exist") - } - - taskInput, err := taskCtx.InputReader().Get(ctx) - if err != nil { - return nil, errors.Wrapf(err, "unable to fetch task inputs") - } - - // Get inputs from literals - inputLiterals := taskInput.GetLiterals() - - trainPathLiteral, ok := inputLiterals["train"] - if !ok { - return nil, errors.Errorf("Required input not specified: [train]") - } - if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil { - return nil, errors.Errorf("[train] Input is required and should be of Type [Scalar.Blob]") - } - validatePathLiteral, ok := inputLiterals["validation"] - if !ok { - return nil, errors.Errorf("Required input not specified: [validation]") - } - if validatePathLiteral.GetScalar() == nil || 
validatePathLiteral.GetScalar().GetBlob() == nil { - return nil, errors.Errorf("[validation] Input is required and should be of Type [Scalar.Blob]") - } - staticHyperparamsLiteral, ok := inputLiterals["static_hyperparameters"] - if !ok { - return nil, errors.Errorf("Required input not specified: [static_hyperparameters]") - } - - outputPath := createOutputPath(taskCtx.OutputWriter().GetOutputPrefixPath().String()) - - // Convert the hyperparameters to the spec value - staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral) - if err != nil { - return nil, errors.Wrapf(err, "could not convert static hyperparameters to spec type") - } - - taskName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetExecutionId().GetName() - - trainingImageStr, err := getTrainingImage(ctx, &sagemakerTrainingJob) - if err != nil { - return nil, errors.Wrapf(err, "failed to find the training image") - } - - logger.Infof(ctx, "The Sagemaker TrainingJob Task plugin received static hyperparameters [%v]", staticHyperparams) - - cfg := config.GetSagemakerConfig() - - if sagemakerTrainingJob.GetAlgorithmSpecification() == nil { - return nil, errors.Errorf("Required field [AlgorithmSpecification] does not exist") - } - var metricDefinitions []commonv1.MetricDefinition - idlMetricDefinitions := sagemakerTrainingJob.GetAlgorithmSpecification().GetMetricDefinitions() - for _, md := range idlMetricDefinitions { - metricDefinitions = append(metricDefinitions, - commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)}) - } - - apiContentType, err := getAPIContentType(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType()) - if err != nil { - return nil, errors.Wrapf(err, "Unsupported input file type [%v]", sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType().String()) - } - - inputModeString := 
strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String())) - - role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations()) - if role == "" { - role = cfg.RoleArn - } - trainingJob := &trainingjobv1.TrainingJob{ - Spec: trainingjobv1.TrainingJobSpec{ - AlgorithmSpecification: &commonv1.AlgorithmSpecification{ - // If the specify a value for this AlgorithmName parameter, the user can't specify a value for TrainingImage. - // in this Flyte plugin, we always use the algorithm name and version the user provides via Flytekit to map to an image - // so we intentionally leave this field nil - AlgorithmName: nil, - TrainingImage: ToStringPtr(trainingImageStr), - TrainingInputMode: commonv1.TrainingInputMode(inputModeString), - MetricDefinitions: metricDefinitions, - }, - // The support of spot training will come in a later version - EnableManagedSpotTraining: nil, - HyperParameters: staticHyperparams, - InputDataConfig: []commonv1.Channel{ - { - ChannelName: ToStringPtr("train"), - DataSource: &commonv1.DataSource{ - S3DataSource: &commonv1.S3DataSource{ - S3DataType: "S3Prefix", - S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()), - }, - }, - ContentType: ToStringPtr(apiContentType), - InputMode: inputModeString, - }, - { - ChannelName: ToStringPtr("validation"), - DataSource: &commonv1.DataSource{ - S3DataSource: &commonv1.S3DataSource{ - S3DataType: "S3Prefix", - S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()), - }, - }, - ContentType: ToStringPtr(apiContentType), - InputMode: inputModeString, - }, - }, - OutputDataConfig: &commonv1.OutputDataConfig{ - S3OutputPath: ToStringPtr(outputPath), - }, - CheckpointConfig: nil, - ResourceConfig: &commonv1.ResourceConfig{ - InstanceType: sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceType(), - InstanceCount: 
ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceCount()), - VolumeSizeInGB: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetVolumeSizeInGb()), - VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. Need to add to proto and flytekit in the future - }, - RoleArn: ToStringPtr(role), - Region: ToStringPtr(cfg.Region), - StoppingCondition: &commonv1.StoppingCondition{ - MaxRuntimeInSeconds: ToInt64Ptr(86400), // TODO: decide how to coordinate this and Flyte's timeout - MaxWaitTimeInSeconds: nil, // TODO: decide how to coordinate this and Flyte's timeout and queueing budget - }, - TensorBoardOutputConfig: nil, - Tags: nil, - TrainingJobName: &taskName, - }, - } - logger.Infof(ctx, "Successfully built a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) - return trainingJob, nil -} - -func (m awsSagemakerPlugin) BuildResourceForHyperparameterTuningJob( - ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { - - logger.Infof(ctx, "Building a hyperparameter tuning job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) - - taskTemplate, err := getTaskTemplate(ctx, taskCtx) - if err != nil { - return nil, err - } - - // Unmarshal the custom field of the task template back into the HyperparameterTuningJob struct generated in flyteidl - sagemakerHPOJob := sagemakerSpec.HyperparameterTuningJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerHPOJob) - if err != nil { - return nil, errors.Wrapf(err, "invalid HyperparameterTuningJob task specification: not able to unmarshal the custom field to [%s]", hyperparameterTuningJobTaskType) - } - if sagemakerHPOJob.GetTrainingJob() == nil { - return nil, errors.Errorf("Required field [TrainingJob] of the HyperparameterTuningJob does not exist") - } - if sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification() == nil { - return 
nil, errors.Errorf("Required field [AlgorithmSpecification] of the HyperparameterTuningJob's underlying TrainingJob does not exist") - } - if sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig() == nil { - return nil, errors.Errorf("Required field [TrainingJobResourceConfig] of the HyperparameterTuningJob's underlying TrainingJob does not exist") - } - - taskInput, err := taskCtx.InputReader().Get(ctx) - if err != nil { - return nil, errors.Wrapf(err, "unable to fetch task inputs") - } - - // Get inputs from literals - inputLiterals := taskInput.GetLiterals() - - trainPathLiteral, ok := inputLiterals["train"] - if !ok { - return nil, errors.Errorf("Required input not specified: [train]") - } - if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil { - return nil, errors.Errorf("[train] Input is required and should be of Type [Scalar.Blob]") - } - validatePathLiteral, ok := inputLiterals["validation"] - if !ok { - return nil, errors.Errorf("Required input not specified: [validation]") - } - if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil { - return nil, errors.Errorf("[validation] Input is required and should be of Type [Scalar.Blob]") - } - staticHyperparamsLiteral, ok := inputLiterals["static_hyperparameters"] - if !ok { - return nil, errors.Errorf("Required input not specified: [static_hyperparameters]") - } - - hpoJobConfigLiteral, ok := inputLiterals["hyperparameter_tuning_job_config"] - if !ok { - return nil, errors.Errorf("Required input not specified: [hyperparameter_tuning_job_config]") - } - - outputPath := createOutputPath(taskCtx.OutputWriter().GetOutputPrefixPath().String()) - - // Convert the hyperparameters to the spec value - staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral) - if err != nil { - return nil, errors.Wrapf(err, "could not convert static hyperparameters to spec type") - } - - // 
hyperparameter_tuning_job_config is marshaled into a byte array in flytekit, so will have to unmarshal it back - hpoJobConfig, err := convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral) - if err != nil { - return nil, errors.Wrapf(err, "failed to convert hyperparameter tuning job config literal to spec type") - } - - if hpoJobConfig.GetTuningObjective() == nil { - return nil, errors.Errorf("Required field [TuningObjective] does not exist") - } - - // Deleting the conflicting static hyperparameters: if a hyperparameter exist in both the map of static hyperparameter - // and the map of the tunable hyperparameter inside the Hyperparameter Tuning Job Config, we delete the entry - // in the static map and let the one in the map of the tunable hyperparameters take precedence - staticHyperparams = deleteConflictingStaticHyperparameters(ctx, staticHyperparams, hpoJobConfig.GetHyperparameterRanges().GetParameterRangeMap()) - - taskName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetExecutionId().GetName() - - trainingImageStr, err := getTrainingImage(ctx, sagemakerHPOJob.GetTrainingJob()) - if err != nil { - return nil, errors.Wrapf(err, "failed to find the training image") - } - - hpoJobParameterRanges := buildParameterRanges(hpoJobConfig) - - logger.Infof(ctx, "The Sagemaker HyperparameterTuningJob Task plugin received the following inputs: \n"+ - "static hyperparameters: [%v]\n"+ - "hyperparameter tuning job config: [%v]\n"+ - "parameter ranges: [%v]", staticHyperparams, hpoJobConfig, hpoJobParameterRanges) - - cfg := config.GetSagemakerConfig() - - var metricDefinitions []commonv1.MetricDefinition - idlMetricDefinitions := sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetMetricDefinitions() - for _, md := range idlMetricDefinitions { - metricDefinitions = append(metricDefinitions, - commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)}) - } - - apiContentType, err := 
getAPIContentType(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType()) - if err != nil { - return nil, errors.Wrapf(err, "Unsupported input file type [%v]", - sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType().String()) - } - - inputModeString := strings.Title(strings.ToLower(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputMode().String())) - tuningStrategyString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningStrategy().String())) - tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String())) - trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String())) - - role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations()) - if role == "" { - role = cfg.RoleArn - } - - hpoJob := &hpojobv1.HyperparameterTuningJob{ - Spec: hpojobv1.HyperparameterTuningJobSpec{ - HyperParameterTuningJobName: &taskName, - HyperParameterTuningJobConfig: &commonv1.HyperParameterTuningJobConfig{ - ResourceLimits: &commonv1.ResourceLimits{ - MaxNumberOfTrainingJobs: ToInt64Ptr(sagemakerHPOJob.GetMaxNumberOfTrainingJobs()), - MaxParallelTrainingJobs: ToInt64Ptr(sagemakerHPOJob.GetMaxParallelTrainingJobs()), - }, - Strategy: commonv1.HyperParameterTuningJobStrategyType(tuningStrategyString), - HyperParameterTuningJobObjective: &commonv1.HyperParameterTuningJobObjective{ - Type: commonv1.HyperParameterTuningJobObjectiveType(tuningObjectiveTypeString), - MetricName: ToStringPtr(hpoJobConfig.GetTuningObjective().GetMetricName()), - }, - ParameterRanges: hpoJobParameterRanges, - TrainingJobEarlyStoppingType: commonv1.TrainingJobEarlyStoppingType(trainingJobEarlyStoppingTypeString), - }, - TrainingJobDefinition: &commonv1.HyperParameterTrainingJobDefinition{ - StaticHyperParameters: staticHyperparams, - AlgorithmSpecification: 
&commonv1.HyperParameterAlgorithmSpecification{ - TrainingImage: ToStringPtr(trainingImageStr), - TrainingInputMode: commonv1.TrainingInputMode(inputModeString), - MetricDefinitions: metricDefinitions, - AlgorithmName: nil, - }, - InputDataConfig: []commonv1.Channel{ - { - ChannelName: ToStringPtr("train"), - DataSource: &commonv1.DataSource{ - S3DataSource: &commonv1.S3DataSource{ - S3DataType: "S3Prefix", - S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()), - }, - }, - ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata - InputMode: inputModeString, - }, - { - ChannelName: ToStringPtr("validation"), - DataSource: &commonv1.DataSource{ - S3DataSource: &commonv1.S3DataSource{ - S3DataType: "S3Prefix", - S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()), - }, - }, - ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata - InputMode: inputModeString, - }, - }, - OutputDataConfig: &commonv1.OutputDataConfig{ - S3OutputPath: ToStringPtr(outputPath), - }, - ResourceConfig: &commonv1.ResourceConfig{ - InstanceType: sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetInstanceType(), - InstanceCount: ToInt64Ptr(sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetInstanceCount()), - VolumeSizeInGB: ToInt64Ptr(sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetVolumeSizeInGb()), - VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. 
Need to add to proto and flytekit in the future - }, - RoleArn: ToStringPtr(role), - StoppingCondition: &commonv1.StoppingCondition{ - MaxRuntimeInSeconds: ToInt64Ptr(86400), - MaxWaitTimeInSeconds: nil, - }, - }, - Region: ToStringPtr(cfg.Region), - }, - } - - logger.Infof(ctx, "Successfully built a hyperparameter tuning job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) - return hpoJob, nil -} - -func getTaskTemplate(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (*core.TaskTemplate, error) { - taskTemplate, err := taskCtx.TaskReader().Read(ctx) - if err != nil { - return nil, errors.Wrapf(err, "unable to fetch task specification") - } else if taskTemplate == nil { - return nil, errors.Errorf("nil task specification") - } - return taskTemplate, nil -} - -func (m awsSagemakerPlugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { - - // Unmarshal the custom field of the task template back into the HyperparameterTuningJob struct generated in flyteidl - if m.TaskType == trainingJobTaskType { - return m.BuildResourceForTrainingJob(ctx, taskCtx) - } - if m.TaskType == hyperparameterTuningJobTaskType { - return m.BuildResourceForHyperparameterTuningJob(ctx, taskCtx) - } - return nil, errors.Errorf("The SageMaker plugin is unable to build resource for unknown task type [%s]", m.TaskType) -} - -func (m awsSagemakerPlugin) getEventInfoForJob(ctx context.Context, job k8s.Resource) (*pluginsCore.TaskInfo, error) { - - var jobRegion, jobName, jobTypeInURL, sagemakerLinkName string - if m.TaskType == trainingJobTaskType { - trainingJob := job.(*trainingjobv1.TrainingJob) - jobRegion = *trainingJob.Spec.Region - jobName = *trainingJob.Spec.TrainingJobName - jobTypeInURL = "jobs" - sagemakerLinkName = "SageMaker Training Job" - } else if m.TaskType == hyperparameterTuningJobTaskType { - trainingJob := job.(*hpojobv1.HyperparameterTuningJob) - jobRegion = 
*trainingJob.Spec.Region - jobName = *trainingJob.Spec.HyperParameterTuningJobName - jobTypeInURL = "hyper-tuning-jobs" - sagemakerLinkName = "SageMaker Hyperparameter Tuning Job" - } else { - return nil, errors.Errorf("The plugin is unable to get event info for unknown task type {%v}", m.TaskType) - } - - logger.Infof(ctx, "Getting event information for task type: [%v], job region: [%v], job name: [%v], "+ - "job type in url: [%v], sagemaker link name: [%v]", m.TaskType, jobRegion, jobName, jobTypeInURL, sagemakerLinkName) - - cwLogURL := fmt.Sprintf("https://%s.console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=/aws/sagemaker/TrainingJobs;prefix=%s;streamFilter=typeLogStreamPrefix", - jobRegion, jobRegion, jobName) - smLogURL := fmt.Sprintf("https://%s.console.aws.amazon.com/sagemaker/home?region=%s#/%s/%s", - jobRegion, jobRegion, jobTypeInURL, jobName) - - taskLogs := []*core.TaskLog{ - { - Uri: cwLogURL, - Name: "CloudWatch Logs", - MessageFormat: core.TaskLog_JSON, - }, - { - Uri: smLogURL, - Name: sagemakerLinkName, - MessageFormat: core.TaskLog_UNKNOWN, - }, - } - - customInfoMap := make(map[string]string) - - customInfo, err := utils.MarshalObjToStruct(customInfoMap) - if err != nil { - return nil, err - } - - return &pluginsCore.TaskInfo{ - Logs: taskLogs, - CustomInfo: customInfo, - }, nil -} - -func getOutputs(ctx context.Context, tr pluginsCore.TaskReader, outputPath string) (*core.LiteralMap, error) { - tk, err := tr.Read(ctx) - if err != nil { - return nil, err - } - if tk.Interface.Outputs != nil && tk.Interface.Outputs.Variables == nil { - logger.Warnf(ctx, "No outputs declared in the output interface. 
Ignoring the generated outputs.") - return nil, nil - } - - // We know that for XGBoost task there is only one output to be generated - if len(tk.Interface.Outputs.Variables) > 1 { - return nil, fmt.Errorf("expected to generate more than one outputs of type [%v]", tk.Interface.Outputs.Variables) - } - op := createOutputLiteralMap(tk, outputPath) - return op, nil -} - -func createOutputPath(prefix string) string { - return fmt.Sprintf("%s/hyperparameter_tuning_outputs", prefix) -} - -func createModelOutputPath(prefix, bestExperiment string) string { - return fmt.Sprintf("%s/%s/output/model.tar.gz", createOutputPath(prefix), bestExperiment) -} - -func (m awsSagemakerPlugin) GetTaskPhaseForTrainingJob( - ctx context.Context, pluginContext k8s.PluginContext, trainingJob *trainingjobv1.TrainingJob) (pluginsCore.PhaseInfo, error) { - - logger.Infof(ctx, "Getting task phase for sagemaker training job [%v]", trainingJob.Status.SageMakerTrainingJobName) - info, err := m.getEventInfoForJob(ctx, trainingJob) - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - - occurredAt := time.Now() - - switch trainingJob.Status.TrainingJobStatus { - case trainingjobController.ReconcilingTrainingJobStatus: - logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", trainingJob.Status.Additional) - // TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call errors - execError := &core.ExecutionError{ - Message: trainingJob.Status.Additional, - Kind: core.ExecutionError_USER, - Code: trainingjobController.ReconcilingTrainingJobStatus, - } - return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil - case sagemaker.TrainingJobStatusFailed: - execError := &core.ExecutionError{ - Message: trainingJob.Status.Additional, - Kind: core.ExecutionError_USER, - Code: sagemaker.TrainingJobStatusFailed, - } - return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, 
info), nil - case sagemaker.TrainingJobStatusStopped: - reason := fmt.Sprintf("Training Job Stopped") - return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil - case sagemaker.TrainingJobStatusCompleted: - // Now that it is success we will set the outputs as expected by the task - out, err := getOutputs(ctx, pluginContext.TaskReader(), createModelOutputPath(pluginContext.OutputWriter().GetOutputPrefixPath().String(), trainingJob.Status.SageMakerTrainingJobName)) - if err != nil { - logger.Errorf(ctx, "Failed to create outputs, err: %s", err) - return pluginsCore.PhaseInfoUndefined, errors.Wrapf(err, "failed to create outputs for the task") - } - if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(out, nil)); err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - logger.Debugf(ctx, "Successfully produced and returned outputs") - return pluginsCore.PhaseInfoSuccess(info), nil - case "": - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil - } - - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil -} - -func (m awsSagemakerPlugin) GetTaskPhaseForHyperparameterTuningJob( - ctx context.Context, pluginContext k8s.PluginContext, hpoJob *hpojobv1.HyperparameterTuningJob) (pluginsCore.PhaseInfo, error) { - - logger.Infof(ctx, "Getting task phase for hyperparameter tuning job [%v]", hpoJob.Status.SageMakerHyperParameterTuningJobName) - info, err := m.getEventInfoForJob(ctx, hpoJob) - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - - occurredAt := time.Now() - - switch hpoJob.Status.HyperParameterTuningJobStatus { - case hpojobController.ReconcilingTuningJobStatus: - logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", hpoJob.Status.Additional) - // TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call errors - execError := 
&core.ExecutionError{ - Message: hpoJob.Status.Additional, - Kind: core.ExecutionError_USER, - Code: hpojobController.ReconcilingTuningJobStatus, - } - return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil - case sagemaker.HyperParameterTuningJobStatusFailed: - execError := &core.ExecutionError{ - Message: hpoJob.Status.Additional, - Kind: core.ExecutionError_USER, - Code: sagemaker.HyperParameterTuningJobStatusFailed, - } - return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, info), nil - case sagemaker.HyperParameterTuningJobStatusStopped: - reason := fmt.Sprintf("Hyperparameter tuning job stopped") - return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil - case sagemaker.HyperParameterTuningJobStatusCompleted: - // Now that it is success we will set the outputs as expected by the task - out, err := getOutputs(ctx, pluginContext.TaskReader(), createModelOutputPath(pluginContext.OutputWriter().GetOutputPrefixPath().String(), *hpoJob.Status.BestTrainingJob.TrainingJobName)) - if err != nil { - logger.Errorf(ctx, "Failed to create outputs, err: %s", err) - return pluginsCore.PhaseInfoUndefined, errors.Wrapf(err, "failed to create outputs for the task") - } - if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(out, nil)); err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - logger.Debugf(ctx, "Successfully produced and returned outputs") - return pluginsCore.PhaseInfoSuccess(info), nil - case "": - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil - } - - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil -} - -func (m awsSagemakerPlugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource k8s.Resource) (pluginsCore.PhaseInfo, error) { - if m.TaskType == trainingJobTaskType { - job := 
resource.(*trainingjobv1.TrainingJob) - return m.GetTaskPhaseForTrainingJob(ctx, pluginContext, job) - } else if m.TaskType == hyperparameterTuningJobTaskType { - job := resource.(*hpojobv1.HyperparameterTuningJob) - return m.GetTaskPhaseForHyperparameterTuningJob(ctx, pluginContext, job) - } - return pluginsCore.PhaseInfoUndefined, errors.Errorf("cannot get task phase for unknown task type [%s]", m.TaskType) -} - -func init() { - if err := commonv1.AddToScheme(scheme.Scheme); err != nil { - panic(err) - } - - // Registering the plugin for HyperparameterTuningJob - pluginmachinery.PluginRegistry().RegisterK8sPlugin( - k8s.PluginEntry{ - ID: hyperparameterTuningJobTaskPluginID, - RegisteredTaskTypes: []pluginsCore.TaskType{hyperparameterTuningJobTaskType}, - ResourceToWatch: &hpojobv1.HyperparameterTuningJob{}, - Plugin: awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType}, - IsDefault: false, - }) - - // Registering the plugin for standalone TrainingJob - pluginmachinery.PluginRegistry().RegisterK8sPlugin( - k8s.PluginEntry{ - ID: trainingJobTaskPluginID, - RegisteredTaskTypes: []pluginsCore.TaskType{trainingJobTaskType}, - ResourceToWatch: &trainingjobv1.TrainingJob{}, - Plugin: awsSagemakerPlugin{TaskType: trainingJobTaskType}, - IsDefault: false, - }) -} diff --git a/go/tasks/plugins/k8s/sagemaker/testdata/config.yaml b/go/tasks/plugins/k8s/sagemaker/testdata/config.yaml index f669952b4..f30a7978c 100644 --- a/go/tasks/plugins/k8s/sagemaker/testdata/config.yaml +++ b/go/tasks/plugins/k8s/sagemaker/testdata/config.yaml @@ -9,6 +9,6 @@ plugins: - region: "us-west-2" versionConfigs: - version: "0.90" - image: "image-0.90" + image: "XGBOOST_us-west-2_image-0.90" - version: "1.0" - image: "image-1.0" + image: "XGBOOST_us-west-2_image-1.0" diff --git a/go/tasks/plugins/k8s/sagemaker/utils.go b/go/tasks/plugins/k8s/sagemaker/utils.go index 11155fa83..31c7060ef 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/go/tasks/plugins/k8s/sagemaker/utils.go @@ 
-3,23 +3,30 @@ package sagemaker import ( "context" "fmt" + "sort" "strings" + pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flytestdlib/logger" "github.com/Masterminds/semver" commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" awssagemaker "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/controllertest" + "github.com/golang/protobuf/proto" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - sagemakerSpec "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" + flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + flyteSagemakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" "github.com/pkg/errors" - - "github.com/golang/protobuf/proto" ) -func getAPIContentType(fileType sagemakerSpec.InputContentType_Value) (string, error) { - if fileType == sagemakerSpec.InputContentType_TEXT_CSV { +func getAPIContentType(fileType flyteSagemakerIdl.InputContentType_Value) (string, error) { + if fileType == flyteSagemakerIdl.InputContentType_TEXT_CSV { return TEXTCSVInputContentType, nil } return "", errors.Errorf("Unsupported input file type [%v]", fileType.String()) @@ -44,12 +51,22 @@ func getLatestTrainingImage(versionConfigs []config.VersionConfig) (string, erro return latestImg, nil } -func getTrainingImage(ctx context.Context, job *sagemakerSpec.TrainingJob) (string, error) { +func getTrainingJobImage(ctx context.Context, _ pluginsCore.TaskExecutionContext, job *flyteSagemakerIdl.TrainingJob) (string, error) { + image, err := getPrebuiltTrainingImage(ctx, job) + if err != nil { + return "", errors.Wrapf(err, "Failed to get prebuilt image for job [%v]", *job) + } + return image, nil +} + +func getPrebuiltTrainingImage(ctx context.Context, job 
*flyteSagemakerIdl.TrainingJob) (string, error) { + // This function determines which image URI to put into the CRD of the training job and the hyperparameter tuning job + cfg := config.GetSagemakerConfig() var foundAlgorithmCfg *config.PrebuiltAlgorithmConfig var foundRegionalCfg *config.RegionalConfig - if specifiedAlg := job.GetAlgorithmSpecification().GetAlgorithmName(); specifiedAlg != sagemakerSpec.AlgorithmName_CUSTOM { + if specifiedAlg := job.GetAlgorithmSpecification().GetAlgorithmName(); specifiedAlg != flyteSagemakerIdl.AlgorithmName_CUSTOM { // Built-in algorithm mode apiAlgorithmName := specifiedAlg.String() @@ -108,10 +125,11 @@ func getTrainingImage(ctx context.Context, job *sagemakerSpec.TrainingJob) (stri return "", errors.Errorf("Failed to find an image for [%v]:[%v]:[%v]", job.GetAlgorithmSpecification().GetAlgorithmName(), cfg.Region, job.GetAlgorithmSpecification().GetAlgorithmVersion()) } - return "custom image", errors.Errorf("Custom images are not supported yet") + // Custom image + return "", errors.Errorf("It is invalid to try getting a prebuilt image for AlgorithmName == CUSTOM ") } -func buildParameterRanges(hpoJobConfig *sagemakerSpec.HyperparameterTuningJobConfig) *commonv1.ParameterRanges { +func buildParameterRanges(hpoJobConfig *flyteSagemakerIdl.HyperparameterTuningJobConfig) *commonv1.ParameterRanges { prMap := hpoJobConfig.GetHyperparameterRanges().GetParameterRangeMap() var retValue = &commonv1.ParameterRanges{ CategoricalParameterRanges: []commonv1.CategoricalParameterRange{}, @@ -122,14 +140,14 @@ func buildParameterRanges(hpoJobConfig *sagemakerSpec.HyperparameterTuningJobCon for prName, pr := range prMap { scalingTypeString := strings.Title(strings.ToLower(pr.GetContinuousParameterRange().GetScalingType().String())) switch pr.GetParameterRangeType().(type) { - case *sagemakerSpec.ParameterRangeOneOf_CategoricalParameterRange: + case *flyteSagemakerIdl.ParameterRangeOneOf_CategoricalParameterRange: var newElem = 
commonv1.CategoricalParameterRange{ Name: awssagemaker.ToStringPtr(prName), Values: pr.GetCategoricalParameterRange().GetValues(), } retValue.CategoricalParameterRanges = append(retValue.CategoricalParameterRanges, newElem) - case *sagemakerSpec.ParameterRangeOneOf_ContinuousParameterRange: + case *flyteSagemakerIdl.ParameterRangeOneOf_ContinuousParameterRange: var newElem = commonv1.ContinuousParameterRange{ MaxValue: awssagemaker.ToStringPtr(fmt.Sprintf("%f", pr.GetContinuousParameterRange().GetMaxValue())), MinValue: awssagemaker.ToStringPtr(fmt.Sprintf("%f", pr.GetContinuousParameterRange().GetMinValue())), @@ -138,7 +156,7 @@ func buildParameterRanges(hpoJobConfig *sagemakerSpec.HyperparameterTuningJobCon } retValue.ContinuousParameterRanges = append(retValue.ContinuousParameterRanges, newElem) - case *sagemakerSpec.ParameterRangeOneOf_IntegerParameterRange: + case *flyteSagemakerIdl.ParameterRangeOneOf_IntegerParameterRange: var newElem = commonv1.IntegerParameterRange{ MaxValue: awssagemaker.ToStringPtr(fmt.Sprintf("%d", pr.GetIntegerParameterRange().GetMaxValue())), MinValue: awssagemaker.ToStringPtr(fmt.Sprintf("%d", pr.GetIntegerParameterRange().GetMinValue())), @@ -152,8 +170,8 @@ func buildParameterRanges(hpoJobConfig *sagemakerSpec.HyperparameterTuningJobCon return retValue } -func convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral *core.Literal) (*sagemakerSpec.HyperparameterTuningJobConfig, error) { - var retValue = &sagemakerSpec.HyperparameterTuningJobConfig{} +func convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral *core.Literal) (*flyteSagemakerIdl.HyperparameterTuningJobConfig, error) { + var retValue = &flyteSagemakerIdl.HyperparameterTuningJobConfig{} if hpoJobConfigLiteral.GetScalar() == nil || hpoJobConfigLiteral.GetScalar().GetBinary() == nil { return nil, errors.Errorf("[Hyperparameters] should be of type [Scalar.Binary]") } @@ -174,13 +192,23 @@ func 
convertStaticHyperparamsLiteralToSpecType(hyperparamLiteral *core.Literal) if hyperFields == nil { return nil, errors.Errorf("Failed to get the static hyperparameters field from the literal") } - for k, v := range hyperFields { + + keys := make([]string, 0) + for k := range hyperFields { + keys = append(keys, k) + } + sort.Strings(keys) + + fmt.Printf("[%v]", keys) + for _, k := range keys { + v := hyperFields[k] var newElem = commonv1.KeyValuePair{ Name: k, Value: v.GetStringValue(), } retValue = append(retValue, &newElem) } + return retValue, nil } @@ -212,43 +240,117 @@ func ToFloat64Ptr(f float64) *float64 { return &f } -func createOutputLiteralMap(tk *core.TaskTemplate, outputPath string) *core.LiteralMap { - op := &core.LiteralMap{} - for k := range tk.Interface.Outputs.Variables { - // if v != core.LiteralType_Blob{} - op.Literals = make(map[string]*core.Literal) - op.Literals[k] = &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Blob{ - Blob: &core.Blob{ - Metadata: &core.BlobMetadata{ - Type: &core.BlobType{Dimensionality: core.BlobType_SINGLE}, - }, - Uri: outputPath, - }, - }, - }, - }, - } - } - return op -} - func deleteConflictingStaticHyperparameters( ctx context.Context, staticHPs []*commonv1.KeyValuePair, - tunableHPMap map[string]*sagemakerSpec.ParameterRangeOneOf) []*commonv1.KeyValuePair { + tunableHPMap map[string]*flyteSagemakerIdl.ParameterRangeOneOf) []*commonv1.KeyValuePair { - finalStaticHPs := make([]*commonv1.KeyValuePair, 0, len(staticHPs)) + resolvedStaticHPs := make([]*commonv1.KeyValuePair, 0, len(staticHPs)) for _, hp := range staticHPs { if _, found := tunableHPMap[hp.Name]; !found { - finalStaticHPs = append(finalStaticHPs, hp) + resolvedStaticHPs = append(resolvedStaticHPs, hp) } else { logger.Infof(ctx, "Static hyperparameter [%v] is removed because the same hyperparameter can be found in the map of tunable hyperparameters", hp.Name) } } - return finalStaticHPs + return 
resolvedStaticHPs +} + +func makeHyperparametersKeysValuesFromArgs(_ context.Context, args []string) []*commonv1.KeyValuePair { + ret := make([]*commonv1.KeyValuePair, 0) + for argOrder, arg := range args { + ret = append(ret, &commonv1.KeyValuePair{ + Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, argOrder, arg, FlyteSageMakerKeySuffix), + Value: FlyteSageMakerCmdDummyValue, + }) + } + return ret +} + +func injectTaskTemplateEnvVarToHyperparameters(ctx context.Context, taskTemplate *flyteIdlCore.TaskTemplate, hps []*commonv1.KeyValuePair) ([]*commonv1.KeyValuePair, error) { + if taskTemplate == nil || taskTemplate.GetContainer() == nil { + return hps, errors.Errorf("The taskTemplate is nil or the container is nil") + } + + if hps == nil { + return nil, errors.Errorf("A nil slice of hyperparameters is passed in") + } + + for _, ev := range taskTemplate.GetContainer().GetEnv() { + hpKey := fmt.Sprintf("%s%s%s", FlyteSageMakerEnvVarKeyPrefix, ev.Key, FlyteSageMakerKeySuffix) + logger.Infof(ctx, "Injecting env var {%v: %v} into the hyperparameter list", hpKey, ev.Value) + hps = append(hps, &commonv1.KeyValuePair{ + Name: hpKey, + Value: ev.Value}) + } + + return hps, nil +} + +func injectArgsAndEnvVars(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, taskTemplate *flyteIdlCore.TaskTemplate) ([]*commonv1.KeyValuePair, error) { + templateArgs := taskTemplate.GetContainer().GetArgs() + templateArgs, err := utils.ReplaceTemplateCommandArgs(ctx, templateArgs, taskCtx.InputReader(), taskCtx.OutputWriter()) + if err != nil { + return nil, errors.Wrapf(err, "Failed to de-template the hyperparameter values") + } + hyperParameters := makeHyperparametersKeysValuesFromArgs(ctx, templateArgs) + hyperParameters, err = injectTaskTemplateEnvVarToHyperparameters(ctx, taskTemplate, hyperParameters) + if err != nil { + return nil, errors.Wrapf(err, "Failed to inject the task template's container env vars to the hyperparameter list") + } + return 
hyperParameters, nil +} + +func checkIfRequiredInputLiteralsExist(inputLiterals map[string]*flyteIdlCore.Literal, inputKeys []string) error { + for _, inputKey := range inputKeys { + _, ok := inputLiterals[inputKey] + if !ok { + return errors.Errorf("Required input not specified: [%v]", inputKey) + } + } + return nil +} + +func getTaskTemplate(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (*flyteIdlCore.TaskTemplate, error) { + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, errors.Wrapf(err, "unable to fetch task specification") + } else if taskTemplate == nil { + return nil, errors.Errorf("nil task specification") + } + return taskTemplate, nil +} + +func createTaskInfo(_ context.Context, jobRegion string, jobName string, jobTypeInURL string, sagemakerLinkName string) (*pluginsCore.TaskInfo, error) { + cwLogURL := fmt.Sprintf("https://%s.console.aws.amazon.com/cloudwatch/home?region=%s#logStream:group=/aws/sagemaker/TrainingJobs;prefix=%s;streamFilter=typeLogStreamPrefix", + jobRegion, jobRegion, jobName) + smLogURL := fmt.Sprintf("https://%s.console.aws.amazon.com/sagemaker/home?region=%s#/%s/%s", + jobRegion, jobRegion, jobTypeInURL, jobName) + + taskLogs := []*flyteIdlCore.TaskLog{ + { + Uri: cwLogURL, + Name: CloudWatchLogLinkName, + MessageFormat: flyteIdlCore.TaskLog_JSON, + }, + { + Uri: smLogURL, + Name: sagemakerLinkName, + MessageFormat: flyteIdlCore.TaskLog_UNKNOWN, + }, + } + + customInfoMap := make(map[string]string) + + customInfo, err := utils.MarshalObjToStruct(customInfoMap) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "Unable to create a custom info object") + } + + return &pluginsCore.TaskInfo{ + Logs: taskLogs, + CustomInfo: customInfo, + }, nil } diff --git a/go/tasks/plugins/k8s/sagemaker/utils_test.go b/go/tasks/plugins/k8s/sagemaker/utils_test.go index 3aa1777c6..641d141c4 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils_test.go +++ 
b/go/tasks/plugins/k8s/sagemaker/utils_test.go @@ -2,10 +2,13 @@ package sagemaker import ( "context" + "fmt" "reflect" "strconv" "testing" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common" sagemakerSpec "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" "github.com/lyft/flytestdlib/config/viper" @@ -13,6 +16,7 @@ import ( stdConfig "github.com/lyft/flytestdlib/config" + flyteSagemakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker" "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" sagemakerConfig "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config" ) @@ -221,7 +225,7 @@ func Test_getLatestTrainingImage(t *testing.T) { } } -func Test_getTrainingImage(t *testing.T) { +func Test_getPrebuiltTrainingImage(t *testing.T) { ctx := context.TODO() _ = sagemakerConfig.SetSagemakerConfig(generateMockSageMakerConfig()) @@ -265,23 +269,23 @@ func Test_getTrainingImage(t *testing.T) { InputContentType: 0, }, TrainingJobResourceConfig: nil, - }}, want: "custom image", wantErr: true}, + }}, want: "", wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getTrainingImage(tt.args.ctx, tt.args.job) + got, err := getPrebuiltTrainingImage(tt.args.ctx, tt.args.job) if (err != nil) != tt.wantErr { - t.Errorf("getTrainingImage() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("getPrebuiltTrainingImage() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { - t.Errorf("getTrainingImage() got = %v, want %v", got, tt.want) + t.Errorf("getPrebuiltTrainingImage() got = %v, want %v", got, tt.want) } }) } } -func Test_getTrainingImage_LoadConfig(t *testing.T) { +func Test_getPrebuiltTrainingImage_LoadConfig(t *testing.T) { configAccessor := viper.NewAccessor(stdConfig.Options{ StrictMode: true, SearchPaths: []string{"testdata/config.yaml"}, @@ -292,19 +296,127 @@ func 
Test_getTrainingImage_LoadConfig(t *testing.T) { assert.NotNil(t, config.GetSagemakerConfig()) - image, err := getTrainingImage(context.TODO(), &sagemakerSpec.TrainingJob{AlgorithmSpecification: &sagemakerSpec.AlgorithmSpecification{ + image, err := getPrebuiltTrainingImage(context.TODO(), &sagemakerSpec.TrainingJob{AlgorithmSpecification: &sagemakerSpec.AlgorithmSpecification{ AlgorithmName: sagemakerSpec.AlgorithmName_XGBOOST, AlgorithmVersion: "0.90", }}) assert.NoError(t, err) - assert.Equal(t, "image-0.90", image) + assert.Equal(t, "XGBOOST_us-west-2_image-0.90", image) - image, err = getTrainingImage(context.TODO(), &sagemakerSpec.TrainingJob{AlgorithmSpecification: &sagemakerSpec.AlgorithmSpecification{ + image, err = getPrebuiltTrainingImage(context.TODO(), &sagemakerSpec.TrainingJob{AlgorithmSpecification: &sagemakerSpec.AlgorithmSpecification{ AlgorithmName: sagemakerSpec.AlgorithmName_XGBOOST, AlgorithmVersion: "1.0", }}) assert.NoError(t, err) - assert.Equal(t, "image-1.0", image) + assert.Equal(t, "XGBOOST_us-west-2_image-1.0", image) +} + +func Test_getTrainingJobImage(t *testing.T) { + + ctx := context.TODO() + defaultCfg := config.GetSagemakerConfig() + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + + type Result struct { + name string + want string + wantErr bool + } + + configAccessor := viper.NewAccessor(stdConfig.Options{ + StrictMode: true, + SearchPaths: []string{"testdata/config.yaml"}, + }) + + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + expectedResult := Result{ + "Should retrieve image url from config for built-in algorithms", "XGBOOST_us-west-2_image-0.90", false, + } + t.Run(expectedResult.name, func(t *testing.T) { + tjObj := generateMockTrainingJobCustomObj( + flyteSagemakerIdl.InputMode_FILE, + flyteSagemakerIdl.AlgorithmName_XGBOOST, + "0.90", + []*flyteSagemakerIdl.MetricDefinition{}, + flyteSagemakerIdl.InputContentType_TEXT_CSV, + 1, + "ml.m4.xlarge", + 25) + taskTemplate := 
generateMockTrainingJobTaskTemplate("the job", tjObj) + taskCtx := generateMockTrainingJobTaskContext(taskTemplate, false) + sagemakerTrainingJob := flyteSagemakerIdl.TrainingJob{} + err := utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerTrainingJob) + if err != nil { + panic(err) + } + + got, err := getTrainingJobImage(ctx, taskCtx, &sagemakerTrainingJob) + if (err != nil) != expectedResult.wantErr { + t.Errorf("getTrainingJobImage() error = %v, wantErr %v", err, expectedResult.wantErr) + return + } + if got != expectedResult.want { + t.Errorf("getTrainingJobImage() got = %v, want %v", got, expectedResult.want) + } + }) + +} + +func Test_makeHyperparametersKeysValuesFromArgs(t *testing.T) { + outputPrefix := "s3://abcdefghijklmnopqrtsuvwxyz/abcdefghijklmnopqrtsuvwxyz/abcdefghijklmnopqrtsuvwxyz" + inputs := "s3://abcdefghijklmnopqrtsuvwxyz/abcdefghijklmnopqrtsuvwxyz/abcdefghijklmnopqrtsuvwxyz/inputs.pb" + type args struct { + in0 context.Context + args []string + } + tests := []struct { + name string + args args + want []*commonv1.KeyValuePair + }{ + {name: "service pyflyte-execute", + args: args{ + in0: context.TODO(), + args: []string{ + "service_venv", + "pyflyte-execute", + "--task-module", + "abc", + "--task-name", + "abc", + "--output-prefix", + outputPrefix, + "--inputs", + inputs, + "--test", + }, + }, + want: []*commonv1.KeyValuePair{ + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 0, "service_venv", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 1, "pyflyte-execute", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 2, "--task-module", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 3, "abc", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", 
FlyteSageMakerCmdKeyPrefix, 4, "--task-name", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 5, "abc", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 6, "--output-prefix", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 7, outputPrefix, FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 8, "--inputs", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 9, inputs, FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + {Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 10, "--test", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := makeHyperparametersKeysValuesFromArgs(tt.args.in0, tt.args.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("makeHyperparametersKeysValuesFromArgs() = %v, want %v", got, tt.want) + } + }) + } }