From 62f54d1f84a09ba6cb4dc2d6f96b226e04ab75a1 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 2 Dec 2020 11:12:46 -0800 Subject: [PATCH] Add interpolation to Hive queries (#142) --- copilot/go.mod | 2 +- copilot/go.sum | 1 + go.sum | 1 + .../{utils => core/template}/template.go | 87 ++++--- .../{utils => core/template}/template_test.go | 213 ++++++++++++++---- .../flytek8s/container_helper.go | 7 +- go/tasks/pluginmachinery/io/iface.go | 2 +- .../plugins/array/awsbatch/transformer.go | 8 +- go/tasks/plugins/array/k8s/task.go | 6 +- go/tasks/plugins/hive/execution_state.go | 84 ++++++- go/tasks/plugins/hive/execution_state_test.go | 46 ++-- go/tasks/plugins/hive/executions_cache.go | 2 +- .../plugins/hive/executions_cache_test.go | 2 +- go/tasks/plugins/hive/test_helpers.go | 13 ++ go/tasks/plugins/k8s/sagemaker/utils.go | 4 +- go/tasks/plugins/k8s/sidecar/sidecar.go | 6 +- go/tasks/plugins/k8s/spark/spark.go | 4 +- go/tasks/plugins/presto/execution_state.go | 4 +- 18 files changed, 377 insertions(+), 115 deletions(-) rename go/tasks/pluginmachinery/{utils => core/template}/template.go (67%) mode change 100755 => 100644 rename go/tasks/pluginmachinery/{utils => core/template}/template_test.go (64%) mode change 100755 => 100644 diff --git a/copilot/go.mod b/copilot/go.mod index 0d4a94bca9..71e2cd243f 100644 --- a/copilot/go.mod +++ b/copilot/go.mod @@ -8,7 +8,7 @@ require ( github.com/gogo/protobuf v1.3.1 github.com/golang/protobuf v1.4.2 github.com/imdario/mergo v0.3.9 // indirect - github.com/lyft/flyteidl v0.18.0 + github.com/lyft/flyteidl v0.18.9 github.com/lyft/flyteplugins v0.4.4 github.com/lyft/flytestdlib v0.3.9 github.com/mitchellh/go-ps v1.0.0 diff --git a/copilot/go.sum b/copilot/go.sum index 59c2f0c474..7044a9fc49 100644 --- a/copilot/go.sum +++ b/copilot/go.sum @@ -377,6 +377,7 @@ github.com/lyft/flyteidl v0.17.32 h1:Iio3gYjTyPhAiOMWJ/H/4YtfWIZm5KZSlWMULT1Ef6U github.com/lyft/flyteidl v0.17.32/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flyteidl v0.18.0 h1:f4yv1MafE26wpMC6QlthM02EeTEDXpy/waL54dRDiSs= github.com/lyft/flyteidl v0.18.0/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= +github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flytepropeller v0.3.6/go.mod h1:1Iw3ngmJBP+52coloHL1rOxcX7EDDUUvTYFQQy2WYzk= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.9 h1:NaKp9xkeWWwhVvqTOcR/FqlASy1N2gu/kN7PVe4S7YI= diff --git a/go.sum b/go.sum index 4b9bf8cfe7..10b4e22520 100644 --- a/go.sum +++ b/go.sum @@ -395,6 +395,7 @@ github.com/lyft/flytestdlib v0.3.3 h1:MkWXPkwQinh6MR3Yf5siZhmRSt9r4YmsF+5kvVVVed github.com/lyft/flytestdlib v0.3.3/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.9 h1:NaKp9xkeWWwhVvqTOcR/FqlASy1N2gu/kN7PVe4S7YI= github.com/lyft/flytestdlib v0.3.9/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= +github.com/lyft/spark-on-k8s-operator v0.1.4-0.20201027003055-c76b67e3b6d0 h1:1vSmc+Bo70X0JVYywQ9Hy/aet6p613ejacy9x5td0m4= github.com/lyft/spark-on-k8s-operator v0.1.4-0.20201027003055-c76b67e3b6d0/go.mod h1:hkRqdqAsdNnxT/Zst6MNMRbTAoiCZ0JRw7svRgAYb0A= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4= diff --git a/go/tasks/pluginmachinery/utils/template.go b/go/tasks/pluginmachinery/core/template/template.go old mode 100755 new mode 100644 similarity index 67% rename from go/tasks/pluginmachinery/utils/template.go rename to go/tasks/pluginmachinery/core/template/template.go index 8f4a829038..cab579ef7e --- a/go/tasks/pluginmachinery/utils/template.go +++ b/go/tasks/pluginmachinery/core/template/template.go @@ -1,25 +1,37 @@ -package utils +package template import ( "context" "fmt" - "reflect" "regexp" "strings" - "github.com/golang/protobuf/ptypes" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flytestdlib/logger" - "github.com/pkg/errors" + "reflect" + + "github.com/golang/protobuf/ptypes" + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/pkg/errors" ) -var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`) -var inputPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]InputPrefix\s*}}`) -var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`) -var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P[^}\s]+)\s*}}`) -var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPrefix\s*}}`) +var alphaNumericOnly = regexp.MustCompile("[^a-zA-Z0-9_]+") +var startsWithAlpha = regexp.MustCompile("^[^a-zA-Z_]+") + +type ErrorCollection struct { + Errors []error +} + +func (e ErrorCollection) Error() string { + sb := strings.Builder{} + for idx, err := range e.Errors { + sb.WriteString(fmt.Sprintf("%v: %v\r\n", idx, err)) + } + + return sb.String() +} // Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive // Supported templates are: @@ -32,7 +44,16 @@ var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPr // NOTE: I wanted to do in-place replacement, until I realized that in-place replacement will alter the definition of the // graph. This is not desirable, as we may have to retry and in that case the replacement will not work and we want // to create a new location for outputs -func ReplaceTemplateCommandArgs(ctx context.Context, command []string, in io.InputReader, out io.OutputFilePaths) ([]string, error) { +func ReplaceTemplateCommandArgs(ctx context.Context, tExecMeta core.TaskExecutionMetadata, command []string, in io.InputReader, + out io.OutputFilePaths) ([]string, error) { + + // TODO: Change GetGeneratedName to follow these conventions + var perRetryUniqueKey = tExecMeta.GetTaskExecutionID().GetGeneratedName() + perRetryUniqueKey = startsWithAlpha.ReplaceAllString(perRetryUniqueKey, "a") + perRetryUniqueKey = alphaNumericOnly.ReplaceAllString(perRetryUniqueKey, "_") + + logger.Debugf(ctx, "Using [%s] from [%s]", perRetryUniqueKey, tExecMeta.GetTaskExecutionID().GetGeneratedName()) + if len(command) == 0 { return []string{}, nil } @@ -41,7 +62,7 @@ func ReplaceTemplateCommandArgs(ctx context.Context, command []string, in io.Inp } res := make([]string, 0, len(command)) for _, commandTemplate := range command { - updated, err := replaceTemplateCommandArgs(ctx, commandTemplate, in, out) + updated, err := replaceTemplateCommandArgs(ctx, perRetryUniqueKey, commandTemplate, in, out) if err != nil { return res, err } @@ -52,7 +73,14 @@ func ReplaceTemplateCommandArgs(ctx context.Context, command []string, in io.Inp return res, nil } -func transformVarNameToStringVal(ctx context.Context, varName string, inputs *core.LiteralMap) (string, error) { +var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`) +var inputPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]InputPrefix\s*}}`) +var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`) +var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P[^}\s]+)\s*}}`) +var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPrefix\s*}}`) +var perRetryUniqueKey = regexp.MustCompile(`(?i){{\s*[\.$]PerRetryUniqueKey\s*}}`) + +func transformVarNameToStringVal(ctx context.Context, varName string, inputs *idlCore.LiteralMap) (string, error) { inputVal, exists := inputs.Literals[varName] if !exists { return "", fmt.Errorf("requested input is not found [%s]", varName) @@ -65,11 +93,14 @@ func transformVarNameToStringVal(ctx context.Context, varName string, inputs *co return v, nil } -func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in io.InputReader, out io.OutputFilePaths) (string, error) { +func replaceTemplateCommandArgs(ctx context.Context, perRetryKey string, commandTemplate string, + in io.InputReader, out io.OutputFilePaths) (string, error) { + val := inputFileRegex.ReplaceAllString(commandTemplate, in.GetInputPath().String()) val = outputRegex.ReplaceAllString(val, out.GetOutputPrefixPath().String()) val = inputPrefixRegex.ReplaceAllString(val, in.GetInputPrefixPath().String()) val = rawOutputDataPrefixRegex.ReplaceAllString(val, out.GetRawOutputPrefix().String()) + val = perRetryUniqueKey.ReplaceAllString(val, perRetryKey) inputs, err := in.Get(ctx) if err != nil { @@ -98,39 +129,41 @@ func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in return val, nil } -func serializePrimitive(p *core.Primitive) (string, error) { +func serializePrimitive(p *idlCore.Primitive) (string, error) { switch o := p.Value.(type) { - case *core.Primitive_Integer: + case *idlCore.Primitive_Integer: return fmt.Sprintf("%v", o.Integer), nil - case *core.Primitive_Boolean: + case *idlCore.Primitive_Boolean: return fmt.Sprintf("%v", o.Boolean), nil - case *core.Primitive_Datetime: + case *idlCore.Primitive_Datetime: return ptypes.TimestampString(o.Datetime), nil - case *core.Primitive_Duration: + case *idlCore.Primitive_Duration: return o.Duration.String(), nil - case *core.Primitive_FloatValue: + case *idlCore.Primitive_FloatValue: return fmt.Sprintf("%v", o.FloatValue), nil - case *core.Primitive_StringValue: + case *idlCore.Primitive_StringValue: return o.StringValue, nil default: return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) } } -func serializeLiteralScalar(l *core.Scalar) (string, error) { +func serializeLiteralScalar(l *idlCore.Scalar) (string, error) { switch o := l.Value.(type) { - case *core.Scalar_Primitive: + case *idlCore.Scalar_Primitive: return serializePrimitive(o.Primitive) - case *core.Scalar_Blob: + case *idlCore.Scalar_Blob: return o.Blob.Uri, nil + case *idlCore.Scalar_Schema: + return o.Schema.Uri, nil default: return "", fmt.Errorf("received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) } } -func serializeLiteral(ctx context.Context, l *core.Literal) (string, error) { +func serializeLiteral(ctx context.Context, l *idlCore.Literal) (string, error) { switch o := l.Value.(type) { - case *core.Literal_Collection: + case *idlCore.Literal_Collection: res := make([]string, 0, len(o.Collection.Literals)) for _, sub := range o.Collection.Literals { s, err := serializeLiteral(ctx, sub) @@ -142,7 +175,7 @@ func serializeLiteral(ctx context.Context, l *core.Literal) (string, error) { } return fmt.Sprintf("[%v]", strings.Join(res, ",")), nil - case *core.Literal_Scalar: + case *idlCore.Literal_Scalar: return serializeLiteralScalar(o.Scalar) default: logger.Debugf(ctx, "received unexpected primitive type") diff --git a/go/tasks/pluginmachinery/utils/template_test.go b/go/tasks/pluginmachinery/core/template/template_test.go old mode 100755 new mode 100644 similarity index 64% rename from go/tasks/pluginmachinery/utils/template_test.go rename to go/tasks/pluginmachinery/core/template/template_test.go index 16031953b1..41f813ed6d --- a/go/tasks/pluginmachinery/utils/template_test.go +++ b/go/tasks/pluginmachinery/core/template/template_test.go @@ -1,11 +1,14 @@ -package utils +package template import ( "context" "fmt" + "regexp" "testing" "time" + pluginsCoreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/lyft/flyteidl/clients/go/coreutils" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flytestdlib/storage" @@ -13,12 +16,6 @@ import ( "github.com/stretchr/testify/assert" ) -func BenchmarkRegexCommandArgs(b *testing.B) { - for i := 0; i < b.N; i++ { - inputFileRegex.MatchString("{{ .InputFile }}") - } -} - type dummyInputReader struct { inputPrefix storage.DataReference inputPath storage.DataReference @@ -62,35 +59,14 @@ func (d dummyOutputPaths) GetErrorPath() storage.DataReference { panic("should not be called") } -func TestInputRegexMatch(t *testing.T) { - assert.True(t, inputFileRegex.MatchString("{{$input}}")) - assert.True(t, inputFileRegex.MatchString("{{ $Input }}")) - assert.True(t, inputFileRegex.MatchString("{{.input}}")) - assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) - assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) - assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) - assert.True(t, inputFileRegex.MatchString("{{ .Input}}")) - assert.True(t, inputFileRegex.MatchString("{{.Input }}")) - assert.True(t, inputFileRegex.MatchString("--something={{.Input}}")) - assert.False(t, inputFileRegex.MatchString("{{input}}"), "Missing $") - assert.False(t, inputFileRegex.MatchString("{$input}}"), "Missing Brace") -} - -func TestOutputRegexMatch(t *testing.T) { - assert.True(t, outputRegex.MatchString("{{.OutputPrefix}}")) - assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) - assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) - assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) - assert.True(t, outputRegex.MatchString("{{ .OutputPrefix}}")) - assert.True(t, outputRegex.MatchString("{{.OutputPrefix }}")) - assert.True(t, outputRegex.MatchString("--something={{.OutputPrefix}}")) - assert.False(t, outputRegex.MatchString("{{output}}"), "Missing $") - assert.False(t, outputRegex.MatchString("{.OutputPrefix}}"), "Missing Brace") -} - func TestReplaceTemplateCommandArgs(t *testing.T) { + taskExecutionID := &pluginsCoreMocks.TaskExecutionID{} + taskExecutionID.On("GetGeneratedName").Return("per_retry_unique_key") + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.On("GetTaskExecutionID").Return(taskExecutionID) + t.Run("empty cmd", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{}, nil, nil) assert.NoError(t, err) assert.Equal(t, []string{}, actual) @@ -103,7 +79,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { } t.Run("nothing to substitute", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", }, in, out) @@ -116,7 +92,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Sub InputFile", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", "{{ .Input }}", @@ -132,7 +108,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { t.Run("Sub Input Prefix", func(t *testing.T) { in := dummyInputReader{inputPath: "input/prefix"} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", "{{ .Input }}", @@ -147,7 +123,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Sub Output Prefix", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", "{{ .OutputPrefix }}", @@ -161,7 +137,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Sub Input Output prefix", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", "{{ .Input }}", @@ -177,7 +153,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) t.Run("Bad input template", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", "${{input}}", @@ -206,7 +182,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }, }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, @@ -229,7 +205,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "date": coreutils.MustMakeLiteral(time.Date(1900, 01, 01, 01, 01, 01, 000000001, time.UTC)), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", `--someArg {{ .Inputs.date }}`, @@ -252,7 +228,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, @@ -272,7 +248,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { t.Run("nil input", func(t *testing.T) { in := dummyInputReader{inputs: &core.LiteralMap{}} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, @@ -298,7 +274,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "min": coreutils.MustMakeLiteral(15), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ `SELECT COUNT(*) as total_count FROM @@ -323,7 +299,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), }, }} - _, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + _, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", `--someArg {{ .Inputs.blah }}`, @@ -338,32 +314,169 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), }, }} - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", "world", - `--someArg {{ .Inputs.blah blah }}`, + `--someArg {{ .Inputs.blah blah }} {{ .PerretryuNIqueKey }}`, "{{ .OutputPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ "hello", "world", - `--someArg {{ .Inputs.blah blah }}`, + `--someArg {{ .Inputs.blah blah }} per_retry_unique_key`, "output/blah", }, actual) }) t.Run("sub raw output data prefix", func(t *testing.T) { - actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ "hello", + "{{ .perRetryUniqueKey }}", "world", "{{ .rawOutputDataPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ "hello", + "per_retry_unique_key", "world", "s3://custom-bucket", }, actual) }) } + +func TestReplaceTemplateCommandArgsSpecialChars(t *testing.T) { + in := dummyInputReader{inputPath: "input/blah"} + out := dummyOutputPaths{ + outputPath: "output/blah", + rawOutputDataPrefix: "s3://custom-bucket", + } + + t.Run("dashes are replaced", func(t *testing.T) { + taskExecutionID := &pluginsCoreMocks.TaskExecutionID{} + taskExecutionID.On("GetGeneratedName").Return("per-retry-unique-key") + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.On("GetTaskExecutionID").Return(taskExecutionID) + + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + "hello", + "{{ .perRetryUniqueKey }}", + "world", + "{{ .rawOutputDataPrefix }}", + }, in, out) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "per_retry_unique_key", + "world", + "s3://custom-bucket", + }, actual) + }) + + t.Run("non-alphabet leading characters are stripped", func(t *testing.T) { + var startsWithAlpha = regexp.MustCompile("^[^a-zA-Z_]+") + taskExecutionID := &pluginsCoreMocks.TaskExecutionID{} + taskExecutionID.On("GetGeneratedName").Return("33 per retry-unique-key") + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.On("GetTaskExecutionID").Return(taskExecutionID) + + testString := "doesn't start with a number" + testString2 := "1 does start with a number" + testString3 := " 1 3 nd spaces " + assert.Equal(t, testString, startsWithAlpha.ReplaceAllString(testString, "a")) + assert.Equal(t, "adoes start with a number", startsWithAlpha.ReplaceAllString(testString2, "a")) + assert.Equal(t, "and spaces ", startsWithAlpha.ReplaceAllString(testString3, "a")) + + actual, err := ReplaceTemplateCommandArgs(context.TODO(), taskMetadata, []string{ + "hello", + "{{ .perRetryUniqueKey }}", + "world", + "{{ .rawOutputDataPrefix }}", + }, in, out) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "aper_retry_unique_key", + "world", + "s3://custom-bucket", + }, actual) + }) +} + +func BenchmarkRegexCommandArgs(b *testing.B) { + for i := 0; i < b.N; i++ { + inputFileRegex.MatchString("{{ .InputFile }}") + } +} + +func TestInputRegexMatch(t *testing.T) { + assert.True(t, inputFileRegex.MatchString("{{$input}}")) + assert.True(t, inputFileRegex.MatchString("{{ $Input }}")) + assert.True(t, inputFileRegex.MatchString("{{.input}}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input }}")) + assert.True(t, inputFileRegex.MatchString("{{ .Input}}")) + assert.True(t, inputFileRegex.MatchString("{{.Input }}")) + assert.True(t, inputFileRegex.MatchString("--something={{.Input}}")) + assert.False(t, inputFileRegex.MatchString("{{input}}"), "Missing $") + assert.False(t, inputFileRegex.MatchString("{$input}}"), "Missing Brace") +} + +func TestOutputRegexMatch(t *testing.T) { + assert.True(t, outputRegex.MatchString("{{.OutputPrefix}}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("{{ .OutputPrefix}}")) + assert.True(t, outputRegex.MatchString("{{.OutputPrefix }}")) + assert.True(t, outputRegex.MatchString("--something={{.OutputPrefix}}")) + assert.False(t, outputRegex.MatchString("{{output}}"), "Missing $") + assert.False(t, outputRegex.MatchString("{.OutputPrefix}}"), "Missing Brace") +} + +func getBlobLiteral(uri string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Blob{ + Blob: &core.Blob{ + Metadata: nil, + Uri: uri, + }, + }, + }, + }, + } +} + +func getSchemaLiteral(uri string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Schema{ + Schema: &core.Schema{Type: nil, Uri: uri}, + }, + }, + }, + } +} + +func TestSerializeLiteral(t *testing.T) { + ctx := context.Background() + + t.Run("serialize blob", func(t *testing.T) { + b := getBlobLiteral("asdf fdsa") + interpolated, err := serializeLiteral(ctx, b) + assert.NoError(t, err) + assert.Equal(t, "asdf fdsa", interpolated) + }) + + t.Run("serialize blob", func(t *testing.T) { + s := getSchemaLiteral("s3://some-bucket/fdsa/x.parquet") + interpolated, err := serializeLiteral(ctx, s) + assert.NoError(t, err) + assert.Equal(t, "s3://some-bucket/fdsa/x.parquet", interpolated) + }) +} diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper.go b/go/tasks/pluginmachinery/flytek8s/container_helper.go index f0583516a0..0be0d375bf 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper.go @@ -4,6 +4,8 @@ import ( "context" "regexp" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flytestdlib/logger" v1 "k8s.io/api/core/v1" @@ -14,7 +16,6 @@ import ( pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" ) var isAcceptableK8sName, _ = regexp.Compile("[a-z0-9]([-a-z0-9]*[a-z0-9])?") @@ -87,12 +88,12 @@ func ApplyResourceOverrides(ctx context.Context, resources v1.ResourceRequiremen // Returns a K8s Container for the execution func ToK8sContainer(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskContainer *core.Container, iFace *core.TypedInterface, inputReader io.InputReader, outputPaths io.OutputFilePaths) (*v1.Container, error) { - modifiedCommand, err := utils.ReplaceTemplateCommandArgs(ctx, taskContainer.GetCommand(), inputReader, outputPaths) + modifiedCommand, err := template.ReplaceTemplateCommandArgs(ctx, taskExecutionMetadata, taskContainer.GetCommand(), inputReader, outputPaths) if err != nil { return nil, err } - modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, taskContainer.GetArgs(), inputReader, outputPaths) + modifiedArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskExecutionMetadata, taskContainer.GetArgs(), inputReader, outputPaths) if err != nil { return nil, err } diff --git a/go/tasks/pluginmachinery/io/iface.go b/go/tasks/pluginmachinery/io/iface.go index 4719bc5bea..6879bcc6e4 100644 --- a/go/tasks/pluginmachinery/io/iface.go +++ b/go/tasks/pluginmachinery/io/iface.go @@ -52,7 +52,7 @@ type RawOutputPaths interface { } // All paths where various meta outputs produced by the task can be placed, such that the framework can directly access them. -// All paths are reperesented using storage.DataReference -> an URN for the configured storage backend +// All paths are represented using storage.DataReference -> an URN for the configured storage backend type OutputFilePaths interface { // RawOutputPaths are available with OutputFilePaths RawOutputPaths diff --git a/go/tasks/plugins/array/awsbatch/transformer.go b/go/tasks/plugins/array/awsbatch/transformer.go index b64025e365..69722dde1c 100644 --- a/go/tasks/plugins/array/awsbatch/transformer.go +++ b/go/tasks/plugins/array/awsbatch/transformer.go @@ -5,6 +5,8 @@ import ( "sort" "time" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/golang/protobuf/ptypes/duration" "k8s.io/apimachinery/pkg/api/resource" @@ -19,8 +21,6 @@ import ( "github.com/aws/aws-sdk-go/service/batch" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" - "github.com/lyft/flyteplugins/go/tasks/errors" pluginCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" v1 "k8s.io/api/core/v1" @@ -66,12 +66,12 @@ func FlyteTaskToBatchInput(ctx context.Context, tCtx pluginCore.TaskExecutionCon return nil, errors.Errorf(errors.BadTaskSpecification, "config[%v] is missing", DynamicTaskQueueKey) } - cmd, err := utils.ReplaceTemplateCommandArgs(ctx, taskTemplate.GetContainer().GetCommand(), tCtx.InputReader(), tCtx.OutputWriter()) + cmd, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), taskTemplate.GetContainer().GetCommand(), tCtx.InputReader(), tCtx.OutputWriter()) if err != nil { return nil, err } - args, err := utils.ReplaceTemplateCommandArgs(ctx, taskTemplate.GetContainer().GetArgs(), + args, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), taskTemplate.GetContainer().GetArgs(), arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) taskTemplate.GetContainer().GetEnv() if err != nil { diff --git a/go/tasks/plugins/array/k8s/task.go b/go/tasks/plugins/array/k8s/task.go index a236b4871a..f65104a5a2 100644 --- a/go/tasks/plugins/array/k8s/task.go +++ b/go/tasks/plugins/array/k8s/task.go @@ -5,9 +5,10 @@ import ( "strconv" "strings" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/lyft/flyteplugins/go/tasks/plugins/array" "github.com/lyft/flyteplugins/go/tasks/plugins/array/arraystatus" arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core" @@ -74,7 +75,8 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl }) pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, arrayJobEnvVars...) - pod.Spec.Containers[0].Args, err = utils.ReplaceTemplateCommandArgs(ctx, args, arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) + pod.Spec.Containers[0].Args, err = template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), args, + arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter()) if err != nil { return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args") } diff --git a/go/tasks/plugins/hive/execution_state.go b/go/tasks/plugins/hive/execution_state.go index e5b08b3467..8212326fa0 100644 --- a/go/tasks/plugins/hive/execution_state.go +++ b/go/tasks/plugins/hive/execution_state.go @@ -6,6 +6,10 @@ import ( "strconv" "time" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" + "github.com/lyft/flytestdlib/cache" "github.com/lyft/flytestdlib/contextutils" @@ -27,7 +31,7 @@ const ( PhaseNotStarted ExecutionPhase = iota PhaseQueued // resource manager token gotten PhaseSubmitted // Sent off to Qubole - + PhaseWriteOutputFile PhaseQuerySucceeded PhaseQueryFailed ) @@ -40,6 +44,8 @@ func (p ExecutionPhase) String() string { return "PhaseQueued" case PhaseSubmitted: return "PhaseSubmitted" + case PhaseWriteOutputFile: + return "PhaseWriteOutputFile" case PhaseQuerySucceeded: return "PhaseQuerySucceeded" case PhaseQueryFailed: @@ -84,6 +90,9 @@ func HandleExecutionState(ctx context.Context, tCtx core.TaskExecutionContext, c case PhaseSubmitted: newState, transformError = MonitorQuery(ctx, tCtx, currentState, executionsCache) + case PhaseWriteOutputFile: + newState, transformError = WriteOutputs(ctx, tCtx, currentState) + case PhaseQuerySucceeded: newState = currentState transformError = nil @@ -96,7 +105,7 @@ func HandleExecutionState(ctx context.Context, tCtx core.TaskExecutionContext, c return newState, transformError } -func MapExecutionStateToPhaseInfo(state ExecutionState, quboleClient client.QuboleClient) core.PhaseInfo { +func MapExecutionStateToPhaseInfo(state ExecutionState, _ client.QuboleClient) core.PhaseInfo { var phaseInfo core.PhaseInfo t := time.Now() @@ -113,6 +122,9 @@ func MapExecutionStateToPhaseInfo(state ExecutionState, quboleClient client.Qubo case PhaseSubmitted: phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, ConstructTaskInfo(state)) + case PhaseWriteOutputFile: + phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion+1, ConstructTaskInfo(state)) + case PhaseQuerySucceeded: phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) @@ -231,7 +243,7 @@ func validateQuboleHiveJob(hiveJob plugins.QuboleHiveJob) error { // This function is the link between the output written by the SDK, and the execution side. It extracts the query // out of the task template. func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( - query string, cluster string, tags []string, timeoutSec uint32, taskName string, err error) { + formattedQuery string, cluster string, tags []string, timeoutSec uint32, taskName string, err error) { taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { @@ -248,7 +260,14 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( return "", "", []string{}, 0, "", err } - query = hiveJob.Query.GetQuery() + query := hiveJob.Query.GetQuery() + + outputs, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), []string{query}, tCtx.InputReader(), tCtx.OutputWriter()) + if err != nil { + return "", "", []string{}, 0, "", err + } + formattedQuery = outputs[0] + cluster = hiveJob.ClusterLabel timeoutSec = hiveJob.Query.TimeoutSec taskName = taskTemplate.Id.Name @@ -257,8 +276,10 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( for k, v := range tCtx.TaskExecutionMetadata().GetLabels() { tags = append(tags, fmt.Sprintf("%s:%s", k, v)) } - logger.Debugf(ctx, "QueryInfo: query: [%v], cluster: [%v], timeoutSec: [%v], tags: [%v]", query, cluster, timeoutSec, tags) - return + logger.Debugf(ctx, "QueryInfo: original query [%s], query: [%s], cluster: [%s], timeoutSec: [%d], tags: [%v]", + query, formattedQuery, cluster, timeoutSec, tags) + + return formattedQuery, cluster, tags, timeoutSec, taskName, err } func mapLabelToPrimaryLabel(ctx context.Context, quboleCfg *config.Config, label string) (primaryLabel string, found bool) { @@ -459,3 +480,54 @@ func IsNotYetSubmitted(e ExecutionState) bool { } return false } + +func WriteOutputs(ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState) ( + ExecutionState, error) { + + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + logger.Errorf(ctx, "Error reading task template: [%s]", err) + return currentState, err + } + + externalLocation := tCtx.OutputWriter().GetRawOutputPrefix() + outputs := taskTemplate.Interface.Outputs.GetVariables() + if len(outputs) != 0 && len(outputs) != 1 { + return currentState, errors.Errorf(errors.BadTaskSpecification, "Hive tasks must have zero or one output: [%d] found", len(outputs)) + } + if len(outputs) == 1 { + if results, ok := outputs["results"]; ok { + if results.GetType().GetSchema() == nil { + return currentState, errors.Errorf(errors.BadTaskSpecification, "A non-SchemaType was found [%v]", results.GetType()) + } + logger.Debugf(ctx, "Writing outputs file for Hive task at [%s]", tCtx.OutputWriter().GetOutputPrefixPath()) + err = tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + &idlCore.LiteralMap{ + Literals: map[string]*idlCore.Literal{ + "results": { + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{Value: &idlCore.Scalar_Schema{ + Schema: &idlCore.Schema{ + Uri: externalLocation.String(), + Type: results.GetType().GetSchema(), + }, + }, + }, + }, + }, + }, + }, nil)) + if err != nil { + logger.Errorf(ctx, "Error writing outputs file: [%s]", err) + return currentState, err + } + } else { + logger.Errorf(ctx, "Wrong name for output [%s]", err) + return currentState, errors.Errorf(errors.BadTaskSpecification, "One output found but wrong name [%s]", outputs) + } + } + + logger.Debugf(ctx, "Moving hive task to succeeded") + currentState.Phase = PhaseQuerySucceeded + return currentState, nil +} diff --git a/go/tasks/plugins/hive/execution_state_test.go b/go/tasks/plugins/hive/execution_state_test.go index e22e054733..38b1a96ff6 100644 --- a/go/tasks/plugins/hive/execution_state_test.go +++ b/go/tasks/plugins/hive/execution_state_test.go @@ -2,10 +2,14 @@ package hive import ( "context" + "fmt" "net/url" "testing" "time" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + ioMock "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils/labeled" @@ -75,24 +79,13 @@ func TestIsNotYetSubmitted(t *testing.T) { func TestGetQueryInfo(t *testing.T) { ctx := context.Background() + tCtx := GetMockTaskExecutionContext() - taskTemplate := GetSingleHiveQueryTaskTemplate() - mockTaskReader := &mocks.TaskReader{} - mockTaskReader.On("Read", mock.Anything).Return(&taskTemplate, nil) - - mockTaskExecutionContext := mocks.TaskExecutionContext{} - mockTaskExecutionContext.On("TaskReader").Return(mockTaskReader) - - taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} - taskMetadata.On("GetNamespace").Return("myproject-staging") - taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"}) - mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) - - query, cluster, tags, timeout, taskName, err := GetQueryInfo(ctx, &mockTaskExecutionContext) + query, cluster, tags, timeout, taskName, err := GetQueryInfo(ctx, tCtx) assert.NoError(t, err) assert.Equal(t, "select 'one'", query) assert.Equal(t, "default", cluster) - assert.Equal(t, []string{"flyte_plugin_test", "ns:myproject-staging", "sample:label"}, tags) + assert.Equal(t, []string{"flyte_plugin_test", "ns:test-namespace", "label-1:val1"}, tags) assert.Equal(t, 500, int(timeout)) assert.Equal(t, "sample_hive_task_test_name", taskName) } @@ -168,6 +161,15 @@ func TestMapExecutionStateToPhaseInfo(t *testing.T) { phaseInfo := MapExecutionStateToPhaseInfo(e, c) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) }) + + t.Run("Write outputs file", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseWriteOutputFile, + } + phaseInfo := MapExecutionStateToPhaseInfo(e, c) + assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) + assert.Equal(t, uint32(1), phaseInfo.Version()) + }) } func TestGetAllocationToken(t *testing.T) { @@ -347,6 +349,22 @@ func TestKickOffQuery(t *testing.T) { assert.True(t, quboleCalled) } +func TestWriteOutputs(t *testing.T) { + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + tCtx.OutputWriter().(*ioMock.OutputWriter).On("Put", mock.Anything, mock.Anything).Return(nil).Run(func(arguments mock.Arguments) { + reader := arguments.Get(1).(io.OutputReader) + literals, err1, err2 := reader.Read(context.Background()) + assert.Nil(t, err1) + assert.NoError(t, err2) + assert.NotNil(t, literals.Literals["results"].GetScalar().GetSchema()) + }) + + state := ExecutionState{} + newState, _ := WriteOutputs(ctx, tCtx, state) + fmt.Println(newState) +} + func createMockQuboleCfg() *config.Config { return &config.Config{ DefaultClusterLabel: "default", diff --git a/go/tasks/plugins/hive/executions_cache.go b/go/tasks/plugins/hive/executions_cache.go index 3e7347ffa3..d854813874 100644 --- a/go/tasks/plugins/hive/executions_cache.go +++ b/go/tasks/plugins/hive/executions_cache.go @@ -155,7 +155,7 @@ func (q *QuboleHiveExecutionsCache) SyncQuboleQuery(ctx context.Context, batch c func QuboleStatusToExecutionPhase(s client.QuboleStatus) (ExecutionPhase, error) { switch s { case client.QuboleStatusDone: - return PhaseQuerySucceeded, nil + return PhaseWriteOutputFile, nil case client.QuboleStatusCancelled: return PhaseQueryFailed, nil case client.QuboleStatusError: diff --git a/go/tasks/plugins/hive/executions_cache_test.go b/go/tasks/plugins/hive/executions_cache_test.go index cc33365b37..b74428588d 100644 --- a/go/tasks/plugins/hive/executions_cache_test.go +++ b/go/tasks/plugins/hive/executions_cache_test.go @@ -86,6 +86,6 @@ func TestQuboleHiveExecutionsCache_SyncQuboleQuery(t *testing.T) { newExecutionState := newCacheItem[0].Item.(ExecutionStateCacheItem) assert.NoError(t, err) assert.Equal(t, cache.Update, newCacheItem[0].Action) - assert.Equal(t, PhaseQuerySucceeded, newExecutionState.Phase) + assert.Equal(t, PhaseWriteOutputFile, newExecutionState.Phase) }) } diff --git a/go/tasks/plugins/hive/test_helpers.go b/go/tasks/plugins/hive/test_helpers.go index fa78fd6f22..e8f924f928 100644 --- a/go/tasks/plugins/hive/test_helpers.go +++ b/go/tasks/plugins/hive/test_helpers.go @@ -47,6 +47,17 @@ func GetSingleHiveQueryTaskTemplate() idlCore.TaskTemplate { Version: "1", ResourceType: idlCore.ResourceType_TASK, }, + Interface: &idlCore.TypedInterface{ + Outputs: &idlCore.VariableMap{ + Variables: map[string]*idlCore.Variable{ + "results": &idlCore.Variable{ + Type: &idlCore.LiteralType{ + Type: &idlCore.LiteralType_Schema{Schema: &idlCore.SchemaType{}}, + }, + }, + }, + }, + }, } return tt @@ -106,6 +117,7 @@ func GetMockTaskExecutionContext() core.TaskExecutionContext { dummyTaskMetadata := GetMockTaskExecutionMetadata() taskCtx := &coreMock.TaskExecutionContext{} inputReader := &ioMock.InputReader{} + inputReader.On("GetInputPrefixPath").Return(storage.DataReference("s3://test-input-prefix")) inputReader.On("GetInputPath").Return(storage.DataReference("test-data-reference")) inputReader.On("Get", mock.Anything).Return(&idlCore.LiteralMap{}, nil) taskCtx.On("InputReader").Return(inputReader) @@ -113,6 +125,7 @@ func GetMockTaskExecutionContext() core.TaskExecutionContext { outputReader := &ioMock.OutputWriter{} outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + outputReader.On("GetRawOutputPrefix").Return(storage.DataReference("gs://custom-output-bucket/b")) taskCtx.On("OutputWriter").Return(outputReader) taskReader := &coreMock.TaskReader{} diff --git a/go/tasks/plugins/k8s/sagemaker/utils.go b/go/tasks/plugins/k8s/sagemaker/utils.go index 6c32cbeaf7..55db278df9 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/go/tasks/plugins/k8s/sagemaker/utils.go @@ -6,6 +6,8 @@ import ( "sort" "strings" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" @@ -323,7 +325,7 @@ func injectTaskTemplateEnvVarToHyperparameters(ctx context.Context, taskTemplate func injectArgsAndEnvVars(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, taskTemplate *flyteIdlCore.TaskTemplate) ([]*commonv1.KeyValuePair, error) { templateArgs := taskTemplate.GetContainer().GetArgs() - templateArgs, err := utils.ReplaceTemplateCommandArgs(ctx, templateArgs, taskCtx.InputReader(), taskCtx.OutputWriter()) + templateArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), templateArgs, taskCtx.InputReader(), taskCtx.OutputWriter()) if err != nil { return nil, errors.Wrapf(ErrSagemaker, err, "Failed to de-template the hyperparameter values") } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar.go b/go/tasks/plugins/k8s/sidecar/sidecar.go index 83eddede16..2e6bea8265 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" @@ -37,13 +39,13 @@ func validateAndFinalizePod( if container.Name == primaryContainerName { hasPrimaryContainer = true } - modifiedCommand, err := utils.ReplaceTemplateCommandArgs(ctx, container.Command, taskCtx.InputReader(), taskCtx.OutputWriter()) + modifiedCommand, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), container.Command, taskCtx.InputReader(), taskCtx.OutputWriter()) if err != nil { return nil, err } container.Command = modifiedCommand - modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, container.Args, taskCtx.InputReader(), taskCtx.OutputWriter()) + modifiedArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), container.Args, taskCtx.InputReader(), taskCtx.OutputWriter()) if err != nil { return nil, err } diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go index 51f69a0c5a..e2cc35bba8 100755 --- a/go/tasks/plugins/k8s/spark/spark.go +++ b/go/tasks/plugins/k8s/spark/spark.go @@ -5,6 +5,8 @@ import ( "fmt" "strconv" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" @@ -133,7 +135,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo }, } - modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, container.GetArgs(), taskCtx.InputReader(), taskCtx.OutputWriter()) + modifiedArgs, err := template.ReplaceTemplateCommandArgs(ctx, taskCtx.TaskExecutionMetadata(), container.GetArgs(), taskCtx.InputReader(), taskCtx.OutputWriter()) if err != nil { return nil, err } diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 701b25d6a7..3f5640949e 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -3,6 +3,8 @@ package presto import ( "context" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/template" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" "k8s.io/apimachinery/pkg/util/rand" @@ -219,7 +221,7 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) (string, return "", "", "", "", err } - outputs, err := utils.ReplaceTemplateCommandArgs(ctx, []string{ + outputs, err := template.ReplaceTemplateCommandArgs(ctx, tCtx.TaskExecutionMetadata(), []string{ prestoQuery.RoutingGroup, prestoQuery.Catalog, prestoQuery.Schema,