Skip to content

Commit

Permalink
improve accuracy of real-ip detection for onlyfrom, allow private ip …
Browse files Browse the repository at this point in the history
…to be realip if no public ip present in headers
  • Loading branch information
umputun committed Oct 23, 2023
1 parent 476a2f9 commit 53446d3
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 55 deletions.
2 changes: 1 addition & 1 deletion benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestBenchmark_Cleanup(t *testing.T) {
MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 50).Microseconds()}, res)
}
{
res := bench.Stats(time.Minute)
res := bench.Stats(time.Minute - time.Second)
t.Logf("%+v", res)
assert.Equal(t, BenchmarkStats{Requests: 60, RequestsSec: 1, AverageRespTime: 50000,
MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 50).Microseconds()}, res)
Expand Down
26 changes: 11 additions & 15 deletions onlyfrom.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ import (
"net"
"net/http"
"strings"

"github.com/go-pkgz/rest/realip"
)

// OnlyFrom middleware allows access for limited list of source IPs.
// Such IPs can be defined as complete ip (like 192.168.1.12), prefix (129.168.) or CIDR (192.168.0.0/16)
func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler {

return func(h http.Handler) http.Handler {

fn := func(w http.ResponseWriter, r *http.Request) {

if len(onlyIps) == 0 {
// no restrictions if no ips defined
h.ServeHTTP(w, r)
return
}
matched, ip := matchSourceIP(r, onlyIps)
if matched {
// matched ip - allow
h.ServeHTTP(w, r)
return
}
Expand All @@ -30,19 +35,10 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler {

// matchSourceIP returns true if request's ip matches any of ips
func matchSourceIP(r *http.Request, ips []string) (result bool, match string) {

// try X-Real-IP first then fail back to X-Forwarded-For and finally to RemoteAddr
ip := r.Header.Get("X-Real-IP")
if ip == "" {
ip = strings.Split(r.Header.Get("X-Forwarded-For"), ", ")[0]
}
if ip == "" {
ip = r.Header.Get("RemoteAddr")
ip, err := realip.Get(r)
if err != nil {
return false, "" // we can't get ip, so no match
}
if ip == "" {
ip = strings.Split(r.RemoteAddr, ":")[0]
}

// check for ip prefix or CIDR
for _, exclIP := range ips {
if _, cidrnet, err := net.ParseCIDR(exclIP); err == nil {
Expand Down
46 changes: 25 additions & 21 deletions onlyfrom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestOnlyFromAllowed(t *testing.T) {

func TestOnlyFromAllowedIP(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blah blah"))
require.NoError(t, err)
Expand All @@ -30,7 +29,6 @@ func TestOnlyFromAllowed(t *testing.T) {
}

func TestOnlyFromAllowedHeaders(t *testing.T) {

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blah blah"))
require.NoError(t, err)
Expand All @@ -48,26 +46,32 @@ func TestOnlyFromAllowedHeaders(t *testing.T) {
}
client := http.Client{}

req, err := reqWithHeader("X-Real-IP")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
t.Run("X-Real-IP", func(t *testing.T) {
req, err := reqWithHeader("X-Real-IP")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
})

req, err = reqWithHeader("X-Forwarded-For")
require.NoError(t, err)
resp, err = client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
t.Run("X-Forwarded-For", func(t *testing.T) {
req, err := reqWithHeader("X-Forwarded-For")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
})

req, err = reqWithHeader("RemoteAddr")
require.NoError(t, err)
resp, err = client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
t.Run("X-Forwarded-For and X-Real-IP missing", func(t *testing.T) {
req, err := reqWithHeader("blah")
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)
})
}

func TestOnlyFromAllowedCIDR(t *testing.T) {
Expand Down
11 changes: 7 additions & 4 deletions realip/real.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,27 @@ var privateRanges = []ipRange{

// Get returns real ip from the given request
func Get(r *http.Request) (string, error) {

var firstIP string
for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} {
addresses := strings.Split(r.Header.Get(h), ",")
// march from right to left until we get a public address
// that will be the address right before our proxy.
for i := len(addresses) - 1; i >= 0; i-- {
ip := strings.TrimSpace(addresses[i])
realIP := net.ParseIP(ip)
if firstIP == "" && realIP != nil {
firstIP = ip
}
if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) {
continue
}
return ip, nil
}
}

// X-Forwarded-For header set but parsing failed above
if r.Header.Get("X-Forwarded-For") != "" {
return "", fmt.Errorf("no valid ip found")
// if we cannot find a public address in X-Forwarded-For or X-Real-IP headers, fallback to first ip
if firstIP != "" {
return firstIP, nil
}

// get IP from RemoteAddr
Expand Down
45 changes: 31 additions & 14 deletions realip/real_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,43 @@ import (
)

func TestGetFromHeaders(t *testing.T) {
{
t.Run("single X-Real-IP", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Real-IP", "8.8.8.8")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
}
{
})
t.Run("X-Forwarded-For last public", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2, 30.30.30.1")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
}
{
})
t.Run("X-Forwarded-For last private", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2,192.168.1.1,10.0.0.65")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "1.1.1.2", adr)
}
{
})
t.Run("X-Forwarded-For all private", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "192.168.1.1,10.0.0.65")
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "10.0.0.65", adr)
})
t.Run("X-Forwarded-For public, X-Real-IP private", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
Expand All @@ -48,8 +57,8 @@ func TestGetFromHeaders(t *testing.T) {
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
}
{
})
t.Run("X-Forwarded-For and X-Real-IP public", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
Expand All @@ -58,8 +67,8 @@ func TestGetFromHeaders(t *testing.T) {
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
}
{
})
t.Run("X-Forwarded-For private and X-Real-IP public]", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
Expand All @@ -68,14 +77,22 @@ func TestGetFromHeaders(t *testing.T) {
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
}
{
})
t.Run("RemoteAddr fallback", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.RemoteAddr = "192.0.2.1:1234"
adr, err := Get(req)
require.NoError(t, err)
assert.Equal(t, "192.0.2.1", adr)
})
t.Run("X-Forwarded-For and X-Real-IP missing, no RemoteAddr either", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
ip, err := Get(req)
assert.Error(t, err)
assert.Equal(t, "", ip)
}
})
}

func TestGetFromRemoteAddr(t *testing.T) {
Expand Down

0 comments on commit 53446d3

Please sign in to comment.