diff --git a/flyteplugins/go/tasks/plugins/array/k8s/config.go b/flyteplugins/go/tasks/plugins/array/k8s/config.go index c200bc24be..df02803a72 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/config.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/config.go @@ -8,6 +8,8 @@ import ( "fmt" "io/ioutil" + "github.com/flyteorg/flyteplugins/go/tasks/logs" + "github.com/pkg/errors" v1 "k8s.io/api/core/v1" restclient "k8s.io/client-go/rest" @@ -113,8 +115,14 @@ type Config struct { RemoteClusterConfig ClusterConfig `json:"remoteClusterConfig" pflag:"-,Configuration of remote K8s cluster for array jobs"` NodeSelector map[string]string `json:"node-selector" pflag:"-,Defines a set of node selector labels to add to the pod."` Tolerations []v1.Toleration `json:"tolerations" pflag:"-,Tolerations to be applied for k8s-array pods"` + NamespaceTemplate string `json:"namespaceTemplate" pflag:"-,Namespace pattern to spawn array-jobs in. Defaults to parent namespace if not set"` OutputAssembler workqueue.Config ErrorAssembler workqueue.Config + LogConfig LogConfig `json:"logs" pflag:",Config for log links for k8s array jobs."` +} + +type LogConfig struct { + Config logs.LogConfig `json:"config" pflag:",Defines the log config for k8s logs."` } func GetConfig() *Config { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/config_flags.go b/flyteplugins/go/tasks/plugins/array/k8s/config_flags.go index 145da4b309..5d7f188090 100755 --- a/flyteplugins/go/tasks/plugins/array/k8s/config_flags.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/config_flags.go @@ -50,5 +50,16 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "ErrorAssembler.workers"), defaultConfig.ErrorAssembler.Workers, "Number of concurrent workers to start processing the queue.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "ErrorAssembler.maxRetries"), defaultConfig.ErrorAssembler.MaxRetries, "Maximum number of retries per item.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "ErrorAssembler.maxItems"), defaultConfig.ErrorAssembler.IndexCacheMaxItems, "Maximum number of entries to keep in the index.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "logs.config.cloudwatch-enabled"), defaultConfig.LogConfig.Config.IsCloudwatchEnabled, "Enable Cloudwatch Logging") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.cloudwatch-region"), defaultConfig.LogConfig.Config.CloudwatchRegion, "AWS region in which Cloudwatch logs are stored.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.cloudwatch-log-group"), defaultConfig.LogConfig.Config.CloudwatchLogGroup, "Log group to which streams are associated.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.cloudwatch-template-uri"), defaultConfig.LogConfig.Config.CloudwatchTemplateURI, "Template Uri to use when building cloudwatch log links") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "logs.config.kubernetes-enabled"), defaultConfig.LogConfig.Config.IsKubernetesEnabled, "Enable Kubernetes Logging") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.kubernetes-url"), defaultConfig.LogConfig.Config.KubernetesURL, "Console URL for Kubernetes logs") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.kubernetes-template-uri"), defaultConfig.LogConfig.Config.KubernetesTemplateURI, "Template Uri to use when building kubernetes log links") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "logs.config.stackdriver-enabled"), defaultConfig.LogConfig.Config.IsStackDriverEnabled, "Enable Log-links to stackdriver") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.gcp-project"), defaultConfig.LogConfig.Config.GCPProjectName, "Name of the project in GCP") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.stackdriver-logresourcename"), defaultConfig.LogConfig.Config.StackdriverLogResourceName, "Name of the logresource in stackdriver") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "logs.config.stackdriver-template-uri"), defaultConfig.LogConfig.Config.StackDriverTemplateURI, "Template Uri to use when building stackdriver log links") return cmdFlags } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/config_flags_test.go b/flyteplugins/go/tasks/plugins/array/k8s/config_flags_test.go index 6fc957b614..9d01ed420e 100755 --- a/flyteplugins/go/tasks/plugins/array/k8s/config_flags_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/config_flags_test.go @@ -297,4 +297,246 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_logs.config.cloudwatch-enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("logs.config.cloudwatch-enabled"); err == nil { + assert.Equal(t, bool(defaultConfig.LogConfig.Config.IsCloudwatchEnabled), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.cloudwatch-enabled", testValue) + if vBool, err := cmdFlags.GetBool("logs.config.cloudwatch-enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LogConfig.Config.IsCloudwatchEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.cloudwatch-region", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.cloudwatch-region"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.CloudwatchRegion), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.cloudwatch-region", testValue) + if vString, err := cmdFlags.GetString("logs.config.cloudwatch-region"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.CloudwatchRegion) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.cloudwatch-log-group", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.cloudwatch-log-group"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.CloudwatchLogGroup), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.cloudwatch-log-group", testValue) + if vString, err := cmdFlags.GetString("logs.config.cloudwatch-log-group"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.CloudwatchLogGroup) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.cloudwatch-template-uri", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.cloudwatch-template-uri"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.CloudwatchTemplateURI), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.cloudwatch-template-uri", testValue) + if vString, err := cmdFlags.GetString("logs.config.cloudwatch-template-uri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.CloudwatchTemplateURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.kubernetes-enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("logs.config.kubernetes-enabled"); err == nil { + assert.Equal(t, bool(defaultConfig.LogConfig.Config.IsKubernetesEnabled), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.kubernetes-enabled", testValue) + if vBool, err := cmdFlags.GetBool("logs.config.kubernetes-enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LogConfig.Config.IsKubernetesEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.kubernetes-url", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.kubernetes-url"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.KubernetesURL), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.kubernetes-url", testValue) + if vString, err := cmdFlags.GetString("logs.config.kubernetes-url"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.KubernetesURL) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.kubernetes-template-uri", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.kubernetes-template-uri"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.KubernetesTemplateURI), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.kubernetes-template-uri", testValue) + if vString, err := cmdFlags.GetString("logs.config.kubernetes-template-uri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.KubernetesTemplateURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.stackdriver-enabled", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vBool, err := cmdFlags.GetBool("logs.config.stackdriver-enabled"); err == nil { + assert.Equal(t, bool(defaultConfig.LogConfig.Config.IsStackDriverEnabled), vBool) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.stackdriver-enabled", testValue) + if vBool, err := cmdFlags.GetBool("logs.config.stackdriver-enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LogConfig.Config.IsStackDriverEnabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.gcp-project", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.gcp-project"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.GCPProjectName), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.gcp-project", testValue) + if vString, err := cmdFlags.GetString("logs.config.gcp-project"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.GCPProjectName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.stackdriver-logresourcename", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.stackdriver-logresourcename"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.StackdriverLogResourceName), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.stackdriver-logresourcename", testValue) + if vString, err := cmdFlags.GetString("logs.config.stackdriver-logresourcename"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.StackdriverLogResourceName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_logs.config.stackdriver-template-uri", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("logs.config.stackdriver-template-uri"); err == nil { + assert.Equal(t, string(defaultConfig.LogConfig.Config.StackDriverTemplateURI), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("logs.config.stackdriver-template-uri", testValue) + if vString, err := cmdFlags.GetString("logs.config.stackdriver-template-uri"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.LogConfig.Config.StackDriverTemplateURI) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go index 3d21ad577a..322cd8df28 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go @@ -6,6 +6,8 @@ import ( "strconv" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/storage" @@ -59,6 +61,12 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon currentState.ArrayStatus = *newArrayStatus } + logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config) + if err != nil { + logger.Errorf(ctx, "Error initializing LogPlugins: [%s]", err) + return currentState, logLinks, subTaskIDs, err + } + for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] indexStr := strconv.Itoa(childIdx) @@ -113,7 +121,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon } var monitorResult MonitorResult - monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox) + monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox, logPlugin) logLinks = task.LogLinks subTaskIDs = task.SubTaskIDs @@ -157,7 +165,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return newState, logLinks, subTaskIDs, nil } -func CheckPodStatus(ctx context.Context, client core.KubeClient, name k8sTypes.NamespacedName) ( +func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8sTypes.NamespacedName, index int, retryAttempt uint32, logPlugin tasklog.Plugin) ( info core.PhaseInfo, err error) { pod := &v1.Pod{ @@ -192,11 +200,19 @@ func CheckPodStatus(ctx context.Context, client core.KubeClient, name k8sTypes.N } if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { - taskLogs, err := logs.GetLogsForContainerInPod(ctx, pod, 0, " (User)") - if err != nil { - return core.PhaseInfoUndefined, err + + if logPlugin != nil { + o, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: pod.Name, + Namespace: pod.Namespace, + LogName: fmt.Sprintf(" #%d-%d", index, retryAttempt), + }) + + if err != nil { + return core.PhaseInfoUndefined, err + } + taskInfo.Logs = o.TaskLogs } - taskInfo.Logs = taskLogs } switch pod.Status.Phase { case v1.PodSucceeded: diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go index 64fec894ee..a6290cde76 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" + "github.com/flyteorg/flyteplugins/go/tasks/logs" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/workqueue" + core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" @@ -93,6 +97,15 @@ func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContex return tCtx } +func TestGetNamespaceForExecution(t *testing.T) { + ctx := context.Background() + tCtx := getMockTaskExecutionContext(ctx) + + assert.Equal(t, GetNamespaceForExecution(tCtx, ""), tCtx.TaskExecutionMetadata().GetNamespace()) + assert.Equal(t, GetNamespaceForExecution(tCtx, "abcd"), "abcd") + assert.Equal(t, GetNamespaceForExecution(tCtx, "a-{{.namespace}}-b"), fmt.Sprintf("a-%s-b", tCtx.TaskExecutionMetadata().GetNamespace())) +} + func testSubTaskIDs(t *testing.T, actual []*string) { var expected = make([]*string, 5) for i := 0; i < len(expected); i++ { @@ -113,16 +126,54 @@ func TestCheckSubTasksState(t *testing.T) { tCtx.OnResourceManager().Return(&resourceManager) t.Run("Happy case", func(t *testing.T) { - config := Config{MaxArrayJobSize: 100} - newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ + config := Config{ + MaxArrayJobSize: 100, + MaxErrorStringLength: 200, + NamespaceTemplate: "a-{{.namespace}}-b", + OutputAssembler: workqueue.Config{ + Workers: 2, + MaxRetries: 0, + IndexCacheMaxItems: 100, + }, + ErrorAssembler: workqueue.Config{ + Workers: 2, + MaxRetries: 0, + IndexCacheMaxItems: 100, + }, + LogConfig: LogConfig{ + Config: logs.LogConfig{ + IsCloudwatchEnabled: true, + CloudwatchTemplateURI: "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.{{ .podName }};streamFilter=typeLogStreamPrefix", + IsKubernetesEnabled: true, + KubernetesTemplateURI: "k8s/log/{{.namespace}}/{{.podName}}/pod?namespace={{.namespace}}", + }}, + } + cacheIndexes := bitarray.NewBitSet(5) + cacheIndexes.Set(0) + cacheIndexes.Set(1) + cacheIndexes.Set(2) + cacheIndexes.Set(3) + cacheIndexes.Set(4) + + newState, logLinks, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, OriginalMinSuccesses: 5, + IndexesToCache: cacheIndexes, }) assert.Nil(t, err) - //assert.NotEmpty(t, logLinks) + assert.NotEmpty(t, logLinks) + assert.Equal(t, 10, len(logLinks)) + for i := 0; i < 10; i = i + 2 { + assert.Equal(t, fmt.Sprintf("Kubernetes Logs #%d-0 (PhaseRunning)", i/2), logLinks[i].Name) + assert.Equal(t, fmt.Sprintf("k8s/log/a-n-b/notfound-%d/pod?namespace=a-n-b", i/2), logLinks[i].Uri) + + assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #%d-0 (PhaseRunning)", i/2), logLinks[i+1].Name) + assert.Equal(t, fmt.Sprintf("https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.notfound-%d;streamFilter=typeLogStreamPrefix", i/2), logLinks[i+1].Uri) + } + p, _ := newState.GetPhase() assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String()) resourceManager.AssertNumberOfCalls(t, "AllocateResource", 0) @@ -177,11 +228,13 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { }, } + cacheIndexes := bitarray.NewBitSet(5) newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, OriginalMinSuccesses: 5, + IndexesToCache: cacheIndexes, ArrayStatus: arraystatus.ArrayStatus{ Detailed: arrayCore.NewPhasesCompactArray(uint(5)), }, diff --git a/flyteplugins/go/tasks/plugins/array/k8s/task.go b/flyteplugins/go/tasks/plugins/array/k8s/task.go index 22569d7db1..a0e5537d77 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/task.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/task.go @@ -2,9 +2,12 @@ package k8s import ( "context" + "fmt" "strconv" "strings" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -49,7 +52,7 @@ const ( ) func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) (LaunchResult, error) { - podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx) + podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, t.Config.NamespaceTemplate) if err != nil { return LaunchError, errors2.Wrapf(ErrBuildPodTemplate, err, "Failed to convert task template to a pod template for a task") } @@ -129,20 +132,31 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl return LaunchSuccess, nil } -func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (MonitorResult, error) { +func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, + logPlugin tasklog.Plugin) (MonitorResult, error) { indexStr := strconv.Itoa(t.ChildIdx) podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) t.SubTaskIDs = append(t.SubTaskIDs, &podName) - phaseInfo, err := CheckPodStatus(ctx, kubeClient, + + // Use original-index for log-name/links + originalIdx := arrayCore.CalculateOriginalIndex(t.ChildIdx, t.State.GetIndexesToCache()) + phaseInfo, err := FetchPodStatusAndLogs(ctx, kubeClient, k8sTypes.NamespacedName{ Name: podName, - Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), - }) + Namespace: GetNamespaceForExecution(tCtx, t.Config.NamespaceTemplate), + }, + originalIdx, + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, + logPlugin) if err != nil { return MonitorError, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.") } if phaseInfo.Info() != nil { + // Append sub-job status in Log Name for viz. + for _, log := range phaseInfo.Info().Logs { + log.Name += fmt.Sprintf(" (%s)", phaseInfo.Phase().String()) + } t.LogLinks = append(t.LogLinks, phaseInfo.Info().Logs...) } @@ -152,7 +166,6 @@ func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kube actualPhase := phaseInfo.Phase() if phaseInfo.Phase().IsSuccess() { - originalIdx := arrayCore.CalculateOriginalIndex(t.ChildIdx, t.State.GetIndexesToCache()) actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, t.ChildIdx, originalIdx) if err != nil { return MonitorError, err @@ -175,7 +188,7 @@ func (t Task) Abort(ctx context.Context, tCtx core.TaskExecutionContext, kubeCli }, ObjectMeta: metav1.ObjectMeta{ Name: podName, - Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), + Namespace: GetNamespaceForExecution(tCtx, t.Config.NamespaceTemplate), }, } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/transformer.go b/flyteplugins/go/tasks/plugins/array/k8s/transformer.go index 9382bdeef8..127e31b038 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/transformer.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/transformer.go @@ -2,6 +2,7 @@ package k8s import ( "context" + "regexp" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -19,6 +20,8 @@ import ( const PodKind = "pod" +var namespaceRegex = regexp.MustCompile("(?i){{.namespace}}(?i)") + type arrayTaskContext struct { core.TaskExecutionContext arrayInputReader io.InputReader @@ -29,9 +32,23 @@ func (a *arrayTaskContext) InputReader() io.InputReader { return a.arrayInputReader } +func GetNamespaceForExecution(tCtx core.TaskExecutionContext, namespaceTemplate string) string { + + // Default to parent namespace + namespace := tCtx.TaskExecutionMetadata().GetNamespace() + if namespaceTemplate != "" { + if namespaceRegex.MatchString(namespaceTemplate) { + namespace = namespaceRegex.ReplaceAllString(namespaceTemplate, namespace) + } else { + namespace = namespaceTemplate + } + } + return namespace +} + // Note that Name is not set on the result object. // It's up to the caller to set the Name before creating the object in K8s. -func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext) ( +func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext, namespaceTemplate string) ( podTemplate v1.Pod, job *idlPlugins.ArrayJob, err error) { // Check that the taskTemplate is valid @@ -71,7 +88,7 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC }, ObjectMeta: metav1.ObjectMeta{ // Note that name is missing here - Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), + Namespace: GetNamespaceForExecution(tCtx, namespaceTemplate), Labels: tCtx.TaskExecutionMetadata().GetLabels(), Annotations: tCtx.TaskExecutionMetadata().GetAnnotations(), OwnerReferences: []metav1.OwnerReference{tCtx.TaskExecutionMetadata().GetOwnerReference()},