Skip to content

Commit

Permalink
feat: add custom query with check state (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yongwoo Lee authored Jul 28, 2020
1 parent 7bdbf03 commit e6cb294
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 26 deletions.
93 changes: 76 additions & 17 deletions baseapp/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"syscall"

"github.com/cosmos/cosmos-sdk/store/types"
abci "github.com/tendermint/tendermint/abci/types"

"github.com/cosmos/cosmos-sdk/codec"
Expand Down Expand Up @@ -310,6 +311,9 @@ func (app *BaseApp) Query(req abci.RequestQuery) abci.ResponseQuery {

case "custom":
return handleQueryCustom(app, path, req)

case "check_state":
return handleQueryCheckState(app, path, req)
}

return sdkerrors.QueryResult(sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown query path"))
Expand Down Expand Up @@ -418,44 +422,99 @@ func handleQueryP2P(app *BaseApp, path []string) abci.ResponseQuery {
}

func handleQueryCustom(app *BaseApp, path []string, req abci.RequestQuery) abci.ResponseQuery {
// path[0] should be "custom" because "/custom" prefix is required for keeper
// queries.
//
// The QueryRouter routes using path[1]. For example, in the path
// "custom/gov/proposal", QueryRouter routes using "gov".
if len(path) < 2 || path[1] == "" {
return sdkerrors.QueryResult(sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "no route for custom query specified"))
}

querier := app.queryRouter.Route(path[1])
if querier == nil {
return sdkerrors.QueryResult(sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "no custom querier found for route %s", path[1]))
querier, err := getCustomQuerier(app, path)
if err != nil {
return sdkerrors.QueryResult(err)
}

// when a client did not provide a query height, manually inject the latest
if req.Height == 0 {
req.Height = app.LastBlockHeight()
}

if req.Height <= 1 && req.Prove {
if err := checkProvable(req); err != nil {
return sdkerrors.QueryResult(err)
}

cacheMS, err := app.cms.CacheMultiStoreWithVersion(req.Height)
if err != nil {
return sdkerrors.QueryResult(
sdkerrors.Wrap(
sdkerrors.Wrapf(
sdkerrors.ErrInvalidRequest,
"cannot query with proof when height <= 1; please provide a valid height",
"failed to load state at height %d; %s (latest height: %d)", req.Height, err, app.LastBlockHeight(),
),
)
}

cacheMS, err := app.cms.CacheMultiStoreWithVersion(req.Height)
return processCustomQuerier(app, cacheMS, querier, path, req)
}

func handleQueryCheckState(app *BaseApp, path []string, req abci.RequestQuery) abci.ResponseQuery {
querier, err := getCustomQuerier(app, path)
if err != nil {
return sdkerrors.QueryResult(err)
}

checkStateHeight := app.checkState.ctx.BlockHeight()

// when a client did not provide a query height, manually inject the latest
if req.Height == 0 {
req.Height = checkStateHeight
}

if req.Height != checkStateHeight {
return sdkerrors.QueryResult(
sdkerrors.Wrapf(
sdkerrors.ErrInvalidRequest,
"failed to load state at height %d; %s (latest height: %d)", req.Height, err, app.LastBlockHeight(),
"invalid request height %d; the height should be equal to the check state height %d",
req.Height,
checkStateHeight,
),
)
}

if err := checkProvable(req); err != nil {
return sdkerrors.QueryResult(err)
}

// a snapshot of CheckState multi-store
cacheMS := app.checkState.ms.CacheMultiStore()

return processCustomQuerier(app, cacheMS, querier, path, req)
}

func getCustomQuerier(app *BaseApp, path []string) (sdk.Querier, error) {
// path[0] should be "custom" or "check_state" because the prefix is required for keeper
// queries.
//
// The QueryRouter routes using path[1]. For example, in the path
// "custom/gov/proposal", QueryRouter routes using "gov".
if len(path) < 2 || path[1] == "" {
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "no route for custom query specified")
}

querier := app.queryRouter.Route(path[1])
if querier == nil {
return nil, sdkerrors.Wrapf(
sdkerrors.ErrUnknownRequest, "no custom querier found for route %s", path[1])
}
return querier, nil
}

func checkProvable(req abci.RequestQuery) error {
if req.Height <= 1 && req.Prove {
return sdkerrors.Wrap(
sdkerrors.ErrInvalidRequest,
"cannot query with proof when height <= 1; please provide a valid height",
)
}
return nil
}

func processCustomQuerier(
app *BaseApp, cacheMS types.CacheMultiStore, querier sdk.Querier, path []string, req abci.RequestQuery,
) abci.ResponseQuery {

// cache wrap the commit-multistore for safety
ctx := sdk.NewContext(
cacheMS, app.checkState.ctx.BlockHeader(), true, app.logger,
Expand Down
67 changes: 67 additions & 0 deletions baseapp/baseapp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,73 @@ func TestQuery(t *testing.T) {
require.Equal(t, value, res.Value)
}

// Test that we can only query from the latest committed state.
func TestCheckStateQuery(t *testing.T) {
keyForAnte, valueForAnte := []byte("ante key"), []byte("ante value")
keyForMsg, valueForMsg := []byte("msg key"), []byte("msg value")
anteOpt := func(bapp *BaseApp) {
bapp.SetAnteHandler(func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) {
store := ctx.KVStore(capKey1)
store.Set(keyForAnte, valueForAnte)
return
})
}

routerOpt := func(bapp *BaseApp) {
bapp.Router().AddRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
store := ctx.KVStore(capKey1)
store.Set(keyForMsg, valueForMsg)
return &sdk.Result{}, nil
})
}

queryRouterOpt := func(bapp *BaseApp) {
querier := func(ctx sdk.Context, path []string, req abci.RequestQuery) ([]byte, error) {
store := ctx.KVStore(capKey1)
anteValue := store.Get(keyForAnte)
msgValue := store.Get(keyForMsg)
return append(anteValue, msgValue...), nil
}
bapp.QueryRouter().AddRoute("queryFoo", querier)
}

app := setupBaseApp(t, anteOpt, routerOpt, queryRouterOpt)

app.InitChain(abci.RequestInitChain{})

// Request for query check state
query := abci.RequestQuery{
Path: "/check_state/queryFoo",
}
tx := newTxCounter(0, 0)

// query is empty before we do anything
res := app.Query(query)
require.Equal(t, 0, len(res.Value))

// ante has been done, so changes of ante should be returned
// however msg has not been executed on CheckTx.
_, resTx, err := app.Check(tx)
require.NoError(t, err)
require.NotNil(t, resTx)
res = app.Query(query)
require.Equal(t, valueForAnte, res.Value)

header := abci.Header{Height: app.LastBlockHeight() + 1}
app.BeginBlock(abci.RequestBeginBlock{Header: header})

_, resTx, err = app.Deliver(tx)
require.NoError(t, err)
require.NotNil(t, resTx)
res = app.Query(query)
require.Equal(t, valueForAnte, res.Value)

// query returns correct value after Commit
app.Commit()
res = app.Query(query)
require.Equal(t, append(valueForAnte, valueForMsg...), res.Value)
}

// Test p2p filter queries
func TestP2PQuery(t *testing.T) {
addrPeerFilterOpt := func(bapp *BaseApp) {
Expand Down
7 changes: 6 additions & 1 deletion x/auth/client/cli/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

const (
flagEvents = "events"
flagCheckState = "check_state"

eventFormat = "{eventType}.{eventAttribute}={value}"
)
Expand Down Expand Up @@ -49,8 +50,10 @@ func GetAccountCmd(cdc *codec.Codec) *cobra.Command {
Short: "Query account balance",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
checkState := viper.GetBool(flagCheckState)

cliCtx := context.NewCLIContext().WithCodec(cdc)
accGetter := types.NewAccountRetriever(cliCtx)
accGetter := types.NewAccountRetriever(cliCtx).WithCheckState(checkState)

key, err := sdk.AccAddressFromBech32(args[0])
if err != nil {
Expand All @@ -66,6 +69,8 @@ func GetAccountCmd(cdc *codec.Codec) *cobra.Command {
},
}

cmd.Flags().Bool(flagCheckState, false, "query with the check state")

return flags.GetCommands(cmd)[0]
}

Expand Down
12 changes: 11 additions & 1 deletion x/auth/client/rest/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,22 @@ func QueryAccountRequestHandlerFn(storeName string, cliCtx context.CLIContext) h
return
}

checkState := false
if p := r.FormValue("check_state"); len(p) > 0 {
checkState, err = strconv.ParseBool(p)
if err != nil {
err := fmt.Errorf("'%s' is not a valid bool", p)
rest.WriteErrorResponse(w, http.StatusBadRequest, err.Error())
return
}
}

cliCtx, ok := rest.ParseQueryHeightOrReturnBadRequest(w, cliCtx, r)
if !ok {
return
}

accGetter := types.NewAccountRetriever(cliCtx)
accGetter := types.NewAccountRetriever(cliCtx).WithCheckState(checkState)

account, height, err := accGetter.GetAccountWithHeight(addr)
if err != nil {
Expand Down
21 changes: 17 additions & 4 deletions x/auth/types/account_retriever.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package types

import (
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/exported"
)

const (
customQueryPath = "custom/" + QuerierRoute + "/" + QueryAccount
checkStateQueryPath = "check_state/" + QuerierRoute + "/" + QueryAccount
)

// NodeQuerier is an interface that is satisfied by types that provide the QueryWithData method
type NodeQuerier interface {
// QueryWithData performs a query to a Tendermint node with the provided path
Expand All @@ -18,14 +21,20 @@ type NodeQuerier interface {
// AccountRetriever defines the properties of a type that can be used to
// retrieve accounts.
type AccountRetriever struct {
querier NodeQuerier
querier NodeQuerier
checkState bool
}

// NewAccountRetriever initialises a new AccountRetriever instance.
func NewAccountRetriever(querier NodeQuerier) AccountRetriever {
return AccountRetriever{querier: querier}
}

func (ar AccountRetriever) WithCheckState(checkState bool) AccountRetriever {
ar.checkState = checkState
return ar
}

// GetAccount queries for an account given an address and a block height. An
// error is returned if the query or decoding fails.
func (ar AccountRetriever) GetAccount(addr sdk.AccAddress) (exported.Account, error) {
Expand All @@ -42,7 +51,11 @@ func (ar AccountRetriever) GetAccountWithHeight(addr sdk.AccAddress) (exported.A
return nil, 0, err
}

res, height, err := ar.querier.QueryWithData(fmt.Sprintf("custom/%s/%s", QuerierRoute, QueryAccount), bs)
var queryPath string
if queryPath = customQueryPath; ar.checkState {
queryPath = checkStateQueryPath
}
res, height, err := ar.querier.QueryWithData(queryPath, bs)
if err != nil {
return nil, height, err
}
Expand Down
33 changes: 30 additions & 3 deletions x/auth/types/account_retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,46 @@ func TestAccountRetriever(t *testing.T) {
bs, err := ModuleCdc.MarshalJSON(NewQueryAccountParams(addr))
require.NoError(t, err)

mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq("custom/acc/account"),
mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq(customQueryPath),
gomock.Eq(bs)).Return(nil, int64(0), errFoo).Times(1)
_, err = accRetr.GetAccount(addr)
require.Error(t, err)

mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq("custom/acc/account"),
mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq(customQueryPath),
gomock.Eq(bs)).Return(nil, int64(0), errFoo).Times(1)
n, s, err := accRetr.GetAccountNumberSequence(addr)
require.Error(t, err)
require.Equal(t, uint64(0), n)
require.Equal(t, uint64(0), s)

mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq("custom/acc/account"),
mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq(customQueryPath),
gomock.Eq(bs)).Return(nil, int64(0), errFoo).Times(1)
require.Error(t, accRetr.EnsureExists(addr))
}

func TestAccountRetrieverWithCheckState(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockNodeQuerier := mocks.NewMockNodeQuerier(mockCtrl)
accRetr := NewAccountRetriever(mockNodeQuerier).WithCheckState(true)
addr := []byte("test")
bs, err := ModuleCdc.MarshalJSON(NewQueryAccountParams(addr))
require.NoError(t, err)

mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq(checkStateQueryPath),
gomock.Eq(bs)).Return(nil, int64(0), errFoo).Times(1)
_, err = accRetr.GetAccount(addr)
require.Error(t, err)

mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq(checkStateQueryPath),
gomock.Eq(bs)).Return(nil, int64(0), errFoo).Times(1)
n, s, err := accRetr.GetAccountNumberSequence(addr)
require.Error(t, err)
require.Equal(t, uint64(0), n)
require.Equal(t, uint64(0), s)

mockNodeQuerier.EXPECT().QueryWithData(gomock.Eq(checkStateQueryPath),
gomock.Eq(bs)).Return(nil, int64(0), errFoo).Times(1)
require.Error(t, accRetr.EnsureExists(addr))
}

0 comments on commit e6cb294

Please sign in to comment.