From 47a3c026fddd1cfbdce79776ec191be0945df6bf Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 24 Mar 2021 11:38:01 -0700 Subject: [PATCH] Pass along external resource IDs in task event execution metadata (#165) --- flyteplugins/go.mod | 2 +- flyteplugins/go.sum | 7 +++- .../go/tasks/pluginmachinery/core/phase.go | 4 ++ .../k8s/mocks/plugin_context.go | 34 +++++++++++++++ .../go/tasks/pluginmachinery/k8s/plugin.go | 3 ++ .../pluginmachinery/webapi/example/plugin.go | 9 ++++ .../tasks/plugins/array/awsbatch/executor.go | 4 +- .../plugins/array/awsbatch/task_links.go | 29 ++++++++++--- .../go/tasks/plugins/array/core/state.go | 12 +++++- .../go/tasks/plugins/array/core/state_test.go | 42 +++++++++++++++---- .../go/tasks/plugins/array/k8s/executor.go | 5 ++- .../go/tasks/plugins/array/k8s/monitor.go | 23 +++++----- .../tasks/plugins/array/k8s/monitor_test.go | 29 +++++++++++-- .../go/tasks/plugins/array/k8s/task.go | 2 + .../go/tasks/plugins/hive/execution_state.go | 12 ++++++ .../plugins/hive/execution_state_test.go | 10 +++++ .../k8s/sagemaker/builtin_training_test.go | 10 +++++ .../k8s/sagemaker/custom_training_test.go | 10 +++++ .../sagemaker/hyperparameter_tuning_test.go | 10 +++++ .../go/tasks/plugins/k8s/sagemaker/utils.go | 8 ++++ .../tasks/plugins/presto/execution_state.go | 9 ++++ .../plugins/presto/execution_state_test.go | 2 + .../go/tasks/plugins/webapi/athena/plugin.go | 9 ++++ .../plugins/webapi/athena/plugin_test.go | 30 +++++++++++++ 24 files changed, 280 insertions(+), 35 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index 157098a6a..2c47222c2 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -13,7 +13,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.0.0 github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v0.18.17 + github.com/flyteorg/flyteidl v0.18.25 github.com/flyteorg/flytestdlib v0.3.13 github.com/go-logr/zapr v0.4.0 // indirect github.com/go-test/deep v1.0.7 diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index c00c83cd4..055058a3b 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -72,6 +72,7 @@ github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= @@ -227,8 +228,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v0.18.17 h1:74pPZ9PzITuzq+CgjMPb9EcFI5bVkf8mM5m4xmmlTmY= -github.com/flyteorg/flyteidl v0.18.17/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= +github.com/flyteorg/flyteidl v0.18.25 h1:XbHwM4G1u5nGAcdKod+ENgbL84cHdNzQIWY+NajuHs8= +github.com/flyteorg/flyteidl v0.18.25/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= github.com/flyteorg/flytestdlib v0.3.13 h1:5ioA/q3ixlyqkFh5kDaHgmPyTP/AHtqq1K/TIbVLUzM= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= @@ -753,6 +754,7 @@ go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= @@ -1214,6 +1216,7 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.0.0-20210217171935-8e2decd92398/go.mod h1:60tmSUpHxGPFerNHbo/ayI2lKxvtrhbxFyXuEIWJd78= k8s.io/api v0.18.2/go.mod h1:SJCWI7OLzhZSvbY7U8zwNl9UA4o1fizoug34OV/2r78= diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 42828f215..0f07e5113 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" structpb "github.com/golang/protobuf/ptypes/struct" ) @@ -72,6 +74,8 @@ type TaskInfo struct { OccurredAt *time.Time // Custom Event information that the plugin would like to expose to the front-end CustomInfo *structpb.Struct + // Metadata around how a task was executed + Metadata *event.TaskExecutionMetadata } func (t *TaskInfo) String() string { diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin_context.go b/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin_context.go index 2d4c65b96..bdf3cbcae 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin_context.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin_context.go @@ -150,6 +150,40 @@ func (_m *PluginContext) OutputWriter() io.OutputWriter { return r0 } +type PluginContext_TaskExecutionMetadata struct { + *mock.Call +} + +func (_m PluginContext_TaskExecutionMetadata) Return(_a0 core.TaskExecutionMetadata) *PluginContext_TaskExecutionMetadata { + return &PluginContext_TaskExecutionMetadata{Call: _m.Call.Return(_a0)} +} + +func (_m *PluginContext) OnTaskExecutionMetadata() *PluginContext_TaskExecutionMetadata { + c := _m.On("TaskExecutionMetadata") + return &PluginContext_TaskExecutionMetadata{Call: c} +} + +func (_m *PluginContext) OnTaskExecutionMetadataMatch(matchers ...interface{}) *PluginContext_TaskExecutionMetadata { + c := _m.On("TaskExecutionMetadata", matchers...) + return &PluginContext_TaskExecutionMetadata{Call: c} +} + +// TaskExecutionMetadata provides a mock function with given fields: +func (_m *PluginContext) TaskExecutionMetadata() core.TaskExecutionMetadata { + ret := _m.Called() + + var r0 core.TaskExecutionMetadata + if rf, ok := ret.Get(0).(func() core.TaskExecutionMetadata); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(core.TaskExecutionMetadata) + } + } + + return r0 +} + type PluginContext_TaskReader struct { *mock.Call } diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 04c2d75c2..529314e2f 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -58,6 +58,9 @@ type PluginContext interface { // Returns the max allowed dataset size that the outputwriter will accept MaxDatasetSizeBytes() int64 + + // Returns a handle to the Task's execution metadata. + TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata } // Defines a simplified interface to author plugins for k8s resources. diff --git a/flyteplugins/go/tasks/pluginmachinery/webapi/example/plugin.go b/flyteplugins/go/tasks/pluginmachinery/webapi/example/plugin.go index fa89447f6..401cb9064 100644 --- a/flyteplugins/go/tasks/pluginmachinery/webapi/example/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/webapi/example/plugin.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/errors" @@ -94,6 +96,13 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co }, }, OccurredAt: &tNow, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: "abc", + }, + }, + }, }), nil } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go index 840f41f3f..104149f29 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/executor.go @@ -104,7 +104,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c } // Always attempt to augment phase with task logs. - logLinks, err := GetTaskLinks(ctx, tCtx.TaskExecutionMetadata(), e.jobStore, pluginState) + subTaskDetails, err := GetTaskLinks(ctx, tCtx.TaskExecutionMetadata(), e.jobStore, pluginState) if err != nil { return core.UnknownTransition, err } @@ -112,7 +112,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c logger.Infof(ctx, "Exiting handle with phase [%v]", pluginState.State.CurrentPhase) // Determine transition information from the state - phaseInfo, err := arrayCore.MapArrayStateToPluginPhase(ctx, pluginState.State, logLinks) + phaseInfo, err := arrayCore.MapArrayStateToPluginPhase(ctx, pluginState.State, subTaskDetails.LogLinks, subTaskDetails.SubTaskIDs) if err != nil { return core.UnknownTransition, err } diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/task_links.go b/flyteplugins/go/tasks/plugins/array/awsbatch/task_links.go index dccd04855..10e844fff 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/task_links.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/task_links.go @@ -36,13 +36,22 @@ func GetJobTaskLog(jobSize int, accountID, region, queue, jobID string) *idlCore } } +type SubTaskDetails struct { + LogLinks []*idlCore.TaskLog + SubTaskIDs []*string +} + func GetTaskLinks(ctx context.Context, taskMeta pluginCore.TaskExecutionMetadata, jobStore *JobStore, state *State) ( - []*idlCore.TaskLog, error) { + SubTaskDetails, error) { logLinks := make([]*idlCore.TaskLog, 0, 4) + subTaskIDs := make([]*string, 0) if state.GetExternalJobID() == nil { - return logLinks, nil + return SubTaskDetails{ + LogLinks: logLinks, + SubTaskIDs: subTaskIDs, + }, nil } // TODO: Add tasktemplate container config to job config @@ -58,14 +67,20 @@ func GetTaskLinks(ctx context.Context, taskMeta pluginCore.TaskExecutionMetadata }) if err != nil { - return nil, errors.Wrapf(errors2.DownstreamSystemError, err, "Failed to retrieve a job from job store.") + return SubTaskDetails{ + LogLinks: logLinks, + SubTaskIDs: subTaskIDs, + }, errors.Wrapf(errors2.DownstreamSystemError, err, "Failed to retrieve a job from job store.") } if job == nil { logger.Debugf(ctx, "Job [%v] not found in jobs store. It might have been evicted. If reasonable, bump the max "+ "size of the LRU cache.", *state.GetExternalJobID()) - return logLinks, nil + return SubTaskDetails{ + LogLinks: logLinks, + SubTaskIDs: subTaskIDs, + }, nil } detailedArrayStatus := state.GetArrayStatus().Detailed @@ -83,7 +98,11 @@ func GetTaskLinks(ctx context.Context, taskMeta pluginCore.TaskExecutionMetadata }) } } + subTaskIDs = append(subTaskIDs, &subJob.ID) } - return logLinks, nil + return SubTaskDetails{ + LogLinks: logLinks, + SubTaskIDs: subTaskIDs, + }, nil } diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index 44fd901a0..2f8d31152 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" @@ -168,7 +170,7 @@ func GetPhaseVersionOffset(currentPhase Phase, length int64) uint32 { // Info fields will always be nil, because we're going to send log links individually. This simplifies our state // handling as we don't have to keep an ever growing list of log links (our batch jobs can be 5000 sub-tasks, keeping // all the log links takes up a lot of space). -func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idlCore.TaskLog) (core.PhaseInfo, error) { +func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idlCore.TaskLog, subTaskIDs []*string) (core.PhaseInfo, error) { phaseInfo := core.PhaseInfoUndefined t := time.Now() @@ -176,6 +178,14 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl OccurredAt: &t, Logs: logLinks, } + if nowTaskInfo.Metadata == nil { + nowTaskInfo.Metadata = &event.TaskExecutionMetadata{} + } + for _, subTaskID := range subTaskIDs { + nowTaskInfo.Metadata.ExternalResources = append(nowTaskInfo.Metadata.ExternalResources, &event.ExternalResourceInfo{ + ExternalId: *subTaskID, + }) + } switch p, version := state.GetPhase(); p { case PhaseStart: diff --git a/flyteplugins/go/tasks/plugins/array/core/state_test.go b/flyteplugins/go/tasks/plugins/array/core/state_test.go index d7d221050..876612623 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state_test.go +++ b/flyteplugins/go/tasks/plugins/array/core/state_test.go @@ -2,8 +2,11 @@ package core import ( "context" + "fmt" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/golang/protobuf/proto" @@ -48,14 +51,32 @@ func assertBitSetsEqual(t testing.TB, b1, b2 *bitarray.BitSet, len int) { } } +func assertTaskExecutionMetadata(t *testing.T, subTaskIDs []*string, metadata *event.TaskExecutionMetadata) { + assert.NotNil(t, metadata) + var externalResources = make([]*event.ExternalResourceInfo, len(subTaskIDs)) + for i, subTaskID := range subTaskIDs { + externalResources[i] = &event.ExternalResourceInfo{ + ExternalId: *subTaskID, + } + } + assert.True(t, proto.Equal(&event.TaskExecutionMetadata{ + ExternalResources: externalResources, + }, metadata)) +} + func TestMapArrayStateToPluginPhase(t *testing.T) { ctx := context.Background() + var subTaskIDs = make([]*string, 3) + for i := 0; i < 3; i++ { + subTaskID := fmt.Sprintf("sub_task_%d", i) + subTaskIDs[i] = &subTaskID + } t.Run("start", func(t *testing.T) { s := State{ CurrentPhase: PhaseStart, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseInitializing, phaseInfo.Phase()) }) @@ -66,7 +87,7 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 0, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) }) @@ -79,10 +100,11 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { ExecutionArraySize: 5, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) assert.Equal(t, uint32(368), phaseInfo.Version()) + assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) }) t.Run("write to discovery", func(t *testing.T) { @@ -93,10 +115,11 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { ExecutionArraySize: 5, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) assert.Equal(t, uint32(548), phaseInfo.Version()) + assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) }) t.Run("success", func(t *testing.T) { @@ -105,9 +128,10 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 0, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseSuccess, phaseInfo.Phase()) + assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) }) t.Run("retryable failure", func(t *testing.T) { @@ -116,9 +140,10 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 0, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRetryableFailure, phaseInfo.Phase()) + assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) }) t.Run("permanent failure", func(t *testing.T) { @@ -127,9 +152,10 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 0, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhasePermanentFailure, phaseInfo.Phase()) + assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) }) t.Run("All phases", func(t *testing.T) { @@ -138,7 +164,7 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { CurrentPhase: p, } - phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil) + phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.NotEqual(t, core.PhaseUndefined, phaseInfo.Phase()) } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/executor.go b/flyteplugins/go/tasks/plugins/array/k8s/executor.go index 2790ee5f1..7e8bc6353 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/executor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/executor.go @@ -86,6 +86,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c var nextState *arrayCore.State var err error var logLinks []*idlCore.TaskLog + var subTaskIDs []*string switch p, _ := pluginState.GetPhase(); p { case arrayCore.PhaseStart: @@ -107,7 +108,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c case arrayCore.PhaseCheckingSubTaskExecutions: - nextState, logLinks, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, + nextState, logLinks, subTaskIDs, err = LaunchAndCheckSubTasksState(ctx, tCtx, e.kubeClient, pluginConfig, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState) case arrayCore.PhaseAssembleFinalOutput: @@ -135,7 +136,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c } // Determine transition information from the state - phaseInfo, err := arrayCore.MapArrayStateToPluginPhase(ctx, nextState, logLinks) + phaseInfo, err := arrayCore.MapArrayStateToPluginPhase(ctx, nextState, logLinks, subTaskIDs) if err != nil { return core.UnknownTransition, err } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go index a2afde9d7..3d21ad577 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go @@ -36,12 +36,12 @@ const ( func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, config *Config, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, currentState *arrayCore.State) ( - newState *arrayCore.State, logLinks []*idlCore.TaskLog, err error) { + newState *arrayCore.State, logLinks []*idlCore.TaskLog, subTaskIDs []*string, err error) { if int64(currentState.GetExecutionArraySize()) > config.MaxArrayJobSize { ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", currentState.GetExecutionArraySize(), config.MaxArrayJobSize) logger.Info(ctx, ee) currentState = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error()) - return currentState, logLinks, nil + return currentState, logLinks, subTaskIDs, nil } logLinks = make([]*idlCore.TaskLog, 0, 4) @@ -51,6 +51,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon Summary: arraystatus.ArraySummary{}, Detailed: arrayCore.NewPhasesCompactArray(uint(currentState.GetExecutionArraySize())), } + subTaskIDs = make([]*string, 0, len(currentState.GetArrayStatus().Detailed.GetItems())) // If we have arrived at this state for the first time then currentState has not been // initialized with number of sub tasks. @@ -70,7 +71,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon err = deallocateResource(ctx, tCtx, config, childIdx) if err != nil { logger.Errorf(ctx, "Error releasing allocation token [%s] in LaunchAndCheckSubTasks [%s]", podName, err) - return currentState, logLinks, errors2.Wrapf(ErrCheckPodStatus, err, "Error releasing allocation token.") + return currentState, logLinks, subTaskIDs, errors2.Wrapf(ErrCheckPodStatus, err, "Error releasing allocation token.") } newArrayStatus.Summary.Inc(existingPhase) newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(existingPhase)) @@ -86,6 +87,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon Config: config, ChildIdx: childIdx, MessageCollector: &msg, + SubTaskIDs: subTaskIDs, } // The first time we enter this state we will launch every subtask. On subsequent rounds, the pod @@ -94,31 +96,32 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon launchResult, err = task.Launch(ctx, tCtx, kubeClient) if err != nil { logger.Errorf(ctx, "K8s array - Launch error %v", err) - return currentState, logLinks, err + return currentState, logLinks, subTaskIDs, err } switch launchResult { case LaunchSuccess: // Continue with execution if successful case LaunchError: - return currentState, logLinks, err + return currentState, logLinks, subTaskIDs, err // If Resource manager is enabled and there are currently not enough resources we can skip this round // for a subtask and wait until there are enough resources. case LaunchWaiting: continue case LaunchReturnState: - return currentState, logLinks, nil + return currentState, logLinks, subTaskIDs, nil } var monitorResult MonitorResult monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox) logLinks = task.LogLinks + subTaskIDs = task.SubTaskIDs if monitorResult != MonitorSuccess { if err != nil { logger.Errorf(ctx, "K8s array - Monitor error %v", err) } - return currentState, logLinks, err + return currentState, logLinks, subTaskIDs, err } } @@ -127,9 +130,9 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon // Check that the taskTemplate is valid taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { - return currentState, logLinks, err + return currentState, logLinks, subTaskIDs, err } else if taskTemplate == nil { - return currentState, logLinks, fmt.Errorf("required value not set, taskTemplate is nil") + return currentState, logLinks, subTaskIDs, fmt.Errorf("required value not set, taskTemplate is nil") } phase := arrayCore.SummaryToPhase(ctx, currentState.GetOriginalMinSuccesses()-currentState.GetOriginalArraySize()+int64(currentState.GetExecutionArraySize()), newArrayStatus.Summary) @@ -151,7 +154,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon newState = newState.SetPhase(phase, core.DefaultPhaseVersion) } - return newState, logLinks, nil + return newState, logLinks, subTaskIDs, nil } func CheckPodStatus(ctx context.Context, client core.KubeClient, name k8sTypes.NamespacedName) ( diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go index 3c9ebc65c..64fec894e 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go @@ -1,6 +1,7 @@ package k8s import ( + "fmt" "testing" core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -92,6 +93,15 @@ func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContex return tCtx } +func testSubTaskIDs(t *testing.T, actual []*string) { + var expected = make([]*string, 5) + for i := 0; i < len(expected); i++ { + subTaskID := fmt.Sprintf("notfound-%d", i) + expected[i] = &subTaskID + } + assert.EqualValues(t, expected, actual) +} + func TestCheckSubTasksState(t *testing.T) { ctx := context.Background() @@ -104,7 +114,7 @@ func TestCheckSubTasksState(t *testing.T) { t.Run("Happy case", func(t *testing.T) { config := Config{MaxArrayJobSize: 100} - newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, @@ -116,6 +126,7 @@ func TestCheckSubTasksState(t *testing.T) { p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "AllocateResource", 0) + testSubTaskIDs(t, subTaskIDs) }) t.Run("Resource exhausted", func(t *testing.T) { @@ -127,17 +138,21 @@ func TestCheckSubTasksState(t *testing.T) { }, } - newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, OriginalMinSuccesses: 5, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(uint(5)), + }, }) assert.Nil(t, err) p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseWaitingForResources.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "AllocateResource", 5) + assert.Empty(t, subTaskIDs, "subtask ids are only populated when monitor is called for a successfully launched task") }) } @@ -148,6 +163,7 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { kubeClient := mocks.KubeClient{} kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) resourceManager := mocks.ResourceManager{} + resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) tCtx.OnResourceManager().Return(&resourceManager) @@ -161,17 +177,21 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { }, } - newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, OriginalMinSuccesses: 5, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: arrayCore.NewPhasesCompactArray(uint(5)), + }, }) assert.Nil(t, err) p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "AllocateResource", 5) + testSubTaskIDs(t, subTaskIDs) }) t.Run("All tasks success", func(t *testing.T) { @@ -191,7 +211,7 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { arrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhaseSuccess)) } - newState, _, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, @@ -203,5 +223,6 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseWriteToDiscovery.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "ReleaseResource", 5) + assert.Empty(t, subTaskIDs, "terminal phases don't need to collect subtask IDs") }) } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/task.go b/flyteplugins/go/tasks/plugins/array/k8s/task.go index 40c7747cd..7598c3fd5 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/task.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/task.go @@ -30,6 +30,7 @@ type Task struct { Config *Config ChildIdx int MessageCollector *errorcollector.ErrorMessageCollector + SubTaskIDs []*string } type LaunchResult int8 @@ -126,6 +127,7 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (MonitorResult, error) { indexStr := strconv.Itoa(t.ChildIdx) podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + t.SubTaskIDs = append(t.SubTaskIDs, &podName) phaseInfo, err := CheckPodStatus(ctx, kubeClient, k8sTypes.NamespacedName{ Name: podName, diff --git a/flyteplugins/go/tasks/plugins/hive/execution_state.go b/flyteplugins/go/tasks/plugins/hive/execution_state.go index e741fd9fb..881a5ce7c 100644 --- a/flyteplugins/go/tasks/plugins/hive/execution_state.go +++ b/flyteplugins/go/tasks/plugins/hive/execution_state.go @@ -6,6 +6,8 @@ import ( "strconv" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -146,11 +148,21 @@ func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { logs := make([]*idlCore.TaskLog, 0, 1) t := time.Now() + + metadata := &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: e.CommandID, + }, + }, + } + if e.CommandID != "" { logs = append(logs, ConstructTaskLog(e)) return &core.TaskInfo{ Logs: logs, OccurredAt: &t, + Metadata: metadata, } } diff --git a/flyteplugins/go/tasks/plugins/hive/execution_state_test.go b/flyteplugins/go/tasks/plugins/hive/execution_state_test.go index 7fcb8508d..cd5cd868e 100644 --- a/flyteplugins/go/tasks/plugins/hive/execution_state_test.go +++ b/flyteplugins/go/tasks/plugins/hive/execution_state_test.go @@ -7,6 +7,9 @@ import ( "testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ioMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" @@ -125,6 +128,13 @@ func TestConstructTaskInfo(t *testing.T) { taskInfo := ConstructTaskInfo(e) assert.Equal(t, "https://wellness.qubole.com/v2/analyze?command_id=123", taskInfo.Logs[0].Uri) + assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: "123", + }, + }, + })) } func TestMapExecutionStateToPhaseInfo(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go index e915842bb..df02c97a3 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go @@ -5,6 +5,9 @@ import ( "fmt" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/golang/protobuf/proto" + "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -270,5 +273,12 @@ func Test_awsSagemakerPlugin_getEventInfoForTrainingJob(t *testing.T) { if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } + assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: "some-acceptable-name", + }, + }, + })) }) } diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/custom_training_test.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/custom_training_test.go index aab7edeb3..91709b507 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/custom_training_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/custom_training_test.go @@ -6,6 +6,9 @@ import ( "strconv" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/golang/protobuf/proto" + "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -294,5 +297,12 @@ func Test_awsSagemakerPlugin_getEventInfoForCustomTrainingJob(t *testing.T) { if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } + assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: "some-acceptable-name", + }, + }, + })) }) } diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go index 176a8375c..796949d86 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go @@ -5,6 +5,9 @@ import ( "fmt" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/golang/protobuf/proto" + "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -126,5 +129,12 @@ func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } + assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: "some-acceptable-name", + }, + }, + })) }) } diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/utils.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/utils.go index 320f6b0a7..ac520f530 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/utils.go @@ -6,6 +6,7 @@ import ( "sort" "strings" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ -394,5 +395,12 @@ func createTaskInfo(_ context.Context, jobRegion string, jobName string, jobType return &pluginsCore.TaskInfo{ Logs: taskLogs, CustomInfo: customInfo, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: jobName, + }, + }, + }, }, nil } diff --git a/flyteplugins/go/tasks/plugins/presto/execution_state.go b/flyteplugins/go/tasks/plugins/presto/execution_state.go index 59df5e07f..8e2700feb 100644 --- a/flyteplugins/go/tasks/plugins/presto/execution_state.go +++ b/flyteplugins/go/tasks/plugins/presto/execution_state.go @@ -3,6 +3,8 @@ package presto import ( "context" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -499,6 +501,13 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { return &core.TaskInfo{ Logs: logs, OccurredAt: &t, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: e.CommandID, + }, + }, + }, } } diff --git a/flyteplugins/go/tasks/plugins/presto/execution_state_test.go b/flyteplugins/go/tasks/plugins/presto/execution_state_test.go index 904c64a5b..d6caad0b8 100644 --- a/flyteplugins/go/tasks/plugins/presto/execution_state_test.go +++ b/flyteplugins/go/tasks/plugins/presto/execution_state_test.go @@ -106,6 +106,8 @@ func TestConstructTaskInfo(t *testing.T) { taskInfo := ConstructTaskInfo(e) assert.Equal(t, "https://prestoproxy-internal.flyteorg.net:443", taskInfo.Logs[0].Uri) + assert.Len(t, taskInfo.Metadata.ExternalResources, 1) + assert.Equal(t, taskInfo.Metadata.ExternalResources[0].ExternalId, "123") } func TestMapExecutionStateToPhaseInfo(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/webapi/athena/plugin.go b/flyteplugins/go/tasks/plugins/webapi/athena/plugin.go index 0f6c19c3c..f2cbba436 100644 --- a/flyteplugins/go/tasks/plugins/webapi/athena/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/athena/plugin.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" awsSdk "github.com/aws/aws-sdk-go-v2/aws" @@ -184,6 +186,13 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo { Name: "Athena Query Console", }, }, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: queryID, + }, + }, + }, } } diff --git a/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go new file mode 100644 index 000000000..d85b42573 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go @@ -0,0 +1,30 @@ +package athena + +import ( + "testing" + + awsSdk "github.com/aws/aws-sdk-go-v2/aws" + idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" +) + +func TestCreateTaskInfo(t *testing.T) { + taskInfo := createTaskInfo("query_id", awsSdk.Config{ + Region: "us-east-1", + }) + assert.EqualValues(t, []*idlCore.TaskLog{ + { + Uri: "https://us-east-1.console.aws.amazon.com/athena/home?force®ion=us-east-1#query/history/query_id", + Name: "Athena Query Console", + }, + }, taskInfo.Logs) + assert.True(t, proto.Equal(&event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + ExternalId: "query_id", + }, + }, + }, taskInfo.Metadata)) +}