Skip to content

Commit

Permalink
handle error in realip parsing explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Oct 23, 2023
1 parent fdc1e79 commit 30ff432
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
4 changes: 1 addition & 3 deletions metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@ import (

// Metrics responds to GET /metrics with list of expvar
func Metrics(onlyIps ...string) func(http.Handler) http.Handler {

return func(h http.Handler) http.Handler {

fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/metrics") {
if matched, ip := matchSourceIP(r, onlyIps); !matched {
if matched, ip, err := matchSourceIP(r, onlyIps); !matched || err != nil {
w.WriteHeader(http.StatusForbidden)
RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)})
return
Expand Down
19 changes: 12 additions & 7 deletions onlyfrom.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,41 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler {
h.ServeHTTP(w, r)
return
}
matched, ip := matchSourceIP(r, onlyIps)
matched, ip, err := matchSourceIP(r, onlyIps)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
RenderJSON(w, JSON{"error": fmt.Sprintf("can't get realip: %s", err)})
return
}
if matched {
// matched ip - allow
h.ServeHTTP(w, r)
return
}

w.WriteHeader(http.StatusForbidden)
RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)})
RenderJSON(w, JSON{"error": fmt.Sprintf("ip %q rejected", ip)})
}
return http.HandlerFunc(fn)
}
}

// matchSourceIP returns true if request's ip matches any of ips
func matchSourceIP(r *http.Request, ips []string) (result bool, match string) {
func matchSourceIP(r *http.Request, ips []string) (result bool, match string, err error) {
ip, err := realip.Get(r)
if err != nil {
return false, "" // we can't get ip, so no match
return false, "", fmt.Errorf("can't get realip: %w", err) // we can't get ip, so no match
}
// check for ip prefix or CIDR
for _, exclIP := range ips {
if _, cidrnet, err := net.ParseCIDR(exclIP); err == nil {
if cidrnet.Contains(net.ParseIP(ip)) {
return true, ip
return true, ip, nil
}
}
if strings.HasPrefix(ip, exclIP) {
return true, ip
return true, ip, nil
}
}
return false, ip
return false, ip, nil
}
38 changes: 38 additions & 0 deletions onlyfrom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,41 @@ func TestOnlyFromRejected(t *testing.T) {
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)
}

func TestOnlyFromErrors(t *testing.T) {
tests := []struct {
name string
remoteAddr string
status int
}{
{
name: "Invalid RemoteAddr",
remoteAddr: "bad-addr",
status: http.StatusInternalServerError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.RemoteAddr = tt.remoteAddr
OnlyFrom("1.1.1.1")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blah blah"))
require.NoError(t, err)
})).ServeHTTP(w, r)
})

ts := httptest.NewServer(handler)
defer ts.Close()

req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody)
require.NoError(t, err)

client := http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, tt.status, resp.StatusCode)
})
}
}

0 comments on commit 30ff432

Please sign in to comment.