Skip to content

Commit

Permalink
feat: add necessary validation checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybxyz committed Oct 1, 2023
1 parent 344d9a6 commit 9c21e1d
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions x/ratelimit/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"google.golang.org/grpc/status"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types"
host "github.com/cosmos/ibc-go/v7/modules/core/24-host"
ibctmtypes "github.com/cosmos/ibc-go/v7/modules/light-clients/07-tendermint"

"github.com/notional-labs/centauri/v5/x/ratelimit/types"
Expand Down Expand Up @@ -36,17 +38,28 @@ func (q queryServer) AllRateLimits(c context.Context, req *types.QueryAllRateLim
return &types.QueryAllRateLimitsResponse{RateLimits: rateLimits}, nil
}

// RateLimit queries a rate limit by denom and channel id.
// RateLimit queries the rate limit by the given denom and channel id.
func (q queryServer) RateLimit(c context.Context, req *types.QueryRateLimitRequest) (*types.QueryRateLimitResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "empty request")
}

if err := sdk.ValidateDenom(req.Denom); err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid denom")
}

if err := host.ChannelIdentifierValidator(req.ChannelID); err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid channel id")
}

ctx := sdk.UnwrapSDKContext(c)

rateLimit, found := q.GetRateLimit(ctx, req.Denom, req.ChannelID)
if !found {
return nil, status.Errorf(codes.NotFound, "rate limit by denom %s and channel id %s not found", req.Denom, req.ChannelID)
return nil, status.Errorf(
codes.NotFound,
sdkerrors.Wrapf(types.ErrRateLimitNotFound, "denom: %s, channel-id %s", req.Denom, req.ChannelID).Error(),
)
}
return &types.QueryRateLimitResponse{RateLimit: &rateLimit}, nil
}
Expand All @@ -60,12 +73,11 @@ func (q queryServer) RateLimitsByChainID(c context.Context, req *types.QueryRate
ctx := sdk.UnwrapSDKContext(c)

chainId := req.ChainId

rateLimits := []types.RateLimit{}
for _, rateLimit := range q.GetAllRateLimits(ctx) {
_, clientState, err := q.channelKeeper.GetChannelClientState(ctx, transfertypes.PortID, rateLimit.Path.ChannelID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "unable to fetch client state by port id %s and channel id: %s", transfertypes.PortID, rateLimit.Path.ChannelID)
return nil, status.Error(codes.NotFound, err.Error())
}

client, ok := clientState.(*ibctmtypes.ClientState)
Expand All @@ -88,12 +100,17 @@ func (q queryServer) RateLimitsByChannelID(c context.Context, req *types.QueryRa
return nil, status.Error(codes.InvalidArgument, "empty request")
}

if err := host.ChannelIdentifierValidator(req.ChannelID); err != nil {
return nil, status.Error(codes.InvalidArgument, "invalid channel id")
}

ctx := sdk.UnwrapSDKContext(c)

channelId := req.ChannelID
rateLimits := []types.RateLimit{}
for _, rateLimit := range q.GetAllRateLimits(ctx) {
// If the channel ID matches, add the rate limit to the returned list
if rateLimit.Path.ChannelID == req.ChannelID {
// Append the rate limit when it matches with the requested channel id
if rateLimit.Path.ChannelID == channelId {
rateLimits = append(rateLimits, rateLimit)
}
}
Expand All @@ -108,6 +125,7 @@ func (q queryServer) AllWhitelistedAddresses(c context.Context, req *types.Query
}

ctx := sdk.UnwrapSDKContext(c)

whitelistedAddresses := q.GetAllWhitelistedAddressPairs(ctx)
return &types.QueryAllWhitelistedAddressesResponse{AddressPairs: whitelistedAddresses}, nil
}

0 comments on commit 9c21e1d

Please sign in to comment.