From 1b9e0538e1827eda94785c87879ffb7765179cf9 Mon Sep 17 00:00:00 2001
From: yihuang <huang@crypto.com>
Date: Tue, 2 Apr 2024 09:54:25 +0800
Subject: [PATCH] Problem: Restore don't work snapshot revert usage (#245)

Solution:
- fix and add test to support the usage pattern in ethermint

add Discard method to CacheWrap

better testing
---
 store/cachekv/store.go         |  4 +++
 store/cachemulti/store.go      | 47 ++++++++++++++++++++++++----------
 store/cachemulti/store_test.go | 39 +++++++++++++++++++++++-----
 store/types/store.go           |  3 +++
 4 files changed, 72 insertions(+), 21 deletions(-)

diff --git a/store/cachekv/store.go b/store/cachekv/store.go
index c11ddd9e2b21..5d174acaaeb8 100644
--- a/store/cachekv/store.go
+++ b/store/cachekv/store.go
@@ -119,6 +119,10 @@ func (store *GStore[V]) Write() {
 	store.writeSet.Clear()
 }
 
+func (store *GStore[V]) Discard() {
+	store.writeSet.Clear()
+}
+
 // CacheWrap implements CacheWrapper.
 func (store *GStore[V]) CacheWrap() types.CacheWrap {
 	return NewGStore(store, store.isZero, store.valueLen)
diff --git a/store/cachemulti/store.go b/store/cachemulti/store.go
index 3c1164018072..a78a4b92dd12 100644
--- a/store/cachemulti/store.go
+++ b/store/cachemulti/store.go
@@ -45,7 +45,7 @@ func NewFromKVStore(
 	}
 
 	for key, store := range stores {
-		cms.stores[key] = cms.initStore(key, store)
+		cms.initStore(key, store)
 	}
 
 	return cms
@@ -80,7 +80,9 @@ func (cms Store) initStore(key types.StoreKey, store types.CacheWrapper) types.C
 			store = tracekv.NewStore(kvstore, cms.traceWriter, tctx)
 		}
 	}
-	return store.CacheWrap()
+	cache := store.CacheWrap()
+	cms.stores[key] = cache
+	return cache
 }
 
 // SetTracer sets the tracer for the MultiStore that the underlying
@@ -126,6 +128,12 @@ func (cms Store) Write() {
 	}
 }
 
+func (cms Store) Discard() {
+	for _, store := range cms.stores {
+		store.Discard()
+	}
+}
+
 // Implements CacheWrapper.
 func (cms Store) CacheWrap() types.CacheWrap {
 	return cms.CacheMultiStore().(types.CacheWrap)
@@ -138,14 +146,9 @@ func (cms Store) CacheMultiStore() types.CacheMultiStore {
 
 func (cms Store) getCacheWrap(key types.StoreKey) types.CacheWrap {
 	store, ok := cms.stores[key]
-	if !ok {
+	if !ok && cms.parentStore != nil {
 		// load on demand
-		if cms.branched {
-			store = cms.parentStore(key).(types.BranchStore).Clone().(types.CacheWrap)
-		} else if cms.parentStore != nil {
-			store = cms.initStore(key, cms.parentStore(key))
-		}
-		cms.stores[key] = store
+		store = cms.initStore(key, cms.parentStore(key))
 	}
 	if key == nil || store == nil {
 		panic(fmt.Sprintf("kv store with key %v has not been registered in stores", key))
@@ -181,12 +184,15 @@ func (cms Store) GetObjKVStore(key types.StoreKey) types.ObjKVStore {
 }
 
 func (cms Store) Clone() Store {
+	stores := make(map[types.StoreKey]types.CacheWrap, len(cms.stores))
+	for k, v := range cms.stores {
+		stores[k] = v.(types.BranchStore).Clone().(types.CacheWrap)
+	}
 	return Store{
-		stores: make(map[types.StoreKey]types.CacheWrap),
-
+		stores:       stores,
 		traceWriter:  cms.traceWriter,
 		traceContext: cms.traceContext,
-		parentStore:  cms.getCacheWrap,
+		parentStore:  cms.parentStore,
 
 		branched: true,
 	}
@@ -197,9 +203,22 @@ func (cms Store) Restore(other Store) {
 		panic("cannot restore from non-branched store")
 	}
 
-	// restore the stores
+	// discard the non-exists stores
+	for k, v := range cms.stores {
+		if _, ok := other.stores[k]; !ok {
+			// clear the cache store if it's not in the other
+			v.Discard()
+		}
+	}
+
+	// restore the other stores
 	for k, v := range other.stores {
-		cms.stores[k].(types.BranchStore).Restore(v.(types.BranchStore))
+		store, ok := cms.stores[k]
+		if !ok {
+			store = cms.initStore(k, cms.parentStore(k))
+		}
+
+		store.(types.BranchStore).Restore(v.(types.BranchStore))
 	}
 }
 
diff --git a/store/cachemulti/store_test.go b/store/cachemulti/store_test.go
index 03053c4e5633..01c7a810bf91 100644
--- a/store/cachemulti/store_test.go
+++ b/store/cachemulti/store_test.go
@@ -35,22 +35,22 @@ func TestRunAtomic(t *testing.T) {
 		func(v any) int { return 1 },
 	)
 	keys := map[string]types.StoreKey{
-		"abc":  types.NewKVStoreKey("abc"),
-		"obj":  types.NewObjectStoreKey("obj"),
-		"lazy": types.NewKVStoreKey("lazy"),
+		"abc": types.NewKVStoreKey("abc"),
+		"obj": types.NewObjectStoreKey("obj"),
 	}
-	s := Store{stores: map[types.StoreKey]types.CacheWrap{
-		keys["abc"]:  store.CacheWrap(),
-		keys["obj"]:  objStore.CacheWrap(),
-		keys["lazy"]: nil,
+	parent := Store{stores: map[types.StoreKey]types.CacheWrap{
+		keys["abc"]: store.CacheWrap(),
+		keys["obj"]: objStore.CacheWrap(),
 	}}
 
+	s := Store{stores: map[types.StoreKey]types.CacheWrap{}, parentStore: parent.getCacheWrap}
 	s.RunAtomic(func(ms types.CacheMultiStore) error {
 		ms.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value"))
 		ms.GetObjKVStore(keys["obj"]).Set([]byte("key"), "value")
 		return nil
 	})
 	require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key")))
+	require.Equal(t, []byte(nil), s.GetKVStore(keys["abc"]).Get([]byte("key-non-exist")))
 	require.Equal(t, "value", s.GetObjKVStore(keys["obj"]).Get([]byte("key")).(string))
 
 	require.Error(t, s.RunAtomic(func(ms types.CacheMultiStore) error {
@@ -61,3 +61,28 @@ func TestRunAtomic(t *testing.T) {
 	require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key")))
 	require.Equal(t, "value", s.GetObjKVStore(keys["obj"]).Get([]byte("key")).(string))
 }
+
+func TestBranchStore(t *testing.T) {
+	store := dbadapter.Store{DB: dbm.NewMemDB()}
+	objStore := internal.NewBTreeStore(btree.NewBTree[any](),
+		func(v any) bool { return v == nil },
+		func(v any) int { return 1 },
+	)
+	keys := map[string]types.StoreKey{
+		"abc": types.NewKVStoreKey("abc"),
+		"obj": types.NewObjectStoreKey("obj"),
+	}
+	parent := Store{stores: map[types.StoreKey]types.CacheWrap{
+		keys["abc"]: store.CacheWrap(),
+		keys["obj"]: objStore.CacheWrap(),
+	}}
+
+	s := Store{stores: map[types.StoreKey]types.CacheWrap{}, parentStore: parent.getCacheWrap}
+	s.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value"))
+	snapshot := s.Clone()
+	s.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value2"))
+	s.GetObjKVStore(keys["obj"]).Set([]byte("key"), "value")
+	s.Restore(snapshot)
+	require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key")))
+	require.Equal(t, nil, s.GetObjKVStore(keys["obj"]).Get([]byte("key")))
+}
diff --git a/store/types/store.go b/store/types/store.go
index e63ad6d9a79b..67bd140f5e75 100644
--- a/store/types/store.go
+++ b/store/types/store.go
@@ -339,6 +339,9 @@ type CacheWrap interface {
 
 	// Write syncs with the underlying store.
 	Write()
+
+	// Discard the write set
+	Discard()
 }
 
 type CacheWrapper interface {