Supporting interruptible for map tasks (flyteorg#253)
* implemented IsInterruptible for SubTaskExecutionMetadata

Signed-off-by: Daniel Rammer <[email protected]>

* fixed possible race condition

Signed-off-by: Daniel Rammer <[email protected]>

* fixed unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issue

Signed-off-by: Daniel Rammer <[email protected]>

* updated TODO documentation

Signed-off-by: Daniel Rammer <[email protected]>

* changed context on NewCompactArray error log

Signed-off-by: Daniel Rammer <[email protected]>

* fixing retry attempt calculation on abort

Signed-off-by: Daniel Rammer <[email protected]>
hamersaw authored Apr 8, 2022
1 parent 6f534eb commit cc2b8c6
Showing 10 changed files with 95 additions and 62 deletions.
1 change: 1 addition & 0 deletions go/tasks/pluginmachinery/core/exec_metadata.go
@@ -44,4 +44,5 @@ type TaskExecutionMetadata interface {
GetSecurityContext() core.SecurityContext
IsInterruptible() bool
GetPlatformResources() *v1.ResourceRequirements
+ GetInterruptibleFailureThreshold() uint32
}
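For orientation, here is how a plugin can combine the two interruptible accessors on this interface. This is a minimal sketch, not code from this commit; the helper name and the import path are assumptions, and only IsInterruptible() and GetInterruptibleFailureThreshold() come from the interface above.

package example

// Import path assumed from this repository's layout.
import (
	pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
)

// shouldRunInterruptible is a hypothetical helper: a task stays on
// interruptible (e.g. spot/preemptible) nodes only while its recorded
// system-failure count is below the configured threshold.
func shouldRunInterruptible(meta pluginsCore.TaskExecutionMetadata, systemFailures uint32) bool {
	return meta.IsInterruptible() && systemFailures < meta.GetInterruptibleFailureThreshold()
}

This is the same check the subtask metadata applies below in subtask_exec_context.go.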
32 changes: 32 additions & 0 deletions go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go

Some generated files are not rendered by default.

20 changes: 4 additions & 16 deletions go/tasks/pluginmachinery/flytek8s/pod_helper.go
@@ -63,18 +63,11 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
// UpdatePod updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec) {
- UpdatePodWithInterruptibleFlag(taskExecutionMetadata, resourceRequirements, podSpec, false)
- }
-
- // UpdatePodWithInterruptibleFlag updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
- func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
- resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec, omitInterruptible bool) {
- isInterruptible := !omitInterruptible && taskExecutionMetadata.IsInterruptible()
if len(podSpec.RestartPolicy) == 0 {
podSpec.RestartPolicy = v1.RestartPolicyNever
}
podSpec.Tolerations = append(
- GetPodTolerations(isInterruptible, resourceRequirements...), podSpec.Tolerations...)
+ GetPodTolerations(taskExecutionMetadata.IsInterruptible(), resourceRequirements...), podSpec.Tolerations...)

if len(podSpec.ServiceAccountName) == 0 {
podSpec.ServiceAccountName = taskExecutionMetadata.GetK8sServiceAccount()
@@ -83,7 +76,7 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut
podSpec.SchedulerName = config.GetK8sPluginConfig().SchedulerName
}
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().DefaultNodeSelector)
- if isInterruptible {
+ if taskExecutionMetadata.IsInterruptible() {
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector)
}
if podSpec.Affinity == nil && config.GetK8sPluginConfig().DefaultAffinity != nil {
Expand All @@ -98,16 +91,11 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut
if podSpec.DNSConfig == nil && config.GetK8sPluginConfig().DefaultPodDNSConfig != nil {
podSpec.DNSConfig = config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy()
}
- ApplyInterruptibleNodeAffinity(isInterruptible, podSpec)
+ ApplyInterruptibleNodeAffinity(taskExecutionMetadata.IsInterruptible(), podSpec)
}

// ToK8sPodSpec constructs a pod spec from the given TaskTemplate
func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) {
- return ToK8sPodSpecWithInterruptible(ctx, tCtx, false)
- }
-
- // ToK8sPodSpecWithInterruptible constructs a pod spec from the given TaskTemplate and optionally adds (interruptible instance) support.
- func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, omitInterruptible bool) (*v1.PodSpec, error) {
task, err := tCtx.TaskReader().Read(ctx)
if err != nil {
logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error())
@@ -138,7 +126,7 @@ func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExe
pod := &v1.PodSpec{
Containers: containers,
}
- UpdatePodWithInterruptibleFlag(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod, omitInterruptible)
+ UpdatePod(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod)

if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), task.GetContainer().GetDataConfig()); err != nil {
return nil, err
38 changes: 0 additions & 38 deletions go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
@@ -109,7 +109,6 @@ func TestPodSetup(t *testing.T) {
t.Run("ApplyInterruptibleNodeAffinity", TestApplyInterruptibleNodeAffinity)
t.Run("UpdatePod", updatePod)
t.Run("ToK8sPodInterruptible", toK8sPodInterruptible)
t.Run("toK8sPodInterruptibleFalse", toK8sPodInterruptibleFalse)
}

func TestApplyInterruptibleNodeAffinity(t *testing.T) {
@@ -349,43 +348,6 @@ func toK8sPodInterruptible(t *testing.T) {
)
}

- func toK8sPodInterruptibleFalse(t *testing.T) {
- ctx := context.TODO()
-
- x := dummyExecContext(&v1.ResourceRequirements{
- Limits: v1.ResourceList{
- v1.ResourceCPU: resource.MustParse("1024m"),
- v1.ResourceStorage: resource.MustParse("100M"),
- ResourceNvidiaGPU: resource.MustParse("1"),
- },
- Requests: v1.ResourceList{
- v1.ResourceCPU: resource.MustParse("1024m"),
- v1.ResourceStorage: resource.MustParse("100M"),
- },
- })
-
- p, err := ToK8sPodSpecWithInterruptible(ctx, x, true)
- assert.NoError(t, err)
- assert.Len(t, p.Tolerations, 1)
- assert.Equal(t, 0, len(p.NodeSelector))
- assert.Equal(t, "", p.NodeSelector["x/interruptible"])
- assert.NotEqualValues(
- t,
- []v1.NodeSelectorTerm{
- v1.NodeSelectorTerm{
- MatchExpressions: []v1.NodeSelectorRequirement{
- v1.NodeSelectorRequirement{
- Key: "x/interruptible",
- Operator: v1.NodeSelectorOpIn,
- Values: []string{"true"},
- },
- },
- },
- },
- p.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms,
- )
- }

func TestToK8sPod(t *testing.T) {
ctx := context.TODO()

3 changes: 3 additions & 0 deletions go/tasks/plugins/array/core/state.go
@@ -53,6 +53,9 @@ type State struct {

// Tracks the number of subtask retries using the execution index
RetryAttempts bitarray.CompactArray `json:"retryAttempts"`

+ // Tracks the number of system failures for each subtask using the execution index
+ SystemFailures bitarray.CompactArray `json:"systemFailures"`
}

func (s State) GetReason() string {
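The new SystemFailures field reuses bitarray.CompactArray, which packs one small unsigned counter per subtask, with a declared maximum value bounding how many bits each entry needs. A rough sketch of the bookkeeping this commit performs: the helper is hypothetical, and the import path and the NewCompactArray/SetItem/GetItem signatures are inferred from the calls visible in this diff.

package example

import (
	"github.com/flyteorg/flytestdlib/bitarray" // import path assumed
)

// newSystemFailures mirrors the initialization in management.go below:
// one zeroed counter per subtask, with maxValue capped at the interruptible
// failure threshold, since counts beyond it no longer change the decision.
func newSystemFailures(size int, threshold uint32) (bitarray.CompactArray, error) {
	arr, err := bitarray.NewCompactArray(uint(size), bitarray.Item(threshold))
	if err != nil {
		return bitarray.CompactArray{}, err
	}
	for i := 0; i < size; i++ {
		arr.SetItem(i, 0) // explicitly zero each subtask's failure count
	}
	return arr, nil
}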
43 changes: 39 additions & 4 deletions go/tasks/plugins/array/k8s/management.go
@@ -93,7 +93,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon

retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue)
if err != nil {
- logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue)
+ logger.Errorf(ctx, "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue)
return currentState, externalResources, nil
}

@@ -106,6 +106,26 @@
currentState.RetryAttempts = retryAttemptsArray
}

+ // If the current State is newly minted then we must initialize SystemFailures to track how many
+ // times the subtask failed due to system issues; this is necessary to correctly evaluate
+ // interruptible subtasks.
+ if len(currentState.SystemFailures.GetItems()) == 0 {
+ count := uint(currentState.GetExecutionArraySize())
+ maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetInterruptibleFailureThreshold())
+
+ systemFailuresArray, err := bitarray.NewCompactArray(count, maxValue)
+ if err != nil {
+ logger.Errorf(ctx, "Failed to create system failures array with [count: %v, maxValue: %v]", count, maxValue)
+ return currentState, externalResources, err
+ }
+
+ for i := 0; i < currentState.GetExecutionArraySize(); i++ {
+ systemFailuresArray.SetItem(i, 0)
+ }
+
+ currentState.SystemFailures = systemFailuresArray
+ }

// initialize log plugin
logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config)
if err != nil {
@@ -146,7 +166,8 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
}

originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache())
- stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt)
+ systemFailures := currentState.SystemFailures.GetItem(childIdx)
+ stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures)
if err != nil {
return currentState, externalResources, err
}
@@ -188,6 +209,16 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return currentState, externalResources, perr
}

+ if phaseInfo.Err() != nil {
+ messageCollector.Collect(childIdx, phaseInfo.Err().String())
+ }
+
+ if phaseInfo.Err() != nil && phaseInfo.Err().GetKind() == idlCore.ExecutionError_SYSTEM {
+ newState.SystemFailures.SetItem(childIdx, systemFailures+1)
+ } else {
+ newState.SystemFailures.SetItem(childIdx, systemFailures)
+ }

// process subtask phase
actualPhase := phaseInfo.Phase()
if actualPhase.IsSuccess() {
@@ -294,15 +325,19 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
messageCollector := errorcollector.NewErrorMessageCollector()
for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() {
existingPhase := core.Phases[existingPhaseIdx]
- retryAttempt := currentState.RetryAttempts.GetItem(childIdx)
+ retryAttempt := uint64(0)
+ if childIdx < len(currentState.RetryAttempts.GetItems()) {
+ // we can use RetryAttempts if it has been initialized, otherwise stay with default 0
+ retryAttempt = currentState.RetryAttempts.GetItem(childIdx)
+ }

// return immediately if subtask has completed or not yet started
if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined {
continue
}

originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache())
- stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt)
+ stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
if err != nil {
return err
}
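The loop above records one system failure per failed attempt, and NewSubTaskExecutionContext (next file) consumes the running count. The resulting policy, written as a hypothetical table-driven test rather than anything in this commit: with a threshold of 2, a subtask of an interruptible map task is demoted to regular nodes once two system failures accumulate.

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestInterruptibleDemotion sketches the demotion policy wired up above.
func TestInterruptibleDemotion(t *testing.T) {
	const threshold = uint32(2)
	parentInterruptible := true // the map task itself is marked interruptible

	cases := []struct {
		systemFailures uint64
		want           bool
	}{
		{0, true},  // fresh subtask: schedule on interruptible nodes
		{1, true},  // one system failure: still under the threshold
		{2, false}, // threshold reached: fall back to regular nodes
	}
	for _, c := range cases {
		got := parentInterruptible && uint32(c.systemFailures) < threshold
		assert.Equal(t, c.want, got)
	}
}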
1 change: 1 addition & 0 deletions go/tasks/plugins/array/k8s/management_test.go
@@ -91,6 +91,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta
tMeta.OnGetAnnotations().Return(nil)
tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{})
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
+ tMeta.OnGetInterruptibleFailureThreshold().Return(2)

ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/prefix/")
15 changes: 12 additions & 3 deletions go/tasks/plugins/array/k8s/subtask_exec_context.go
@@ -43,9 +43,9 @@ func (s SubTaskExecutionContext) TaskReader() pluginsCore.TaskReader {

// NewSubtaskExecutionContext constructs a SubTaskExecutionContext using the provided parameters
func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate,
- executionIndex, originalIndex int, retryAttempt uint64) (SubTaskExecutionContext, error) {
+ executionIndex, originalIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionContext, error) {

- subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt)
+ subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt, systemFailures)
if err != nil {
return SubTaskExecutionContext{}, err
}
@@ -135,6 +135,7 @@ type SubTaskExecutionMetadata struct {
pluginsCore.TaskExecutionMetadata
annotations map[string]string
labels map[string]string
+ interruptible bool
subtaskExecutionID SubTaskExecutionID
}

@@ -153,8 +154,14 @@ func (s SubTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecution
return s.subtaskExecutionID
}

+ // IsInterruptible overrides the embedded TaskExecutionMetadata to return the subtask-specific interruptible value
+ func (s SubTaskExecutionMetadata) IsInterruptible() bool {
+ return s.interruptible
+ }

// NewSubtaskExecutionMetadata constructs a SubTaskExecutionMetadata using the provided parameters
- func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskTemplate *core.TaskTemplate, executionIndex int, retryAttempt uint64) (SubTaskExecutionMetadata, error) {
+ func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskTemplate *core.TaskTemplate,
+ executionIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionMetadata, error) {

var err error
secretsMap := make(map[string]string)
@@ -171,10 +178,12 @@ func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecution
}

subTaskExecutionID := NewSubTaskExecutionID(taskExecutionMetadata.GetTaskExecutionID(), executionIndex, retryAttempt)
+ interruptible := taskExecutionMetadata.IsInterruptible() && uint32(systemFailures) < taskExecutionMetadata.GetInterruptibleFailureThreshold()
return SubTaskExecutionMetadata{
taskExecutionMetadata,
utils.UnionMaps(taskExecutionMetadata.GetAnnotations(), secretsMap),
utils.UnionMaps(taskExecutionMetadata.GetLabels(), injectSecretsLabel),
+ interruptible,
subTaskExecutionID,
}, nil
}
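The override relies on Go struct embedding: SubTaskExecutionMetadata embeds the parent TaskExecutionMetadata, so the shadowed IsInterruptible is the only method whose behavior changes, while every other call still delegates to the parent. A self-contained illustration of the pattern, with hypothetical names:

package main

import "fmt"

type metadata interface {
	IsInterruptible() bool
	GetK8sServiceAccount() string
}

type parentMetadata struct{}

func (parentMetadata) IsInterruptible() bool        { return true }
func (parentMetadata) GetK8sServiceAccount() string { return "default-sa" }

// subMetadata embeds the interface and shadows IsInterruptible only,
// the same shape SubTaskExecutionMetadata uses above.
type subMetadata struct {
	metadata
	interruptible bool
}

func (s subMetadata) IsInterruptible() bool { return s.interruptible }

func main() {
	s := subMetadata{metadata: parentMetadata{}, interruptible: false}
	fmt.Println(s.GetK8sServiceAccount()) // "default-sa": delegated to the parent
	fmt.Println(s.IsInterruptible())      // false: overridden per subtask
}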
3 changes: 2 additions & 1 deletion go/tasks/plugins/array/k8s/subtask_exec_context_test.go
@@ -20,8 +20,9 @@ func TestSubTaskExecutionContext(t *testing.T) {
executionIndex := 0
originalIndex := 5
retryAttempt := uint64(1)
+ systemFailures := uint64(0)

- stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt)
+ stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures)
assert.Nil(t, err)

assert.Equal(t, fmt.Sprintf("notfound-%d-%d", executionIndex, retryAttempt), stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
1 change: 1 addition & 0 deletions tests/end_to_end.go
@@ -176,6 +176,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
Name: execID,
})
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
+ tMeta.OnGetInterruptibleFailureThreshold().Return(2)

catClient := &catalogMocks.Client{}
catData := sync.Map{}
