diff --git a/x/ratelimit/keeper/grpc_query.go b/x/ratelimit/keeper/grpc_query.go index 4270a6af2..32a03e2a1 100644 --- a/x/ratelimit/keeper/grpc_query.go +++ b/x/ratelimit/keeper/grpc_query.go @@ -7,7 +7,9 @@ import ( "google.golang.org/grpc/status" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" + host "github.com/cosmos/ibc-go/v7/modules/core/24-host" ibctmtypes "github.com/cosmos/ibc-go/v7/modules/light-clients/07-tendermint" "github.com/notional-labs/centauri/v5/x/ratelimit/types" @@ -36,17 +38,28 @@ func (q queryServer) AllRateLimits(c context.Context, req *types.QueryAllRateLim return &types.QueryAllRateLimitsResponse{RateLimits: rateLimits}, nil } -// RateLimit queries a rate limit by denom and channel id. +// RateLimit queries the rate limit by the given denom and channel id. func (q queryServer) RateLimit(c context.Context, req *types.QueryRateLimitRequest) (*types.QueryRateLimitResponse, error) { if req == nil { return nil, status.Error(codes.InvalidArgument, "empty request") } + if err := sdk.ValidateDenom(req.Denom); err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid denom") + } + + if err := host.ChannelIdentifierValidator(req.ChannelID); err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid channel id") + } + ctx := sdk.UnwrapSDKContext(c) rateLimit, found := q.GetRateLimit(ctx, req.Denom, req.ChannelID) if !found { - return nil, status.Errorf(codes.NotFound, "rate limit by denom %s and channel id %s not found", req.Denom, req.ChannelID) + return nil, status.Errorf( + codes.NotFound, + sdkerrors.Wrapf(types.ErrRateLimitNotFound, "denom: %s, channel-id %s", req.Denom, req.ChannelID).Error(), + ) } return &types.QueryRateLimitResponse{RateLimit: &rateLimit}, nil } @@ -60,12 +73,11 @@ func (q queryServer) RateLimitsByChainID(c context.Context, req *types.QueryRate ctx := sdk.UnwrapSDKContext(c) chainId := req.ChainId - rateLimits := []types.RateLimit{} for _, rateLimit := range q.GetAllRateLimits(ctx) { _, clientState, err := q.channelKeeper.GetChannelClientState(ctx, transfertypes.PortID, rateLimit.Path.ChannelID) if err != nil { - return nil, status.Errorf(codes.NotFound, "unable to fetch client state by port id %s and channel id: %s", transfertypes.PortID, rateLimit.Path.ChannelID) + return nil, status.Error(codes.NotFound, err.Error()) } client, ok := clientState.(*ibctmtypes.ClientState) @@ -88,12 +100,17 @@ func (q queryServer) RateLimitsByChannelID(c context.Context, req *types.QueryRa return nil, status.Error(codes.InvalidArgument, "empty request") } + if err := host.ChannelIdentifierValidator(req.ChannelID); err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid channel id") + } + ctx := sdk.UnwrapSDKContext(c) + channelId := req.ChannelID rateLimits := []types.RateLimit{} for _, rateLimit := range q.GetAllRateLimits(ctx) { - // If the channel ID matches, add the rate limit to the returned list - if rateLimit.Path.ChannelID == req.ChannelID { + // Append the rate limit when it matches with the requested channel id + if rateLimit.Path.ChannelID == channelId { rateLimits = append(rateLimits, rateLimit) } } @@ -108,6 +125,7 @@ func (q queryServer) AllWhitelistedAddresses(c context.Context, req *types.Query } ctx := sdk.UnwrapSDKContext(c) + whitelistedAddresses := q.GetAllWhitelistedAddressPairs(ctx) return &types.QueryAllWhitelistedAddressesResponse{AddressPairs: whitelistedAddresses}, nil }