Skip to content

Commit

Permalink
feat: Add support of network mask
Browse files Browse the repository at this point in the history
  • Loading branch information
novln committed Aug 3, 2018
1 parent d578da3 commit 3b3ac78
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 84 deletions.
2 changes: 1 addition & 1 deletion drivers/middleware/stdlib/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
// Handler the middleware handler.
func (middleware *Middleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
context, err := middleware.Limiter.Get(r.Context(), limiter.GetIPKey(r, middleware.TrustForwardHeader))
context, err := middleware.Limiter.Get(r.Context(), middleware.Limiter.GetIPKey(r))
if err != nil {
middleware.OnError(w, r, err)
return
Expand Down
7 changes: 0 additions & 7 deletions drivers/middleware/stdlib/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,3 @@ func WithLimitReachedHandler(handler LimitReachedHandler) Option {
func DefaultLimitReachedHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Limit exceeded", http.StatusTooManyRequests)
}

// WithForwardHeader will configure the Middleware to trust X-Real-IP and X-Forwarded-For headers.
func WithForwardHeader(trusted bool) Option {
return option(func(middleware *Middleware) {
middleware.TrustForwardHeader = trusted
})
}
20 changes: 15 additions & 5 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,25 @@ type Context struct {

// Limiter is the limiter instance.
type Limiter struct {
Store Store
Rate Rate
Store Store
Rate Rate
Options Options
}

// New returns an instance of Limiter.
func New(store Store, rate Rate) *Limiter {
func New(store Store, rate Rate, options ...Option) *Limiter {
opt := Options{
IPv4Mask: DefaultIPv4Mask,
IPv6Mask: DefaultIPv6Mask,
TrustForwardHeader: false,
}
for _, o := range options {
o(&opt)
}
return &Limiter{
Store: store,
Rate: rate,
Store: store,
Rate: rate,
Options: opt,
}
}

Expand Down
17 changes: 17 additions & 0 deletions limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package limiter_test

import (
"time"

"github.com/ulule/limiter"
"github.com/ulule/limiter/drivers/store/memory"
)

func New(options ...limiter.Option) *limiter.Limiter {
store := memory.NewStore()
rate := limiter.Rate{
Period: 1 * time.Second,
Limit: int64(10),
}
return limiter.New(store, rate, options...)
}
56 changes: 56 additions & 0 deletions network.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package limiter

import (
"net"
"net/http"
"strings"
)

var (
// DefaultIPv4Mask defines the default IPv4 mask used to obtain user IP.
DefaultIPv4Mask = net.CIDRMask(32, 32)
// DefaultIPv6Mask defines the default IPv6 mask used to obtain user IP.
DefaultIPv6Mask = net.CIDRMask(128, 128)
)

// GetIP returns IP address from request.
func (limiter *Limiter) GetIP(r *http.Request) net.IP {
if limiter.Options.TrustForwardHeader {
ip := r.Header.Get("X-Forwarded-For")
if ip != "" {
parts := strings.SplitN(ip, ",", 2)
part := strings.TrimSpace(parts[0])
return net.ParseIP(part)
}

ip = strings.TrimSpace(r.Header.Get("X-Real-IP"))
if ip != "" {
return net.ParseIP(ip)
}
}

remoteAddr := strings.TrimSpace(r.RemoteAddr)
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return net.ParseIP(remoteAddr)
}

return net.ParseIP(host)
}

// GetIPWithMask returns IP address from request by applying a mask.
func (limiter *Limiter) GetIPWithMask(r *http.Request) net.IP {
ip := limiter.GetIP(r)
if ip.To4() != nil {
return ip.Mask(limiter.Options.IPv4Mask)
}
if ip.To16() != nil {
return ip.Mask(limiter.Options.IPv6Mask)
}
return ip
}

// GetIPKey extracts IP from request and returns hashed IP to use as store key.
func (limiter *Limiter) GetIPKey(r *http.Request) string {
return limiter.GetIPWithMask(r).String()
}
98 changes: 76 additions & 22 deletions utils_test.go → network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import (
func TestGetIP(t *testing.T) {
is := require.New(t)

limiter1 := New(limiter.WithTrustForwardHeader(false))
limiter2 := New(limiter.WithTrustForwardHeader(true))
limiter3 := New(limiter.WithIPv4Mask(net.CIDRMask(24, 32)))
limiter4 := New(limiter.WithIPv6Mask(net.CIDRMask(48, 128)))

request1 := &http.Request{
URL: &url.URL{Path: "/"},
Header: http.Header{},
Expand All @@ -35,71 +40,98 @@ func TestGetIP(t *testing.T) {
}
request3.Header.Add("X-Real-IP", "6.6.6.6")

request4 := &http.Request{
URL: &url.URL{Path: "/"},
Header: http.Header{},
RemoteAddr: "[2001:db8:cafe:1234:beef::fafa]:8888",
}

scenarios := []struct {
request *http.Request
hasProxy bool
limiter *limiter.Limiter
expected net.IP
}{
{
//
// Scenario #1 : RemoteAddr without proxy.
//
request: request1,
hasProxy: false,
expected: net.ParseIP("8.8.8.8"),
limiter: limiter1,
expected: net.ParseIP("8.8.8.8").To4(),
},
{
//
// Scenario #2 : X-Forwarded-For without proxy.
//
request: request2,
hasProxy: false,
expected: net.ParseIP("8.8.8.8"),
limiter: limiter1,
expected: net.ParseIP("8.8.8.8").To4(),
},
{
//
// Scenario #3 : X-Real-IP without proxy.
//
request: request3,
hasProxy: false,
expected: net.ParseIP("8.8.8.8"),
limiter: limiter1,
expected: net.ParseIP("8.8.8.8").To4(),
},
{
//
// Scenario #4 : RemoteAddr with proxy.
//
request: request1,
hasProxy: true,
expected: net.ParseIP("8.8.8.8"),
limiter: limiter2,
expected: net.ParseIP("8.8.8.8").To4(),
},
{
//
// Scenario #5 : X-Forwarded-For with proxy.
//
request: request2,
hasProxy: true,
expected: net.ParseIP("9.9.9.9"),
limiter: limiter2,
expected: net.ParseIP("9.9.9.9").To4(),
},
{
//
// Scenario #6 : X-Real-IP with proxy.
//
request: request3,
hasProxy: true,
expected: net.ParseIP("6.6.6.6"),
limiter: limiter2,
expected: net.ParseIP("6.6.6.6").To4(),
},
{
//
// Scenario #7 : IPv4 with mask.
//
request: request1,
limiter: limiter3,
expected: net.ParseIP("8.8.8.0").To4(),
},
{
//
// Scenario #8 : IPv6 with mask.
//
request: request4,
limiter: limiter4,
expected: net.ParseIP("2001:db8:cafe::").To16(),
},
}

for i, scenario := range scenarios {
message := fmt.Sprintf("Scenario #%d", (i + 1))
ip := limiter.GetIP(scenario.request, scenario.hasProxy)
ip := scenario.limiter.GetIPWithMask(scenario.request)
is.Equal(scenario.expected, ip, message)
}
}

func TestGetIPKey(t *testing.T) {
is := require.New(t)

limiter1 := New(limiter.WithTrustForwardHeader(false))
limiter2 := New(limiter.WithTrustForwardHeader(true))
limiter3 := New(limiter.WithIPv4Mask(net.CIDRMask(24, 32)))
limiter4 := New(limiter.WithIPv6Mask(net.CIDRMask(48, 128)))

request1 := &http.Request{
URL: &url.URL{Path: "/"},
Header: http.Header{},
Expand All @@ -120,64 +152,86 @@ func TestGetIPKey(t *testing.T) {
}
request3.Header.Add("X-Real-IP", "6.6.6.6")

request4 := &http.Request{
URL: &url.URL{Path: "/"},
Header: http.Header{},
RemoteAddr: "[2001:db8:cafe:1234:beef::fafa]:8888",
}

scenarios := []struct {
request *http.Request
hasProxy bool
limiter *limiter.Limiter
expected string
}{
{
//
// Scenario #1 : RemoteAddr without proxy.
//
request: request1,
hasProxy: false,
limiter: limiter1,
expected: "8.8.8.8",
},
{
//
// Scenario #2 : X-Forwarded-For without proxy.
//
request: request2,
hasProxy: false,
limiter: limiter1,
expected: "8.8.8.8",
},
{
//
// Scenario #3 : X-Real-IP without proxy.
//
request: request3,
hasProxy: false,
limiter: limiter1,
expected: "8.8.8.8",
},
{
//
// Scenario #4 : RemoteAddr without proxy.
//
request: request1,
hasProxy: true,
limiter: limiter2,
expected: "8.8.8.8",
},
{
//
// Scenario #5 : X-Forwarded-For without proxy.
//
request: request2,
hasProxy: true,
limiter: limiter2,
expected: "9.9.9.9",
},
{
//
// Scenario #6 : X-Real-IP without proxy.
//
request: request3,
hasProxy: true,
limiter: limiter2,
expected: "6.6.6.6",
},
{
//
// Scenario #7 : IPv4 with mask.
//
request: request1,
limiter: limiter3,
expected: "8.8.8.0",
},
{
//
// Scenario #8 : IPv6 with mask.
//
request: request4,
limiter: limiter4,
expected: "2001:db8:cafe::",
},
}

for i, scenario := range scenarios {
message := fmt.Sprintf("Scenario #%d", (i + 1))
key := limiter.GetIPKey(scenario.request, scenario.hasProxy)
key := scenario.limiter.GetIPKey(scenario.request)
is.Equal(scenario.expected, key, message)
}
}
39 changes: 39 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package limiter

import (
"net"
)

// Option is a functional option.
type Option func(*Options)

// Options are limiter options.
type Options struct {
// IPv4Mask defines the mask used to obtain a IPv4 address.
IPv4Mask net.IPMask
// IPv6Mask defines the mask used to obtain a IPv6 address.
IPv6Mask net.IPMask
// TrustForwardHeader enable parsing of X-Real-IP and X-Forwarded-For headers to obtain user IP.
TrustForwardHeader bool
}

// WithIPv4Mask will configure the limiter to use given mask for IPv4 address.
func WithIPv4Mask(mask net.IPMask) Option {
return func(o *Options) {
o.IPv4Mask = mask
}
}

// WithIPv6Mask will configure the limiter to use given mask for IPv6 address.
func WithIPv6Mask(mask net.IPMask) Option {
return func(o *Options) {
o.IPv6Mask = mask
}
}

// WithTrustForwardHeader will configure the limiter to trust X-Real-IP and X-Forwarded-For headers.
func WithTrustForwardHeader(enable bool) Option {
return func(o *Options) {
o.TrustForwardHeader = enable
}
}
Loading

0 comments on commit 3b3ac78

Please sign in to comment.