From 1f18db9bcf7c071d4e779c1220e55951cb795e78 Mon Sep 17 00:00:00 2001 From: you06 Date: Fri, 6 Sep 2024 20:05:48 +0800 Subject: [PATCH] memdb: fix memdb snapshot get/iter is not actually snapshot (#1393) (#1433) ref tikv/client-go#1394 Signed-off-by: you06 --- internal/unionstore/memdb_arena.go | 2 +- internal/unionstore/memdb_test.go | 36 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/internal/unionstore/memdb_arena.go b/internal/unionstore/memdb_arena.go index 146d6a0fe..c1db745af 100644 --- a/internal/unionstore/memdb_arena.go +++ b/internal/unionstore/memdb_arena.go @@ -353,7 +353,7 @@ func (l *memdbVlog) getSnapshotValue(addr memdbArenaAddr, snap *MemDBCheckpoint) if result.isNull() { return nil, false } - return l.getValue(addr), true + return l.getValue(result), true } func (l *memdbVlog) selectValueHistory(addr memdbArenaAddr, predicate func(memdbArenaAddr) bool) memdbArenaAddr { diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index e5fe20631..0c2852e52 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -37,6 +37,7 @@ package unionstore import ( + "context" "encoding/binary" "fmt" "testing" @@ -860,3 +861,38 @@ func TestUnsetTemporaryFlag(t *testing.T) { require.Nil(err) require.False(flags.HasNeedConstraintCheckInPrewrite()) } + +func TestSnapshotGetIter(t *testing.T) { + assert := assert.New(t) + buffer := newMemDB() + var getters []Getter + var iters []Iterator + for i := 0; i < 100; i++ { + assert.Nil(buffer.Set([]byte{byte(0)}, []byte{byte(i)})) + // getter + getter := buffer.SnapshotGetter() + val, err := getter.Get(context.Background(), []byte{byte(0)}) + assert.Nil(err) + assert.Equal(val, []byte{byte(min(i, 50))}) + getters = append(getters, getter) + // iter + iter := buffer.SnapshotIter(nil, nil) + assert.Nil(err) + assert.Equal(iter.Key(), []byte{byte(0)}) + assert.Equal(iter.Value(), []byte{byte(min(i, 50))}) + iter.Close() + iters = append(iters, buffer.SnapshotIter(nil, nil)) + if i == 50 { + _ = buffer.Staging() + } + } + for _, getter := range getters { + val, err := getter.Get(context.Background(), []byte{byte(0)}) + assert.Nil(err) + assert.Equal(val, []byte{byte(50)}) + } + for _, iter := range iters { + assert.Equal(iter.Key(), []byte{byte(0)}) + assert.Equal(iter.Value(), []byte{byte(50)}) + } +}