diff --git a/store/cachemulti/store_test.go b/store/cachemulti/store_test.go index 8029282eafc4..03053c4e5633 100644 --- a/store/cachemulti/store_test.go +++ b/store/cachemulti/store_test.go @@ -1,6 +1,7 @@ package cachemulti import ( + "errors" "fmt" "testing" @@ -51,4 +52,12 @@ func TestRunAtomic(t *testing.T) { }) require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key"))) require.Equal(t, "value", s.GetObjKVStore(keys["obj"]).Get([]byte("key")).(string)) + + require.Error(t, s.RunAtomic(func(ms types.CacheMultiStore) error { + ms.GetKVStore(keys["abc"]).Set([]byte("key"), []byte("value2")) + ms.GetObjKVStore(keys["obj"]).Set([]byte("key"), "value2") + return errors.New("failure") + })) + require.Equal(t, []byte("value"), s.GetKVStore(keys["abc"]).Get([]byte("key"))) + require.Equal(t, "value", s.GetObjKVStore(keys["obj"]).Get([]byte("key")).(string)) } diff --git a/store/rootmulti/store.go b/store/rootmulti/store.go index 749d7fe14678..9ee396b67b86 100644 --- a/store/rootmulti/store.go +++ b/store/rootmulti/store.go @@ -1061,26 +1061,24 @@ func (rs *Store) loadCommitStoreFromParams(key types.StoreKey, id types.CommitID return commitDBStoreAdapter{Store: dbadapter.Store{DB: db}}, nil case types.StoreTypeTransient: - _, ok := key.(*types.TransientStoreKey) - if !ok { - return nil, fmt.Errorf("invalid StoreKey for StoreTypeTransient: %s", key.String()) + if _, ok := key.(*types.TransientStoreKey); !ok { + return nil, fmt.Errorf("unexpected key type for a TransientStoreKey; got: %s, %T", key.String(), key) } return transient.NewStore(), nil case types.StoreTypeMemory: - _, ok := key.(*types.ObjectStoreKey) - if !ok { - return nil, fmt.Errorf("invalid StoreKey for StoreTypeTransient: %s", key.String()) - } - if _, ok := key.(*types.MemoryStoreKey); !ok { - return nil, fmt.Errorf("unexpected key type for a MemoryStoreKey; got: %s", key.String()) + return nil, fmt.Errorf("unexpected key type for a MemoryStoreKey; got: %s, %T", key.String(), key) } return mem.NewStore(), nil case types.StoreTypeObject: + if _, ok := key.(*types.ObjectStoreKey); !ok { + return nil, fmt.Errorf("unexpected key type for a ObjectStoreKey; got: %s, %T", key.String(), key) + } + return transient.NewObjStore(), nil default: diff --git a/store/rootmulti/store_test.go b/store/rootmulti/store_test.go index 66d9f296535c..4cc6a6ff8323 100644 --- a/store/rootmulti/store_test.go +++ b/store/rootmulti/store_test.go @@ -933,6 +933,7 @@ func prepareStoreMap() (map[types.StoreKey]types.CommitStore, error) { store.MountStoreWithDB(types.NewKVStoreKey("iavl1"), types.StoreTypeIAVL, nil) store.MountStoreWithDB(types.NewKVStoreKey("iavl2"), types.StoreTypeIAVL, nil) store.MountStoreWithDB(types.NewTransientStoreKey("trans1"), types.StoreTypeTransient, nil) + store.MountStoreWithDB(types.NewMemoryStoreKey("mem1"), types.StoreTypeMemory, nil) store.MountStoreWithDB(types.NewObjectStoreKey("obj1"), types.StoreTypeObject, nil) if err := store.LoadLatestVersion(); err != nil { return nil, err