diff --git a/go.mod b/go.mod index a2e9a134b..7a5caddbf 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 github.com/flyteorg/flyteidl v1.3.14 - github.com/flyteorg/flyteplugins v1.0.43 + github.com/flyteorg/flyteplugins v1.0.44 github.com/flyteorg/flytestdlib v1.0.15 github.com/ghodss/yaml v1.0.0 github.com/go-redis/redis v6.15.7+incompatible diff --git a/go.sum b/go.sum index 92e4c9d37..6b4e8cba5 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYF github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8= github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= -github.com/flyteorg/flyteplugins v1.0.43 h1:uI/Y88xqJKfvfuxfu0Sw9CNZ7iu3+HUwwRhxh558cbs= -github.com/flyteorg/flyteplugins v1.0.43/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08= +github.com/flyteorg/flyteplugins v1.0.44 h1:uKizng+i0vfXslyPBlrsfecInhvy71fTB4kRg7eiifE= +github.com/flyteorg/flyteplugins v1.0.44/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/pkg/controller/nodes/task/k8s/plugin_context.go b/pkg/controller/nodes/task/k8s/plugin_context.go index cb90edfb3..aed5bc468 100644 --- a/pkg/controller/nodes/task/k8s/plugin_context.go +++ b/pkg/controller/nodes/task/k8s/plugin_context.go @@ -2,6 +2,7 @@ package k8s import ( "context" + "fmt" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -15,7 +16,8 @@ var _ k8s.PluginContext = &pluginContext{} type pluginContext struct { pluginsCore.TaskExecutionContext // Lazily creates a buffered outputWriter, overriding the input outputWriter. - ow *ioutils.BufferedOutputWriter + ow *ioutils.BufferedOutputWriter + k8sPluginState *k8s.PluginState } // Provides an output sync of type io.OutputWriter @@ -26,9 +28,38 @@ func (p *pluginContext) OutputWriter() io.OutputWriter { return buf } -func newPluginContext(tCtx pluginsCore.TaskExecutionContext) *pluginContext { +// pluginStateReader overrides the default PluginStateReader to return a pre-assigned PluginState. This allows us to +// encapsulate plugin state persistence in the existing k8s PluginManager and only expose the ability to read the +// previous Phase, PhaseVersion, and Reason for all k8s plugins. +type pluginStateReader struct { + k8sPluginState *k8s.PluginState +} + +func (p pluginStateReader) GetStateVersion() uint8 { + return 0 +} + +func (p pluginStateReader) Get(t interface{}) (stateVersion uint8, err error) { + if pointer, ok := t.(*k8s.PluginState); ok { + *pointer = *p.k8sPluginState + } else { + return 0, fmt.Errorf("unexpected type when reading plugin state") + } + + return 0, nil +} + +// PluginStateReader overrides the default behavior to return our k8s plugin specific reader. +func (p *pluginContext) PluginStateReader() pluginsCore.PluginStateReader { + return pluginStateReader{ + k8sPluginState: p.k8sPluginState, + } +} + +func newPluginContext(tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) *pluginContext { return &pluginContext{ TaskExecutionContext: tCtx, ow: nil, + k8sPluginState: k8sPluginState, } } diff --git a/pkg/controller/nodes/task/k8s/plugin_manager.go b/pkg/controller/nodes/task/k8s/plugin_manager.go index e0d786858..67b0356a3 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -59,7 +59,8 @@ const ( ) type PluginState struct { - Phase PluginPhase + Phase PluginPhase + K8sPluginState k8s.PluginState } type PluginMetrics struct { @@ -247,7 +248,7 @@ func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.Tas return pluginsCore.DoTransition(pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "task submitted to K8s")), nil } -func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { +func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) (pluginsCore.Transition, error) { o, err := e.plugin.BuildIdentityResource(ctx, tCtx.TaskExecutionMetadata()) if err != nil { @@ -274,7 +275,7 @@ func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore e.metrics.ResourceDeleted.Inc(ctx) } - pCtx := newPluginContext(tCtx) + pCtx := newPluginContext(tCtx, k8sPluginState) p, err := e.plugin.GetTaskPhase(ctx, pCtx, o) if err != nil { logger.Warnf(ctx, "failed to check status of resource in plugin [%s], with error: %s", e.GetID(), err.Error()) @@ -311,6 +312,7 @@ func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore } func (e PluginManager) Handle(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { + // read phase state ps := PluginState{} if v, err := tCtx.PluginStateReader().Get(&ps); err != nil { if v != pluginStateVersion { @@ -318,16 +320,44 @@ func (e PluginManager) Handle(ctx context.Context, tCtx pluginsCore.TaskExecutio } return pluginsCore.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } + + // evaluate plugin + var err error + var transition pluginsCore.Transition + pluginPhase := ps.Phase if ps.Phase == PluginPhaseNotStarted { - t, err := e.LaunchResource(ctx, tCtx) - if err == nil && t.Info().Phase() == pluginsCore.PhaseQueued { - if err := tCtx.PluginStateWriter().Put(pluginStateVersion, &PluginState{Phase: PluginPhaseStarted}); err != nil { - return pluginsCore.UnknownTransition, err - } + transition, err = e.LaunchResource(ctx, tCtx) + if err == nil && transition.Info().Phase() == pluginsCore.PhaseQueued { + pluginPhase = PluginPhaseStarted } - return t, err + } else { + transition, err = e.CheckResourcePhase(ctx, tCtx, &ps.K8sPluginState) + } + + if err != nil { + return transition, err } - return e.CheckResourcePhase(ctx, tCtx) + + // persist any changes in phase state + k8sPluginState := ps.K8sPluginState + if ps.Phase != pluginPhase || k8sPluginState.Phase != transition.Info().Phase() || + k8sPluginState.PhaseVersion != transition.Info().Version() || k8sPluginState.Reason != transition.Info().Reason() { + + newPluginState := PluginState{ + Phase: pluginPhase, + K8sPluginState: k8s.PluginState{ + Phase: transition.Info().Phase(), + PhaseVersion: transition.Info().Version(), + Reason: transition.Info().Reason(), + }, + } + + if err := tCtx.PluginStateWriter().Put(pluginStateVersion, &newPluginState); err != nil { + return pluginsCore.UnknownTransition, err + } + } + + return transition, err } func (e PluginManager) Abort(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) error { diff --git a/pkg/controller/nodes/task/k8s/plugin_manager_test.go b/pkg/controller/nodes/task/k8s/plugin_manager_test.go index 160dc335f..94b6b5524 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager_test.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "testing" "k8s.io/client-go/kubernetes/scheme" @@ -715,6 +716,157 @@ func TestPluginManager_Handle_CheckResourceStatus(t *testing.T) { } } +func TestPluginManager_Handle_PluginState(t *testing.T) { + ctx := context.TODO() + tm := getMockTaskExecutionMetadata() + res := &v1.Pod{ + ObjectMeta: v12.ObjectMeta{ + Name: tm.GetTaskExecutionID().GetGeneratedName(), + Namespace: tm.GetNamespace(), + }, + } + + pluginStateQueued := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: 0, + Reason: "foo", + }, + } + pluginStateQueuedVersion1 := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: 1, + Reason: "foo", + }, + } + pluginStateQueuedReasonBar := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: 0, + Reason: "bar", + }, + } + pluginStateRunning := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseRunning, + PhaseVersion: 0, + Reason: "", + }, + } + + phaseInfoQueued := pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginStateQueued.K8sPluginState.PhaseVersion, pluginStateQueued.K8sPluginState.Reason, nil) + phaseInfoQueuedVersion1 := pluginsCore.PhaseInfoQueuedWithTaskInfo( + pluginStateQueuedVersion1.K8sPluginState.PhaseVersion, + pluginStateQueuedVersion1.K8sPluginState.Reason, + nil, + ) + phaseInfoQueuedReasonBar := pluginsCore.PhaseInfoQueuedWithTaskInfo( + pluginStateQueuedReasonBar.K8sPluginState.PhaseVersion, + pluginStateQueuedReasonBar.K8sPluginState.Reason, + nil, + ) + phaseInfoRunning := pluginsCore.PhaseInfoRunning(0, nil) + + tests := []struct { + name string + startPluginState PluginState + reportedPhaseInfo pluginsCore.PhaseInfo + expectedPluginState PluginState + }{ + { + "NoChange", + pluginStateQueued, + phaseInfoQueued, + pluginStateQueued, + }, + { + "K8sPhaseChange", + pluginStateQueued, + phaseInfoRunning, + pluginStateRunning, + }, + { + "PhaseVersionChange", + pluginStateQueued, + phaseInfoQueuedVersion1, + pluginStateQueuedVersion1, + }, + { + "ReasonChange", + pluginStateQueued, + phaseInfoQueuedReasonBar, + pluginStateQueuedReasonBar, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // mock TaskExecutionContext + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(getMockTaskExecutionMetadata()) + + tReader := &pluginsCoreMock.TaskReader{} + tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{}, nil) + tCtx.OnTaskReader().Return(tReader) + + // mock state reader / writer to use local pluginState variable + pluginState := &tt.startPluginState + customStateReader := &pluginsCoreMock.PluginStateReader{} + customStateReader.OnGetMatch(mock.MatchedBy(func(i interface{}) bool { + ps, ok := i.(*PluginState) + if ok { + *ps = *pluginState + return true + } + return false + })).Return(uint8(0), nil) + tCtx.OnPluginStateReader().Return(customStateReader) + + customStateWriter := &pluginsCoreMock.PluginStateWriter{} + customStateWriter.OnPutMatch(mock.Anything, mock.MatchedBy(func(i interface{}) bool { + ps, ok := i.(*PluginState) + if ok { + *pluginState = *ps + } + return ok + })).Return(nil) + tCtx.OnPluginStateWriter().Return(customStateWriter) + tCtx.OnOutputWriter().Return(&dummyOutputWriter{}) + + fc := extendedFakeClient{Client: fake.NewFakeClient(res)} + + mockResourceHandler := &pluginsk8sMock.Plugin{} + mockResourceHandler.OnGetProperties().Return(k8s.PluginProperties{}) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tCtx.TaskExecutionMetadata()).Return(&v1.Pod{}, nil) + mockResourceHandler.On("GetTaskPhase", mock.Anything, mock.Anything, mock.Anything).Return(tt.reportedPhaseInfo, nil) + + // create new PluginManager + pluginManager, err := NewPluginManager(ctx, dummySetupContext(fc), k8s.PluginEntry{ + ID: "x", + ResourceToWatch: &v1.Pod{}, + Plugin: mockResourceHandler, + }, NewResourceMonitorIndex()) + assert.NoError(t, err) + + // handle plugin + _, err = pluginManager.Handle(ctx, tCtx) + assert.NoError(t, err) + + // verify expected PluginState + newPluginState := PluginState{} + _, err = tCtx.PluginStateReader().Get(&newPluginState) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(newPluginState, tt.expectedPluginState)) + }) + } +} + func TestPluginManager_CustomKubeClient(t *testing.T) { ctx := context.TODO() tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseStarted)