
Enable custom training job in SageMaker plugin (#113)
* checking key existence in GetRole()

* (training job plugin only) inject hyperparameters; refactor the image URI lookup mechanism; other smaller refactoring

* change the key and value of the Flyte SageMaker command

* get rid of an obsolete function

* add and change unit tests

* add unit tests

* lint

* refactor a dummy object generation function

* refactor a dummy object generation function

* add unit tests

* lint

* forming the runner cmd and injecting it as a hyperparameter in SageMaker plugin (a hedged sketch of this encoding appears just before the file diffs below)

* fix unit tests

* fix unit tests

* add custom training plugin

* revert overridden changes

* compose the __FLYTE_SAGEMAKER_CMD__ in custom training job plugin

* add unit test

* lint

* fix getEventInfoForJob

* fix image retrieval

* fix unit tests

* stick with inputs.pb and modify the hyperparameter injection logic accordingly

* fix the args conversion logic

* use default file-based output for custom training job

* expand the PluginContext interface with the methods the SageMaker plugin needs to access the DataStore and related facilities

* lint error

* add unit tests

* add logic to inject env vars into hyperparameters

* fix output prefix

* fix output prefix

* remove job name from output prefix for now

* fix a unit test

* accommodating new arg and env var passing syntax

* injecting an env var to force-disable statsd for sagemaker custom training

* correcting variable name

* remove unused constant

* remove comments

* fix unit tests

* merge template.go

* pr comments

* add guard statements on the algorithm name for the custom training plugin and the built-in training plugin

* refactor file structure: split the code into multiple files for readability

* add documentation to a set of constants, and fix a constant's name

* split tests into multiple files

* correct error types: mark these failures as permanent

* refactor
bnsblue authored Sep 3, 2020
1 parent 462afce commit 5805d1a
Showing 18 changed files with 2,088 additions and 874 deletions.
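
Several of the commit notes above describe the core mechanism of the custom training plugin: a SageMaker training job only accepts hyperparameters, so the plugin encodes the runner command (the __FLYTE_SAGEMAKER_CMD__ entries) and the task's environment variables as hyperparameters that the in-container entrypoint later decodes. The sketch below illustrates that idea only; the key prefixes, the encoding scheme, and the statsd variable name are assumptions for illustration, not the plugin's actual constants.

```go
// Illustrative sketch only: encode a container command and env vars as
// SageMaker hyperparameters, as described in the commit notes. The prefixes
// and the statsd variable name are hypothetical, not the plugin's constants.
package main

import "fmt"

const (
	cmdKeyPrefix = "__FLYTE_SAGEMAKER_CMD_ARG" // hypothetical per-argument key prefix
	envKeyPrefix = "__FLYTE_SAGEMAKER_ENV_VAR" // hypothetical per-env-var key prefix
)

// encodeCmdAndEnvAsHyperparameters flattens an ordered argument list and an
// env var map into a flat string-to-string hyperparameter map. Arguments are
// indexed so the in-container entrypoint can reconstruct their order.
func encodeCmdAndEnvAsHyperparameters(args []string, env map[string]string) map[string]string {
	hp := make(map[string]string, len(args)+len(env)+1)
	for i, arg := range args {
		hp[fmt.Sprintf("%s_%d__", cmdKeyPrefix, i)] = arg
	}
	for k, v := range env {
		hp[fmt.Sprintf("%s_%s__", envKeyPrefix, k)] = v
	}
	// The commit notes mention force-disabling statsd for custom training;
	// the variable name used here is a placeholder.
	hp[fmt.Sprintf("%s_%s__", envKeyPrefix, "FLYTE_DISABLE_STATSD")] = "true"
	return hp
}

func main() {
	hp := encodeCmdAndEnvAsHyperparameters(
		[]string{"pyflyte-execute", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
		map[string]string{"FLYTE_INTERNAL_CONFIGURATION_PATH": "/etc/flyte/config"},
	)
	for k, v := range hp {
		fmt.Printf("%s = %s\n", k, v)
	}
}
```

Under this scheme, the entrypoint running inside the training container would reverse the encoding to recover the command and environment, which is consistent with the "stick with inputs.pb and modify the hyperparameter injection logic accordingly" note above.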
8 changes: 8 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go
@@ -3,6 +3,8 @@ package k8s
import (
"context"

"github.com/lyft/flytestdlib/storage"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"

"k8s.io/apimachinery/pkg/runtime"
@@ -48,6 +50,12 @@ type PluginContext interface {

// Provides an output sync of type io.OutputWriter
OutputWriter() io.OutputWriter

// Returns a handle to the currently configured storage backend that can be used to communicate with the tasks or write metadata
DataStore() *storage.DataStore

// Returns the max allowed dataset size that the OutputWriter will accept
MaxDatasetSizeBytes() int64
}

// Defines a simplified interface to author plugins for k8s resources.
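
The two accessors added to PluginContext above let k8s plugins reach the configured storage backend directly. Below is a minimal sketch of how a plugin-side helper might use them, assuming the flytestdlib storage API (ConstructReference / WriteProtobuf); the helper name and the "metadata.pb" key are illustrative, not code from this commit.

```go
// Illustrative sketch only: a plugin-side helper using the newly exposed
// DataStore() accessor together with OutputWriter() on k8s.PluginContext.
package example

import (
	"context"

	"github.com/golang/protobuf/proto"
	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"
	"github.com/lyft/flytestdlib/storage"
)

// writeTaskMetadata persists a protobuf message under the task's raw output
// prefix. The function name and the "metadata.pb" key are hypothetical.
func writeTaskMetadata(ctx context.Context, pluginCtx k8s.PluginContext, msg proto.Message) error {
	// Build a storage reference relative to the task's raw output prefix.
	ref, err := pluginCtx.DataStore().ConstructReference(
		ctx, pluginCtx.OutputWriter().GetRawOutputPrefix(), "metadata.pb")
	if err != nil {
		return err
	}
	// Write the message through the configured storage backend.
	return pluginCtx.DataStore().WriteProtobuf(ctx, ref, storage.Options{}, msg)
}
```

Exposing the DataStore on PluginContext is what allows the SageMaker plugin further down in this diff to assemble and publish outputs itself once a training job completes, instead of relying solely on a container writing outputs.pb.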
255 changes: 255 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training.go
@@ -0,0 +1,255 @@
package sagemaker

import (
"context"
"fmt"
"strings"
"time"

trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob"
trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob"

awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils"

pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flytestdlib/logger"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils"

flyteIdlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"

pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"

commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common"
"github.com/aws/aws-sdk-go/service/sagemaker"

taskError "github.com/lyft/flyteplugins/go/tasks/errors"

flyteSageMakerIdl "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins/sagemaker"

"github.com/lyft/flyteplugins/go/tasks/plugins/k8s/sagemaker/config"
)

func (m awsSagemakerPlugin) buildResourceForTrainingJob(
ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) {

logger.Infof(ctx, "Building a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
taskTemplate, err := getTaskTemplate(ctx, taskCtx)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Failed to get the task template of the training job task")
}

// Unmarshal the custom field of the task template back into the TrainingJob struct generated in flyteidl
sagemakerTrainingJob := flyteSageMakerIdl.TrainingJob{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sagemakerTrainingJob)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "invalid TrainingJob task specification: not able to unmarshal the custom field to [%s]", m.TaskType)
}
if sagemakerTrainingJob.GetTrainingJobResourceConfig() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TrainingJobResourceConfig] of the TrainingJob does not exist")
}
if sagemakerTrainingJob.GetAlgorithmSpecification() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [AlgorithmSpecification] does not exist")
}
if sagemakerTrainingJob.GetAlgorithmSpecification().GetAlgorithmName() == flyteSageMakerIdl.AlgorithmName_CUSTOM {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Custom algorithm is not supported by the built-in training job plugin")
}

taskInput, err := taskCtx.InputReader().Get(ctx)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "unable to fetch task inputs")
}

// Get inputs from literals
inputLiterals := taskInput.GetLiterals()
err = checkIfRequiredInputLiteralsExist(inputLiterals,
[]string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable})
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist")
}

trainPathLiteral := inputLiterals[TrainPredefinedInputVariable]
validationPathLiteral := inputLiterals[ValidationPredefinedInputVariable]
staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable]

if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[train] Input is required and should be of Type [Scalar.Blob]")
}
if validationPathLiteral.GetScalar() == nil || validationPathLiteral.GetScalar().GetBlob() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[validation] Input is required and should be of Type [Scalar.Blob]")
}

// Convert the hyperparameters to the spec value
staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type")
}

outputPath := createOutputPath(taskCtx.OutputWriter().GetRawOutputPrefix().String(), TrainingJobOutputPathSubDir)

jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

trainingImageStr, err := getTrainingJobImage(ctx, taskCtx, &sagemakerTrainingJob)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image")
}

logger.Infof(ctx, "The Sagemaker TrainingJob Task plugin received static hyperparameters [%v]", staticHyperparams)

cfg := config.GetSagemakerConfig()

var metricDefinitions []commonv1.MetricDefinition
idlMetricDefinitions := sagemakerTrainingJob.GetAlgorithmSpecification().GetMetricDefinitions()
for _, md := range idlMetricDefinitions {
metricDefinitions = append(metricDefinitions,
commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)})
}

apiContentType, err := getAPIContentType(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType())
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]", sagemakerTrainingJob.GetAlgorithmSpecification().GetInputContentType().String())
}

inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role = cfg.RoleArn
}

trainingJob := &trainingjobv1.TrainingJob{
Spec: trainingjobv1.TrainingJobSpec{
AlgorithmSpecification: &commonv1.AlgorithmSpecification{
// If the user specifies a value for this AlgorithmName parameter, they can't specify a value for TrainingImage.
// In this Flyte plugin, we always use the algorithm name and version the user provides via Flytekit to map to an image,
// so we intentionally leave this field nil
AlgorithmName: nil,
TrainingImage: ToStringPtr(trainingImageStr),
TrainingInputMode: commonv1.TrainingInputMode(inputModeString),
MetricDefinitions: metricDefinitions,
},
// The support of spot training will come in a later version
EnableManagedSpotTraining: nil,
HyperParameters: staticHyperparams,
InputDataConfig: []commonv1.Channel{
{
ChannelName: ToStringPtr(TrainPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType),
InputMode: inputModeString,
},
{
ChannelName: ToStringPtr(ValidationPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(validationPathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType),
InputMode: inputModeString,
},
},
OutputDataConfig: &commonv1.OutputDataConfig{
S3OutputPath: ToStringPtr(outputPath),
},
CheckpointConfig: nil,
ResourceConfig: &commonv1.ResourceConfig{
InstanceType: sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceType(),
InstanceCount: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetInstanceCount()),
VolumeSizeInGB: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetVolumeSizeInGb()),
VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. Need to add to proto and flytekit in the future
},
RoleArn: ToStringPtr(role),
Region: ToStringPtr(cfg.Region),
StoppingCondition: &commonv1.StoppingCondition{
MaxRuntimeInSeconds: ToInt64Ptr(86400), // TODO: decide how to coordinate this and Flyte's timeout
MaxWaitTimeInSeconds: nil, // TODO: decide how to coordinate this and Flyte's timeout and queueing budget
},
TensorBoardOutputConfig: nil,
Tags: nil,
TrainingJobName: &jobName,
},
}
logger.Infof(ctx, "Successfully built a training job resource for task [%v]", taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
return trainingJob, nil
}

func (m awsSagemakerPlugin) getTaskPhaseForTrainingJob(
ctx context.Context, pluginContext k8s.PluginContext, trainingJob *trainingjobv1.TrainingJob) (pluginsCore.PhaseInfo, error) {

logger.Infof(ctx, "Getting task phase for sagemaker training job [%v]", trainingJob.Status.SageMakerTrainingJobName)
info, err := m.getEventInfoForTrainingJob(ctx, trainingJob)
if err != nil {
return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, "Failed to get event info for the job")
}

occurredAt := time.Now()

switch trainingJob.Status.TrainingJobStatus {
case trainingjobController.ReconcilingTrainingJobStatus:
logger.Errorf(ctx, "Job stuck in reconciling status, assuming retryable failure [%s]", trainingJob.Status.Additional)
// TODO talk to AWS about why there cannot be an explicit condition that signals AWS API call errors
execError := &flyteIdlCore.ExecutionError{
Message: trainingJob.Status.Additional,
Kind: flyteIdlCore.ExecutionError_USER,
Code: trainingjobController.ReconcilingTrainingJobStatus,
}
return pluginsCore.PhaseInfoFailed(pluginsCore.PhaseRetryableFailure, execError, info), nil
case sagemaker.TrainingJobStatusFailed:
execError := &flyteIdlCore.ExecutionError{
Message: trainingJob.Status.Additional,
Kind: flyteIdlCore.ExecutionError_USER,
Code: sagemaker.TrainingJobStatusFailed,
}
return pluginsCore.PhaseInfoFailed(pluginsCore.PhasePermanentFailure, execError, info), nil
case sagemaker.TrainingJobStatusStopped:
reason := fmt.Sprintf("Training Job Stopped")
return pluginsCore.PhaseInfoRetryableFailure(taskError.DownstreamSystemError, reason, info), nil
case sagemaker.TrainingJobStatusCompleted:
// Now that it is a success we will set the outputs as expected by the task

// We have specified an output path in the CRD, and we know SageMaker will automatically upload the
// model tarball to s3://<specified-output-path>/<training-job-name>/output/model.tar.gz

// Therefore, here we create an output literal map, filling in the above path as the URI field of the
// blob output, which will later be written out by the OutputWriter to the outputs.pb remotely on S3
outputLiteralMap, err := getOutputLiteralMapFromTaskInterface(ctx, pluginContext.TaskReader(),
createModelOutputPath(trainingJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(), trainingJob.Status.SageMakerTrainingJobName))
if err != nil {
logger.Errorf(ctx, "Failed to create outputs, err: %s", err)
return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to create outputs for the task")
}
// Instantiate an output reader with the literal map, and write the output to the remote location referred to by the OutputWriter
if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputLiteralMap, nil)); err != nil {
return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unable to write output to the remote location")
}
logger.Debugf(ctx, "Successfully produced and returned outputs")
return pluginsCore.PhaseInfoSuccess(info), nil
case "":
return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted"), nil
}

return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
}

func (m awsSagemakerPlugin) getEventInfoForTrainingJob(ctx context.Context, trainingJob *trainingjobv1.TrainingJob) (*pluginsCore.TaskInfo, error) {

var jobRegion, jobName, jobTypeInURL, sagemakerLinkName string
jobRegion = *trainingJob.Spec.Region
jobName = *trainingJob.Spec.TrainingJobName
jobTypeInURL = "jobs"
sagemakerLinkName = TrainingJobSageMakerLinkName

logger.Infof(ctx, "Getting event information for SageMaker BuiltinAlgorithmTrainingJob task, job region: [%v], job name: [%v], "+
"job type in url: [%v], sagemaker link name: [%v]", jobRegion, jobName, jobTypeInURL, sagemakerLinkName)

return createTaskInfo(ctx, jobRegion, jobName, jobTypeInURL, sagemakerLinkName)
}
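
The completed-job branch above relies on the convention noted in its comments: SageMaker uploads the model tarball to s3://<specified-output-path>/<training-job-name>/output/model.tar.gz, and the plugin records that URI as the task's blob output. As a rough illustration of the path construction only (not the plugin's actual createModelOutputPath helper), assuming the raw output prefix, sub-directory, and job name are joined as plain path segments:

```go
// Illustrative sketch only: build the S3 URI where SageMaker places the model
// tarball, following the convention described in the comments above. This is
// not the plugin's createModelOutputPath implementation.
package example

import "fmt"

// modelTarballURI joins the raw output prefix, an output sub-directory
// (TrainingJobOutputPathSubDir in the plugin), and the training job name into
// the path SageMaker writes to: <prefix>/<subdir>/<job>/output/model.tar.gz.
func modelTarballURI(rawOutputPrefix, outputSubDir, jobName string) string {
	return fmt.Sprintf("%s/%s/%s/output/model.tar.gz", rawOutputPrefix, outputSubDir, jobName)
}
```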