From b7dedc9dfaffcbec61c8657651acba8648e5c5f1 Mon Sep 17 00:00:00 2001 From: Jeev B Date: Tue, 7 Nov 2023 18:51:39 -0800 Subject: [PATCH] Passthrough unique node ID in task execution ID for generating log template vars Signed-off-by: Jeev B --- flyteplugins/go/tasks/logs/logging_utils.go | 5 +- .../go/tasks/logs/logging_utils_test.go | 59 +++++++----- .../pluginmachinery/core/exec_metadata.go | 4 + .../core/mocks/task_execution_id.go | 32 +++++++ .../tasks/pluginmachinery/tasklog/plugin.go | 3 +- .../tasks/pluginmachinery/tasklog/template.go | 96 +++++++++---------- .../pluginmachinery/tasklog/template_test.go | 93 +++++++----------- .../go/tasks/plugins/k8s/dask/dask.go | 10 +- .../k8s/kfoperators/common/common_operator.go | 50 +++++----- .../go/tasks/plugins/k8s/pod/plugin.go | 6 +- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 6 +- .../go/tasks/plugins/k8s/spark/spark.go | 34 +++---- flyteplugins/tests/end_to_end.go | 1 + .../controller/nodes/task/taskexec_context.go | 23 +++-- .../controller/nodes/task/transformer_test.go | 2 + 15 files changed, 228 insertions(+), 196 deletions(-) diff --git a/flyteplugins/go/tasks/logs/logging_utils.go b/flyteplugins/go/tasks/logs/logging_utils.go index 0ca515d7c87..6af1889e9f9 100644 --- a/flyteplugins/go/tasks/logs/logging_utils.go +++ b/flyteplugins/go/tasks/logs/logging_utils.go @@ -8,6 +8,7 @@ import ( v1 "k8s.io/api/core/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flytestdlib/logger" ) @@ -18,7 +19,7 @@ type logPlugin struct { } // Internal -func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, taskExecID *core.TaskExecutionIdentifier, pod *v1.Pod, index uint32, nameSuffix string, extraLogTemplateVarsByScheme *tasklog.TemplateVarsByScheme) ([]*core.TaskLog, error) { +func 
GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, taskExecID pluginsCore.TaskExecutionID, pod *v1.Pod, index uint32, nameSuffix string, extraLogTemplateVarsByScheme *tasklog.TemplateVarsByScheme) ([]*core.TaskLog, error) { if logPlugin == nil { return nil, nil } @@ -53,7 +54,7 @@ func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, tas PodRFC3339FinishTime: time.Unix(finishTime, 0).Format(time.RFC3339), PodUnixStartTime: startTime, PodUnixFinishTime: finishTime, - TaskExecutionIdentifier: taskExecID, + TaskExecutionID: taskExecID, ExtraTemplateVarsByScheme: extraLogTemplateVarsByScheme, }, ) diff --git a/flyteplugins/go/tasks/logs/logging_utils_test.go b/flyteplugins/go/tasks/logs/logging_utils_test.go index fbf86b99337..066fdd96c8b 100644 --- a/flyteplugins/go/tasks/logs/logging_utils_test.go +++ b/flyteplugins/go/tasks/logs/logging_utils_test.go @@ -10,34 +10,41 @@ import ( v12 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" ) const podName = "PodName" -var dummyTaskExecID = &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", +func dummyTaskExecID() pluginCore.TaskExecutionID { + tID := &coreMocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("generated-name") + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: 
core.ResourceType_TASK, + Name: "my-task-name", + Project: "my-task-project", + Domain: "my-task-domain", + Version: "1", }, - }, - RetryAttempt: 1, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my-execution-name", + Project: "my-execution-project", + Domain: "my-execution-domain", + }, + }, + RetryAttempt: 1, + }) + tID.OnGetUniqueNodeID().Return("n0-0-n0") + return tID } func TestGetLogsForContainerInPod_NoPlugins(t *testing.T) { logPlugin, err := InitializeLogPlugins(&LogConfig{}) assert.NoError(t, err) - l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, nil, 0, " Suffix", nil) + l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), nil, 0, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, l) } @@ -49,7 +56,7 @@ func TestGetLogsForContainerInPod_NoLogs(t *testing.T) { CloudwatchLogGroup: "/kubernetes/flyte-production", }) assert.NoError(t, err) - p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, nil, 0, " Suffix", nil) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), nil, 0, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, p) } @@ -80,7 +87,7 @@ func TestGetLogsForContainerInPod_BadIndex(t *testing.T) { } pod.Name = podName - p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 1, " Suffix", nil) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 1, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, p) } @@ -105,7 +112,7 @@ func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) { } pod.Name = podName - p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 1, " Suffix", nil) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 1, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, p) } @@ -135,7 +142,7 @@ func 
TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -165,7 +172,7 @@ func TestGetLogsForContainerInPod_K8s(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -198,7 +205,7 @@ func TestGetLogsForContainerInPod_All(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 2) } @@ -229,7 +236,7 @@ func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -303,7 +310,7 @@ func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*c }, } - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " my-Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " my-Suffix", nil) assert.Nil(tb, err) assert.Len(tb, logs, len(expectedTaskLogs)) if diff := deep.Equal(logs, expectedTaskLogs); len(diff) > 0 { @@ -337,7 +344,7 @@ func TestGetLogsForContainerInPod_Templates(t *testing.T) { Name: "StackDriver my-Suffix", }, { - Uri: 
"https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0/taskId/my-task-name/attempt/1/view/logs", + Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0-0-n0/taskId/my-task-name/attempt/1/view/logs", MessageFormat: core.TaskLog_JSON, Name: "Internal my-Suffix", }, diff --git a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go index 8517b9c3853..9ac650baaa7 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go @@ -27,6 +27,10 @@ type TaskExecutionID interface { // GetID returns the underlying idl task identifier. GetID() core.TaskExecutionIdentifier + + // GetUniqueNodeID returns the fully-qualified Node ID that is unique within a + // given workflow execution. + GetUniqueNodeID() string } // TaskExecutionMetadata represents any execution information for a Task. 
It is used to communicate meta information about the diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go index 7db5590170a..44596bf82f0 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go @@ -114,3 +114,35 @@ func (_m *TaskExecutionID) GetID() flyteidlcore.TaskExecutionIdentifier { return r0 } + +type TaskExecutionID_GetUniqueNodeID struct { + *mock.Call +} + +func (_m TaskExecutionID_GetUniqueNodeID) Return(_a0 string) *TaskExecutionID_GetUniqueNodeID { + return &TaskExecutionID_GetUniqueNodeID{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionID) OnGetUniqueNodeID() *TaskExecutionID_GetUniqueNodeID { + c_call := _m.On("GetUniqueNodeID") + return &TaskExecutionID_GetUniqueNodeID{Call: c_call} +} + +func (_m *TaskExecutionID) OnGetUniqueNodeIDMatch(matchers ...interface{}) *TaskExecutionID_GetUniqueNodeID { + c_call := _m.On("GetUniqueNodeID", matchers...) 
+ return &TaskExecutionID_GetUniqueNodeID{Call: c_call} +} + +// GetUniqueNodeID provides a mock function with given fields: +func (_m *TaskExecutionID) GetUniqueNodeID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go index 0ca91c33704..b812221f6d1 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go @@ -4,6 +4,7 @@ import ( "regexp" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" ) //go:generate enumer --type=TemplateScheme --trimprefix=TemplateScheme -json -yaml @@ -42,7 +43,7 @@ type Input struct { PodUnixStartTime int64 PodUnixFinishTime int64 PodUID string - TaskExecutionIdentifier *core.TaskExecutionIdentifier + TaskExecutionID pluginsCore.TaskExecutionID ExtraTemplateVarsByScheme *TemplateVarsByScheme } diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go index 2a68f42cff1..77c49d26950 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go @@ -114,55 +114,55 @@ func (input Input) templateVarsForScheme(scheme TemplateScheme) TemplateVars { vars = append(vars, input.ExtraTemplateVarsByScheme.Pod...) 
} case TemplateSchemeTaskExecution: - if input.TaskExecutionIdentifier != nil { - vars = append(vars, TemplateVar{ + taskExecutionIdentifier := input.TaskExecutionID.GetID() + vars = append( + vars, + TemplateVar{ + defaultRegexes.NodeID, + input.TaskExecutionID.GetUniqueNodeID(), + }, + TemplateVar{ defaultRegexes.TaskRetryAttempt, - strconv.FormatUint(uint64(input.TaskExecutionIdentifier.RetryAttempt), 10), - }) - if input.TaskExecutionIdentifier.TaskId != nil { - vars = append( - vars, - TemplateVar{ - defaultRegexes.TaskID, - input.TaskExecutionIdentifier.TaskId.Name, - }, - TemplateVar{ - defaultRegexes.TaskVersion, - input.TaskExecutionIdentifier.TaskId.Version, - }, - TemplateVar{ - defaultRegexes.TaskProject, - input.TaskExecutionIdentifier.TaskId.Project, - }, - TemplateVar{ - defaultRegexes.TaskDomain, - input.TaskExecutionIdentifier.TaskId.Domain, - }, - ) - } - if input.TaskExecutionIdentifier.NodeExecutionId != nil { - vars = append(vars, TemplateVar{ - defaultRegexes.NodeID, - input.TaskExecutionIdentifier.NodeExecutionId.NodeId, - }) - if input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId != nil { - vars = append( - vars, - TemplateVar{ - defaultRegexes.ExecutionName, - input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Name, - }, - TemplateVar{ - defaultRegexes.ExecutionProject, - input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Project, - }, - TemplateVar{ - defaultRegexes.ExecutionDomain, - input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Domain, - }, - ) - } - } + strconv.FormatUint(uint64(taskExecutionIdentifier.RetryAttempt), 10), + }, + ) + if taskExecutionIdentifier.TaskId != nil { + vars = append( + vars, + TemplateVar{ + defaultRegexes.TaskID, + taskExecutionIdentifier.TaskId.Name, + }, + TemplateVar{ + defaultRegexes.TaskVersion, + taskExecutionIdentifier.TaskId.Version, + }, + TemplateVar{ + defaultRegexes.TaskProject, + taskExecutionIdentifier.TaskId.Project, + }, + TemplateVar{ + 
defaultRegexes.TaskDomain, + taskExecutionIdentifier.TaskId.Domain, + }, + ) + } + if taskExecutionIdentifier.NodeExecutionId != nil && taskExecutionIdentifier.NodeExecutionId.ExecutionId != nil { + vars = append( + vars, + TemplateVar{ + defaultRegexes.ExecutionName, + taskExecutionIdentifier.NodeExecutionId.ExecutionId.Name, + }, + TemplateVar{ + defaultRegexes.ExecutionProject, + taskExecutionIdentifier.NodeExecutionId.ExecutionId.Project, + }, + TemplateVar{ + defaultRegexes.ExecutionDomain, + taskExecutionIdentifier.NodeExecutionId.ExecutionId.Domain, + }, + ) } if gotExtraTemplateVars { vars = append(vars, input.ExtraTemplateVarsByScheme.TaskExecution...) diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go index e3f03047aa0..320ece05a46 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ) func TestTemplateLog(t *testing.T) { @@ -38,6 +40,30 @@ func Benchmark_initDefaultRegexes(b *testing.B) { } } +func dummyTaskExecID() pluginCore.TaskExecutionID { + tID := &coreMocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("generated-name") + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Name: "my-task-name", + Project: "my-task-project", + Domain: "my-task-domain", + Version: "1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my-execution-name", + Project: "my-execution-project", + Domain: "my-execution-domain", + }, + }, + RetryAttempt: 1, + }) + 
tID.OnGetUniqueNodeID().Return("n0-0-n0") + return tID +} + func Test_Input_templateVarsForScheme(t *testing.T) { testRegexes := struct { Foo *regexp.Regexp @@ -66,25 +92,8 @@ func Test_Input_templateVarsForScheme(t *testing.T) { PodUnixFinishTime: 12345, } taskExecutionBase := Input{ - LogName: "main_logs", - TaskExecutionIdentifier: &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 0, - }, + LogName: "main_logs", + TaskExecutionID: dummyTaskExecID(), } tests := []struct { @@ -162,12 +171,12 @@ func Test_Input_templateVarsForScheme(t *testing.T) { nil, TemplateVars{ {defaultRegexes.LogName, "main_logs"}, - {defaultRegexes.TaskRetryAttempt, "0"}, + {defaultRegexes.NodeID, "n0-0-n0"}, + {defaultRegexes.TaskRetryAttempt, "1"}, {defaultRegexes.TaskID, "my-task-name"}, {defaultRegexes.TaskVersion, "1"}, {defaultRegexes.TaskProject, "my-task-project"}, {defaultRegexes.TaskDomain, "my-task-domain"}, - {defaultRegexes.NodeID, "n0"}, {defaultRegexes.ExecutionName, "my-execution-name"}, {defaultRegexes.ExecutionProject, "my-execution-project"}, {defaultRegexes.ExecutionDomain, "my-execution-domain"}, @@ -484,30 +493,13 @@ func TestTemplateLogPlugin_NewTaskLog(t *testing.T) { PodRFC3339FinishTime: "1970-01-01T04:25:45+01:00", PodUnixStartTime: 123, PodUnixFinishTime: 12345, - TaskExecutionIdentifier: &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: 
&core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 0, - }, + TaskExecutionID: dummyTaskExecID(), }, }, Output{ TaskLogs: []*core.TaskLog{ { - Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0/taskId/my-task-name/attempt/0/view/logs", + Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0-0-n0/taskId/my-task-name/attempt/1/view/logs", MessageFormat: core.TaskLog_JSON, Name: "main_logs", }, @@ -534,24 +526,7 @@ func TestTemplateLogPlugin_NewTaskLog(t *testing.T) { PodRFC3339FinishTime: "1970-01-01T04:25:45+01:00", PodUnixStartTime: 123, PodUnixFinishTime: 12345, - TaskExecutionIdentifier: &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 0, - }, + TaskExecutionID: dummyTaskExecID(), ExtraTemplateVarsByScheme: &TemplateVarsByScheme{ TaskExecution: TemplateVars{ {MustCreateRegex("subtaskExecutionIndex"), "1"}, @@ -564,7 +539,7 @@ func TestTemplateLogPlugin_NewTaskLog(t *testing.T) { Output{ TaskLogs: []*core.TaskLog{ { - Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0/taskId/my-task-name/attempt/0/mappedIndex/1/mappedAttempt/1/view/logs", + Uri: 
"https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0-0-n0/taskId/my-task-name/attempt/0/mappedIndex/1/mappedAttempt/1/view/logs", MessageFormat: core.TaskLog_JSON, Name: "main_logs", }, diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index 65050f5bb26..eb27aec3ced 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -298,13 +298,13 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s status == daskAPI.DaskJobClusterCreated if !isQueued { - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() o, err := logPlugin.GetTaskLogs( tasklog.Input{ - Namespace: job.ObjectMeta.Namespace, - PodName: job.Status.JobRunnerPodName, - LogName: "(User logs)", - TaskExecutionIdentifier: &taskExecID, + Namespace: job.ObjectMeta.Namespace, + PodName: job.Status.JobRunnerPodName, + LogName: "(User logs)", + TaskExecutionID: taskExecID, }, ) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index e0903d02a3c..594767b4b42 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -98,7 +98,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v namespace := objectMeta.Namespace taskLogs := make([]*core.TaskLog, 0, 10) - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig()) @@ -120,14 +120,14 @@ func GetLogs(pluginContext k8s.PluginContext, 
taskType string, objectMeta meta_v if taskType == PytorchTaskType && hasMaster { masterTaskLog, masterErr := logPlugin.GetTaskLogs( tasklog.Input{ - PodName: name + "-master-0", - Namespace: namespace, - LogName: "master", - PodRFC3339StartTime: RFC3999StartTime, - PodRFC3339FinishTime: RFC3999FinishTime, - PodUnixStartTime: startTime, - PodUnixFinishTime: finishTime, - TaskExecutionIdentifier: &taskExecID, + PodName: name + "-master-0", + Namespace: namespace, + LogName: "master", + PodRFC3339StartTime: RFC3999StartTime, + PodRFC3339FinishTime: RFC3999FinishTime, + PodUnixStartTime: startTime, + PodUnixFinishTime: finishTime, + TaskExecutionID: taskExecID, }, ) if masterErr != nil { @@ -139,13 +139,13 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get all workers log for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-worker-%d", workerIndex), - Namespace: namespace, - PodRFC3339StartTime: RFC3999StartTime, - PodRFC3339FinishTime: RFC3999FinishTime, - PodUnixStartTime: startTime, - PodUnixFinishTime: finishTime, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-worker-%d", workerIndex), + Namespace: namespace, + PodRFC3339StartTime: RFC3999StartTime, + PodRFC3339FinishTime: RFC3999FinishTime, + PodUnixStartTime: startTime, + PodUnixFinishTime: finishTime, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err @@ -160,9 +160,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get all parameter servers logs for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), - Namespace: namespace, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), + Namespace: 
namespace, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err @@ -172,9 +172,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get chief worker log, and the max number of chief worker is 1 if chiefReplicasCount != 0 { chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), - Namespace: namespace, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), + Namespace: namespace, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err @@ -184,9 +184,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get evaluator log, and the max number of evaluator is 1 if evaluatorReplicasCount != 0 { evaluatorReplicasCount, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), - Namespace: namespace, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), + Namespace: namespace, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go index d1ba98bcaad..11de8770216 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go @@ -164,9 +164,9 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin ReportedAt: &reportedAt, } - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { - taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, &taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme) + taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme) 
if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -211,7 +211,7 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin } else { // if the primary container annotation exists, we use the status of the specified container phaseInfo = flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info) - if phaseInfo.Phase() == pluginsCore.PhasePermanentFailure && phaseInfo.Err() != nil && + if phaseInfo.Phase() == pluginsCore.PhasePermanentFailure && phaseInfo.Err() != nil && phaseInfo.Err().GetCode() == flytek8s.PrimaryContainerNotFound { // if the primary container status is not found ensure that the primary container exists. // note: it should be impossible for the primary container to not exist at this point. diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index c1dcc2b8e2b..cc8d1983343 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -444,10 +444,10 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs // RayJob CRD does not include the name of the worker or head pod for now - taskID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() logOutput, err := logPlugin.GetTaskLogs(tasklog.Input{ - Namespace: rayJob.Namespace, - TaskExecutionIdentifier: &taskID, + Namespace: rayJob.Namespace, + TaskExecutionID: taskExecID, }) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index d0506ccfb50..e5fd14478a0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -329,7 +329,7 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj 
*sparkOp.SparkAppl sparkConfig := GetSparkConfig() taskLogs := make([]*core.TaskLog, 0, 3) - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() if !isQueued { if sj.Status.DriverInfo.PodName != "" { @@ -340,10 +340,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Status.DriverInfo.PodName, - Namespace: sj.Namespace, - LogName: "(Driver Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Status.DriverInfo.PodName, + Namespace: sj.Namespace, + LogName: "(Driver Logs)", + TaskExecutionID: taskExecID, }) if err != nil { @@ -361,10 +361,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Status.DriverInfo.PodName, - Namespace: sj.Namespace, - LogName: "(User Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Status.DriverInfo.PodName, + Namespace: sj.Namespace, + LogName: "(User Logs)", + TaskExecutionID: taskExecID, }) if err != nil { @@ -381,10 +381,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Name, - Namespace: sj.Namespace, - LogName: "(System Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Name, + Namespace: sj.Namespace, + LogName: "(System Logs)", + TaskExecutionID: taskExecID, }) if err != nil { @@ -402,10 +402,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Name, - Namespace: sj.Namespace, - LogName: "(Spark-Submit/All User Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Name, + Namespace: sj.Namespace, + LogName: "(Spark-Submit/All User Logs)", + TaskExecutionID: taskExecID, }) if err != nil { 
diff --git a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index 603b4d3a301..037ae877d96 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -136,6 +136,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i }, RetryAttempt: 0, }) + tID.OnGetUniqueNodeID().Return("unique-node-id") overrides := &coreMocks.TaskOverrides{} overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{ diff --git a/flytepropeller/pkg/controller/nodes/task/taskexec_context.go b/flytepropeller/pkg/controller/nodes/task/taskexec_context.go index 8b819c79eb6..74fdbc31f55 100644 --- a/flytepropeller/pkg/controller/nodes/task/taskexec_context.go +++ b/flytepropeller/pkg/controller/nodes/task/taskexec_context.go @@ -33,14 +33,19 @@ var ( const IDMaxLength = 50 type taskExecutionID struct { - execName string - id *core.TaskExecutionIdentifier + execName string + id *core.TaskExecutionIdentifier + uniqueNodeId string } func (te taskExecutionID) GetID() core.TaskExecutionIdentifier { return *te.id } +func (te taskExecutionID) GetUniqueNodeID() string { + return te.uniqueNodeId +} + func (te taskExecutionID) GetGeneratedName() string { return te.execName } @@ -291,11 +296,15 @@ func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx interfaces.N NodeExecutionContext: nCtx, tm: taskExecutionMetadata{ NodeExecutionMetadata: nCtx.NodeExecutionMetadata(), - taskExecID: taskExecutionID{execName: uniqueID, id: id}, - o: nCtx.Node(), - maxAttempts: maxAttempts, - platformResources: convertTaskResourcesToRequirements(nCtx.ExecutionContext().GetExecutionConfig().TaskResources), - environmentVariables: nCtx.ExecutionContext().GetExecutionConfig().EnvironmentVariables, + taskExecID: taskExecutionID{ + execName: uniqueID, + id: id, + uniqueNodeId: currentNodeUniqueID, + }, + o: nCtx.Node(), + maxAttempts: maxAttempts, + platformResources: 
convertTaskResourcesToRequirements(nCtx.ExecutionContext().GetExecutionConfig().TaskResources), + environmentVariables: nCtx.ExecutionContext().GetExecutionConfig().EnvironmentVariables, }, rm: resourcemanager.GetTaskResourceManager( t.resourceManager, resourceNamespacePrefix, id), diff --git a/flytepropeller/pkg/controller/nodes/task/transformer_test.go b/flytepropeller/pkg/controller/nodes/task/transformer_test.go index a26705baeed..a9fd9538f7d 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer_test.go @@ -67,6 +67,7 @@ func TestToTaskExecutionEvent(t *testing.T) { generatedName := "generated_name" tID.OnGetGeneratedName().Return(generatedName) tID.OnGetID().Return(*id) + tID.OnGetUniqueNodeID().Return("unique-node-id") tMeta := &pluginMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) @@ -261,6 +262,7 @@ func TestToTaskExecutionEventWithParent(t *testing.T) { generatedName := "generated_name" tID.OnGetGeneratedName().Return(generatedName) tID.OnGetID().Return(*id) + tID.OnGetUniqueNodeID().Return("unique-node-id") tMeta := &pluginMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID)