From da77476ecbb8b183612405e9f3e3f7117d67d2a8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 12 Oct 2023 15:22:51 -0700 Subject: [PATCH] Update the item only if it exists in the cache (#4117) * Add item ID to the workqueue instead Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * Update the item only if it exists in the cache Signed-off-by: Kevin Su * Update tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * address comment Signed-off-by: Kevin Su * fixed tests Signed-off-by: Kevin Su * address comment Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- .../internal/webapi/monitor_test.go | 2 +- flytestdlib/cache/auto_refresh.go | 62 +++++++++++++++---- flytestdlib/cache/auto_refresh_test.go | 30 ++++++++- 3 files changed, 80 insertions(+), 14 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go index f916a93800..bb6f0b089c 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go @@ -37,7 +37,7 @@ func Test_monitor(t *testing.T) { client.OnStatusMatch(ctx, mock.Anything).Return(core2.PhaseInfoSuccess(nil), nil) wg := sync.WaitGroup{} - wg.Add(4) + wg.Add(8) cacheObj, err := cache.NewAutoRefreshCache(rand.String(5), func(ctx context.Context, batch cache.Batch) (updatedBatch []cache.ItemSyncResponse, err error) { wg.Done() t.Logf("Syncing Item [%+v]", batch[0]) diff --git a/flytestdlib/cache/auto_refresh.go b/flytestdlib/cache/auto_refresh.go index 13a787b03f..d0a609dc3b 100644 --- a/flytestdlib/cache/auto_refresh.go +++ b/flytestdlib/cache/auto_refresh.go @@ -3,6 +3,7 @@ package cache import ( "context" "fmt" + "sync" "time" "github.com/flyteorg/flyte/flytestdlib/contextutils" @@ -122,6 +123,7 @@ type autoRefresh struct { syncPeriod time.Duration workqueue workqueue.RateLimitingInterface parallelizm int + lock sync.RWMutex } func getEvictionFunction(counter prometheus.Counter) func(key interface{}, value interface{}) { @@ -173,6 +175,25 @@ func (w *autoRefresh) Start(ctx context.Context) error { return nil } +// Update updates the item only if it exists in the cache, return true if we updated the item. +func (w *autoRefresh) Update(id ItemID, item Item) (ok bool) { + w.lock.Lock() + defer w.lock.Unlock() + ok = w.lruMap.Contains(id) + if ok { + w.lruMap.Add(id, item) + } + return ok +} + +// Delete deletes the item from the cache if it exists. +func (w *autoRefresh) Delete(key interface{}) { + w.lock.Lock() + defer w.lock.Unlock() + w.toDelete.Remove(key) + w.lruMap.Remove(key) +} + func (w *autoRefresh) Get(id ItemID) (Item, error) { if val, ok := w.lruMap.Get(id); ok { w.metrics.CacheHit.Inc() @@ -212,8 +233,7 @@ func (w *autoRefresh) enqueueBatches(ctx context.Context) error { snapshot := make([]ItemWrapper, 0, len(keys)) for _, k := range keys { if w.toDelete.Contains(k) { - w.lruMap.Remove(k) - w.toDelete.Remove(k) + w.Delete(k) continue } // If not ok, it means evicted between the item was evicted between getting the keys and this update loop @@ -273,18 +293,37 @@ func (w *autoRefresh) sync(ctx context.Context) (err error) { case <-ctx.Done(): return nil default: - item, shutdown := w.workqueue.Get() + batch, shutdown := w.workqueue.Get() if shutdown { + logger.Debugf(ctx, "Shutting down worker") return nil } - t := w.metrics.SyncLatency.Start() - updatedBatch, err := w.syncCb(ctx, *item.(*Batch)) - // Since we create batches every time we sync, we will just remove the item from the queue here // regardless of whether it succeeded the sync or not. - w.workqueue.Forget(item) - w.workqueue.Done(item) + w.workqueue.Forget(batch) + w.workqueue.Done(batch) + + newBatch := make(Batch, 0, len(*batch.(*Batch))) + for _, b := range *batch.(*Batch) { + itemID := b.GetID() + item, ok := w.lruMap.Get(itemID) + if !ok { + logger.Debugf(ctx, "item with id [%v] not found in cache", itemID) + continue + } + if item.(Item).IsTerminal() { + logger.Debugf(ctx, "item with id [%v] is terminal", itemID) + continue + } + newBatch = append(newBatch, b) + } + if len(newBatch) == 0 { + continue + } + + t := w.metrics.SyncLatency.Start() + updatedBatch, err := w.syncCb(ctx, newBatch) if err != nil { w.metrics.SyncErrors.Inc() @@ -295,14 +334,13 @@ func (w *autoRefresh) sync(ctx context.Context) (err error) { for _, item := range updatedBatch { if item.Action == Update { - // Add adds the item if it has been evicted or updates an existing one. - w.lruMap.Add(item.ID, item.Item) + // Updates an existing item. + w.Update(item.ID, item.Item) } } w.toDelete.Range(func(key interface{}) bool { - w.lruMap.Remove(key) - w.toDelete.Remove(key) + w.Delete(key) return true }) diff --git a/flytestdlib/cache/auto_refresh_test.go b/flytestdlib/cache/auto_refresh_test.go index e1250afd68..ccd2da6a19 100644 --- a/flytestdlib/cache/auto_refresh_test.go +++ b/flytestdlib/cache/auto_refresh_test.go @@ -62,7 +62,7 @@ func syncTerminalItem(_ context.Context, batch Batch) ([]ItemSyncResponse, error panic("This should never be called") } -func TestCacheThree(t *testing.T) { +func TestCacheFour(t *testing.T) { testResyncPeriod := time.Millisecond rateLimiter := workqueue.DefaultControllerRateLimiter() @@ -142,6 +142,34 @@ func TestCacheThree(t *testing.T) { cancel() }) + + t.Run("Test update and delete cache", func(t *testing.T) { + cache, err := NewAutoRefreshCache("fake3", syncTerminalItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope()) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + assert.NoError(t, cache.Start(ctx)) + + itemID := "dummy_id" + _, err = cache.GetOrCreate(itemID, terminalCacheItem{ + val: 0, + }) + assert.NoError(t, err) + + // Wait half a second for all resync periods to complete + // If the cache tries to enqueue the item, a panic will be thrown. + time.Sleep(500 * time.Millisecond) + + err = cache.DeleteDelayed(itemID) + assert.NoError(t, err) + + time.Sleep(500 * time.Millisecond) + item, err := cache.Get(itemID) + assert.Nil(t, item) + assert.Error(t, err) + + cancel() + }) } func TestQueueBuildUp(t *testing.T) {