Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Enable custom training job in SageMaker plugin #113

Merged
merged 48 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
7ffa794
checking key existence in GetRole()
bnsblue Aug 12, 2020
3a8be98
(training job plugin only) inject hp; refactor image uri getting mech…
bnsblue Aug 13, 2020
685fbf0
change the key and value of the flyte sm command
bnsblue Aug 13, 2020
65eb0a2
get rid of an obsolete function
bnsblue Aug 13, 2020
43ea822
add and change unit tests
bnsblue Aug 14, 2020
536f2cf
add unit tests
bnsblue Aug 14, 2020
2a1e007
lint
bnsblue Aug 14, 2020
4de7787
refactor a dummy object generation function
bnsblue Aug 14, 2020
5b5e4c4
refactor a dummy object generation function
bnsblue Aug 14, 2020
3a8ee70
add unit tests
bnsblue Aug 14, 2020
688e1f9
lint
bnsblue Aug 15, 2020
b80f943
forming the runner cmd and inject it as a hyperparameter in SageMaker…
bnsblue Aug 20, 2020
67275fd
fix unit tests
bnsblue Aug 20, 2020
cd3edaa
fix unit tests
bnsblue Aug 20, 2020
f1adee2
add custom training plugin
bnsblue Aug 22, 2020
53e517c
revert overridden changes
bnsblue Aug 22, 2020
c519e93
compose the __FLYTE_SAGEMAKER_CMD__ in custom training job plugin
bnsblue Aug 22, 2020
a125e9b
add unit test
bnsblue Aug 22, 2020
27b35bc
lint
bnsblue Aug 22, 2020
0f3e050
fix getEventInfoForJob
bnsblue Aug 25, 2020
e1c2762
fix image getting
bnsblue Aug 25, 2020
0062934
fix unit tests
bnsblue Aug 25, 2020
837ae99
stick with inputs.pb and modify hp injecting logic accordingly
bnsblue Aug 25, 2020
ce23f18
fix args converting logic
bnsblue Aug 25, 2020
d0a085d
use default file-based output for custom training job
bnsblue Aug 26, 2020
580b18a
expanding PluginContext interface with necessary methods so SM plugin…
bnsblue Aug 27, 2020
46cfcf8
lint error
bnsblue Aug 27, 2020
01d3132
add unit tests
bnsblue Aug 27, 2020
531bc56
add logic to inject env vars into hyperparameters
bnsblue Aug 27, 2020
6fa9990
fix output prefix
bnsblue Aug 27, 2020
ad4d035
fix output prefix
bnsblue Aug 27, 2020
43a1131
remove job name from output prefix for now
bnsblue Aug 28, 2020
ad68fea
fix a unit test
bnsblue Aug 28, 2020
a4548c9
accommodating new arg and env var passing syntax
bnsblue Aug 31, 2020
26cdda0
injecting an env var to force disable statsd for sagemaker custom trai…
bnsblue Aug 31, 2020
bf78cc2
correcting variable name
bnsblue Sep 1, 2020
3451dcb
remove unused constant
bnsblue Sep 1, 2020
c1fb782
remove comments
bnsblue Sep 1, 2020
6870b99
fix unit tests
bnsblue Sep 1, 2020
4fa1648
resolve conflict
bnsblue Sep 1, 2020
3039fee
merge template.go
bnsblue Sep 1, 2020
2a3046c
pr comments
bnsblue Sep 1, 2020
01a5e12
add guarding statement wrt algorithm name for custom training plugin …
bnsblue Sep 2, 2020
6a63de2
refactor file structures: splitting the code into multiple files to o…
bnsblue Sep 2, 2020
6a6b0a5
add documentations to a set of constants, and fix a constant's name
bnsblue Sep 2, 2020
6135f80
split tests into multiple files
bnsblue Sep 2, 2020
f46aed9
correcting error types: make permanent failures
bnsblue Sep 3, 2020
d64afce
refactor
bnsblue Sep 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions go/tasks/pluginmachinery/k8s/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package k8s
import (
"context"

"github.com/lyft/flytestdlib/storage"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"

"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -48,6 +50,12 @@ type PluginContext interface {

// Provides an output sink of type io.OutputWriter
OutputWriter() io.OutputWriter

// Returns a handle to the currently configured storage backend that can be used to communicate with the tasks or write metadata
bnsblue marked this conversation as resolved.
Show resolved Hide resolved
DataStore() *storage.DataStore

// Returns the max allowed dataset size that the outputwriter will accept
MaxDatasetSizeBytes() int64
}

// Defines a simplified interface to author plugins for k8s resources.
Expand Down
15 changes: 9 additions & 6 deletions go/tasks/pluginmachinery/utils/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ func transformVarNameToStringVal(ctx context.Context, varName string, inputs *co
return v, nil
}

func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in io.InputReader, out io.OutputFilePaths) (string, error) {
val := inputFileRegex.ReplaceAllString(commandTemplate, in.GetInputPath().String())
val = outputRegex.ReplaceAllString(val, out.GetOutputPrefixPath().String())
val = inputPrefixRegex.ReplaceAllString(val, in.GetInputPrefixPath().String())
val = rawOutputDataPrefixRegex.ReplaceAllString(val, out.GetRawOutputPrefix().String())

func replaceInputVarsTemplateCommandArgs(ctx context.Context, in io.InputReader, val string) (string, error) {
inputs, err := in.Get(ctx)
if err != nil {
return val, errors.Wrapf(err, "unable to read inputs")
Expand Down Expand Up @@ -98,6 +93,14 @@ func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in
return val, nil
}

// replaceTemplateCommandArgs resolves the path-style placeholders in a single
// command template string and then resolves the per-input-variable
// placeholders.
//
// The path placeholders are substituted in this order: input file path, output
// prefix path, input prefix path, raw output data prefix. The result is handed
// to replaceInputVarsTemplateCommandArgs, whose string and error are returned
// unchanged.
func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in io.InputReader, out io.OutputFilePaths) (string, error) {
	val := inputFileRegex.ReplaceAllString(commandTemplate, in.GetInputPath().String())
	val = outputRegex.ReplaceAllString(val, out.GetOutputPrefixPath().String())
	val = inputPrefixRegex.ReplaceAllString(val, in.GetInputPrefixPath().String())
	// The raw output data prefix substitution was lost when the input-variable
	// handling was split out into its own helper; without it that placeholder
	// would reach the container unresolved.
	val = rawOutputDataPrefixRegex.ReplaceAllString(val, out.GetRawOutputPrefix().String())

	return replaceInputVarsTemplateCommandArgs(ctx, in, val)
}

func serializePrimitive(p *core.Primitive) (string, error) {
switch o := p.Value.(type) {
case *core.Primitive_Integer:
Expand Down
4 changes: 3 additions & 1 deletion go/tasks/plugins/awsutils/awsutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import "context"

func GetRole(_ context.Context, roleAnnotationKey string, annotations map[string]string) string {
if len(roleAnnotationKey) > 0 {
return annotations[roleAnnotationKey]
if role, found := annotations[roleAnnotationKey]; found {
bnsblue marked this conversation as resolved.
Show resolved Hide resolved
return role
}
}

return ""
Expand Down
18 changes: 18 additions & 0 deletions go/tasks/plugins/k8s/sagemaker/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ const (
trainingJobTaskType = "sagemaker_training_job_task"
)

// Identifiers for the SageMaker custom training job plugin.
// NOTE(review): judging by the names, the first is the ID the plugin is
// registered under and the second is the task type it handles — confirm
// against the registration code.
const (
customTrainingJobTaskPluginID = "sagemaker_custom_training"
customTrainingJobTaskType = "sagemaker_custom_training_job_task"
)

const (
hyperparameterTuningJobTaskPluginID = "sagemaker_hyperparameter_tuning"
hyperparameterTuningJobTaskType = "sagemaker_hyperparameter_tuning_job_task"
Expand All @@ -13,3 +18,16 @@ const (
const (
TEXTCSVInputContentType string = "text/csv"
)

// String constants used to build Flyte-specific keys and values for SageMaker.
// NOTE(review): going by the names, these encode environment variables and the
// task command as key/value pairs (presumably hyperparameters) — confirm
// against the usage sites.
const (
// Prefix of a key that carries an environment variable.
FlyteSageMakerEnvVarKeyPrefix string = "__FLYTE_ENV_VAR_"
// Suffix that closes a key built from one of the prefixes in this block.
FlyteSageMakerKeySuffix string = "__"
// Prefix of a key that carries a piece of the command.
FlyteSageMakerCmdKeyPrefix string = "__FLYTE_CMD_"
// Placeholder value; presumably paired with command keys whose value is unused — TODO confirm.
FlyteSageMakerCmdDummyValue string = "__FLYTE_CMD_DUMMY_VALUE__"
// Name of the environment variable for disabling statsd.
// NOTE(review): this identifier is spelled "Flytesage" (lower-case s) unlike
// its "FlyteSage" siblings; it is exported, so renaming needs a coordinated change.
FlytesageMakerEnvVarKeyStatsdDisabled string = "FLYTE_STATSD_DISABLED"
)

// Sub-directory names for job outputs.
// NOTE(review): presumably appended to an output prefix to separate training
// outputs from hyperparameter tuning outputs — confirm against the callers.
const (
TrainingJobOutputPathSubDir = "training_outputs"
HyperparameterOutputPathSubDir = "hyperparameter_tuning_outputs"
)
46 changes: 46 additions & 0 deletions go/tasks/plugins/k8s/sagemaker/outputs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package sagemaker
bnsblue marked this conversation as resolved.
Show resolved Hide resolved

import (
"context"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/lyft/flytestdlib/logger"
"github.com/lyft/flytestdlib/storage"
)

// JobOutputPaths computes output locations for a job underneath
// outputPrefix/jobName.
//
// It embeds io.OutputFilePaths and overrides GetOutputPrefixPath,
// GetOutputPath and GetErrorPath.
// NOTE(review): NewJobOutputPaths does not set the embedded interface, so any
// io.OutputFilePaths method that is not overridden here would be called on a
// nil interface value.
type JobOutputPaths struct {
io.OutputFilePaths
// jobName is the path segment appended to outputPrefix.
jobName string
// store builds references from a base and a suffix.
store storage.ReferenceConstructor
// outputPrefix is the base reference under which the job's directory lives.
outputPrefix storage.DataReference
}

// constructPath joins suffix onto base using store and returns the resulting
// reference.
//
// A construction failure is logged rather than returned, and whatever
// reference ConstructReference produced alongside the error is returned as-is.
func constructPath(store storage.ReferenceConstructor, base storage.DataReference, suffix string) storage.DataReference {
	res, err := store.ConstructReference(context.Background(), base, suffix)
	if err != nil {
		// Errorf, not Error: the message contains format verbs that must be
		// interpolated with base and err.
		logger.Errorf(context.Background(), "Failed to construct path. Base[%v] Error: %v", base, err)
	}

	return res
}

// GetOutputPrefixPath returns the job's output directory: the configured
// output prefix with the job name appended.
func (s JobOutputPaths) GetOutputPrefixPath() storage.DataReference {
	jobDir := constructPath(s.store, s.outputPrefix, s.jobName)
	return jobDir
}

// GetOutputPath returns the reference formed by appending
// ioutils.OutputsSuffix to the job's output prefix path.
func (s JobOutputPaths) GetOutputPath() storage.DataReference {
	prefix := s.GetOutputPrefixPath()
	return constructPath(s.store, prefix, ioutils.OutputsSuffix)
}

// GetErrorPath returns the reference formed by appending ioutils.ErrorsSuffix
// to the job's output prefix path.
func (s JobOutputPaths) GetErrorPath() storage.DataReference {
	prefix := s.GetOutputPrefixPath()
	return constructPath(s.store, prefix, ioutils.ErrorsSuffix)
}

// NewJobOutputPaths builds a JobOutputPaths from the given reference
// constructor, output prefix and job name. The context argument is unused and
// the embedded io.OutputFilePaths is left at its zero value.
func NewJobOutputPaths(_ context.Context, store storage.ReferenceConstructor, outputPrefix storage.DataReference, jobName string) JobOutputPaths {
	var paths JobOutputPaths
	paths.jobName = jobName
	paths.store = store
	paths.outputPrefix = outputPrefix
	return paths
}
Loading