diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 3d5253fe5e..e59794f85e 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -91,6 +91,7 @@ func TestToK8sPodIterruptible(t *testing.T) { op := &pluginsIOMock.OutputFilePaths{} op.On("GetOutputPrefixPath").Return(storage.DataReference("")) + op.On("GetRawOutputPrefix").Return(storage.DataReference("")) x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ Limits: v1.ResourceList{ @@ -139,6 +140,7 @@ func TestToK8sPod(t *testing.T) { op := &pluginsIOMock.OutputFilePaths{} op.On("GetOutputPrefixPath").Return(storage.DataReference("")) + op.On("GetRawOutputPrefix").Return(storage.DataReference("")) t.Run("WithGPU", func(t *testing.T) { x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ diff --git a/go/tasks/pluginmachinery/utils/template.go b/go/tasks/pluginmachinery/utils/template.go index 087ced2f62..8f4a829038 100755 --- a/go/tasks/pluginmachinery/utils/template.go +++ b/go/tasks/pluginmachinery/utils/template.go @@ -19,6 +19,7 @@ var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`) var inputPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]InputPrefix\s*}}`) var outputRegex = regexp.MustCompile(`(?i){{\s*[\.$]OutputPrefix\s*}}`) var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P[^}\s]+)\s*}}`) +var rawOutputDataPrefixRegex = regexp.MustCompile(`(?i){{\s*[\.$]RawOutputDataPrefix\s*}}`) // Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive // Supported templates are: @@ -68,6 +69,7 @@ func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in val := inputFileRegex.ReplaceAllString(commandTemplate, in.GetInputPath().String()) val = outputRegex.ReplaceAllString(val, out.GetOutputPrefixPath().String()) val = inputPrefixRegex.ReplaceAllString(val, in.GetInputPrefixPath().String()) + val = rawOutputDataPrefixRegex.ReplaceAllString(val, out.GetRawOutputPrefix().String()) inputs, err := in.Get(ctx) if err != nil { diff --git a/go/tasks/pluginmachinery/utils/template_test.go b/go/tasks/pluginmachinery/utils/template_test.go index 07ee1e3fa8..16031953b1 100755 --- a/go/tasks/pluginmachinery/utils/template_test.go +++ b/go/tasks/pluginmachinery/utils/template_test.go @@ -42,11 +42,12 @@ func (d dummyInputReader) Get(ctx context.Context) (*core.LiteralMap, error) { } type dummyOutputPaths struct { - outputPath storage.DataReference + outputPath storage.DataReference + rawOutputDataPrefix storage.DataReference } func (d dummyOutputPaths) GetRawOutputPrefix() storage.DataReference { - panic("should not be called") + return d.rawOutputDataPrefix } func (d dummyOutputPaths) GetOutputPrefixPath() storage.DataReference { @@ -96,7 +97,10 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }) in := dummyInputReader{inputPath: "input/blah"} - out := dummyOutputPaths{outputPath: "output/blah"} + out := dummyOutputPaths{ + outputPath: "output/blah", + rawOutputDataPrefix: "s3://custom-bucket", + } t.Run("nothing to substitute", func(t *testing.T) { actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ @@ -178,6 +182,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", "${{input}}", "{{ .OutputPrefix }}", + "--switch {{ .rawOutputDataPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ @@ -185,6 +190,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", "${{input}}", "output/blah", + "--switch s3://custom-bucket", }, actual) }) @@ -205,6 +211,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", `--someArg {{ .Inputs.arr }}`, "{{ .OutputPrefix }}", + "{{ $RawOutputDataPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ @@ -212,6 +219,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", "--someArg [a,b]", "output/blah", + "s3://custom-bucket", }, actual) }) @@ -226,6 +234,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", `--someArg {{ .Inputs.date }}`, "{{ .OutputPrefix }}", + "{{ .rawOutputDataPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ @@ -233,6 +242,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", "--someArg 1900-01-01T01:01:01.000000001Z", "output/blah", + "s3://custom-bucket", }, actual) }) @@ -247,6 +257,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", `--someArg {{ .Inputs.arr }}`, "{{ .OutputPrefix }}", + "{{ .wrongOutputDataPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ @@ -254,6 +265,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", "--someArg [[a,b],[1,2]]", "output/blah", + "{{ .wrongOutputDataPrefix }}", }, actual) }) @@ -265,6 +277,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", `--someArg {{ .Inputs.arr }}`, "{{ .OutputPrefix }}", + "--raw-data-output-prefix {{ .rawOutputDataPrefix }}", }, in, out) assert.NoError(t, err) assert.Equal(t, []string{ @@ -272,6 +285,7 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "world", `--someArg {{ .Inputs.arr }}`, "output/blah", + "--raw-data-output-prefix s3://custom-bucket", }, actual) }) @@ -338,4 +352,18 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "output/blah", }, actual) }) + + t.Run("sub raw output data prefix", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + "{{ .rawOutputDataPrefix }}", + }, in, out) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "s3://custom-bucket", + }, actual) + }) } diff --git a/go/tasks/plugins/array/awsbatch/launcher_test.go b/go/tasks/plugins/array/awsbatch/launcher_test.go index 159c2759a9..f260a684a3 100644 --- a/go/tasks/plugins/array/awsbatch/launcher_test.go +++ b/go/tasks/plugins/array/awsbatch/launcher_test.go @@ -78,6 +78,7 @@ func TestLaunchSubTasks(t *testing.T) { ow := &mocks3.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") + ow.OnGetRawOutputPrefix().Return("s3://") ir := &mocks3.InputReader{} ir.OnGetInputPrefixPath().Return("/prefix/") diff --git a/go/tasks/plugins/array/awsbatch/transformer_test.go b/go/tasks/plugins/array/awsbatch/transformer_test.go index 10fe7f92be..7a4f4f7074 100644 --- a/go/tasks/plugins/array/awsbatch/transformer_test.go +++ b/go/tasks/plugins/array/awsbatch/transformer_test.go @@ -172,6 +172,7 @@ func TestArrayJobToBatchInput(t *testing.T) { or := &mocks2.OutputWriter{} or.OnGetOutputPrefixPath().Return("/path/output") + or.OnGetRawOutputPrefix().Return("s3://") taskCtx := &mocks.TaskExecutionContext{} taskCtx.OnTaskExecutionMetadata().Return(tMetadata) diff --git a/go/tasks/plugins/k8s/container/container_test.go b/go/tasks/plugins/k8s/container/container_test.go index 38301fd25c..32703a2778 100755 --- a/go/tasks/plugins/k8s/container/container_test.go +++ b/go/tasks/plugins/k8s/container/container_test.go @@ -86,6 +86,7 @@ func dummyContainerTaskContext(resources *v1.ResourceRequirements, command []str outputReader := &pluginsIOMock.OutputWriter{} outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + outputReader.On("GetRawOutputPrefix").Return(storage.DataReference("")) taskCtx.On("OutputWriter").Return(outputReader) taskReader := &pluginsCoreMock.TaskReader{} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 2b325e841b..67d4fb4aa8 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -107,6 +107,7 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx outputReader := &pluginIOMocks.OutputWriter{} outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) + outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("")) taskCtx.OnOutputWriter().Return(outputReader) taskReader := &mocks.TaskReader{} diff --git a/go/tasks/plugins/k8s/sagemaker/config/config_flags.go b/go/tasks/plugins/k8s/sagemaker/config/config_flags.go index de6c82cb04..683f88b9a8 100755 --- a/go/tasks/plugins/k8s/sagemaker/config/config_flags.go +++ b/go/tasks/plugins/k8s/sagemaker/config/config_flags.go @@ -43,5 +43,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) cmdFlags.String(fmt.Sprintf("%v%v", prefix, "roleArn"), defaultConfig.RoleArn, "The role the SageMaker plugin uses to communicate with the SageMaker service") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "region"), defaultConfig.Region, "The AWS region the SageMaker plugin communicates to") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "roleAnnotationKey"), defaultConfig.RoleAnnotationKey, "Map key to use to lookup role from task annotations.") return cmdFlags } diff --git a/go/tasks/plugins/k8s/sagemaker/config/config_flags_test.go b/go/tasks/plugins/k8s/sagemaker/config/config_flags_test.go index c97f992863..52a3c92931 100755 --- a/go/tasks/plugins/k8s/sagemaker/config/config_flags_test.go +++ b/go/tasks/plugins/k8s/sagemaker/config/config_flags_test.go @@ -143,4 +143,26 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_roleAnnotationKey", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("roleAnnotationKey"); err == nil { + assert.Equal(t, string(defaultConfig.RoleAnnotationKey), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("roleAnnotationKey", testValue) + if vString, err := cmdFlags.GetString("roleAnnotationKey"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RoleAnnotationKey) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/go/tasks/plugins/k8s/sidecar/sidecar_test.go index 57fc71d819..fdc01a0580 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -103,6 +103,7 @@ func getDummySidecarTaskContext(taskTemplate *core.TaskTemplate, resources *v1.R outputReader := &pluginsIOMock.OutputWriter{} outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + outputReader.On("GetRawOutputPrefix").Return(storage.DataReference("")) taskCtx.On("OutputWriter").Return(outputReader) taskReader := &pluginsCoreMock.TaskReader{} diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go index 5684e995c9..5707738b89 100755 --- a/go/tasks/plugins/k8s/spark/spark_test.go +++ b/go/tasks/plugins/k8s/spark/spark_test.go @@ -231,6 +231,8 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExec outputReader := &pluginIOMocks.OutputWriter{} outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + outputReader.On("GetRawOutputPrefix").Return(storage.DataReference("")) + taskCtx.On("OutputWriter").Return(outputReader) taskReader := &mocks.TaskReader{} diff --git a/go/tasks/plugins/presto/helpers_test.go b/go/tasks/plugins/presto/helpers_test.go index f5bc6a2355..f2366ae1b9 100644 --- a/go/tasks/plugins/presto/helpers_test.go +++ b/go/tasks/plugins/presto/helpers_test.go @@ -95,6 +95,7 @@ func GetMockTaskExecutionContext() core.TaskExecutionContext { outputReader := &ioMock.OutputWriter{} outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + outputReader.On("GetRawOutputPrefix").Return(storage.DataReference("s3://")) taskCtx.On("OutputWriter").Return(outputReader) taskReader := &coreMock.TaskReader{}