Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Enable custom training job in SageMaker plugin #113

Merged
merged 48 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
7ffa794
checking key existence in GetRole()
bnsblue Aug 12, 2020
3a8be98
(training job plugin only) inject hp; refactor image uri getting mech…
bnsblue Aug 13, 2020
685fbf0
change the key and value of the flyte sm command
bnsblue Aug 13, 2020
65eb0a2
get rid of a obsolete function
bnsblue Aug 13, 2020
43ea822
add and change unit tests
bnsblue Aug 14, 2020
536f2cf
add unit tests
bnsblue Aug 14, 2020
2a1e007
lint
bnsblue Aug 14, 2020
4de7787
refactor a dummy object generation function
bnsblue Aug 14, 2020
5b5e4c4
refactor a dummy object generation function
bnsblue Aug 14, 2020
3a8ee70
add unit tests
bnsblue Aug 14, 2020
688e1f9
lint
bnsblue Aug 15, 2020
b80f943
forming the runner cmd and inject it as a hyperparameter in SageMaker…
bnsblue Aug 20, 2020
67275fd
fix unit tests
bnsblue Aug 20, 2020
cd3edaa
fix unit tests
bnsblue Aug 20, 2020
f1adee2
add custom training plugin
bnsblue Aug 22, 2020
53e517c
revert overriden changes
bnsblue Aug 22, 2020
c519e93
compose the __FLYTE_SAGEMAKER_CMD__ in custom training job plugin
bnsblue Aug 22, 2020
a125e9b
add unit test
bnsblue Aug 22, 2020
27b35bc
lint
bnsblue Aug 22, 2020
0f3e050
fix getEventInfoForJob
bnsblue Aug 25, 2020
e1c2762
fix image getting
bnsblue Aug 25, 2020
0062934
fix unit tests
bnsblue Aug 25, 2020
837ae99
stick with inputs.pb and modify hp injecting logic accordingly
bnsblue Aug 25, 2020
ce23f18
fix args converting logic
bnsblue Aug 25, 2020
d0a085d
use default file-based output for custom training job
bnsblue Aug 26, 2020
580b18a
expanding PluginContext interface with necessary methods so SM plugin…
bnsblue Aug 27, 2020
46cfcf8
lint error
bnsblue Aug 27, 2020
01d3132
add unit tests
bnsblue Aug 27, 2020
531bc56
add logic to inject env vars into hyperparameters
bnsblue Aug 27, 2020
6fa9990
fix output prefix
bnsblue Aug 27, 2020
ad4d035
fix output prefix
bnsblue Aug 27, 2020
43a1131
remove job name from output prefix for now
bnsblue Aug 28, 2020
ad68fea
fix a unit test
bnsblue Aug 28, 2020
a4548c9
accommodating new arg and env var passsing syntax
bnsblue Aug 31, 2020
26cdda0
injecting a env var to force disable statsd for sagemaker custom trai…
bnsblue Aug 31, 2020
bf78cc2
correcting variable name
bnsblue Sep 1, 2020
3451dcb
remove unused constant
bnsblue Sep 1, 2020
c1fb782
remove comments
bnsblue Sep 1, 2020
6870b99
fix unit tests
bnsblue Sep 1, 2020
4fa1648
resolve conflict
bnsblue Sep 1, 2020
3039fee
merge template.go
bnsblue Sep 1, 2020
2a3046c
pr comments
bnsblue Sep 1, 2020
01a5e12
add guarding statement wrt algorithm name for custom training plugin …
bnsblue Sep 2, 2020
6a63de2
refactor file structures: splitting the code into multiple files to o…
bnsblue Sep 2, 2020
6a6b0a5
add documentations to a set of constants, and fix a constant's name
bnsblue Sep 2, 2020
6135f80
split tests into multiple files
bnsblue Sep 2, 2020
f46aed9
correcting error types: make permanent failures
bnsblue Sep 3, 2020
d64afce
refactor
bnsblue Sep 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions go/tasks/pluginmachinery/k8s/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package k8s
import (
"context"

"github.com/lyft/flytestdlib/storage"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"

"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -48,6 +50,12 @@ type PluginContext interface {

// Provides an output sync of type io.OutputWriter
OutputWriter() io.OutputWriter

// Returns a handle to the currently configured storage backend that can be used to communicate with the tasks or write metadata
bnsblue marked this conversation as resolved.
Show resolved Hide resolved
DataStore() *storage.DataStore

// Returns the max allowed dataset size that the outputwriter will accept
MaxDatasetSizeBytes() int64
}

// Defines a simplified interface to author plugins for k8s resources.
Expand Down
241 changes: 241 additions & 0 deletions go/tasks/plugins/k8s/sagemaker/builtin_training.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
package sagemaker

import (
"context"
"fmt"
"strings"
"time"

trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob"
trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob"

awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils"

"github.com/lyft/flytestdlib/logger"
"github.com/pkg/errors"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils"

flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"

pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"

commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common"
"github.com/aws/aws-sdk-go/service/sagemaker"

taskError "github.com/lyft/flyteplugins/go/tasks/errors"

flyteSageMakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker"

"github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config"
)

func (m awsSagemakerPlugin) buildResourceForTrainingJob(
ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) {

logger.Infof(ctx, "Building a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
taskTemplate, err := getTaskTemplate(ctx, taskCtx)
if err != nil {
return nil, err
}

// Unmarshal the custom field of the task template back into the Hyperparameter Tuning Job struct generated in flyteidl
sagemakerTrainingJob := flyteSageMakerIdl.TrainingJob{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerTrainingJob)
if err != nil {
return nil, errors.Wrapf(err, "invalid TrainingJob task specification: not able to unmarshal the custom field to [%s]", m.TaskType)
}
if sagemakerTrainingJob.GetTrainingJobResourceConfig() == nil {
return nil, errors.Errorf("Required field [TrainingJobResourceConfig] of the TrainingJob does not exist")
}
if sagemakerTrainingJob.GetAlgorithmSpecification() == nil {
return nil, errors.Errorf("Required field [AlgorithmSpecification] does not exist")
}
if sagemakerTrainingJob.GetAlgorithmSpecification().GetAlgorithmName() == flyteSageMakerIdl.AlgorithmName_CUSTOM {
return nil, errors.Errorf("Custom algorithm is not supported by the built-in training job plugin")
}

taskInput, err := taskCtx.InputReader().Get(ctx)
if err != nil {
return nil, errors.Wrapf(err, "unable to fetch task inputs")
}

// Get inputs from literals
inputLiterals := taskInput.GetLiterals()
err = checkIfRequiredInputLiteralsExist(inputLiterals,
[]string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable})
if err != nil {
return nil, errors.Wrapf(err, "Error occurred when checking if all the required inputs exist")
}

trainPathLiteral := inputLiterals[TrainPredefinedInputVariable]
validationPathLiteral := inputLiterals[ValidationPredefinedInputVariable]
staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable]

if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
return nil, errors.Errorf("[train] Input is required and should be of Type [Scalar.Blob]")
}
if validationPathLiteral.GetScalar() == nil || validationPathLiteral.GetScalar().GetBlob() == nil {
return nil, errors.Errorf("[validation] Input is required and should be of Type [Scalar.Blob]")
}

// Convert the hyperparameters to the spec value
staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral)
if err != nil {
return nil, errors.Wrapf(err, "could not convert static hyperparameters to spec type")
}

outputPath := createOutputPath(taskCtx.OutputWriter().GetRawOutputPrefix().String(), TrainingJobOutputPathSubDir)

jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

trainingImageStr, err := getTrainingJobImage(ctx, taskCtx, &sagemakerTrainingJob)
if err != nil {
return nil, errors.Wrapf(err, "failed to find the training image")
}

logger.Infof(ctx, "The Sagemaker TrainingJob Task plugin received static hyperparameters [%v]", staticHyperparams)

cfg := config.GetSagemakerConfig()

var metricDefinitions []commonv1.MetricDefinition
idlMetricDefinitions := sagemakerTrainingJob.GetAlgorithmSpecification().GetMetricDefinitions()
for _, md := range idlMetricDefinitions {
metricDefinitions = append(metricDefinitions,
commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)})
}

apiContentType, err := getAPIContentType(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType())
if err != nil {
return nil, errors.Wrapf(err, "Unsupported input file type [%v]", sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType().String())
}

inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role = cfg.RoleArn
}

trainingJob := &trainingjobv1.TrainingJob{
Spec: trainingjobv1.TrainingJobSpec{
AlgorithmSpecification: &commonv1.AlgorithmSpecification{
// If the specify a value for this AlgorithmName parameter, the user can't specify a value for TrainingImage.
// in this Flyte plugin, we always use the algorithm name and version the user provides via Flytekit to map to an image
// so we intentionally leave this field nil
AlgorithmName: nil,
TrainingImage: ToStringPtr(trainingImageStr),
TrainingInputMode: commonv1.TrainingInputMode(inputModeString),
MetricDefinitions: metricDefinitions,
},
// The support of spot training will come in a later version
EnableManagedSpotTraining: nil,
HyperParameters: staticHyperparams,
InputDataConfig: []commonv1.Channel{
{
ChannelName: ToStringPtr(TrainPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType),
InputMode: inputModeString,
},
{
ChannelName: ToStringPtr(ValidationPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(validationPathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType),
InputMode: inputModeString,
},
},
OutputDataConfig: &commonv1.OutputDataConfig{
S3OutputPath: ToStringPtr(outputPath),
},
CheckpointConfig: nil,
ResourceConfig: &commonv1.ResourceConfig{
InstanceType: sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceType(),
InstanceCount: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceCount()),
VolumeSizeInGB: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetVolumeSizeInGb()),
VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. Need to add to proto and flytekit in the future
},
RoleArn: ToStringPtr(role),
Region: ToStringPtr(cfg.Region),
StoppingCondition: &commonv1.StoppingCondition{
MaxRuntimeInSeconds: ToInt64Ptr(86400), // TODO: decide how to coordinate this and Flyte's timeout
MaxWaitTimeInSeconds: nil, // TODO: decide how to coordinate this and Flyte's timeout and queueing budget
},
TensorBoardOutputConfig: nil,
Tags: nil,
TrainingJobName: &jobName,
},
}
logger.Infof(ctx, "Successfully built a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
return trainingJob, nil
}

func (m awsSagemakerPlugin) getTaskPhaseForTrainingJob(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this different for every job type? or are the jobSTatus common?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output handling is different for different type of tasks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the job statuses are not common either

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya I realized once I went through all of them

ctx context.Context, pluginContext k8s.PluginContext, trainingJob *trainingjobv1.TrainingJob) (pluginsCore.PhaseInfo, error) {

logger.Infof(ctx, "Getting task phase for sagemaker training job [%v]", trainingJob.Status.SageMakerTrainingJobName)
info, err := m.getEventInfoForJob(ctx, trainingJob)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

occurredAt := time.Now()

switch trainingJob.Status.TrainingJobStatus {
case trainingjobController.ReconcilingTrainingJobStatus:
logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", trainingJob.Status.Additional)
// TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call errors
execError := &flyteIdlCore.ExecutionError{
Message: trainingJob.Status.Additional,
Kind: flyteIdlCore.ExecutionError_USER,
Code: trainingjobController.ReconcilingTrainingJobStatus,
}
return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil
case sagemaker.TrainingJobStatusFailed:
execError := &flyteIdlCore.ExecutionError{
Message: trainingJob.Status.Additional,
Kind: flyteIdlCore.ExecutionError_USER,
Code: sagemaker.TrainingJobStatusFailed,
}
return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, info), nil
case sagemaker.TrainingJobStatusStopped:
reason := fmt.Sprintf("Training Job Stopped")
return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil
case sagemaker.TrainingJobStatusCompleted:
// Now that it is a success we will set the outputs as expected by the task

// We have specified an output path in the CRD, and we know SageMaker will automatically upload the
// model tarball to s3://<specified-output-path>/<training-job-name>/output/model.tar.gz

// Therefore, here we create a output literal map, where we fill in the above path to the URI field of the
// blob output, which will later be written out by the OutputWriter to the outputs.pb remotely on S3
outputLiteralMap, err := getOutputLiteralMapFromTaskInterface(ctx, pluginContext.TaskReader(),
createModelOutputPath(trainingJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(), trainingJob.Status.SageMakerTrainingJobName))
if err != nil {
logger.Errorf(ctx, "Failed to create outputs, err: %s", err)
return pluginsCore.PhaseInfoUndefined, errors.Wrapf(err, "failed to create outputs for the task")
}
// Instantiate a output reader with the literal map, and write the output to the remote location referred to by the OutputWriter
if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputLiteralMap, nil)); err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
logger.Debugf(ctx, "Successfully produced and returned outputs")
return pluginsCore.PhaseInfoSuccess(info), nil
case "":
return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil
}

return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
}
Loading