diff --git a/go/tasks/plugins/array/k8s/config.go b/go/tasks/plugins/array/k8s/config.go
index abcf74c54..c7e5d515c 100644
--- a/go/tasks/plugins/array/k8s/config.go
+++ b/go/tasks/plugins/array/k8s/config.go
@@ -44,7 +44,7 @@ type Config struct {
     DefaultScheduler     string            `json:"scheduler" pflag:",Decides the scheduler to use when launching array-pods."`
     MaxErrorStringLength int               `json:"maxErrLength" pflag:",Determines the maximum length of the error string returned for the array."`
     MaxArrayJobSize      int64             `json:"maxArrayJobSize" pflag:",Maximum size of array job."`
-    ResourceConfig       ResourceConfig    `json:"resourceConfig" pflag:"-,ResourceConfiguration to limit number of resources used by k8s-array."`
+    ResourceConfig       ResourceConfig    `json:"resourceConfig" pflag:",ResourceConfiguration to limit number of resources used by k8s-array."`
     NodeSelector         map[string]string `json:"node-selector" pflag:"-,Defines a set of node selector labels to add to the pod."`
     Tolerations          []v1.Toleration   `json:"tolerations" pflag:"-,Tolerations to be applied for k8s-array pods"`
     OutputAssembler      workqueue.Config
diff --git a/go/tasks/plugins/array/k8s/config_flags.go b/go/tasks/plugins/array/k8s/config_flags.go
index 4a03aefc9..7bb1404de 100755
--- a/go/tasks/plugins/array/k8s/config_flags.go
+++ b/go/tasks/plugins/array/k8s/config_flags.go
@@ -44,6 +44,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "scheduler"), defaultConfig.DefaultScheduler, "Decides the scheduler to use when launching array-pods.")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "maxErrLength"), defaultConfig.MaxErrorStringLength, "Determines the maximum length of the error string returned for the array.")
     cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "maxArrayJobSize"), defaultConfig.MaxArrayJobSize, "Maximum size of array job.")
+    cmdFlags.String(fmt.Sprintf("%v%v", prefix, "resourceConfig.primaryLabel"), defaultConfig.ResourceConfig.PrimaryLabel, "PrimaryLabel of a given service cluster")
+    cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "resourceConfig.limit"), defaultConfig.ResourceConfig.Limit, "Resource quota (in the number of outstanding requests) for the cluster")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "OutputAssembler.workers"), defaultConfig.OutputAssembler.Workers, "Number of concurrent workers to start processing the queue.")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "OutputAssembler.maxRetries"), defaultConfig.OutputAssembler.MaxRetries, "Maximum number of retries per item.")
     cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "OutputAssembler.maxItems"), defaultConfig.OutputAssembler.IndexCacheMaxItems, "Maximum number of entries to keep in the index.")
diff --git a/go/tasks/plugins/array/k8s/config_flags_test.go b/go/tasks/plugins/array/k8s/config_flags_test.go
index df7b41909..d0ffbdb49 100755
--- a/go/tasks/plugins/array/k8s/config_flags_test.go
+++ b/go/tasks/plugins/array/k8s/config_flags_test.go
@@ -165,6 +165,50 @@ func TestConfig_SetFlags(t *testing.T) {
             }
         })
     })
+    t.Run("Test_resourceConfig.primaryLabel", func(t *testing.T) {
+        t.Run("DefaultValue", func(t *testing.T) {
+            // Test that default value is set properly
+            if vString, err := cmdFlags.GetString("resourceConfig.primaryLabel"); err == nil {
+                assert.Equal(t, string(defaultConfig.ResourceConfig.PrimaryLabel), vString)
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := "1"
+
+            cmdFlags.Set("resourceConfig.primaryLabel", testValue)
+            if vString, err := cmdFlags.GetString("resourceConfig.primaryLabel"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ResourceConfig.PrimaryLabel)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
+    t.Run("Test_resourceConfig.limit", func(t *testing.T) {
+        t.Run("DefaultValue", func(t *testing.T) {
+            // Test that default value is set properly
+            if vInt, err := cmdFlags.GetInt("resourceConfig.limit"); err == nil {
+                assert.Equal(t, int(defaultConfig.ResourceConfig.Limit), vInt)
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := "1"
+
+            cmdFlags.Set("resourceConfig.limit", testValue)
+            if vInt, err := cmdFlags.GetInt("resourceConfig.limit"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ResourceConfig.Limit)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
     t.Run("Test_OutputAssembler.workers", func(t *testing.T) {
         t.Run("DefaultValue", func(t *testing.T) {
             // Test that default value is set properly
diff --git a/go/tasks/plugins/array/k8s/launcher.go b/go/tasks/plugins/array/k8s/launcher.go
index 919b956ee..5af0a6965 100644
--- a/go/tasks/plugins/array/k8s/launcher.go
+++ b/go/tasks/plugins/array/k8s/launcher.go
@@ -74,6 +74,10 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
         if err != nil {
             errs.Collect(childIdx, err.Error())
         }
+        err = task.Abort(ctx, tCtx, kubeClient)
+        if err != nil {
+            errs.Collect(childIdx, err.Error())
+        }
     }
 
     if errs.Length() > 0 {
diff --git a/go/tasks/plugins/array/k8s/monitor.go b/go/tasks/plugins/array/k8s/monitor.go
index e1fdf5179..38ace640c 100644
--- a/go/tasks/plugins/array/k8s/monitor.go
+++ b/go/tasks/plugins/array/k8s/monitor.go
@@ -90,27 +90,28 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
 
         // The first time we enter this state we will launch every subtask. On subsequent rounds, the pod
         // has already been created so we return a Success value and continue with the Monitor step.
-        var status TaskStatus
-        status, err = task.Launch(ctx, tCtx, kubeClient)
+        var launchResult LaunchResult
+        launchResult, err = task.Launch(ctx, tCtx, kubeClient)
         if err != nil {
             return currentState, logLinks, err
         }
 
-        switch status {
-        case Success:
+        switch launchResult {
+        case LaunchSuccess:
             // Continue with execution if successful
-        case Error:
+        case LaunchError:
             return currentState, logLinks, err
         // If Resource manager is enabled and there are currently not enough resources we can skip this round
         // for a subtask and wait until there are enough resources.
-        case Waiting:
+        case LaunchWaiting:
             continue
-        case ReturnState:
+        case LaunchReturnState:
             return currentState, logLinks, nil
         }
 
-        status, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox)
-        if status != Success {
+        var monitorResult MonitorResult
+        monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox)
+        if monitorResult != MonitorSuccess {
             return currentState, logLinks, err
         }
diff --git a/go/tasks/plugins/array/k8s/task.go b/go/tasks/plugins/array/k8s/task.go
index 5e0a34c10..c52610869 100644
--- a/go/tasks/plugins/array/k8s/task.go
+++ b/go/tasks/plugins/array/k8s/task.go
@@ -31,19 +31,25 @@ type Task struct {
     MessageCollector *errorcollector.ErrorMessageCollector
 }
 
-type TaskStatus int8
+type LaunchResult int8
+type MonitorResult int8
 
 const (
-    Success TaskStatus = iota
-    Error
-    Waiting
-    ReturnState
+    LaunchSuccess LaunchResult = iota
+    LaunchError
+    LaunchWaiting
+    LaunchReturnState
 )
 
-func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) (TaskStatus, error) {
+const (
+    MonitorSuccess MonitorResult = iota
+    MonitorError
+)
+
+func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) (LaunchResult, error) {
     podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx)
     if err != nil {
-        return Error, errors2.Wrapf(ErrBuildPodTemplate, err, "Failed to convert task template to a pod template for a task")
+        return LaunchError, errors2.Wrapf(ErrBuildPodTemplate, err, "Failed to convert task template to a pod template for a task")
     }
 
     var args []string
@@ -51,7 +57,7 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
         args = append(podTemplate.Spec.Containers[0].Command, podTemplate.Spec.Containers[0].Args...)
         podTemplate.Spec.Containers[0].Command = []string{}
     } else {
-        return Error, errors2.Wrapf(ErrReplaceCmdTemplate, err, "No containers found in podSpec.")
+        return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "No containers found in podSpec.")
     }
 
     indexStr := strconv.Itoa(t.ChildIdx)
@@ -67,7 +73,7 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
     pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, arrayJobEnvVars...)
     pod.Spec.Containers[0].Args, err = utils.ReplaceTemplateCommandArgs(ctx, args, arrayJobInputReader{tCtx.InputReader()}, tCtx.OutputWriter())
     if err != nil {
-        return Error, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args")
+        return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args")
     }
 
     pod = ApplyPodPolicies(ctx, t.Config, pod)
@@ -76,12 +82,12 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
 
     allocationStatus, err := allocateResource(ctx, tCtx, t.Config, podName)
     if err != nil {
-        return Error, err
+        return LaunchError, err
     }
     if allocationStatus != core.AllocationStatusGranted {
         t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(core.PhaseWaitingForResources))
         t.NewArrayStatus.Summary.Inc(core.PhaseWaitingForResources)
-        return Waiting, nil
+        return LaunchWaiting, nil
     }
 
     err = kubeClient.GetClient().Create(ctx, pod)
@@ -96,16 +102,16 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
             }
 
             t.State = t.State.SetReason(err.Error())
-            return ReturnState, nil
+            return LaunchReturnState, nil
         }
-        return Error, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.")
+        return LaunchError, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.")
     }
 
-    return Success, nil
+    return LaunchSuccess, nil
 }
 
-func (t Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (TaskStatus, error) {
+func (t Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference) (MonitorResult, error) {
     indexStr := strconv.Itoa(t.ChildIdx)
     podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
     phaseInfo, err := CheckPodStatus(ctx, kubeClient,
@@ -114,7 +120,7 @@ func (t Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeC
             Namespace: tCtx.TaskExecutionMetadata().GetNamespace(),
         })
     if err != nil {
-        return Error, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.")
+        return MonitorError, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.")
     }
 
     if phaseInfo.Info() != nil {
@@ -130,19 +136,17 @@ func (t Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeC
         originalIdx := arrayCore.CalculateOriginalIndex(t.ChildIdx, t.State.GetIndexesToCache())
         actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, t.ChildIdx, originalIdx)
         if err != nil {
-            return Error, err
+            return MonitorError, err
         }
     }
 
     t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(actualPhase))
     t.NewArrayStatus.Summary.Inc(actualPhase)
-    return Success, nil
+    return MonitorSuccess, nil
 }
 
-func (t Task) Abort() {}
-
-func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error {
+func (t Task) Abort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error {
     indexStr := strconv.Itoa(t.ChildIdx)
     podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
     pod := &corev1.Pod{
@@ -182,6 +186,21 @@ func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kube
 
 }
 
+func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error {
+    indexStr := strconv.Itoa(t.ChildIdx)
+    podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
+
+    // Deallocate Resource
+    err := deallocateResource(ctx, tCtx, t.Config, t.ChildIdx)
+    if err != nil {
+        logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err)
+        return err
+    }
+
+    return nil
+
+}
+
 func allocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) (core.AllocationStatus, error) {
     if !IsResourceConfigSet() {
         return core.AllocationStatusGranted, nil
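Reviewer note: a minimal, standalone sketch of the dotted flag-name convention that config_flags.go above now uses for the newly exposed ResourceConfig fields (registered as "resourceConfig.primaryLabel" and "resourceConfig.limit"). The flag-set name, defaults, and main function below are illustrative placeholders, not the plugin's real wiring; only the spf13/pflag calls mirror the generated code and tests.

package main

import (
	"fmt"

	"github.com/spf13/pflag"
)

func main() {
	// Register the two flags the same way GetPFlagSet does (with an empty prefix).
	fs := pflag.NewFlagSet("k8s-array", pflag.ExitOnError)
	fs.String("resourceConfig.primaryLabel", "", "PrimaryLabel of a given service cluster")
	fs.Int("resourceConfig.limit", 0, "Resource quota (in the number of outstanding requests) for the cluster")

	// Override and read back, mirroring the Override sub-tests added in config_flags_test.go.
	_ = fs.Set("resourceConfig.limit", "100")
	if limit, err := fs.GetInt("resourceConfig.limit"); err == nil {
		fmt.Println("resourceConfig.limit =", limit) // prints: resourceConfig.limit = 100
	}
}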
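For context, a self-contained sketch of the control flow that the LaunchResult/MonitorResult split in task.go and monitor.go is meant to enforce: a launch outcome can no longer be confused with a monitor outcome, as both previously shared the single TaskStatus type. The subTask type and the loop below are illustrative stand-ins, not the plugin's real types.

package main

import "fmt"

type LaunchResult int8
type MonitorResult int8

const (
	LaunchSuccess LaunchResult = iota
	LaunchError
	LaunchWaiting
	LaunchReturnState
)

const (
	MonitorSuccess MonitorResult = iota
	MonitorError
)

// subTask is an illustrative stand-in for the plugin's Task type.
type subTask struct{ idx int }

func (t subTask) Launch() (LaunchResult, error)   { return LaunchSuccess, nil }
func (t subTask) Monitor() (MonitorResult, error) { return MonitorSuccess, nil }

func main() {
	for _, task := range []subTask{{0}, {1}} {
		launchResult, err := task.Launch()
		if err != nil {
			fmt.Println("launch failed:", err)
			return
		}

		switch launchResult {
		case LaunchSuccess:
			// Fall through to monitoring.
		case LaunchWaiting:
			continue // not enough resources yet; retry this subtask on a later round
		case LaunchError, LaunchReturnState:
			return // simplified: the real code returns the current state to the caller
		}

		if monitorResult, err := task.Monitor(); monitorResult != MonitorSuccess {
			fmt.Println("monitor failed for subtask", task.idx, ":", err)
			return
		}
	}
}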