diff --git a/store/cachekv/store.go b/store/cachekv/store.go index c11ddd9e2b2..5d174acaaeb 100644 --- a/store/cachekv/store.go +++ b/store/cachekv/store.go @@ -119,6 +119,10 @@ func (store *GStore[V]) Write() { store.writeSet.Clear() } +func (store *GStore[V]) Discard() { + store.writeSet.Clear() +} + // CacheWrap implements CacheWrapper. func (store *GStore[V]) CacheWrap() types.CacheWrap { return NewGStore(store, store.isZero, store.valueLen) diff --git a/store/cachemulti/store.go b/store/cachemulti/store.go index 3c116401807..a78a4b92dd1 100644 --- a/store/cachemulti/store.go +++ b/store/cachemulti/store.go @@ -45,7 +45,7 @@ func NewFromKVStore( } for key, store := range stores { - cms.stores[key] = cms.initStore(key, store) + cms.initStore(key, store) } return cms @@ -80,7 +80,9 @@ func (cms Store) initStore(key types.StoreKey, store types.CacheWrapper) types.C store = tracekv.NewStore(kvstore, cms.traceWriter, tctx) } } - return store.CacheWrap() + cache := store.CacheWrap() + cms.stores[key] = cache + return cache } // SetTracer sets the tracer for the MultiStore that the underlying @@ -126,6 +128,12 @@ func (cms Store) Write() { } } +func (cms Store) Discard() { + for _, store := range cms.stores { + store.Discard() + } +} + // Implements CacheWrapper. func (cms Store) CacheWrap() types.CacheWrap { return cms.CacheMultiStore().(types.CacheWrap) @@ -138,14 +146,9 @@ func (cms Store) CacheMultiStore() types.CacheMultiStore { func (cms Store) getCacheWrap(key types.StoreKey) types.CacheWrap { store, ok := cms.stores[key] - if !ok { + if !ok && cms.parentStore != nil { // load on demand - if cms.branched { - store = cms.parentStore(key).(types.BranchStore).Clone().(types.CacheWrap) - } else if cms.parentStore != nil { - store = cms.initStore(key, cms.parentStore(key)) - } - cms.stores[key] = store + store = cms.initStore(key, cms.parentStore(key)) } if key == nil || store == nil { panic(fmt.Sprintf("kv store with key %v has not been registered in stores", key)) @@ -181,12 +184,15 @@ func (cms Store) GetObjKVStore(key types.StoreKey) types.ObjKVStore { } func (cms Store) Clone() Store { + stores := make(map[types.StoreKey]types.CacheWrap, len(cms.stores)) + for k, v := range cms.stores { + stores[k] = v.(types.BranchStore).Clone().(types.CacheWrap) + } return Store{ - stores: make(map[types.StoreKey]types.CacheWrap), - + stores: stores, traceWriter: cms.traceWriter, traceContext: cms.traceContext, - parentStore: cms.getCacheWrap, + parentStore: cms.parentStore, branched: true, } @@ -197,9 +203,22 @@ func (cms Store) Restore(other Store) { panic("cannot restore from non-branched store") } - // restore the stores + // discard the non-exists stores + for k, v := range cms.stores { + if _, ok := other.stores[k]; !ok { + // clear the cache store if it's not in the other + v.Discard() + } + } + + // restore the other stores for k, v := range other.stores { - cms.stores[k].(types.BranchStore).Restore(v.(types.BranchStore)) + store, ok := cms.stores[k] + if !ok { + store = cms.initStore(k, cms.parentStore(k)) + } + + store.(types.BranchStore).Restore(v.(types.BranchStore)) } } diff --git a/store/cachemulti/store_test.go b/store/cachemulti/store_test.go index 03053c4e563..01c7a810bf9 100644 --- a/store/cachemulti/store_test.go +++ b/store/cachemulti/store_test.go @@ -35,22 +35,22 @@ func TestRunAtomic(t *testing.T) { func(v any) int { return 1 }, ) keys := map[string]types.StoreKey{ - "abc": types.NewKVStoreKey("abc"), - "obj": types.NewObjectStoreKey("obj"), - "lazy": types.NewKVStoreKey("lazy"), + "abc": types.NewKVStoreKey("abc"), + "obj": types.NewObjectStoreKey("obj"), } - s := Store{stores: map[types.StoreKey]types.CacheWrap{ - keys["abc"]: store.CacheWrap(), - keys["obj"]: objStore.CacheWrap(), - keys["lazy"]: nil, + parent := Store{stores: map[types.StoreKey]types.CacheWrap{ + keys["abc"]: store.CacheWrap(), + keys["obj"]: objStore.CacheWrap(), }} + s := Store{stores: map[types.StoreKey]types.CacheWrap{}, parentStore: parent.getCacheWrap} s.RunAtomic(func(ms types.CacheMultiStore) error { ms.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value")) ms.GetObjKVStore(keys["obj"]).Set([]byte("key"), "value") return nil }) require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key"))) + require.Equal(t, []byte(nil), s.GetKVStore(keys["abc"]).Get([]byte("key-non-exist"))) require.Equal(t, "value", s.GetObjKVStore(keys["obj"]).Get([]byte("key")).(string)) require.Error(t, s.RunAtomic(func(ms types.CacheMultiStore) error { @@ -61,3 +61,28 @@ func TestRunAtomic(t *testing.T) { require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key"))) require.Equal(t, "value", s.GetObjKVStore(keys["obj"]).Get([]byte("key")).(string)) } + +func TestBranchStore(t *testing.T) { + store := dbadapter.Store{DB: dbm.NewMemDB()} + objStore := internal.NewBTreeStore(btree.NewBTree[any](), + func(v any) bool { return v == nil }, + func(v any) int { return 1 }, + ) + keys := map[string]types.StoreKey{ + "abc": types.NewKVStoreKey("abc"), + "obj": types.NewObjectStoreKey("obj"), + } + parent := Store{stores: map[types.StoreKey]types.CacheWrap{ + keys["abc"]: store.CacheWrap(), + keys["obj"]: objStore.CacheWrap(), + }} + + s := Store{stores: map[types.StoreKey]types.CacheWrap{}, parentStore: parent.getCacheWrap} + s.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value")) + snapshot := s.Clone() + s.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value2")) + s.GetObjKVStore(keys["obj"]).Set([]byte("key"), "value") + s.Restore(snapshot) + require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key"))) + require.Equal(t, nil, s.GetObjKVStore(keys["obj"]).Get([]byte("key"))) +} diff --git a/store/types/store.go b/store/types/store.go index e63ad6d9a79..67bd140f5e7 100644 --- a/store/types/store.go +++ b/store/types/store.go @@ -339,6 +339,9 @@ type CacheWrap interface { // Write syncs with the underlying store. Write() + + // Discard the write set + Discard() } type CacheWrapper interface {