Skip to content

Commit

Permalink
Adding logic to insert the MPI-enabling hyperparameter in SageMaker p…
Browse files Browse the repository at this point in the history
…lugin (flyteorg#124)

* adding logic to insert the MPI-enabling hyperparameter

* add tests for logic around the distributed protocol field

* flyteidl v0.18.9

* do not override user's choice on sagemaker_mpi_enabled
  • Loading branch information
bnsblue authored Oct 9, 2020
1 parent b5bbe10 commit 0679ea0
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 52 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/hashicorp/golang-lru v0.5.4
github.com/kubeflow/pytorch-operator v0.6.0
github.com/kubeflow/tf-operator v0.5.3
github.com/lyft/flyteidl v0.18.7
github.com/lyft/flyteidl v0.18.9
github.com/lyft/flytepropeller v0.4.2
github.com/lyft/flytestdlib v0.3.9
github.com/magiconair/properties v1.8.1
Expand Down
41 changes: 7 additions & 34 deletions go.sum

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions go/tasks/plugins/k8s/sagemaker/builtin_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func Test_awsSagemakerPlugin_BuildResourceForTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job x", tjObj)

trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate, false))
Expand All @@ -73,7 +73,7 @@ func Test_awsSagemakerPlugin_BuildResourceForTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job y", tjObj)

trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate, false))
Expand All @@ -100,7 +100,7 @@ func Test_awsSagemakerPlugin_BuildResourceForTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)

trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockTrainingJobTaskContext(taskTemplate, false))
Expand Down Expand Up @@ -150,7 +150,7 @@ func Test_awsSagemakerPlugin_GetTaskPhaseForTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockTrainingJobTaskContext(taskTemplate, false)
trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx)
Expand Down Expand Up @@ -220,7 +220,7 @@ func Test_awsSagemakerPlugin_getEventInfoForTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockTrainingJobTaskContext(taskTemplate, false)
trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx)
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/sagemaker/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
FlyteSageMakerCmdKeyPrefix string = "__FLYTE_CMD_"
FlyteSageMakerCmdDummyValue string = "__FLYTE_CMD_DUMMY_VALUE__"
FlyteSageMakerEnvVarKeyStatsdDisabled string = "FLYTE_STATSD_DISABLED"
SageMakerMpiEnableEnvVarName string = "sagemaker_mpi_enabled"
)

const (
Expand Down
20 changes: 20 additions & 0 deletions go/tasks/plugins/k8s/sagemaker/custom_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ func (m awsSagemakerPlugin) buildResourceForCustomTrainingJob(
Value: strconv.FormatBool(true),
})

if sagemakerTrainingJob.GetTrainingJobResourceConfig() == nil {
logger.Errorf(ctx, "TrainingJobResourceConfig is nil")
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "TrainingJobResourceConfig is nil")
}

if sagemakerTrainingJob.GetTrainingJobResourceConfig().GetDistributedProtocol() == flyteSageMakerIdl.DistributedProtocol_MPI {
// inject sagemaker_mpi_enabled=true into hyperparameters if the user code designates MPI as its distributed training framework
logger.Infof(ctx, "MPI is enabled by the user. TrainingJob.TrainingJobResourceConfig.DistributedProtocol=[%v]", sagemakerTrainingJob.GetTrainingJobResourceConfig().GetDistributedProtocol().String())
hyperParameters = append(hyperParameters, &commonv1.KeyValuePair{
Name: SageMakerMpiEnableEnvVarName,
Value: strconv.FormatBool(true),
})
} else {
// default value: injecting sagemaker_mpi_enabled=false
logger.Infof(ctx, "Distributed protocol is unspecified or a non-MPI value [%v] in the training job", sagemakerTrainingJob.GetTrainingJobResourceConfig().GetDistributedProtocol())
hyperParameters = append(hyperParameters, &commonv1.KeyValuePair{
Name: SageMakerMpiEnableEnvVarName,
Value: strconv.FormatBool(false),
})
}
logger.Infof(ctx, "The Sagemaker TrainingJob Task plugin received static hyperparameters [%v]", hyperParameters)

trainingJob := &trainingjobv1.TrainingJob{
Expand Down
90 changes: 85 additions & 5 deletions go/tasks/plugins/k8s/sagemaker/custom_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Test_awsSagemakerPlugin_BuildResourceForCustomTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)

trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockCustomTrainingJobTaskContext(taskTemplate, false))
Expand All @@ -66,6 +66,7 @@ func Test_awsSagemakerPlugin_BuildResourceForCustomTrainingJob(t *testing.T) {
{Name: fmt.Sprintf("%s%d_%s%s", FlyteSageMakerCmdKeyPrefix, 6, "--test-flag", FlyteSageMakerKeySuffix), Value: FlyteSageMakerCmdDummyValue},
{Name: fmt.Sprintf("%s%s%s", FlyteSageMakerEnvVarKeyPrefix, "Env_Var", FlyteSageMakerKeySuffix), Value: "Env_Val"},
{Name: fmt.Sprintf("%s%s%s", FlyteSageMakerEnvVarKeyPrefix, FlyteSageMakerEnvVarKeyStatsdDisabled, FlyteSageMakerKeySuffix), Value: strconv.FormatBool(true)},
{Name: SageMakerMpiEnableEnvVarName, Value: strconv.FormatBool(false)},
}
assert.Equal(t, len(expectedHPs), len(trainingJob.Spec.HyperParameters))
for i := range expectedHPs {
Expand All @@ -74,6 +75,85 @@ func Test_awsSagemakerPlugin_BuildResourceForCustomTrainingJob(t *testing.T) {
}

assert.Equal(t, testImage, *trainingJob.Spec.AlgorithmSpecification.TrainingImage)

// Since the distributed protocol is UNSPECIFIED, we should find sagemaker_mpi_enabled=false in the hyperparameters
count := 0
for i := range trainingJob.Spec.HyperParameters {
if trainingJob.Spec.HyperParameters[i].Name == SageMakerMpiEnableEnvVarName && trainingJob.Spec.HyperParameters[i].Value == strconv.FormatBool(false) {
count++
}
}
assert.Equal(t, 1, count)
})

t.Run("In a custom training job when users specify the MPI distributed protocol, even when the instance count is 1, we should find sagemaker_mpi_enabled=true in the hyperparameters", func(t *testing.T) {
// Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config
configAccessor := viper.NewAccessor(stdConfig.Options{
StrictMode: true,
// Use a different
SearchPaths: []string{"testdata/config2.yaml"},
})

err := configAccessor.UpdateConfig(context.TODO())
assert.NoError(t, err)

awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: customTrainingJobTaskType}

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_MPI)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)

trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockCustomTrainingJobTaskContext(taskTemplate, false))
assert.NoError(t, err)
assert.NotNil(t, trainingJobResource)

trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob)
assert.True(t, ok)

count := 0
for i := range trainingJob.Spec.HyperParameters {
if trainingJob.Spec.HyperParameters[i].Name == SageMakerMpiEnableEnvVarName {
count++
assert.Equal(t, trainingJob.Spec.HyperParameters[i].Value, strconv.FormatBool(true))
}
}
assert.Equal(t, 1, count)
})

t.Run("When users specify the MPI distributed protocol, we should find sagemaker_mpi_enabled=true in the hyperparameters", func(t *testing.T) {
// Injecting a config which contains a mismatched roleAnnotationKey -> expecting to get the role from the config
configAccessor := viper.NewAccessor(stdConfig.Options{
StrictMode: true,
// Use a different
SearchPaths: []string{"testdata/config2.yaml"},
})

err := configAccessor.UpdateConfig(context.TODO())
assert.NoError(t, err)

awsSageMakerTrainingJobHandler := awsSagemakerPlugin{TaskType: customTrainingJobTaskType}

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 2, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_MPI)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)

trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, generateMockCustomTrainingJobTaskContext(taskTemplate, false))
assert.NoError(t, err)
assert.NotNil(t, trainingJobResource)

trainingJob, ok := trainingJobResource.(*trainingjobv1.TrainingJob)
assert.True(t, ok)

count := 0
for i := range trainingJob.Spec.HyperParameters {
if trainingJob.Spec.HyperParameters[i].Name == SageMakerMpiEnableEnvVarName {
count++
assert.Equal(t, trainingJob.Spec.HyperParameters[i].Value, strconv.FormatBool(true))
}
}
assert.Equal(t, 1, count)
})
}

Expand All @@ -94,7 +174,7 @@ func Test_awsSagemakerPlugin_GetTaskPhaseForCustomTrainingJob(t *testing.T) {
t.Run("TrainingJobStatusCompleted", func(t *testing.T) {
tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, false)
trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx)
Expand All @@ -105,7 +185,7 @@ func Test_awsSagemakerPlugin_GetTaskPhaseForCustomTrainingJob(t *testing.T) {
t.Run("TrainingJobStatusCompleted", func(t *testing.T) {
tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, false)
trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx)
Expand All @@ -123,7 +203,7 @@ func Test_awsSagemakerPlugin_GetTaskPhaseForCustomTrainingJob(t *testing.T) {
t.Run("OutputWriter.Put returns an error", func(t *testing.T) {
tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, true)
trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx)
Expand Down Expand Up @@ -163,7 +243,7 @@ func Test_awsSagemakerPlugin_getEventInfoForCustomTrainingJob(t *testing.T) {

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockCustomTrainingJobTaskContext(taskTemplate, false)
trainingJobResource, err := awsSageMakerTrainingJobHandler.BuildResource(ctx, taskCtx)
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Test_awsSagemakerPlugin_BuildResourceForHyperparameterTuningJob(t *testing.

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5)
taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj)
hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, generateMockHyperparameterTuningJobTaskContext(taskTemplate))
Expand Down Expand Up @@ -74,7 +74,7 @@ func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T

tjObj := generateMockTrainingJobCustomObj(
sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{},
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25)
sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED)
htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5)
taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj)
taskCtx := generateMockHyperparameterTuningJobTaskContext(taskTemplate)
Expand Down
9 changes: 5 additions & 4 deletions go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ func genMockTaskExecutionMetadata() *mocks.TaskExecutionMetadata {
func generateMockTrainingJobCustomObj(
inputMode sagemakerIdl.InputMode_Value, algName sagemakerIdl.AlgorithmName_Value, algVersion string,
metricDefinitions []*sagemakerIdl.MetricDefinition, contentType sagemakerIdl.InputContentType_Value,
instanceCount int64, instanceType string, volumeSizeInGB int64) *sagemakerIdl.TrainingJob {
instanceCount int64, instanceType string, volumeSizeInGB int64, protocol sagemakerIdl.DistributedProtocol_Value) *sagemakerIdl.TrainingJob {
return &sagemakerIdl.TrainingJob{
AlgorithmSpecification: &sagemakerIdl.AlgorithmSpecification{
InputMode: inputMode,
Expand All @@ -407,9 +407,10 @@ func generateMockTrainingJobCustomObj(
InputContentType: contentType,
},
TrainingJobResourceConfig: &sagemakerIdl.TrainingJobResourceConfig{
InstanceCount: instanceCount,
InstanceType: instanceType,
VolumeSizeInGb: volumeSizeInGB,
InstanceCount: instanceCount,
InstanceType: instanceType,
VolumeSizeInGb: volumeSizeInGB,
DistributedProtocol: protocol,
},
}
}
Expand Down
3 changes: 2 additions & 1 deletion go/tasks/plugins/k8s/sagemaker/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ func Test_getTrainingJobImage(t *testing.T) {
flyteSagemakerIdl.InputContentType_TEXT_CSV,
1,
"ml.m4.xlarge",
25)
25,
flyteSagemakerIdl.DistributedProtocol_UNSPECIFIED)
taskTemplate := generateMockTrainingJobTaskTemplate("the job", tjObj)
taskCtx := generateMockTrainingJobTaskContext(taskTemplate, false)
sagemakerTrainingJob := flyteSagemakerIdl.TrainingJob{}
Expand Down

0 comments on commit 0679ea0

Please sign in to comment.