diff --git a/go.mod b/go.mod index 519ee4524a..f4eceb6a65 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/coocood/freecache v1.1.1 github.com/flyteorg/flyteidl v0.21.0 - github.com/flyteorg/flytestdlib v0.3.33 + github.com/flyteorg/flytestdlib v0.3.36 github.com/go-logr/zapr v0.4.0 // indirect github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.4.3 diff --git a/go.sum b/go.sum index e9d895013f..893532f0a9 100644 --- a/go.sum +++ b/go.sum @@ -229,8 +229,8 @@ github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4 github.com/flyteorg/flyteidl v0.21.0 h1:AwHNusfxJMfRRSDk2QWfb3aIlyLJrFWVGtpXCbCtJ5A= github.com/flyteorg/flyteidl v0.21.0/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= -github.com/flyteorg/flytestdlib v0.3.33 h1:+oCx3zXUIldL7CWmNMD7PMFPXvGqaPgYkSKn9wB6qvY= -github.com/flyteorg/flytestdlib v0.3.33/go.mod h1:7cDWkY3v7xsoesFcDdu6DSW5Q2U2W5KlHUbUHSwBG1Q= +github.com/flyteorg/flytestdlib v0.3.36 h1:XLvc7kfc9XkQBpPvNXevh5+Ijbgmd7gEOHTWhdEY5eA= +github.com/flyteorg/flytestdlib v0.3.36/go.mod h1:7cDWkY3v7xsoesFcDdu6DSW5Q2U2W5KlHUbUHSwBG1Q= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= diff --git a/go/tasks/pluginmachinery/internal/webapi/cache.go b/go/tasks/pluginmachinery/internal/webapi/cache.go index ce08ae2753..856e2104b3 100644 --- a/go/tasks/pluginmachinery/internal/webapi/cache.go +++ b/go/tasks/pluginmachinery/internal/webapi/cache.go @@ -101,7 +101,7 @@ func (q *ResourceCache) SyncResource(ctx context.Context, batch cache.Batch) ( logger.Debugf(ctx, "Querying AsyncPlugin for %s", resource.GetID()) newResource, err := q.client.Get(ctx, newPluginContext(cacheItem.ResourceMeta, cacheItem.Resource, "", nil)) if err != nil { - logger.Errorf(ctx, "Error retrieving resource [%s]. Error: %v", resource.GetID(), err) + logger.Infof(ctx, "Error retrieving resource [%s]. Error: %v", resource.GetID(), err) cacheItem.SyncFailureCount++ // Make sure we don't return nil for the first argument, because that deletes it from the cache. diff --git a/go/tasks/pluginmachinery/internal/webapi/monitor.go b/go/tasks/pluginmachinery/internal/webapi/monitor.go index 9f93195242..6d3d4ac33d 100644 --- a/go/tasks/pluginmachinery/internal/webapi/monitor.go +++ b/go/tasks/pluginmachinery/internal/webapi/monitor.go @@ -16,8 +16,8 @@ func monitor(ctx context.Context, tCtx core.TaskExecutionContext, p Client, cach State: *state, } - item, err := cache.GetOrCreate( - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), newCacheItem) + cacheItemID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + item, err := cache.GetOrCreate(cacheItemID, newCacheItem) if err != nil { return nil, core.PhaseInfo{}, err } @@ -50,6 +50,15 @@ func monitor(ctx context.Context, tCtx core.TaskExecutionContext, p Client, cach cacheItem.Phase = newPluginPhase + if newPluginPhase.IsTerminal() { + // Queue item for deletion in the cache. + err = cache.DeleteDelayed(cacheItemID) + if err != nil { + logger.Warnf(ctx, "Failed to queue item for deletion in the cache with Item Id: [%v]. Error: %v", + cacheItemID, err) + } + } + // If there were updates made to the state, we'll have picked them up automatically. Nothing more to do. return &cacheItem.State, newPhase, nil } diff --git a/go/tasks/pluginmachinery/internal/webapi/monitor_test.go b/go/tasks/pluginmachinery/internal/webapi/monitor_test.go new file mode 100644 index 0000000000..0842f87d6f --- /dev/null +++ b/go/tasks/pluginmachinery/internal/webapi/monitor_test.go @@ -0,0 +1,80 @@ +package webapi + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/flyteorg/flytestdlib/cache" + "github.com/flyteorg/flytestdlib/promutils" + "k8s.io/apimachinery/pkg/util/rand" + "k8s.io/client-go/util/workqueue" + + core2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + + "github.com/stretchr/testify/mock" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + internalMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/internal/webapi/mocks" +) + +func Test_monitor(t *testing.T) { + ctx := context.Background() + tCtx := &mocks.TaskExecutionContext{} + ctxMeta := &mocks.TaskExecutionMetadata{} + execID := &mocks.TaskExecutionID{} + execID.OnGetGeneratedName().Return("generated_name") + execID.OnGetID().Return(core.TaskExecutionIdentifier{}) + ctxMeta.OnGetTaskExecutionID().Return(execID) + tCtx.OnTaskExecutionMetadata().Return(ctxMeta) + + client := &internalMocks.Client{} + client.OnStatusMatch(ctx, mock.Anything).Return(core2.PhaseInfoSuccess(nil), nil) + + wg := sync.WaitGroup{} + wg.Add(4) + cacheObj, err := cache.NewAutoRefreshCache(rand.String(5), func(ctx context.Context, batch cache.Batch) (updatedBatch []cache.ItemSyncResponse, err error) { + wg.Done() + t.Logf("Syncing Item [%+v]", batch[0]) + return []cache.ItemSyncResponse{ + { + ID: batch[0].GetID(), + Item: batch[0].GetItem(), + Action: cache.Update, + }, + }, nil + }, workqueue.DefaultControllerRateLimiter(), time.Second, 1, 10, promutils.NewTestScope()) + assert.NoError(t, err) + + assert.NoError(t, cacheObj.Start(ctx)) + + // Insert a dummy item to make sure the sync loop keeps getting invoked + _, err = cacheObj.GetOrCreate("generated_name2", CacheItem{Resource: "fake_resource2"}) + assert.NoError(t, err) + + _, err = cacheObj.GetOrCreate("generated_name", CacheItem{Resource: "fake_resource"}) + assert.NoError(t, err) + + s := &State{} + newState, phaseInfo, err := monitor(ctx, tCtx, client, cacheObj, s) + assert.NoError(t, err) + assert.NotNil(t, newState) + assert.NotNil(t, phaseInfo) + assert.Equal(t, core2.PhaseSuccess.String(), phaseInfo.Phase().String()) + + // Make sure the item is still in the cache as is... + cachedItem, err := cacheObj.GetOrCreate("generated_name", CacheItem{Resource: "shouldnt_insert"}) + assert.NoError(t, err) + assert.Equal(t, "fake_resource", cachedItem.(CacheItem).Resource.(string)) + + // Wait for sync to run to actually delete the resource + wg.Wait() + cachedItem, err = cacheObj.GetOrCreate("generated_name", CacheItem{Resource: "new_resource"}) + assert.NoError(t, err) + assert.Equal(t, "new_resource", cachedItem.(CacheItem).Resource.(string)) +}