diff --git a/go/tasks/pluginmachinery/core/exec_context.go b/go/tasks/pluginmachinery/core/exec_context.go index 2f14b1abf..e724c601a 100644 --- a/go/tasks/pluginmachinery/core/exec_context.go +++ b/go/tasks/pluginmachinery/core/exec_context.go @@ -9,8 +9,15 @@ import ( "github.com/flyteorg/flytestdlib/storage" ) +// An interface to access a remote/sharable location that contains the serialized TaskTemplate +type TaskTemplatePath interface { + // Returns the path + Path(ctx context.Context) (storage.DataReference, error) +} + // An interface to access the TaskInformation type TaskReader interface { + TaskTemplatePath // Returns the core TaskTemplate Read(ctx context.Context) (*core.TaskTemplate, error) } diff --git a/go/tasks/pluginmachinery/core/mocks/task_reader.go b/go/tasks/pluginmachinery/core/mocks/task_reader.go index 3fae3aa7c..b7aef61a2 100644 --- a/go/tasks/pluginmachinery/core/mocks/task_reader.go +++ b/go/tasks/pluginmachinery/core/mocks/task_reader.go @@ -5,8 +5,11 @@ package mocks import ( context "context" - core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + flyteidlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" + + storage "github.com/flyteorg/flytestdlib/storage" ) // TaskReader is an autogenerated mock type for the TaskReader type @@ -14,11 +17,50 @@ type TaskReader struct { mock.Mock } +type TaskReader_Path struct { + *mock.Call +} + +func (_m TaskReader_Path) Return(_a0 storage.DataReference, _a1 error) *TaskReader_Path { + return &TaskReader_Path{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TaskReader) OnPath(ctx context.Context) *TaskReader_Path { + c := _m.On("Path", ctx) + return &TaskReader_Path{Call: c} +} + +func (_m *TaskReader) OnPathMatch(matchers ...interface{}) *TaskReader_Path { + c := _m.On("Path", matchers...) + return &TaskReader_Path{Call: c} +} + +// Path provides a mock function with given fields: ctx +func (_m *TaskReader) Path(ctx context.Context) (storage.DataReference, error) { + ret := _m.Called(ctx) + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func(context.Context) storage.DataReference); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type TaskReader_Read struct { *mock.Call } -func (_m TaskReader_Read) Return(_a0 *core.TaskTemplate, _a1 error) *TaskReader_Read { +func (_m TaskReader_Read) Return(_a0 *flyteidlcore.TaskTemplate, _a1 error) *TaskReader_Read { return &TaskReader_Read{Call: _m.Call.Return(_a0, _a1)} } @@ -33,15 +75,15 @@ func (_m *TaskReader) OnReadMatch(matchers ...interface{}) *TaskReader_Read { } // Read provides a mock function with given fields: ctx -func (_m *TaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { +func (_m *TaskReader) Read(ctx context.Context) (*flyteidlcore.TaskTemplate, error) { ret := _m.Called(ctx) - var r0 *core.TaskTemplate - if rf, ok := ret.Get(0).(func(context.Context) *core.TaskTemplate); ok { + var r0 *flyteidlcore.TaskTemplate + if rf, ok := ret.Get(0).(func(context.Context) *flyteidlcore.TaskTemplate); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*core.TaskTemplate) + r0 = ret.Get(0).(*flyteidlcore.TaskTemplate) } } diff --git a/go/tasks/pluginmachinery/core/mocks/task_template_path.go b/go/tasks/pluginmachinery/core/mocks/task_template_path.go new file mode 100644 index 000000000..245517849 --- /dev/null +++ b/go/tasks/pluginmachinery/core/mocks/task_template_path.go @@ -0,0 +1,55 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + storage "github.com/flyteorg/flytestdlib/storage" +) + +// TaskTemplatePath is an autogenerated mock type for the TaskTemplatePath type +type TaskTemplatePath struct { + mock.Mock +} + +type TaskTemplatePath_Path struct { + *mock.Call +} + +func (_m TaskTemplatePath_Path) Return(_a0 storage.DataReference, _a1 error) *TaskTemplatePath_Path { + return &TaskTemplatePath_Path{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TaskTemplatePath) OnPath(ctx context.Context) *TaskTemplatePath_Path { + c := _m.On("Path", ctx) + return &TaskTemplatePath_Path{Call: c} +} + +func (_m *TaskTemplatePath) OnPathMatch(matchers ...interface{}) *TaskTemplatePath_Path { + c := _m.On("Path", matchers...) + return &TaskTemplatePath_Path{Call: c} +} + +// Path provides a mock function with given fields: ctx +func (_m *TaskTemplatePath) Path(ctx context.Context) (storage.DataReference, error) { + ret := _m.Called(ctx) + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func(context.Context) storage.DataReference); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/go/tasks/pluginmachinery/core/template/template.go b/go/tasks/pluginmachinery/core/template/template.go index 702197ecd..21707ac15 100644 --- a/go/tasks/pluginmachinery/core/template/template.go +++ b/go/tasks/pluginmachinery/core/template/template.go @@ -33,6 +33,14 @@ func (e ErrorCollection) Error() string { return sb.String() } +// The Parameters struct is used by the Templating Engine to replace the templated parameters +type Parameters struct { + TaskExecMetadata core.TaskExecutionMetadata + Inputs io.InputReader + OutputPath io.OutputFilePaths + Task core.TaskTemplatePath +} + // Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive // Supported templates are: // - {{ .InputFile }} to receive the input file path. The protocol used will depend on the underlying system @@ -44,25 +52,23 @@ func (e ErrorCollection) Error() string { // NOTE: I wanted to do in-place replacement, until I realized that in-place replacement will alter the definition of the // graph. This is not desirable, as we may have to retry and in that case the replacement will not work and we want // to create a new location for outputs -func ReplaceTemplateCommandArgs(ctx context.Context, tExecMeta core.TaskExecutionMetadata, command []string, in io.InputReader, - out io.OutputFilePaths) ([]string, error) { +func Render(ctx context.Context, inputTemplate []string, params Parameters) ([]string, error) { + if len(inputTemplate) == 0 { + return []string{}, nil + } // TODO: Change GetGeneratedName to follow these conventions - var perRetryUniqueKey = tExecMeta.GetTaskExecutionID().GetGeneratedName() + var perRetryUniqueKey = params.TaskExecMetadata.GetTaskExecutionID().GetGeneratedName() perRetryUniqueKey = startsWithAlpha.ReplaceAllString(perRetryUniqueKey, "a") perRetryUniqueKey = alphaNumericOnly.ReplaceAllString(perRetryUniqueKey, "_") - logger.Debugf(ctx, "Using [%s] from [%s]", perRetryUniqueKey, tExecMeta.GetTaskExecutionID().GetGeneratedName()) - - if len(command) == 0 { - return []string{}, nil - } - if in == nil || out == nil { + logger.Debugf(ctx, "Using [%s] from [%s]", perRetryUniqueKey, params.TaskExecMetadata.GetTaskExecutionID().GetGeneratedName()) + if params.Inputs == nil || params.OutputPath == nil { return nil, fmt.Errorf("input reader and output path cannot be nil") } - res := make([]string, 0, len(command)) - for _, commandTemplate := range command { - updated, err := replaceTemplateCommandArgs(ctx, perRetryUniqueKey, commandTemplate, in, out) + res := make([]string, 0, len(inputTemplate)) + for _, t := range inputTemplate { + updated, err := render(ctx, t, params, perRetryUniqueKey) if err != nil { return res, err } @@ -79,30 +85,28 @@ var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`) var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P[^}\s]+)\s*}}`) var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPrefix\s*}}`) var perRetryUniqueKey = regexp.MustCompile(`(?i){{\s*[\.$]PerRetryUniqueKey\s*}}`) +var taskTemplateRegex = regexp.MustCompile(`(?i){{\s*[\.$]TaskTemplatePath\s*}}`) -func transformVarNameToStringVal(ctx context.Context, varName string, inputs *idlCore.LiteralMap) (string, error) { - inputVal, exists := inputs.Literals[varName] - if !exists { - return "", fmt.Errorf("requested input is not found [%s]", varName) - } - - v, err := serializeLiteral(ctx, inputVal) - if err != nil { - return "", errors.Wrapf(err, "failed to bind a value to inputName [%s]", varName) - } - return v, nil -} +func render(ctx context.Context, inputTemplate string, params Parameters, perRetryKey string) (string, error) { -func replaceTemplateCommandArgs(ctx context.Context, perRetryKey string, commandTemplate string, - in io.InputReader, out io.OutputFilePaths) (string, error) { - - val := inputFileRegex.ReplaceAllString(commandTemplate, in.GetInputPath().String()) - val = outputRegex.ReplaceAllString(val, out.GetOutputPrefixPath().String()) - val = inputPrefixRegex.ReplaceAllString(val, in.GetInputPrefixPath().String()) - val = rawOutputDataPrefixRegex.ReplaceAllString(val, out.GetRawOutputPrefix().String()) + val := inputFileRegex.ReplaceAllString(inputTemplate, params.Inputs.GetInputPath().String()) + val = outputRegex.ReplaceAllString(val, params.OutputPath.GetOutputPrefixPath().String()) + val = inputPrefixRegex.ReplaceAllString(val, params.Inputs.GetInputPrefixPath().String()) + val = rawOutputDataPrefixRegex.ReplaceAllString(val, params.OutputPath.GetRawOutputPrefix().String()) val = perRetryUniqueKey.ReplaceAllString(val, perRetryKey) - inputs, err := in.Get(ctx) + // For Task template, we will replace only if there is a match. This is because, task template replacement + // may be expensive, as we may offload + if taskTemplateRegex.MatchString(val) { + p, err := params.Task.Path(ctx) + if err != nil { + logger.Debugf(ctx, "Failed to substitute Task Template reference - reason %s", err) + return "", err + } + val = taskTemplateRegex.ReplaceAllString(val, p.String()) + } + + inputs, err := params.Inputs.Get(ctx) if err != nil { return val, errors.Wrapf(err, "unable to read inputs") } @@ -129,6 +133,19 @@ func replaceTemplateCommandArgs(ctx context.Context, perRetryKey string, command return val, nil } +func transformVarNameToStringVal(ctx context.Context, varName string, inputs *idlCore.LiteralMap) (string, error) { + inputVal, exists := inputs.Literals[varName] + if !exists { + return "", fmt.Errorf("requested input is not found [%s]", varName) + } + + v, err := serializeLiteral(ctx, inputVal) + if err != nil { + return "", errors.Wrapf(err, "failed to bind a value to inputName [%s]", varName) + } + return v, nil +} + func serializePrimitive(p *idlCore.Primitive) (string, error) { switch o := p.Value.(type) { case *idlCore.Primitive_Integer: diff --git a/go/tasks/pluginmachinery/core/template/template_test.go b/go/tasks/pluginmachinery/core/template/template_test.go index 03259ecd1..5ede705c5 100644 --- a/go/tasks/pluginmachinery/core/template/template_test.go +++ b/go/tasks/pluginmachinery/core/template/template_test.go @@ -66,8 +66,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { taskMetadata.On("GetTaskExecutionID").Return(taskExecutionID) t.Run("empty cmd", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, - []string{}, nil, nil) + actual, err := Render(context.TODO(), []string{}, Parameters{}) assert.NoError(t, err) assert.Equal(t, []string{}, actual) }) @@ -78,11 +77,17 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { rawOutputDataPrefix: "s3://custom-bucket", } + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } t.Run("nothing to substitute", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "world", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ @@ -92,11 +97,11 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Sub InputFile", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "world", "{{ .Input }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ @@ -108,11 +113,17 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { t.Run("Sub Input Prefix", func(t *testing.T) { in := dummyInputReader{inputPath: "input/prefix"} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + actual, err := Render(context.TODO(), []string{ "hello", "world", "{{ .Input }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ @@ -123,11 +134,11 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Sub Output Prefix", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "world", "{{ .OutputPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -137,12 +148,12 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Sub Input Output prefix", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "world", "{{ .Input }}", "{{ .OutputPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -153,13 +164,13 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Bad input template", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "world", "${{input}}", "{{ .OutputPrefix }}", "--switch {{ .rawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -182,13 +193,19 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }, }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + actual, err := Render(context.TODO(), []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, "{{ .OutputPrefix }}", "{{ $RawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -205,13 +222,19 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "date": coreutils.MustMakeLiteral(time.Date(1900, 01, 01, 01, 01, 01, 000000001, time.UTC)), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + actual, err := Render(context.TODO(), []string{ "hello", "world", `--someArg {{ .Inputs.date }}`, "{{ .OutputPrefix }}", "{{ .rawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -228,13 +251,19 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + actual, err := Render(context.TODO(), []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, "{{ .OutputPrefix }}", "{{ .wrongOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -247,14 +276,20 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { t.Run("nil input", func(t *testing.T) { in := dummyInputReader{inputs: &core.LiteralMap{}} - - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + + actual, err := Render(context.TODO(), []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, "{{ .OutputPrefix }}", "--raw-data-output-prefix {{ .rawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -274,14 +309,20 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "min": coreutils.MustMakeLiteral(15), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + actual, err := Render(context.TODO(), []string{ `SELECT COUNT(*) as total_count FROM hive.events.{{ .Inputs.table }} WHERE ds = '{{ .Inputs.ds }}' AND hr = '{{ .Inputs.hr }}' AND min = {{ .Inputs.min }} - `}, in, out) + `}, params) assert.NoError(t, err) assert.Equal(t, []string{ `SELECT @@ -299,12 +340,18 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), }, }} - _, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + _, err := Render(context.TODO(), []string{ "hello", "world", `--someArg {{ .Inputs.blah }}`, "{{ .OutputPrefix }}", - }, in, out) + }, params) assert.Error(t, err) }) @@ -314,12 +361,18 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: nil, + } + actual, err := Render(context.TODO(), []string{ "hello", "world", `--someArg {{ .Inputs.blah blah }} {{ .PerretryuNIqueKey }}`, "{{ .OutputPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -330,12 +383,12 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("sub raw output data prefix", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "{{ .perRetryUniqueKey }}", "world", "{{ .rawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -344,6 +397,52 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "s3://custom-bucket", }, actual) }) + + t.Run("sub task template happy", func(t *testing.T) { + ctx := context.TODO() + tMock := &pluginsCoreMocks.TaskTemplatePath{} + tMock.OnPath(ctx).Return("s3://task-path", nil) + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: tMock, + } + + actual, err := Render(ctx, []string{ + "hello", + "{{ .perRetryUniqueKey }}", + "world", + "{{ .taskTemplatePath }}", + }, params) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "per_retry_unique_key", + "world", + "s3://task-path", + }, actual) + }) + + t.Run("sub task template error", func(t *testing.T) { + ctx := context.TODO() + tMock := &pluginsCoreMocks.TaskTemplatePath{} + tMock.OnPath(ctx).Return("", fmt.Errorf("error")) + params := Parameters{ + TaskExecMetadata: taskMetadata, + Inputs: in, + OutputPath: out, + Task: tMock, + } + + _, err := Render(ctx, []string{ + "hello", + "{{ .perRetryUniqueKey }}", + "world", + "{{ .taskTemplatePath }}", + }, params) + assert.Error(t, err) + }) } func TestReplaceTemplateCommandArgsSpecialChars(t *testing.T) { @@ -353,18 +452,21 @@ func TestReplaceTemplateCommandArgsSpecialChars(t *testing.T) { rawOutputDataPrefix: "s3://custom-bucket", } + params := Parameters{Inputs: in, OutputPath: out} + t.Run("dashes are replaced", func(t *testing.T) { taskExecutionID := &pluginsCoreMocks.TaskExecutionID{} taskExecutionID.On("GetGeneratedName").Return("per-retry-unique-key") taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} taskMetadata.On("GetTaskExecutionID").Return(taskExecutionID) - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + params.TaskExecMetadata = taskMetadata + actual, err := Render(context.TODO(), []string{ "hello", "{{ .perRetryUniqueKey }}", "world", "{{ .rawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", @@ -381,6 +483,7 @@ func TestReplaceTemplateCommandArgsSpecialChars(t *testing.T) { taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} taskMetadata.On("GetTaskExecutionID").Return(taskExecutionID) + params.TaskExecMetadata = taskMetadata testString := "doesn't start with a number" testString2 := "1 does start with a number" testString3 := " 1 3 nd spaces " @@ -388,12 +491,12 @@ func TestReplaceTemplateCommandArgsSpecialChars(t *testing.T) { assert.Equal(t, "adoes start with a number", startsWithAlpha.ReplaceAllString(testString2, "a")) assert.Equal(t, "and spaces ", startsWithAlpha.ReplaceAllString(testString3, "a")) - actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + actual, err := Render(context.TODO(), []string{ "hello", "{{ .perRetryUniqueKey }}", "world", "{{ .rawOutputDataPrefix }}", - }, in, out) + }, params) assert.NoError(t, err) assert.Equal(t, []string{ "hello", diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper.go b/go/tasks/pluginmachinery/flytek8s/container_helper.go index 9cbbf56ab..45d549800 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper.go @@ -13,9 +13,7 @@ import ( "k8s.io/apimachinery/pkg/util/rand" "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ) const resourceGPU = "gpu" @@ -84,35 +82,34 @@ func ApplyResourceOverrides(ctx context.Context, resources v1.ResourceRequiremen } // Returns a K8s Container for the execution -func ToK8sContainer(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskContainer *core.Container, iFace *core.TypedInterface, - inputReader io.InputReader, outputPaths io.OutputFilePaths) (*v1.Container, error) { - modifiedCommand, err := template.ReplaceTemplateCommandArgs(ctx, taskExecutionMetadata, taskContainer.GetCommand(), inputReader, outputPaths) +func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *core.TypedInterface, parameters template.Parameters) (*v1.Container, error) { + modifiedCommand, err := template.Render(ctx, taskContainer.GetCommand(), parameters) if err != nil { return nil, err } - modifiedArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskExecutionMetadata, taskContainer.GetArgs(), inputReader, outputPaths) + modifiedArgs, err := template.Render(ctx, taskContainer.GetArgs(), parameters) if err != nil { return nil, err } - envVars := DecorateEnvVars(ctx, ToK8sEnvVar(taskContainer.GetEnv()), taskExecutionMetadata.GetTaskExecutionID()) + envVars := DecorateEnvVars(ctx, ToK8sEnvVar(taskContainer.GetEnv()), parameters.TaskExecMetadata.GetTaskExecutionID()) - if taskExecutionMetadata.GetOverrides() == nil { + if parameters.TaskExecMetadata.GetOverrides() == nil { return nil, errors.Errorf(errors.BadTaskSpecification, "platform/compiler error, overrides not set for task") } - if taskExecutionMetadata.GetOverrides() == nil || taskExecutionMetadata.GetOverrides().GetResources() == nil { + if parameters.TaskExecMetadata.GetOverrides() == nil || parameters.TaskExecMetadata.GetOverrides().GetResources() == nil { return nil, errors.Errorf(errors.BadTaskSpecification, "resource requirements not found for container task, required!") } - res := taskExecutionMetadata.GetOverrides().GetResources() + res := parameters.TaskExecMetadata.GetOverrides().GetResources() if res != nil { res = ApplyResourceOverrides(ctx, *res) } // Make the container name the same as the pod name, unless it violates K8s naming conventions // Container names are subject to the DNS-1123 standard - containerName := taskExecutionMetadata.GetTaskExecutionID().GetGeneratedName() + containerName := parameters.TaskExecMetadata.GetTaskExecutionID().GetGeneratedName() if errs := validation.IsDNS1123Label(containerName); len(errs) > 0 { containerName = rand.String(4) } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 3b0583e15..48ba02fa5 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flytestdlib/logger" @@ -14,7 +16,6 @@ import ( pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ) const PodKind = "pod" @@ -45,9 +46,8 @@ func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, } } -func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskReader pluginsCore.TaskReader, - inputs io.InputReader, outputPaths io.OutputFilePaths) (*v1.PodSpec, error) { - task, err := taskReader.Read(ctx) +func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { + task, err := tCtx.TaskReader().Read(ctx) if err != nil { logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error()) return nil, err @@ -56,7 +56,12 @@ func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExe logger.Errorf(ctx, "Default Pod creation logic works for default container in the task template only.") return nil, fmt.Errorf("container not specified in task template") } - c, err := ToK8sContainer(ctx, taskExecutionMetadata, task.GetContainer(), task.Interface, inputs, outputPaths) + c, err := ToK8sContainer(ctx, task.GetContainer(), task.Interface, template.Parameters{ + Task: tCtx.TaskReader(), + Inputs: tCtx.InputReader(), + OutputPath: tCtx.OutputWriter(), + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + }) if err != nil { return nil, err } @@ -67,9 +72,9 @@ func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExe pod := &v1.PodSpec{ Containers: containers, } - UpdatePod(taskExecutionMetadata, []v1.ResourceRequirements{c.Resources}, pod) + UpdatePod(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod) - if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), taskExecutionMetadata, inputs, outputPaths, task.GetContainer().GetDataConfig()); err != nil { + if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), task.GetContainer().GetDataConfig()); err != nil { return nil, err } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 4a6452bde..164a46e2d 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -80,6 +80,19 @@ func dummyInputReader() io.InputReader { return inputReader } +func dummyExecContext(r *v1.ResourceRequirements) pluginsCore.TaskExecutionContext { + ow := &pluginsIOMock.OutputWriter{} + ow.OnGetOutputPrefixPath().Return("") + ow.OnGetRawOutputPrefix().Return("") + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(r)) + tCtx.OnInputReader().Return(dummyInputReader()) + tCtx.OnTaskReader().Return(dummyTaskReader()) + tCtx.OnOutputWriter().Return(ow) + return tCtx +} + func TestPodSetup(t *testing.T) { configAccessor := viper.NewAccessor(config1.Options{ StrictMode: true, @@ -142,11 +155,7 @@ func updatePod(t *testing.T) { func toK8sPodInterruptible(t *testing.T) { ctx := context.TODO() - op := &pluginsIOMock.OutputFilePaths{} - op.On("GetOutputPrefixPath").Return(storage.DataReference("")) - op.On("GetRawOutputPrefix").Return(storage.DataReference("")) - - x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ + x := dummyExecContext(&v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceStorage: resource.MustParse("100M"), @@ -158,7 +167,7 @@ func toK8sPodInterruptible(t *testing.T) { }, }) - p, err := ToK8sPodSpec(ctx, x, dummyTaskReader(), dummyInputReader(), op) + p, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Len(t, p.Tolerations, 2) assert.Equal(t, "x/flyte", p.Tolerations[1].Key) @@ -196,7 +205,7 @@ func TestToK8sPod(t *testing.T) { op.On("GetRawOutputPrefix").Return(storage.DataReference("")) t.Run("WithGPU", func(t *testing.T) { - x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ + x := dummyExecContext(&v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceStorage: resource.MustParse("100M"), @@ -208,13 +217,13 @@ func TestToK8sPod(t *testing.T) { }, }) - p, err := ToK8sPodSpec(ctx, x, dummyTaskReader(), dummyInputReader(), op) + p, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, len(p.Tolerations), 1) }) t.Run("NoGPU", func(t *testing.T) { - x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ + x := dummyExecContext(&v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceStorage: resource.MustParse("100M"), @@ -225,14 +234,14 @@ func TestToK8sPod(t *testing.T) { }, }) - p, err := ToK8sPodSpec(ctx, x, dummyTaskReader(), dummyInputReader(), op) + p, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, len(p.Tolerations), 0) assert.Equal(t, "some-acceptable-name", p.Containers[0].Name) }) t.Run("Default toleration, selector, scheduler", func(t *testing.T) { - x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ + x := dummyExecContext(&v1.ResourceRequirements{ Limits: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceStorage: resource.MustParse("100M"), @@ -256,7 +265,7 @@ func TestToK8sPod(t *testing.T) { SchedulerName: "myScheduler", })) - p, err := ToK8sPodSpec(ctx, x, dummyTaskReader(), dummyInputReader(), op) + p, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, 1, len(p.Tolerations)) assert.Equal(t, 1, len(p.NodeSelector)) diff --git a/go/tasks/pluginmachinery/ioutils/doc.go b/go/tasks/pluginmachinery/ioutils/doc.go new file mode 100644 index 000000000..bc0381a28 --- /dev/null +++ b/go/tasks/pluginmachinery/ioutils/doc.go @@ -0,0 +1,5 @@ +// Package ioutils contains utilities for interacting with the IO Layer of FlytePropeller Metastore +// For example, utilities like reading inputs, writing outputs, computing output paths, prefixes. +// These helpers are intended to be used by FlytePropeller and aim to reduce the burden of implementing simple +// io functions +package ioutils diff --git a/go/tasks/pluginmachinery/ioutils/paths.go b/go/tasks/pluginmachinery/ioutils/paths.go index 69df844d8..724ef44bc 100644 --- a/go/tasks/pluginmachinery/ioutils/paths.go +++ b/go/tasks/pluginmachinery/ioutils/paths.go @@ -9,11 +9,12 @@ import ( ) const ( - InputsSuffix = "inputs.pb" - FuturesSuffix = "futures.pb" - OutputsSuffix = "outputs.pb" - ErrorsSuffix = "error.pb" - IndexLookupSuffix = "indexlookup.pb" + InputsSuffix = "inputs.pb" + TaskTemplateSuffix = "task.pb" + FuturesSuffix = "futures.pb" + OutputsSuffix = "outputs.pb" + ErrorsSuffix = "error.pb" + IndexLookupSuffix = "indexlookup.pb" ) func constructPath(store storage.ReferenceConstructor, base storage.DataReference, suffix string) storage.DataReference { @@ -25,30 +26,12 @@ func constructPath(store storage.ReferenceConstructor, base storage.DataReferenc return res } -func GetPath(ctx context.Context, store storage.ReferenceConstructor, root storage.DataReference, subNames ...string) (res storage.DataReference, err error) { - return store.ConstructReference(ctx, root, subNames...) -} - -func GetMasterOutputsPath(ctx context.Context, store storage.ReferenceConstructor, output storage.DataReference) (res storage.DataReference, err error) { - return store.ConstructReference(ctx, output, OutputsSuffix) -} - -func GetInputsPath(ctx context.Context, store storage.ReferenceConstructor, prefix storage.DataReference) (res storage.DataReference, err error) { - return store.ConstructReference(ctx, prefix, InputsSuffix) -} - -func GetOutputsPath(ctx context.Context, store storage.ReferenceConstructor, prefix storage.DataReference) (res storage.DataReference, err error) { - return store.ConstructReference(ctx, prefix, OutputsSuffix) -} - -func GetFuturesPath(ctx context.Context, store storage.ReferenceConstructor, prefix storage.DataReference) (res storage.DataReference, err error) { - return store.ConstructReference(ctx, prefix, FuturesSuffix) -} - -func GetErrorsPath(ctx context.Context, store storage.ReferenceConstructor, prefix storage.DataReference) (res storage.DataReference, err error) { - return store.ConstructReference(ctx, prefix, ErrorsSuffix) +// GetTaskTemplatePath returns a protobuf file path where TaskTemplate is stored +func GetTaskTemplatePath(ctx context.Context, store storage.ReferenceConstructor, base storage.DataReference) (storage.DataReference, error) { + return store.ConstructReference(ctx, base, TaskTemplateSuffix) } +// GetIndexLookupPath returns the indexpath suffixed to IndexLookupSuffix func GetIndexLookupPath(ctx context.Context, store storage.ReferenceConstructor, prefix storage.DataReference) (res storage.DataReference, err error) { return store.ConstructReference(ctx, prefix, IndexLookupSuffix) } diff --git a/go/tasks/pluginmachinery/ioutils/task_reader.go b/go/tasks/pluginmachinery/ioutils/task_reader.go new file mode 100644 index 000000000..a65d1a3df --- /dev/null +++ b/go/tasks/pluginmachinery/ioutils/task_reader.go @@ -0,0 +1,59 @@ +package ioutils + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flytestdlib/atomic" + "github.com/flyteorg/flytestdlib/storage" + "github.com/pkg/errors" +) + +var ( + _ pluginsCore.TaskReader = &lazyUploadingTaskReader{} +) + +// SimpleTaskReader provides only the TaskReader interface. This is created to conveniently use the uploading taskreader +// interface +type SimpleTaskReader interface { + Read(ctx context.Context) (*core.TaskTemplate, error) +} + +// lazyUploadingTaskReader provides a lazy interface that uploads the core.TaskTemplate to a configured location, +// only if the location is accessed. This reduces the potential overhead of writing the template +type lazyUploadingTaskReader struct { + SimpleTaskReader + uploaded atomic.Bool + store storage.ProtobufStore + remotePath storage.DataReference +} + +func (r *lazyUploadingTaskReader) Path(ctx context.Context) (storage.DataReference, error) { + // We are using atomic because it is ok to re-upload in some cases. We know that most of the plugins are + // executed in a single go-routine, so chances of a race condition are minimal. + if !r.uploaded.Load() { + t, err := r.SimpleTaskReader.Read(ctx) + if err != nil { + return "", err + } + err = r.store.WriteProtobuf(ctx, r.remotePath, storage.Options{}, t) + if err != nil { + return "", errors.Wrapf(err, "failed to store task template to remote path [%s]", r.remotePath) + } + r.uploaded.Store(true) + } + return r.remotePath, nil +} + +// NewLazyUploadingTaskReader decorates an existing TaskReader and adds a functionality to allow lazily uploading the task template to +// a remote location, only when the location information is accessed +func NewLazyUploadingTaskReader(baseTaskReader SimpleTaskReader, remotePath storage.DataReference, store storage.ProtobufStore) pluginsCore.TaskReader { + return &lazyUploadingTaskReader{ + SimpleTaskReader: baseTaskReader, + uploaded: atomic.NewBool(false), + store: store, + remotePath: remotePath, + } +} diff --git a/go/tasks/pluginmachinery/ioutils/task_reader_test.go b/go/tasks/pluginmachinery/ioutils/task_reader_test.go new file mode 100644 index 000000000..f7134093d --- /dev/null +++ b/go/tasks/pluginmachinery/ioutils/task_reader_test.go @@ -0,0 +1,100 @@ +package ioutils + +import ( + "context" + "fmt" + "testing" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" +) + +const dummyPath = storage.DataReference("test") + +func TestLazyUploadingTaskReader_Happy(t *testing.T) { + ttm := &core.TaskTemplate{} + + ctx := context.TODO() + tr := &mocks.TaskReader{} + tr.OnRead(ctx).Return(ttm, nil) + + rawStore, err := storage.NewInMemoryRawStore(nil, promutils.NewTestScope()) + assert.NoError(t, err) + protoStore := storage.NewDefaultProtobufStore(rawStore, promutils.NewTestScope()) + + ltr := NewLazyUploadingTaskReader(tr, dummyPath, protoStore) + + x, err := ltr.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, x, ttm) + + p, err := ltr.Path(ctx) + assert.NoError(t, err) + assert.Equal(t, p, dummyPath) + + v, err := rawStore.Head(ctx, dummyPath) + assert.NoError(t, err) + assert.True(t, v.Exists()) +} + +// test storage.ProtobufStore to test upload failure +type failingProtoStore struct { + storage.ProtobufStore +} + +func (d *failingProtoStore) WriteProtobuf(ctx context.Context, reference storage.DataReference, opts storage.Options, msg proto.Message) error { + return fmt.Errorf("failed") +} + +func TestLazyUploadingTaskReader_TaskWriteFailure(t *testing.T) { + ttm := &core.TaskTemplate{} + + ctx := context.TODO() + tr := &mocks.TaskReader{} + tr.OnRead(ctx).Return(ttm, nil) + + ltr := NewLazyUploadingTaskReader(tr, dummyPath, &failingProtoStore{}) + + x, err := ltr.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, x, ttm) + + p, err := ltr.Path(ctx) + assert.Error(t, err) + assert.Equal(t, p, storage.DataReference("")) +} + +func TestLazyUploadingTaskReader_TaskReadFailure(t *testing.T) { + + ctx := context.TODO() + tr := &mocks.TaskReader{} + tr.OnRead(ctx).Return(nil, fmt.Errorf("read fail")) + + rawStore, err := storage.NewInMemoryRawStore(nil, promutils.NewTestScope()) + assert.NoError(t, err) + protoStore := storage.NewDefaultProtobufStore(rawStore, promutils.NewTestScope()) + + ltr := NewLazyUploadingTaskReader(tr, dummyPath, protoStore) + + x, err := ltr.Read(ctx) + assert.Error(t, err) + assert.Nil(t, x) + + p, err := ltr.Path(ctx) + assert.Error(t, err) + assert.Equal(t, p, storage.DataReference("")) + + v, err := rawStore.Head(ctx, dummyPath) + assert.NoError(t, err) + assert.False(t, v.Exists()) +} + +func init() { + labeled.SetMetricKeys(contextutils.ExecIDKey) +} diff --git a/go/tasks/plugins/array/awsbatch/transformer.go b/go/tasks/plugins/array/awsbatch/transformer.go index 597f6d988..81dd119d7 100644 --- a/go/tasks/plugins/array/awsbatch/transformer.go +++ b/go/tasks/plugins/array/awsbatch/transformer.go @@ -50,13 +50,26 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon return nil, errors.Errorf(errors.BadTaskSpecification, "config[%v] is missing", DynamicTaskQueueKey) } - cmd, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), taskTemplate.GetContainer().GetCommand(), tCtx.InputReader(), tCtx.OutputWriter()) + inputReader := array.GetInputReader(tCtx, taskTemplate) + cmd, err := template.Render( + ctx, + taskTemplate.GetContainer().GetCommand(), + template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: inputReader, + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) if err != nil { return nil, err } - inputReader := array.GetInputReader(tCtx, taskTemplate) - args, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), taskTemplate.GetContainer().GetArgs(), - inputReader, tCtx.OutputWriter()) + args, err := template.Render(ctx, taskTemplate.GetContainer().GetArgs(), + template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: inputReader, + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) taskTemplate.GetContainer().GetEnv() if err != nil { return nil, err diff --git a/go/tasks/plugins/array/k8s/task.go b/go/tasks/plugins/array/k8s/task.go index 7598c3fd5..22569d7db 100644 --- a/go/tasks/plugins/array/k8s/task.go +++ b/go/tasks/plugins/array/k8s/task.go @@ -83,8 +83,13 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl return LaunchError, errors2.Wrapf(ErrGetTaskTypeVersion, err, "Missing task template") } inputReader := array.GetInputReader(tCtx, taskTemplate) - pod.Spec.Containers[0].Args, err = template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), args, - inputReader, tCtx.OutputWriter()) + pod.Spec.Containers[0].Args, err = template.Render(ctx, args, + template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: inputReader, + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) if err != nil { return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args") } diff --git a/go/tasks/plugins/array/k8s/transformer.go b/go/tasks/plugins/array/k8s/transformer.go index 70061249c..9382bdeef 100644 --- a/go/tasks/plugins/array/k8s/transformer.go +++ b/go/tasks/plugins/array/k8s/transformer.go @@ -3,6 +3,8 @@ package k8s import ( "context" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" @@ -17,6 +19,16 @@ import ( const PodKind = "pod" +type arrayTaskContext struct { + core.TaskExecutionContext + arrayInputReader io.InputReader +} + +// Overrides the TaskExecutionContext from base and returns a specialized context for Array +func (a *arrayTaskContext) InputReader() io.InputReader { + return a.arrayInputReader +} + // Note that Name is not set on the result object. // It's up to the caller to set the Name before creating the object in K8s. func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext) ( @@ -35,7 +47,10 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC "Required value not set, taskTemplate Container") } - inputReader := array.GetInputReader(tCtx, taskTemplate) + arrTCtx := &arrayTaskContext{ + TaskExecutionContext: tCtx, + arrayInputReader: array.GetInputReader(tCtx, taskTemplate), + } var arrayJob *idlPlugins.ArrayJob if taskTemplate.GetCustom() != nil { arrayJob, err = core2.ToArrayJob(taskTemplate.GetCustom(), taskTemplate.TaskTypeVersion) @@ -44,8 +59,7 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC } } - podSpec, err := flytek8s.ToK8sPodSpec(ctx, tCtx.TaskExecutionMetadata(), tCtx.TaskReader(), inputReader, - tCtx.OutputWriter()) + podSpec, err := flytek8s.ToK8sPodSpec(ctx, arrTCtx) if err != nil { return v1.Pod{}, nil, err } diff --git a/go/tasks/plugins/hive/execution_state.go b/go/tasks/plugins/hive/execution_state.go index 881a5ce7c..cbc45cc06 100644 --- a/go/tasks/plugins/hive/execution_state.go +++ b/go/tasks/plugins/hive/execution_state.go @@ -274,7 +274,13 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( query := hiveJob.Query.GetQuery() - outputs, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), []string{query}, tCtx.InputReader(), tCtx.OutputWriter()) + outputs, err := template.Render(ctx, []string{query}, + template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: tCtx.InputReader(), + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) if err != nil { return "", "", []string{}, 0, "", err } diff --git a/go/tasks/plugins/k8s/container/container.go b/go/tasks/plugins/k8s/container/container.go index 84a35cd68..6a148f191 100755 --- a/go/tasks/plugins/k8s/container/container.go +++ b/go/tasks/plugins/k8s/container/container.go @@ -61,7 +61,7 @@ func (Plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, // Creates a new Pod that will Exit on completion. The pods have no retries by design func (Plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx.TaskExecutionMetadata(), taskCtx.TaskReader(), taskCtx.InputReader(), taskCtx.OutputWriter()) + podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, err } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 005e97771..0312a3538 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -62,7 +62,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx.TaskExecutionMetadata(), taskCtx.TaskReader(), taskCtx.InputReader(), taskCtx.OutputWriter()) + podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 382da309e..2ad381853 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -62,7 +62,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx.TaskExecutionMetadata(), taskCtx.TaskReader(), taskCtx.InputReader(), taskCtx.OutputWriter()) + podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } diff --git a/go/tasks/plugins/k8s/sagemaker/utils.go b/go/tasks/plugins/k8s/sagemaker/utils.go index ac520f530..7dc26ac8b 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/go/tasks/plugins/k8s/sagemaker/utils.go @@ -334,7 +334,12 @@ func injectTaskTemplateEnvVarToHyperparameters(ctx context.Context, taskTemplate func injectArgsAndEnvVars(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, taskTemplate *flyteIdlCore.TaskTemplate) ([]*commonv1.KeyValuePair, error) { templateArgs := taskTemplate.GetContainer().GetArgs() - templateArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), templateArgs, taskCtx.InputReader(), taskCtx.OutputWriter()) + templateArgs, err := template.Render(ctx, templateArgs, template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + }) if err != nil { return nil, errors.Wrapf(ErrSagemaker, err, "Failed to de-template the hyperparameter values") } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar.go b/go/tasks/plugins/k8s/sidecar/sidecar.go index ca92e70bc..18be269fb 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -39,13 +39,23 @@ func validateAndFinalizePod( if container.Name == primaryContainerName { hasPrimaryContainer = true } - modifiedCommand, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), container.Command, taskCtx.InputReader(), taskCtx.OutputWriter()) + modifiedCommand, err := template.Render(ctx, container.Command, template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + }) if err != nil { return nil, err } container.Command = modifiedCommand - modifiedArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), container.Args, taskCtx.InputReader(), taskCtx.OutputWriter()) + modifiedArgs, err := template.Render(ctx, container.Args, template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + }) if err != nil { return nil, err } diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go index 4a0469973..11801dcb8 100755 --- a/go/tasks/plugins/k8s/spark/spark.go +++ b/go/tasks/plugins/k8s/spark/spark.go @@ -116,7 +116,12 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo }, } - modifiedArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), container.GetArgs(), taskCtx.InputReader(), taskCtx.OutputWriter()) + modifiedArgs, err := template.Render(ctx, container.GetArgs(), template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + }) if err != nil { return nil, err } diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 8e2700feb..3370b0b95 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -223,12 +223,17 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) (string, return "", "", "", "", err } - outputs, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), []string{ + outputs, err := template.Render(ctx, []string{ prestoQuery.RoutingGroup, prestoQuery.Catalog, prestoQuery.Schema, prestoQuery.Statement, - }, tCtx.InputReader(), tCtx.OutputWriter()) + }, template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: tCtx.InputReader(), + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) if err != nil { return "", "", "", "", err } diff --git a/go/tasks/plugins/webapi/athena/utils.go b/go/tasks/plugins/webapi/athena/utils.go index 574986b9b..a4b89b4e6 100644 --- a/go/tasks/plugins/webapi/athena/utils.go +++ b/go/tasks/plugins/webapi/athena/utils.go @@ -97,10 +97,15 @@ func extractQueryInfo(ctx context.Context, tCtx webapi.TaskExecutionContextReade return QueryInfo{}, errors.Wrapf(ErrUser, err, "Expects a valid QubleHiveJob proto in custom field.") } - outputs, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), []string{ + outputs, err := template.Render(ctx, []string{ hiveQuery.Query.Query, hiveQuery.ClusterLabel, - }, tCtx.InputReader(), tCtx.OutputWriter()) + }, template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: tCtx.InputReader(), + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) if err != nil { return QueryInfo{}, err } @@ -121,12 +126,17 @@ func extractQueryInfo(ctx context.Context, tCtx webapi.TaskExecutionContextReade return QueryInfo{}, errors.Wrapf(ErrUser, err, "Expects a valid PrestoQuery proto in custom field.") } - outputs, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), []string{ + outputs, err := template.Render(ctx, []string{ prestoQuery.RoutingGroup, prestoQuery.Catalog, prestoQuery.Schema, prestoQuery.Statement, - }, tCtx.InputReader(), tCtx.OutputWriter()) + }, template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: tCtx.InputReader(), + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + }) if err != nil { return QueryInfo{}, err }