diff --git a/.gitignore b/.gitignore index 003f7ba8e..fcbbd33d3 100644 --- a/.gitignore +++ b/.gitignore @@ -33,7 +33,6 @@ bin/* # Dependency directories (remove the comment below to include it) # vendor/ - .DS_Store .terraform *.pem @@ -43,4 +42,7 @@ bin/* .vendor vendor go.work -go.work.sum \ No newline at end of file +go.work.sum + +# Testing +coverage.txt diff --git a/Makefile b/Makefile index 0cecb6827..3a3e954bf 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ ifeq (,$(VERSION)) endif endif +PACKAGES_NOSIMULATION=$(shell go list ./... | grep -v '/simulation') LEDGER_ENABLED ?= true SDK_PACK := $(shell go list -m github.com/cosmos/cosmos-sdk | sed 's/ /\@/g') DOCKER := $(shell which docker) @@ -82,6 +83,9 @@ ifeq (,$(findstring nostrip,$(COSMOS_BUILD_OPTIONS))) BUILD_FLAGS += -trimpath endif +############################################################################### +### Build ### +############################################################################### all: install @@ -100,6 +104,24 @@ lint: @find . -name '*.go' -type f -not -path "./vendor*" -not -path "*.git*" -not -name '*.pb.go' -not -name '*.gw.go' | xargs go run golang.org/x/tools/cmd/goimports -w -local github.com/notional-labs/centauri .PHONY: lint +############################################################################### +### Tests & Simulation ### +############################################################################### + +test: test-unit +test-all: test-unit test-race test-cover + +test-unit: + @VERSION=$(VERSION) go test -mod=readonly -tags='norace' $(PACKAGES_NOSIMULATION) + +test-race: + @go test -mod=readonly -timeout 30m -race -coverprofile=coverage.txt -covermode=atomic -tags='ledger test_ledger_mock' ./... + +test-cover: + @go test -mod=readonly -timeout 30m -coverprofile=coverage.txt -covermode=atomic -tags='norace ledger test_ledger_mock' ./... + +.PHONY: test test-all test-unit test-race test-cover + ############################################################################### ### Proto ### ############################################################################### @@ -128,6 +150,8 @@ proto-check-breaking: @$(DOCKER_BUF) breaking --against $(HTTPS_GIT)#branch=main .PHONY: proto-all proto-gen proto-format proto-lint proto-check-breaking + +############################################################################### ### Interchain test ### ############################################################################### diff --git a/app/ibctesting/path.go b/app/ibctesting/path.go index 1d1eb3324..0f1500b90 100644 --- a/app/ibctesting/path.go +++ b/app/ibctesting/path.go @@ -1,6 +1,7 @@ package ibctesting import ( + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" ibctesting "github.com/cosmos/ibc-go/v7/testing" ) @@ -26,6 +27,12 @@ func NewPath(chainA, chainB *TestChain) *Path { } } +// SetChannelOrdered sets the channel order for both endpoints to ORDERED. +func (path *Path) SetChannelOrdered() { + path.EndpointA.ChannelConfig.Order = channeltypes.ORDERED + path.EndpointB.ChannelConfig.Order = channeltypes.ORDERED +} + // NewDefaultEndpoint constructs a new endpoint using default values. // CONTRACT: the counterparty endpoitn must be set by the caller. func NewDefaultEndpoint(chain *TestChain) *Endpoint { diff --git a/x/ratelimit/client/cli/query.go b/x/ratelimit/client/cli/query.go index 863873fcd..b45b699a8 100644 --- a/x/ratelimit/client/cli/query.go +++ b/x/ratelimit/client/cli/query.go @@ -7,6 +7,7 @@ import ( "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/flags" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/notional-labs/centauri/v5/x/ratelimit/types" ) @@ -24,6 +25,10 @@ func GetQueryCmd() *cobra.Command { cmd.AddCommand( GetCmdQueryAllRateLimits(), + GetCmdQueryRateLimit(), + GetRateLimitsByChainID(), + GetRateLimitsByChannelID(), + GetAllWhitelistedAddresses(), ) return cmd } @@ -35,7 +40,11 @@ func GetCmdQueryAllRateLimits() *cobra.Command { Short: "Query all rate limits", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - clientCtx := client.GetClientContextFromCmd(cmd) + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + queryClient := types.NewQueryClient(clientCtx) req := &types.QueryAllRateLimitsRequest{} @@ -52,3 +61,133 @@ func GetCmdQueryAllRateLimits() *cobra.Command { return cmd } + +// GetCmdQueryRateLimit return a rate limit by denom and channel id. +func GetCmdQueryRateLimit() *cobra.Command { + cmd := &cobra.Command{ + Use: "rate-limit [denom] [channel-id]", + Short: "Query a rate limit by denom and channel id", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + + denom := args[0] + channelID := args[1] + + if err := sdk.ValidateDenom(denom); err != nil { + return err + } + + queryClient := types.NewQueryClient(clientCtx) + + req := &types.QueryRateLimitRequest{ + Denom: denom, + ChannelID: channelID, + } + res, err := queryClient.RateLimit(cmd.Context(), req) + if err != nil { + return err + } + + return clientCtx.PrintProto(res) + }, + } + + flags.AddQueryFlagsToCmd(cmd) + + return cmd +} + +// GetRateLimitsByChainID return all rate limits by chain id. +func GetRateLimitsByChainID() *cobra.Command { + cmd := &cobra.Command{ + Use: "list-rate-limits [chain-id]", + Short: "Query all rate limits by chain id", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + + queryClient := types.NewQueryClient(clientCtx) + + req := &types.QueryRateLimitsByChainIDRequest{ + ChainId: args[0], + } + res, err := queryClient.RateLimitsByChainID(cmd.Context(), req) + if err != nil { + return err + } + + return clientCtx.PrintProto(res) + }, + } + + flags.AddQueryFlagsToCmd(cmd) + + return cmd +} + +// GetRateLimitsByChannelID return all rate limits by channel id. +func GetRateLimitsByChannelID() *cobra.Command { + cmd := &cobra.Command{ + Use: "list-rate-limits [channel-id]", + Short: "Query a rate limit by denom and channel id", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + + queryClient := types.NewQueryClient(clientCtx) + + req := &types.QueryRateLimitsByChannelIDRequest{ + ChannelID: args[0], + } + res, err := queryClient.RateLimitsByChannelID(cmd.Context(), req) + if err != nil { + return err + } + + return clientCtx.PrintProto(res) + }, + } + + flags.AddQueryFlagsToCmd(cmd) + + return cmd +} + +// GetAllWhitelistedAddresses return all whitelisted addresses. +func GetAllWhitelistedAddresses() *cobra.Command { + cmd := &cobra.Command{ + Use: "list-whitelisted-addresses", + Short: "Query all whitelisted addresses", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + + queryClient := types.NewQueryClient(clientCtx) + + req := &types.QueryAllWhitelistedAddressesRequest{} + res, err := queryClient.AllWhitelistedAddresses(cmd.Context(), req) + if err != nil { + return err + } + + return clientCtx.PrintProto(res) + }, + } + + flags.AddQueryFlagsToCmd(cmd) + + return cmd +} diff --git a/x/ratelimit/client/cli/tx.go b/x/ratelimit/client/cli/tx.go deleted file mode 100644 index f4bdde62b..000000000 --- a/x/ratelimit/client/cli/tx.go +++ /dev/null @@ -1,23 +0,0 @@ -package cli - -import ( - "fmt" - - "github.com/spf13/cobra" - - "github.com/notional-labs/centauri/v5/x/ratelimit/types" -) - -// GetTxCmd returns the tx commands for router -func GetTxCmd() *cobra.Command { - txCmd := &cobra.Command{ - Use: "transfermiddleware", - DisableFlagParsing: true, - SuggestionsMinimumDistance: 2, - Short: fmt.Sprintf("Tx commands for the %s module", types.ModuleName), - } - - txCmd.AddCommand() - - return txCmd -} diff --git a/x/ratelimit/keeper/grpc_query.go b/x/ratelimit/keeper/grpc_query.go index 87ff085c6..32a03e2a1 100644 --- a/x/ratelimit/keeper/grpc_query.go +++ b/x/ratelimit/keeper/grpc_query.go @@ -3,52 +3,90 @@ package keeper import ( "context" - errorsmod "cosmossdk.io/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" + host "github.com/cosmos/ibc-go/v7/modules/core/24-host" ibctmtypes "github.com/cosmos/ibc-go/v7/modules/light-clients/07-tendermint" "github.com/notional-labs/centauri/v5/x/ratelimit/types" ) -var _ types.QueryServer = Keeper{} +var _ types.QueryServer = queryServer{} + +type queryServer struct { + Keeper +} + +// NewQueryServer returns an implementation of the QueryServer +// for the provided Keeper. +func NewQueryServer(k Keeper) types.QueryServer { + return queryServer{Keeper: k} +} + +// AllRateLimits queries all rate limits. +func (q queryServer) AllRateLimits(c context.Context, req *types.QueryAllRateLimitsRequest) (*types.QueryAllRateLimitsResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "empty request") + } -// Query all rate limits -func (k Keeper) AllRateLimits(goCtx context.Context, _ *types.QueryAllRateLimitsRequest) (*types.QueryAllRateLimitsResponse, error) { - ctx := sdk.UnwrapSDKContext(goCtx) - rateLimits := k.GetAllRateLimits(ctx) + ctx := sdk.UnwrapSDKContext(c) + rateLimits := q.GetAllRateLimits(ctx) return &types.QueryAllRateLimitsResponse{RateLimits: rateLimits}, nil } -// Query a rate limit by denom and channelID -func (k Keeper) RateLimit(goCtx context.Context, req *types.QueryRateLimitRequest) (*types.QueryRateLimitResponse, error) { - ctx := sdk.UnwrapSDKContext(goCtx) - rateLimit, found := k.GetRateLimit(ctx, req.Denom, req.ChannelID) +// RateLimit queries the rate limit by the given denom and channel id. +func (q queryServer) RateLimit(c context.Context, req *types.QueryRateLimitRequest) (*types.QueryRateLimitResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "empty request") + } + + if err := sdk.ValidateDenom(req.Denom); err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid denom") + } + + if err := host.ChannelIdentifierValidator(req.ChannelID); err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid channel id") + } + + ctx := sdk.UnwrapSDKContext(c) + + rateLimit, found := q.GetRateLimit(ctx, req.Denom, req.ChannelID) if !found { - return &types.QueryRateLimitResponse{}, nil + return nil, status.Errorf( + codes.NotFound, + sdkerrors.Wrapf(types.ErrRateLimitNotFound, "denom: %s, channel-id %s", req.Denom, req.ChannelID).Error(), + ) } return &types.QueryRateLimitResponse{RateLimit: &rateLimit}, nil } -// Query all rate limits for a given chain -func (k Keeper) RateLimitsByChainID(goCtx context.Context, req *types.QueryRateLimitsByChainIDRequest) (*types.QueryRateLimitsByChainIDResponse, error) { - ctx := sdk.UnwrapSDKContext(goCtx) +// RateLimitsByChainID queries all rate limits for a given chain. +func (q queryServer) RateLimitsByChainID(c context.Context, req *types.QueryRateLimitsByChainIDRequest) (*types.QueryRateLimitsByChainIDResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "empty request") + } - rateLimits := []types.RateLimit{} - for _, rateLimit := range k.GetAllRateLimits(ctx) { + ctx := sdk.UnwrapSDKContext(c) - // Determine the client state from the channel Id - _, clientState, err := k.channelKeeper.GetChannelClientState(ctx, transfertypes.PortID, rateLimit.Path.ChannelID) + chainId := req.ChainId + rateLimits := []types.RateLimit{} + for _, rateLimit := range q.GetAllRateLimits(ctx) { + _, clientState, err := q.channelKeeper.GetChannelClientState(ctx, transfertypes.PortID, rateLimit.Path.ChannelID) if err != nil { - return &types.QueryRateLimitsByChainIDResponse{}, errorsmod.Wrapf(types.ErrInvalidClientState, "Unable to fetch client state from channelID") + return nil, status.Error(codes.NotFound, err.Error()) } + client, ok := clientState.(*ibctmtypes.ClientState) if !ok { - return &types.QueryRateLimitsByChainIDResponse{}, errorsmod.Wrapf(types.ErrInvalidClientState, "Client state is not tendermint") + return nil, status.Error(codes.InvalidArgument, "invalid client state") } - // If the chain ID matches, add the rate limit to the returned list - if client.ChainId == req.ChainId { + // Append the rate limit when it matches with the requested chain id + if client.ChainId == chainId { rateLimits = append(rateLimits, rateLimit) } } @@ -56,14 +94,23 @@ func (k Keeper) RateLimitsByChainID(goCtx context.Context, req *types.QueryRateL return &types.QueryRateLimitsByChainIDResponse{RateLimits: rateLimits}, nil } -// Query all rate limits for a given channel -func (k Keeper) RateLimitsByChannelID(goCtx context.Context, req *types.QueryRateLimitsByChannelIDRequest) (*types.QueryRateLimitsByChannelIDResponse, error) { - ctx := sdk.UnwrapSDKContext(goCtx) +// RateLimitsByChannelID queries all rate limits for a given channel. +func (q queryServer) RateLimitsByChannelID(c context.Context, req *types.QueryRateLimitsByChannelIDRequest) (*types.QueryRateLimitsByChannelIDResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "empty request") + } + + if err := host.ChannelIdentifierValidator(req.ChannelID); err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid channel id") + } + + ctx := sdk.UnwrapSDKContext(c) + channelId := req.ChannelID rateLimits := []types.RateLimit{} - for _, rateLimit := range k.GetAllRateLimits(ctx) { - // If the channel ID matches, add the rate limit to the returned list - if rateLimit.Path.ChannelID == req.ChannelID { + for _, rateLimit := range q.GetAllRateLimits(ctx) { + // Append the rate limit when it matches with the requested channel id + if rateLimit.Path.ChannelID == channelId { rateLimits = append(rateLimits, rateLimit) } } @@ -71,9 +118,14 @@ func (k Keeper) RateLimitsByChannelID(goCtx context.Context, req *types.QueryRat return &types.QueryRateLimitsByChannelIDResponse{RateLimits: rateLimits}, nil } -// Query all whitelisted addresses -func (k Keeper) AllWhitelistedAddresses(goCtx context.Context, _ *types.QueryAllWhitelistedAddressesRequest) (*types.QueryAllWhitelistedAddressesResponse, error) { - ctx := sdk.UnwrapSDKContext(goCtx) - whitelistedAddresses := k.GetAllWhitelistedAddressPairs(ctx) +// AllWhitelistedAddresses queries all whitelisted addresses. +func (q queryServer) AllWhitelistedAddresses(c context.Context, req *types.QueryAllWhitelistedAddressesRequest) (*types.QueryAllWhitelistedAddressesResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "empty request") + } + + ctx := sdk.UnwrapSDKContext(c) + + whitelistedAddresses := q.GetAllWhitelistedAddressPairs(ctx) return &types.QueryAllWhitelistedAddressesResponse{AddressPairs: whitelistedAddresses}, nil } diff --git a/x/ratelimit/keeper/grpc_query_test.go b/x/ratelimit/keeper/grpc_query_test.go new file mode 100644 index 000000000..b1b29a8b0 --- /dev/null +++ b/x/ratelimit/keeper/grpc_query_test.go @@ -0,0 +1,303 @@ +package keeper_test + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + // ibctesting "github.com/cosmos/ibc-go/v7/testing" + + ibctesting "github.com/notional-labs/centauri/v5/app/ibctesting" + "github.com/notional-labs/centauri/v5/x/ratelimit/types" +) + +func (s *KeeperTestSuite) TestGRPCAllRateLimits() { + // Add some sample rate limits + s.SetupSampleRateLimits(sampleRateLimitA, sampleRateLimitB, sampleRateLimitC, sampleRateLimitD) + + for _, tc := range []struct { + name string + req *types.QueryAllRateLimitsRequest + expectErr bool + postRun func(*types.QueryAllRateLimitsResponse) + }{ + { + "nil request", + nil, + true, + nil, + }, + { + "happy case", + &types.QueryAllRateLimitsRequest{}, + false, + func(resp *types.QueryAllRateLimitsResponse) { + s.Require().Len(resp.GetRateLimits(), 4) + s.Require().Equal(sampleRateLimitA.Denom, resp.RateLimits[0].Path.Denom) + s.Require().Equal(sampleRateLimitA.ChannelID, resp.RateLimits[0].Path.ChannelID) + s.Require().Equal(sampleRateLimitA.MaxPercentSend, resp.RateLimits[0].Quota.MaxPercentSend) + s.Require().Equal(sampleRateLimitA.MaxPercentRecv, resp.RateLimits[0].Quota.MaxPercentRecv) + s.Require().Equal(sampleRateLimitA.MinRateLimitAmount, resp.RateLimits[0].MinRateLimitAmount) + s.Require().Equal(sampleRateLimitA.DurationHours, resp.RateLimits[0].Quota.DurationHours) + }, + }, + } { + s.Run(tc.name, func() { + resp, err := s.querier.AllRateLimits(sdk.WrapSDKContext(s.ctx), tc.req) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + tc.postRun(resp) + } + }) + } +} + +func (s *KeeperTestSuite) TestGRPCRateLimit() { + // Add some sample rate limits + s.SetupSampleRateLimits(sampleRateLimitA, sampleRateLimitB, sampleRateLimitC, sampleRateLimitD) + + for _, tc := range []struct { + name string + req *types.QueryRateLimitRequest + expectErr bool + postRun func(*types.QueryRateLimitResponse) + }{ + { + "nil request", + nil, + true, + nil, + }, + { + "happy case", + &types.QueryRateLimitRequest{ + Denom: sampleRateLimitA.Denom, + ChannelID: sampleRateLimitA.ChannelID, + }, + false, + func(resp *types.QueryRateLimitResponse) { + s.Require().Equal(sampleRateLimitA.Denom, resp.RateLimit.Path.Denom) + s.Require().Equal(sampleRateLimitA.ChannelID, resp.RateLimit.Path.ChannelID) + s.Require().Equal(sampleRateLimitA.MaxPercentSend, resp.RateLimit.Quota.MaxPercentSend) + s.Require().Equal(sampleRateLimitA.MaxPercentRecv, resp.RateLimit.Quota.MaxPercentRecv) + s.Require().Equal(sampleRateLimitA.MinRateLimitAmount, resp.RateLimit.MinRateLimitAmount) + s.Require().Equal(sampleRateLimitA.DurationHours, resp.RateLimit.Quota.DurationHours) + }, + }, + { + "query by invalid denom", + &types.QueryRateLimitRequest{ + Denom: "invalidDenom", + ChannelID: sampleRateLimitA.ChannelID, + }, + true, + nil, + }, + { + "query by invalid channel id", + &types.QueryRateLimitRequest{ + Denom: sampleRateLimitA.Denom, + ChannelID: "invalidChannelID", + }, + true, + nil, + }, + } { + s.Run(tc.name, func() { + resp, err := s.querier.RateLimit(sdk.WrapSDKContext(s.ctx), tc.req) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + tc.postRun(resp) + } + }) + } +} + +func (s *KeeperTestSuite) TestRateLimitsByChainID() { + // Add some sample rate limits + s.SetupSampleRateLimits(sampleRateLimitA, sampleRateLimitB) + + // // Create client and connections on both chains + // path := ibctesting.NewPath(s.chainA, s.chainB) + // s.coordinator.SetupConnections(path) + // path.SetChannelOrdered() + + // // Initialize channel + // err := path.EndpointA.ChanOpenInit() + // s.Require().NoError(err) + + for _, tc := range []struct { + name string + req *types.QueryRateLimitsByChainIDRequest + expectErr bool + postRun func(*types.QueryRateLimitsByChainIDResponse) + }{ + { + "nil request", + nil, + true, + nil, + }, + { + "happy case", + &types.QueryRateLimitsByChainIDRequest{ + ChainId: s.chainA.ChainID, + }, + false, + func(resp *types.QueryRateLimitsByChainIDResponse) { + fmt.Println("resp: ", resp) + // s.Require().Len(resp.GetRateLimits(), 2) + }, + }, + } { + s.Run(tc.name, func() { + ctx := sdk.WrapSDKContext(s.chainA.GetContext()) + + resp, err := s.querier.RateLimitsByChainID(ctx, tc.req) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + tc.postRun(resp) + } + }) + } +} + +func (s *KeeperTestSuite) TestRateLimitsByChannelID() { + // Add some sample rate limits + s.SetupSampleRateLimits(sampleRateLimitA, sampleRateLimitB, sampleRateLimitC, sampleRateLimitD) + + for _, tc := range []struct { + name string + req *types.QueryRateLimitsByChannelIDRequest + expectErr bool + postRun func(*types.QueryRateLimitsByChannelIDResponse) + }{ + { + "nil request", + nil, + true, + nil, + }, + { + "happy case", + &types.QueryRateLimitsByChannelIDRequest{ + ChannelID: sampleRateLimitA.ChannelID, + }, + false, + func(resp *types.QueryRateLimitsByChannelIDResponse) { + s.Require().Len(resp.GetRateLimits(), 2) + s.Require().Equal(sampleRateLimitA.Denom, resp.RateLimits[0].Path.Denom) + s.Require().Equal(sampleRateLimitA.ChannelID, resp.RateLimits[0].Path.ChannelID) + s.Require().Equal(sampleRateLimitA.MaxPercentSend, resp.RateLimits[0].Quota.MaxPercentSend) + s.Require().Equal(sampleRateLimitA.MaxPercentRecv, resp.RateLimits[0].Quota.MaxPercentRecv) + s.Require().Equal(sampleRateLimitA.MinRateLimitAmount, resp.RateLimits[0].MinRateLimitAmount) + s.Require().Equal(sampleRateLimitA.DurationHours, resp.RateLimits[0].Quota.DurationHours) + s.Require().Equal(sampleRateLimitB.Denom, resp.RateLimits[1].Path.Denom) + s.Require().Equal(sampleRateLimitB.ChannelID, resp.RateLimits[1].Path.ChannelID) + s.Require().Equal(sampleRateLimitB.MaxPercentSend, resp.RateLimits[1].Quota.MaxPercentSend) + s.Require().Equal(sampleRateLimitB.MaxPercentRecv, resp.RateLimits[1].Quota.MaxPercentRecv) + s.Require().Equal(sampleRateLimitB.MinRateLimitAmount, resp.RateLimits[1].MinRateLimitAmount) + s.Require().Equal(sampleRateLimitB.DurationHours, resp.RateLimits[1].Quota.DurationHours) + }, + }, + { + "query by chain id that does not exist", + &types.QueryRateLimitsByChannelIDRequest{ + ChannelID: "channel-10", + }, + false, + func(resp *types.QueryRateLimitsByChannelIDResponse) { + s.Require().Empty(resp.RateLimits) + }, + }, + { + "query by invalid chain id", + &types.QueryRateLimitsByChannelIDRequest{ + ChannelID: "invalid/ChannelID", + }, + true, + nil, + }, + } { + s.Run(tc.name, func() { + resp, err := s.querier.RateLimitsByChannelID(sdk.WrapSDKContext(s.ctx), tc.req) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + tc.postRun(resp) + } + }) + } +} + +func (s *KeeperTestSuite) TestAllWhitelistedAddresses() { + // Add some sample whitelisted addresses + whitelistedAddrPairs := []types.WhitelistedAddressPair{ + {Sender: s.addr(1).String(), Receiver: s.addr(2).String()}, + {Sender: s.addr(3).String(), Receiver: s.addr(4).String()}, + {Sender: s.addr(5).String(), Receiver: s.addr(6).String()}, + } + + for _, wap := range whitelistedAddrPairs { + s.keeper.SetWhitelistedAddressPair(s.ctx, wap) + } + + for _, tc := range []struct { + name string + req *types.QueryAllWhitelistedAddressesRequest + expectErr bool + postRun func(*types.QueryAllWhitelistedAddressesResponse) + }{ + { + "nil request", + nil, + true, + nil, + }, + { + "happy case", + &types.QueryAllWhitelistedAddressesRequest{}, + false, + func(resp *types.QueryAllWhitelistedAddressesResponse) { + s.Require().Len(resp.GetAddressPairs(), 3) + }, + }, + } { + s.Run(tc.name, func() { + resp, err := s.querier.AllWhitelistedAddresses(sdk.WrapSDKContext(s.ctx), tc.req) + if tc.expectErr { + s.Require().Error(err) + } else { + s.Require().NoError(err) + tc.postRun(resp) + } + }) + } +} + +func (s *KeeperTestSuite) TestChanOpenInit() { + // Create client and connections on both chains + path := ibctesting.NewPath(s.chainA, s.chainB) + s.coordinator.SetupConnections(path) + + path.SetChannelOrdered() + + // Initialize channel + // + // Works well with ibctesting "github.com/cosmos/ibc-go/v7/testing" but it does not working due to the following error message with customibctesting + // Error: could not retrieve module from port-id: ports/mock: capability not found + // + err := path.EndpointA.ChanOpenInit() + s.Require().NoError(err) + + storedChannel, found := s.chainA.App.GetIBCKeeper().ChannelKeeper.GetChannel(s.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + s.True(found) + fmt.Println("storedChannel: ", storedChannel) +} diff --git a/x/ratelimit/keeper/keeper_test.go b/x/ratelimit/keeper/keeper_test.go new file mode 100644 index 000000000..934391227 --- /dev/null +++ b/x/ratelimit/keeper/keeper_test.go @@ -0,0 +1,155 @@ +package keeper_test + +import ( + "encoding/binary" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + sdkmath "cosmossdk.io/math" + tmproto "github.com/cometbft/cometbft/proto/tendermint/types" + sdk "github.com/cosmos/cosmos-sdk/types" + minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" + + // ibctesting "github.com/cosmos/ibc-go/v7/testing" + + "github.com/notional-labs/centauri/v5/app" + "github.com/notional-labs/centauri/v5/app/helpers" + ibctesting "github.com/notional-labs/centauri/v5/app/ibctesting" + "github.com/notional-labs/centauri/v5/x/ratelimit/keeper" + "github.com/notional-labs/centauri/v5/x/ratelimit/types" +) + +var ( + sampleRateLimitA = types.MsgAddRateLimit{ + Denom: "denomA", + ChannelID: "channel-0", + MaxPercentSend: sdkmath.NewInt(10), + MaxPercentRecv: sdkmath.NewInt(10), + MinRateLimitAmount: sdkmath.NewInt(1_000_000), + DurationHours: uint64(1), + } + sampleRateLimitB = types.MsgAddRateLimit{ + Denom: "denomB", + ChannelID: "channel-0", + MaxPercentSend: sdkmath.NewInt(20), + MaxPercentRecv: sdkmath.NewInt(20), + MinRateLimitAmount: sdkmath.NewInt(1_000_000), + DurationHours: uint64(1), + } + sampleRateLimitC = types.MsgAddRateLimit{ + Denom: "denomB", + ChannelID: "channel-1", + MaxPercentSend: sdkmath.NewInt(50), + MaxPercentRecv: sdkmath.NewInt(50), + MinRateLimitAmount: sdkmath.NewInt(5_000_000), + DurationHours: uint64(5), + } + sampleRateLimitD = types.MsgAddRateLimit{ + Denom: "denomC", + ChannelID: "channel-2", + MaxPercentSend: sdkmath.NewInt(80), + MaxPercentRecv: sdkmath.NewInt(80), + MinRateLimitAmount: sdkmath.NewInt(10_000_000), + DurationHours: uint64(10), + } +) + +type KeeperTestSuite struct { + suite.Suite + + app *app.CentauriApp + ctx sdk.Context + keeper keeper.Keeper + querier types.QueryServer + msgServer types.MsgServer + + coordinator *ibctesting.Coordinator + + // testing chains used for convenience and readability + chainA *ibctesting.TestChain + chainB *ibctesting.TestChain +} + +func TestKeeperTestSuite(t *testing.T) { + suite.Run(t, new(KeeperTestSuite)) +} + +func (s *KeeperTestSuite) SetupTest() { + s.app = helpers.SetupCentauriAppWithValSet(s.T()) + s.ctx = s.app.BaseApp.NewContext(false, tmproto.Header{ + Height: 1, + ChainID: "centauri-1", + Time: time.Now().UTC(), + }) + s.keeper = s.app.RatelimitKeeper + s.querier = keeper.NewQueryServer(s.keeper) + s.msgServer = keeper.NewMsgServerImpl(s.keeper) + + // Creates a coordinator with 2 test chains + s.coordinator = ibctesting.NewCoordinator(s.T(), 2) + s.chainA = s.coordinator.GetChain(ibctesting.GetChainID(0)) + s.chainB = s.coordinator.GetChain(ibctesting.GetChainID(1)) + + // Commit some blocks so that QueryProof returns valid proof (cannot return valid query if height <= 1) + s.coordinator.CommitNBlocks(s.chainA, 2) + s.coordinator.CommitNBlocks(s.chainB, 2) +} + +func (s *KeeperTestSuite) SetupSampleRateLimits(rateLimits ...types.MsgAddRateLimit) { + for _, rateLimit := range rateLimits { + s.addRateLimit( + rateLimit.Denom, + rateLimit.ChannelID, + rateLimit.MaxPercentSend, + rateLimit.MaxPercentRecv, + rateLimit.MinRateLimitAmount, + rateLimit.DurationHours, + ) + } +} + +// +// Below are helper functions to write test code easily +// + +func (s *KeeperTestSuite) addr(addrNum int) sdk.AccAddress { + addr := make(sdk.AccAddress, 20) + binary.PutVarint(addr, int64(addrNum)) + return addr +} + +func (s *KeeperTestSuite) fundAddr(addr sdk.AccAddress, amt sdk.Coins) { + s.T().Helper() + err := s.app.BankKeeper.MintCoins(s.ctx, minttypes.ModuleName, amt) + s.Require().NoError(err) + err = s.app.BankKeeper.SendCoinsFromModuleToAccount(s.ctx, minttypes.ModuleName, addr, amt) + s.Require().NoError(err) +} + +// addRateLimit is a convenient method to add new RateLimit without the need of authority. +func (s *KeeperTestSuite) addRateLimit( + denom string, + channelID string, + maxPercentSend sdkmath.Int, + maxPercentRecv sdkmath.Int, + minRateLimitAmount sdkmath.Int, + durationHours uint64, +) { + s.T().Helper() + + // Increase total supply since adding new rate limit requires total supply of the given denom + s.fundAddr(s.addr(0), sdk.NewCoins(sdk.NewCoin(denom, sdk.NewInt(100_000_000_000)))) + + err := s.keeper.AddRateLimit(s.ctx, &types.MsgAddRateLimit{ + Authority: "", + Denom: denom, + ChannelID: channelID, + MaxPercentSend: maxPercentSend, + MaxPercentRecv: maxPercentRecv, + MinRateLimitAmount: minRateLimitAmount, + DurationHours: durationHours, + }) + s.Require().NoError(err) +} diff --git a/x/ratelimit/keeper/msg_server.go b/x/ratelimit/keeper/msg_server.go index 3ea6ce7d9..f6d7a1496 100644 --- a/x/ratelimit/keeper/msg_server.go +++ b/x/ratelimit/keeper/msg_server.go @@ -12,81 +12,70 @@ import ( var _ types.MsgServer = msgServer{} +type msgServer struct { + Keeper +} + // NewMsgServerImpl returns an implementation of the MsgServer interface // for the provided Keeper. -func NewMsgServerImpl(keeper Keeper) types.MsgServer { +func NewMsgServerImpl(k Keeper) types.MsgServer { return &msgServer{ - Keeper: keeper, + Keeper: k, } } -type msgServer struct { - Keeper -} - -func (k Keeper) AddTransferRateLimit(goCtx context.Context, msg *types.MsgAddRateLimit) (*types.MsgAddRateLimitResponse, error) { +func (m msgServer) AddTransferRateLimit(goCtx context.Context, msg *types.MsgAddRateLimit) (*types.MsgAddRateLimitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if k.authority != msg.Authority { - return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, msg.Authority) + if m.authority != msg.Authority { + return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", m.authority, msg.Authority) } - if err := msg.ValidateBasic(); err != nil { - return nil, err - } - - err := k.AddRateLimit(ctx, msg) - if err != nil { + if err := m.AddRateLimit(ctx, msg); err != nil { return nil, err } return &types.MsgAddRateLimitResponse{}, nil } -func (k Keeper) UpdateTransferRateLimit(goCtx context.Context, msg *types.MsgUpdateRateLimit) (*types.MsgUpdateRateLimitResponse, error) { +func (m msgServer) UpdateTransferRateLimit(goCtx context.Context, msg *types.MsgUpdateRateLimit) (*types.MsgUpdateRateLimitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if k.authority != msg.Authority { - return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, msg.Authority) + if m.authority != msg.Authority { + return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", m.authority, msg.Authority) } - if err := msg.ValidateBasic(); err != nil { - return nil, err - } - - err := k.UpdateRateLimit(ctx, msg) - if err != nil { + if err := m.UpdateRateLimit(ctx, msg); err != nil { return nil, err } return &types.MsgUpdateRateLimitResponse{}, nil } -func (k Keeper) RemoveTransferRateLimit(goCtx context.Context, msg *types.MsgRemoveRateLimit) (*types.MsgRemoveRateLimitResponse, error) { +func (m msgServer) RemoveTransferRateLimit(goCtx context.Context, msg *types.MsgRemoveRateLimit) (*types.MsgRemoveRateLimitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if k.authority != msg.Authority { - return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, msg.Authority) + if m.authority != msg.Authority { + return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", m.authority, msg.Authority) } - err := k.RemoveRateLimit(ctx, msg.Denom, msg.ChannelID) - if err != nil { + if err := m.RemoveRateLimit(ctx, msg.Denom, msg.ChannelID); err != nil { return nil, err } return &types.MsgRemoveRateLimitResponse{}, nil } -func (k Keeper) ResetTransferRateLimit(goCtx context.Context, msg *types.MsgResetRateLimit) (*types.MsgResetRateLimitResponse, error) { +func (m msgServer) ResetTransferRateLimit(goCtx context.Context, msg *types.MsgResetRateLimit) (*types.MsgResetRateLimitResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if k.authority != msg.Authority { - return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", k.authority, msg.Authority) + if m.authority != msg.Authority { + return nil, errors.Wrapf(govtypes.ErrInvalidSigner, "invalid authority; expected %s, got %s", m.authority, msg.Authority) } - err := k.ResetRateLimit(ctx, msg.Denom, msg.ChannelID) - if err != nil { + if err := m.ResetRateLimit(ctx, msg.Denom, msg.ChannelID); err != nil { return nil, err } + return &types.MsgResetRateLimitResponse{}, nil } diff --git a/x/ratelimit/keeper/rate_limit.go b/x/ratelimit/keeper/rate_limit.go index 184c8bb10..c8e569a25 100644 --- a/x/ratelimit/keeper/rate_limit.go +++ b/x/ratelimit/keeper/rate_limit.go @@ -180,7 +180,7 @@ func (k Keeper) GetRateLimit(ctx sdk.Context, denom, channelID string) (rateLimi return rateLimit, true } -// AddRateLimit +// AddRateLimit adds new rate limit func (k Keeper) AddRateLimit(ctx sdk.Context, msg *types.MsgAddRateLimit) error { // Check if this is denom - channel transfer from Picasso denom := msg.Denom @@ -190,6 +190,7 @@ func (k Keeper) AddRateLimit(ctx sdk.Context, msg *types.MsgAddRateLimit) error denom = tokenInfo.IbcDenom } } + // Confirm the channel value is not zero channelValue := k.GetChannelValue(ctx, denom) if channelValue.IsZero() { diff --git a/x/ratelimit/module.go b/x/ratelimit/module.go index e523d8f91..37bf49b7e 100644 --- a/x/ratelimit/module.go +++ b/x/ratelimit/module.go @@ -97,7 +97,7 @@ func (AppModule) QuerierRoute() string { // RegisterServices registers module services. func (am AppModule) RegisterServices(cfg module.Configurator) { - types.RegisterQueryServer(cfg.QueryServer(), am.keeper) + types.RegisterQueryServer(cfg.QueryServer(), keeper.NewQueryServer(*am.keeper)) types.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(*am.keeper)) } diff --git a/x/ratelimit/types/msg.go b/x/ratelimit/types/msg.go index 5a100c3e8..307438243 100644 --- a/x/ratelimit/types/msg.go +++ b/x/ratelimit/types/msg.go @@ -8,6 +8,14 @@ import ( host "github.com/cosmos/ibc-go/v7/modules/core/24-host" ) +var ( + _ sdk.Msg = &MsgAddRateLimit{} + _ sdk.Msg = &MsgUpdateRateLimit{} + _ sdk.Msg = &MsgRemoveRateLimit{} + _ sdk.Msg = &MsgResetRateLimit{} +) + +// Message types for the module const ( TypeMsgAddRateLimit = "add_rate_limit" TypeMsgUpdateRateLimit = "update_rate_limit" @@ -15,8 +23,6 @@ const ( TypeMsgResetRateLimit = "reset_rate_limit" ) -var _ sdk.Msg = &MsgAddRateLimit{} - func NewMsgAddRateLimit( authority string, denom string, @@ -54,12 +60,10 @@ func (msg *MsgAddRateLimit) GetSigners() []sdk.AccAddress { // ValidateBasic does a sanity check on the provided data. func (msg *MsgAddRateLimit) ValidateBasic() error { - // validate authority if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil { return errorsmod.Wrap(err, "invalid authority address") } - // validate channelIDs if err := host.ChannelIdentifierValidator(msg.ChannelID); err != nil { return err } @@ -87,8 +91,6 @@ func (msg *MsgAddRateLimit) ValidateBasic() error { return nil } -var _ sdk.Msg = &MsgUpdateRateLimit{} - func NewMsgUpdateRateLimit( authority string, denom string, @@ -126,12 +128,10 @@ func (msg *MsgUpdateRateLimit) GetSigners() []sdk.AccAddress { // ValidateBasic does a sanity check on the provided data. func (msg *MsgUpdateRateLimit) ValidateBasic() error { - // validate authority if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil { return errorsmod.Wrap(err, "invalid authority address") } - // validate channelIDs if err := host.ChannelIdentifierValidator(msg.ChannelID); err != nil { return err } @@ -159,8 +159,6 @@ func (msg *MsgUpdateRateLimit) ValidateBasic() error { return nil } -var _ sdk.Msg = &MsgRemoveRateLimit{} - func NewMsgRemoveRateLimit( authority string, denom string, @@ -192,19 +190,17 @@ func (msg *MsgRemoveRateLimit) GetSigners() []sdk.AccAddress { // ValidateBasic does a sanity check on the provided data. func (msg *MsgRemoveRateLimit) ValidateBasic() error { - // validate authority if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil { return errorsmod.Wrap(err, "invalid authority address") } - // validate channelIDs - err := host.ChannelIdentifierValidator(msg.ChannelID) + if err := host.ChannelIdentifierValidator(msg.ChannelID); err != nil { + return err + } - return err + return nil } -var _ sdk.Msg = &MsgResetRateLimit{} - func NewMsgResetRateLimit( authority string, denom string, @@ -236,13 +232,13 @@ func (msg *MsgResetRateLimit) GetSigners() []sdk.AccAddress { // ValidateBasic does a sanity check on the provided data. func (msg *MsgResetRateLimit) ValidateBasic() error { - // validate authority if _, err := sdk.AccAddressFromBech32(msg.Authority); err != nil { return errorsmod.Wrap(err, "invalid authority address") } - // validate channelIDs - err := host.ChannelIdentifierValidator(msg.ChannelID) + if err := host.ChannelIdentifierValidator(msg.ChannelID); err != nil { + return err + } - return err + return nil }