From 30ff4329ab6bb4fe1cddc2fa12155627f163dfd0 Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 23 Oct 2023 17:10:05 -0500 Subject: [PATCH] handle error in realip parsing explicitly --- metrics.go | 4 +--- onlyfrom.go | 19 ++++++++++++------- onlyfrom_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/metrics.go b/metrics.go index 13e8080..1cf0a41 100644 --- a/metrics.go +++ b/metrics.go @@ -9,12 +9,10 @@ import ( // Metrics responds to GET /metrics with list of expvar func Metrics(onlyIps ...string) func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/metrics") { - if matched, ip := matchSourceIP(r, onlyIps); !matched { + if matched, ip, err := matchSourceIP(r, onlyIps); !matched || err != nil { w.WriteHeader(http.StatusForbidden) RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)}) return diff --git a/onlyfrom.go b/onlyfrom.go index 61fbf06..cccd77e 100644 --- a/onlyfrom.go +++ b/onlyfrom.go @@ -19,7 +19,12 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler { h.ServeHTTP(w, r) return } - matched, ip := matchSourceIP(r, onlyIps) + matched, ip, err := matchSourceIP(r, onlyIps) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + RenderJSON(w, JSON{"error": fmt.Sprintf("can't get realip: %s", err)}) + return + } if matched { // matched ip - allow h.ServeHTTP(w, r) @@ -27,28 +32,28 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler { } w.WriteHeader(http.StatusForbidden) - RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)}) + RenderJSON(w, JSON{"error": fmt.Sprintf("ip %q rejected", ip)}) } return http.HandlerFunc(fn) } } // matchSourceIP returns true if request's ip matches any of ips -func matchSourceIP(r *http.Request, ips []string) (result bool, match string) { +func matchSourceIP(r *http.Request, ips []string) (result bool, match string, err error) { ip, err := realip.Get(r) if err != nil { - return false, "" // we can't get ip, so no match + return false, "", fmt.Errorf("can't get realip: %w", err) // we can't get ip, so no match } // check for ip prefix or CIDR for _, exclIP := range ips { if _, cidrnet, err := net.ParseCIDR(exclIP); err == nil { if cidrnet.Contains(net.ParseIP(ip)) { - return true, ip + return true, ip, nil } } if strings.HasPrefix(ip, exclIP) { - return true, ip + return true, ip, nil } } - return false, ip + return false, ip, nil } diff --git a/onlyfrom_test.go b/onlyfrom_test.go index d83612f..718a85b 100644 --- a/onlyfrom_test.go +++ b/onlyfrom_test.go @@ -112,3 +112,41 @@ func TestOnlyFromRejected(t *testing.T) { defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) } + +func TestOnlyFromErrors(t *testing.T) { + tests := []struct { + name string + remoteAddr string + status int + }{ + { + name: "Invalid RemoteAddr", + remoteAddr: "bad-addr", + status: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.RemoteAddr = tt.remoteAddr + OnlyFrom("1.1.1.1")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("blah blah")) + require.NoError(t, err) + })).ServeHTTP(w, r) + }) + + ts := httptest.NewServer(handler) + defer ts.Close() + + req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody) + require.NoError(t, err) + + client := http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, tt.status, resp.StatusCode) + }) + } +}