diff --git a/client.go b/client.go index 5a53b548..b85316e2 100644 --- a/client.go +++ b/client.go @@ -16,8 +16,8 @@ type Client struct { client *rpc.Client // rate limiter - rl *rate.Limiter - rlPerCall bool + rl *rate.Limiter + rlCostFunc func(method string) (cost int) } // NewClient returns a new Client given an rpc.Client client. @@ -72,11 +72,6 @@ func (c *Client) CallCtx(ctx context.Context, calls ...w3types.Caller) error { return nil } - // invoke rate limiter - if err := c.rateLimit(ctx, len(calls)); err != nil { - return err - } - // create requests batchElems := make([]rpc.BatchElem, len(calls)) var err error @@ -87,6 +82,11 @@ func (c *Client) CallCtx(ctx context.Context, calls ...w3types.Caller) error { } } + // invoke rate limiter + if err := c.rateLimit(ctx, batchElems); err != nil { + return err + } + // do requests if len(batchElems) > 1 { // batch requests if >1 request @@ -130,15 +130,22 @@ func (c *Client) Call(calls ...w3types.Caller) error { return c.CallCtx(context.Background(), calls...) } -func (c *Client) rateLimit(ctx context.Context, n int) error { +func (c *Client) rateLimit(ctx context.Context, batchElems []rpc.BatchElem) error { if c.rl == nil { return nil } - if c.rlPerCall { - return c.rl.WaitN(ctx, n) + if c.rlCostFunc == nil { + // limit requests + return c.rl.Wait(ctx) + } + + // limit requests based on Compute Units (CUs) + var cost int + for _, batchElem := range batchElems { + cost += c.rlCostFunc(batchElem.Method) } - return c.rl.Wait(ctx) + return c.rl.WaitN(ctx, cost) } // CallErrors is an error type that contains the errors of multiple calls. The @@ -171,3 +178,17 @@ func (e CallErrors) Is(target error) bool { _, ok := target.(CallErrors) return ok } + +// An Option configures a Client. +type Option func(*Client) + +// WithRateLimiter sets the rate limiter for the client. Set the optional argument +// costFunc to nil to limit the number of requests. Supply a costFunc to limit +// the the number of requests based on individual RPC calls for advanced rate +// limiting by Compute Units (CUs). +func WithRateLimiter(rl *rate.Limiter, costFunc func(method string) (cost int)) Option { + return func(c *Client) { + c.rl = rl + c.rlCostFunc = costFunc + } +} diff --git a/client_options.go b/client_options.go deleted file mode 100644 index fff1ac02..00000000 --- a/client_options.go +++ /dev/null @@ -1,16 +0,0 @@ -package w3 - -import "golang.org/x/time/rate" - -// An Option configures a Client. -type Option func(*Client) - -// WithRateLimiter sets the rate limiter for the client. If perCall is true, the -// rate limiter is applied to each call. Otherwise, the rate limiter is applied -// to each request. -func WithRateLimiter(rl *rate.Limiter, perCall bool) Option { - return func(c *Client) { - c.rl = rl - c.rlPerCall = perCall - } -} diff --git a/client_test.go b/client_test.go index f81bdc37..fd27242c 100644 --- a/client_test.go +++ b/client_test.go @@ -299,7 +299,7 @@ func ExampleWithRateLimiter() { // Limit the client to 30 requests per second and allow bursts of up to // 100 requests. client := w3.MustDial("https://rpc.ankr.com/eth", - w3.WithRateLimiter(rate.NewLimiter(rate.Every(time.Second/30), 100), false), + w3.WithRateLimiter(rate.NewLimiter(rate.Every(time.Second/30), 100), nil), ) defer client.Close() }