diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/fake_k8s_cache.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/fake_k8s_cache.go new file mode 100644 index 0000000000..5b67e20d02 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/fake_k8s_cache.go @@ -0,0 +1,173 @@ +package mocks + +import ( + "context" + "fmt" + "reflect" + "sync" + + "sigs.k8s.io/controller-runtime/pkg/cache" + + "k8s.io/apimachinery/pkg/api/meta" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type FakeKubeCache struct { + syncObj sync.RWMutex + Cache map[string]runtime.Object +} + +func (m *FakeKubeCache) GetInformer(ctx context.Context, obj client.Object) (cache.Informer, error) { + panic("implement me") +} + +func (m *FakeKubeCache) GetInformerForKind(ctx context.Context, gvk schema.GroupVersionKind) (cache.Informer, error) { + panic("implement me") +} + +func (m *FakeKubeCache) Start(ctx context.Context) error { + panic("implement me") +} + +func (m *FakeKubeCache) WaitForCacheSync(ctx context.Context) bool { + panic("implement me") +} + +func (m *FakeKubeCache) IndexField(ctx context.Context, obj client.Object, field string, extractValue client.IndexerFunc) error { + panic("implement me") +} + +func (m *FakeKubeCache) Get(ctx context.Context, key client.ObjectKey, out client.Object) error { + m.syncObj.RLock() + defer m.syncObj.RUnlock() + + item, found := m.Cache[formatKey(key, out.GetObjectKind().GroupVersionKind())] + if found { + // deep copy to avoid mutating cache + item = item.(runtime.Object).DeepCopyObject() + _, isUnstructured := out.(*unstructured.Unstructured) + if isUnstructured { + // Copy the value of the item in the cache to the returned value + outVal := reflect.ValueOf(out) + objVal := reflect.ValueOf(item) + if !objVal.Type().AssignableTo(outVal.Type()) { + return fmt.Errorf("cache had type %s, but %s was asked for", objVal.Type(), outVal.Type()) + } + reflect.Indirect(outVal).Set(reflect.Indirect(objVal)) + return nil + } + + p, err := runtime.DefaultUnstructuredConverter.ToUnstructured(item) + if err != nil { + return err + } + + return runtime.DefaultUnstructuredConverter.FromUnstructured(p, out) + } + + return errors.NewNotFound(schema.GroupResource{}, key.Name) +} + +func (m *FakeKubeCache) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + m.syncObj.RLock() + defer m.syncObj.RUnlock() + + objs := make([]runtime.Object, 0, len(m.Cache)) + + listOptions := &client.ListOptions{} + for _, opt := range opts { + opt.ApplyToList(listOptions) + } + + for _, val := range m.Cache { + if listOptions.Raw != nil { + if val.GetObjectKind().GroupVersionKind().Kind != listOptions.Raw.Kind { + continue + } + + if val.GetObjectKind().GroupVersionKind().GroupVersion().String() != listOptions.Raw.APIVersion { + continue + } + } + + objs = append(objs, val.(runtime.Object).DeepCopyObject()) + } + + return meta.SetList(list, objs) +} + +func (m *FakeKubeCache) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) (err error) { + m.syncObj.Lock() + defer m.syncObj.Unlock() + + accessor, err := meta.Accessor(obj) + if err != nil { + return err + } + + key := formatKey(types.NamespacedName{ + Name: accessor.GetName(), + Namespace: accessor.GetNamespace(), + }, obj.GetObjectKind().GroupVersionKind()) + + if _, exists := m.Cache[key]; !exists { + m.Cache[key] = obj + return nil + } + + return errors.NewAlreadyExists(schema.GroupResource{}, accessor.GetName()) +} + +func (m *FakeKubeCache) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + m.syncObj.Lock() + defer m.syncObj.Unlock() + + accessor, err := meta.Accessor(obj) + if err != nil { + return err + } + + key := formatKey(types.NamespacedName{ + Name: accessor.GetName(), + Namespace: accessor.GetNamespace(), + }, obj.GetObjectKind().GroupVersionKind()) + + delete(m.Cache, key) + + return nil +} + +func (m *FakeKubeCache) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + m.syncObj.Lock() + defer m.syncObj.Unlock() + + accessor, err := meta.Accessor(obj) + if err != nil { + return err + } + + key := formatKey(types.NamespacedName{ + Name: accessor.GetName(), + Namespace: accessor.GetNamespace(), + }, obj.GetObjectKind().GroupVersionKind()) + + if _, exists := m.Cache[key]; exists { + m.Cache[key] = obj + return nil + } + + return errors.NewNotFound(schema.GroupResource{}, accessor.GetName()) +} + +func NewFakeKubeCache() *FakeKubeCache { + return &FakeKubeCache{ + syncObj: sync.RWMutex{}, + Cache: map[string]runtime.Object{}, + } +} diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index e89cfc8b2d..b6a06bb3e2 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -152,13 +152,20 @@ func PhaseInfoNotReady(t time.Time, version uint32, reason string) PhaseInfo { return pi } -// Return in the case the plugin is not ready to start +// Deprecated: Please use PhaseInfoWaitingForResourcesInfo instead func PhaseInfoWaitingForResources(t time.Time, version uint32, reason string) PhaseInfo { pi := phaseInfo(PhaseWaitingForResources, version, nil, &TaskInfo{OccurredAt: &t}) pi.reason = reason return pi } +// Return in the case the plugin is not ready to start +func PhaseInfoWaitingForResourcesInfo(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo { + pi := phaseInfo(PhaseWaitingForResources, version, nil, info) + pi.reason = reason + return pi +} + func PhaseInfoQueued(t time.Time, version uint32, reason string) PhaseInfo { pi := phaseInfo(PhaseQueued, version, nil, &TaskInfo{OccurredAt: &t}) pi.reason = reason diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index 2f8d311524..4bca673ac2 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -200,7 +200,7 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl phaseInfo = core.PhaseInfoRunning(version, nowTaskInfo) case PhaseWaitingForResources: - phaseInfo = core.PhaseInfoWaitingForResources(t, version, state.GetReason()) + phaseInfo = core.PhaseInfoWaitingForResourcesInfo(t, version, state.GetReason(), nowTaskInfo) case PhaseCheckingSubTaskExecutions: // For future Running core.Phases, we have to make sure we don't use an earlier Admin version number, diff --git a/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go b/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go index 9c30b15749..9b76265c5c 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/integration_test.go @@ -37,6 +37,7 @@ func init() { func newMockExecutor(ctx context.Context, t testing.TB) (Executor, array.AdvanceIteration) { kubeClient := &mocks.KubeClient{} kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) e, err := NewExecutor(kubeClient, &Config{ MaxErrorStringLength: 200, OutputAssembler: workqueue.Config{ diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go index 322cd8df28..b938b3b76c 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor.go @@ -71,6 +71,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon existingPhase := core.Phases[existingPhaseIdx] indexStr := strconv.Itoa(childIdx) podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) + if existingPhase.IsTerminal() { // If we get here it means we have already "processed" this terminal phase since we will only persist // the phase after all processing is done (e.g. check outputs/errors file, record events... etc.). @@ -83,13 +84,29 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon } newArrayStatus.Summary.Inc(existingPhase) newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(existingPhase)) + originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache()) + + phaseInfo, err := FetchPodStatusAndLogs(ctx, kubeClient, + k8sTypes.NamespacedName{ + Name: podName, + Namespace: GetNamespaceForExecution(tCtx, config.NamespaceTemplate), + }, + originalIdx, + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, + logPlugin) + + if err != nil { + return currentState, logLinks, subTaskIDs, err + } + + if phaseInfo.Info() != nil { + logLinks = append(logLinks, phaseInfo.Info().Logs...) + } - // TODO: collect log links before doing this continue } task := &Task{ - LogLinks: logLinks, State: newState, NewArrayStatus: newArrayStatus, Config: config, @@ -97,7 +114,6 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon MessageCollector: &msg, SubTaskIDs: subTaskIDs, } - // The first time we enter this state we will launch every subtask. On subsequent rounds, the pod // has already been created so we return a Success value and continue with the Monitor step. var launchResult LaunchResult @@ -121,8 +137,11 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon } var monitorResult MonitorResult - monitorResult, err = task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox, logPlugin) - logLinks = task.LogLinks + monitorResult, taskLogs, err := task.Monitor(ctx, tCtx, kubeClient, dataStore, outputPrefix, baseOutputDataSandbox, logPlugin) + + if len(taskLogs) > 0 { + logLinks = append(logLinks, taskLogs...) + } subTaskIDs = task.SubTaskIDs if monitorResult != MonitorSuccess { @@ -214,20 +233,35 @@ func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8s taskInfo.Logs = o.TaskLogs } } + + var phaseInfo core.PhaseInfo + var err2 error + switch pod.Status.Phase { case v1.PodSucceeded: - return flytek8s.DemystifySuccess(pod.Status, taskInfo) + phaseInfo, err2 = flytek8s.DemystifySuccess(pod.Status, taskInfo) case v1.PodFailed: code, message := flytek8s.ConvertPodFailureToError(pod.Status) - return core.PhaseInfoRetryableFailure(code, message, &taskInfo), nil + phaseInfo = core.PhaseInfoRetryableFailure(code, message, &taskInfo) case v1.PodPending: - return flytek8s.DemystifyPending(pod.Status) + phaseInfo, err2 = flytek8s.DemystifyPending(pod.Status) case v1.PodUnknown: - return core.PhaseInfoUndefined, nil + phaseInfo = core.PhaseInfoUndefined + default: + if len(taskInfo.Logs) > 0 { + phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion+1, &taskInfo) + } else { + phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, &taskInfo) + } } - if len(taskInfo.Logs) > 0 { - return core.PhaseInfoRunning(core.DefaultPhaseVersion+1, &taskInfo), nil + + if err2 == nil && phaseInfo.Info() != nil { + // Append sub-job status in Log Name for viz. + for _, log := range phaseInfo.Info().Logs { + log.Name += fmt.Sprintf(" (%s)", phaseInfo.Phase().String()) + } } - return core.PhaseInfoRunning(core.DefaultPhaseVersion, &taskInfo), nil + + return phaseInfo, err2 } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go index a6290cde76..b261f1dcfe 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/monitor_test.go @@ -121,6 +121,8 @@ func TestCheckSubTasksState(t *testing.T) { tCtx := getMockTaskExecutionContext(ctx) kubeClient := mocks.KubeClient{} kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + resourceManager := mocks.ResourceManager{} resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusExhausted, nil) tCtx.OnResourceManager().Return(&resourceManager) @@ -213,6 +215,8 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { tCtx := getMockTaskExecutionContext(ctx) kubeClient := mocks.KubeClient{} kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient()) + kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache()) + resourceManager := mocks.ResourceManager{} resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil) @@ -264,12 +268,14 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { arrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhaseSuccess)) } + cacheIndexes := bitarray.NewBitSet(5) newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, OriginalArraySize: 10, OriginalMinSuccesses: 5, ArrayStatus: *arrayStatus, + IndexesToCache: cacheIndexes, }) assert.Nil(t, err) diff --git a/flyteplugins/go/tasks/plugins/array/k8s/task.go b/flyteplugins/go/tasks/plugins/array/k8s/task.go index a0e5537d77..3066f24082 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/task.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/task.go @@ -2,15 +2,16 @@ package k8s import ( "context" - "fmt" "strconv" "strings" + "sigs.k8s.io/controller-runtime/pkg/client" + + idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" @@ -27,7 +28,6 @@ import ( ) type Task struct { - LogLinks []*idlCore.TaskLog State *arrayCore.State NewArrayStatus *arraystatus.ArrayStatus Config *Config @@ -111,21 +111,35 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl return LaunchWaiting, nil } - err = kubeClient.GetClient().Create(ctx, pod) - if err != nil && !k8serrors.IsAlreadyExists(err) { - if k8serrors.IsForbidden(err) { - if strings.Contains(err.Error(), "exceeded quota") { - // TODO: Quota errors are retried forever, it would be good to have support for backoff strategy. - logger.Infof(ctx, "Failed to launch job, resource quota exceeded. Err: %v", err) - t.State = t.State.SetPhase(arrayCore.PhaseWaitingForResources, 0).SetReason("Not enough resources to launch job") - } else { - t.State = t.State.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to launch job.") + // Check for existing pods to prevent unnecessary Resource-Quota usage: https://github.com/kubernetes/kubernetes/issues/76787 + existingPod := &corev1.Pod{} + err = kubeClient.GetCache().Get(ctx, client.ObjectKey{ + Namespace: pod.GetNamespace(), + Name: pod.GetName(), + }, existingPod) + + if err != nil && k8serrors.IsNotFound(err) { + // Attempt creating non-existing pod. + err = kubeClient.GetClient().Create(ctx, pod) + if err != nil && !k8serrors.IsAlreadyExists(err) { + if k8serrors.IsForbidden(err) { + if strings.Contains(err.Error(), "exceeded quota") { + // TODO: Quota errors are retried forever, it would be good to have support for backoff strategy. + logger.Infof(ctx, "Failed to launch job, resource quota exceeded. Err: %v", err) + t.State = t.State.SetPhase(arrayCore.PhaseWaitingForResources, 0).SetReason("Not enough resources to launch job") + } else { + t.State = t.State.SetPhase(arrayCore.PhaseRetryableFailure, 0).SetReason("Failed to launch job.") + } + + t.State = t.State.SetReason(err.Error()) + return LaunchReturnState, nil } - t.State = t.State.SetReason(err.Error()) - return LaunchReturnState, nil + return LaunchError, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.") } - + } else if err != nil { + // Another error returned. + logger.Error(ctx, err) return LaunchError, errors2.Wrapf(ErrSubmitJob, err, "Failed to submit job.") } @@ -133,10 +147,11 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl } func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference, - logPlugin tasklog.Plugin) (MonitorResult, error) { + logPlugin tasklog.Plugin) (MonitorResult, []*idlCore.TaskLog, error) { indexStr := strconv.Itoa(t.ChildIdx) podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr) t.SubTaskIDs = append(t.SubTaskIDs, &podName) + var loglinks []*idlCore.TaskLog // Use original-index for log-name/links originalIdx := arrayCore.CalculateOriginalIndex(t.ChildIdx, t.State.GetIndexesToCache()) @@ -149,15 +164,11 @@ func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kube tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, logPlugin) if err != nil { - return MonitorError, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.") + return MonitorError, loglinks, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.") } if phaseInfo.Info() != nil { - // Append sub-job status in Log Name for viz. - for _, log := range phaseInfo.Info().Logs { - log.Name += fmt.Sprintf(" (%s)", phaseInfo.Phase().String()) - } - t.LogLinks = append(t.LogLinks, phaseInfo.Info().Logs...) + loglinks = phaseInfo.Info().Logs } if phaseInfo.Err() != nil { @@ -168,14 +179,14 @@ func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kube if phaseInfo.Phase().IsSuccess() { actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputDataSandbox, t.ChildIdx, originalIdx) if err != nil { - return MonitorError, err + return MonitorError, loglinks, err } } t.NewArrayStatus.Detailed.SetItem(t.ChildIdx, bitarray.Item(actualPhase)) t.NewArrayStatus.Summary.Inc(actualPhase) - return MonitorSuccess, nil + return MonitorSuccess, loglinks, nil } func (t Task) Abort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error {