diff --git a/pipeline/errors/when.go b/pipeline/errors/when.go index c7684aab49..fcb4548e60 100644 --- a/pipeline/errors/when.go +++ b/pipeline/errors/when.go @@ -21,8 +21,13 @@ type ( } WhenRequest struct { - CIDR []string `json:"cidr"` - Header *WhenRequestHeader `json:"header"` + RemoteIP *WhenRequestRemoteIP `json:"remote_ip"` + Header *WhenRequestHeader `json:"header"` + } + + WhenRequestRemoteIP struct { + Match []string `json:"match"` + RespectForwardedForHeader bool `json:"respect_forwarded_for_header"` } WhenRequestHeader struct { @@ -104,20 +109,20 @@ func matchesAcceptMIME(request string, handlers []string) bool { for _, match := range hspec { m := match.Value switch { - case a == "*/*": - return true case m == "*/*": return true + case strings.HasSuffix(m, "/*") && strings.TrimSuffix(m, "/*") == strings.Split(a, "/")[0]: + return true + case a == m: return true - case strings.HasSuffix(m, "/*"): - if strings.TrimSuffix(m, "/*") == strings.Split(a, "/")[0] { - return true - } - case strings.HasSuffix(a, "/*"): - if strings.TrimSuffix(a, "/*") == strings.Split(m, "/")[0] { - return true - } + + // If the request contains wildcards, we expect the handler / search value to have an exact match! Otherwise + // we will get a lot of conflicts. + case a == "*/*" && a == m: + return true + case strings.HasSuffix(a, "/*") && a == m: + return true } } } @@ -161,18 +166,21 @@ func matchesRequest(when When, r *http.Request) error { } } - if len(when.Request.CIDR) > 0 { + if when.Request.RemoteIP != nil && len(when.Request.RemoteIP.Match) > 0 { remoteIP, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return errors.WithStack(err) } check := []string{remoteIP} - for _, fwd := range stringsx.Splitx(r.Header.Get("X-Forwarded-For"), ",") { - check = append(check, strings.TrimSpace(fwd)) + + if when.Request.RemoteIP.RespectForwardedForHeader { + for _, fwd := range stringsx.Splitx(r.Header.Get("X-Forwarded-For"), ",") { + check = append(check, strings.TrimSpace(fwd)) + } } - for _, rn := range when.Request.CIDR { + for _, rn := range when.Request.RemoteIP.Match { _, cidr, err := net.ParseCIDR(rn) if err != nil { return errors.WithStack(err) diff --git a/pipeline/errors/when_test.go b/pipeline/errors/when_test.go index fbec554a94..bb2c5f241b 100644 --- a/pipeline/errors/when_test.go +++ b/pipeline/errors/when_test.go @@ -20,8 +20,11 @@ func TestMatchesMIME(t *testing.T) { assert.False(t, matchesAcceptMIME("application/json", []string{"application/xml"})) assert.True(t, matchesAcceptMIME("application/json", []string{"application/xml", "application/json"})) - assert.True(t, matchesAcceptMIME("application/*", []string{"application/json"})) - assert.True(t, matchesAcceptMIME("*/*", []string{"application/json"})) + assert.False(t, matchesAcceptMIME("application/*", []string{"application/json"})) + assert.False(t, matchesAcceptMIME("*/*", []string{"application/json"})) + + assert.True(t, matchesAcceptMIME("application/*", []string{"application/*"})) + assert.True(t, matchesAcceptMIME("*/*", []string{"*/*"})) assert.True(t, matchesAcceptMIME("text/html;q=0.9, application/xml;q=0.8, application/json;q=0.7", []string{"application/xml", "application/json"})) assert.True(t, matchesAcceptMIME("application/xml", []string{"application/xml", "application/json"})) @@ -39,6 +42,14 @@ func TestMatchesMIME(t *testing.T) { func TestMatchesWhen(t *testing.T) { mixedAccept := func(t *testing.T, r *http.Request) { r.Header.Set("Accept", "application/json,text/html") } jsonAccept := func(t *testing.T, r *http.Request) { r.Header.Set("Accept", "application/json") } + + chromeAccept := func(t *testing.T, r *http.Request) { + r.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9") + } + firefoxAccept := func(t *testing.T, r *http.Request) { + r.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") + } + jsonContentType := func(t *testing.T, r *http.Request) { r.Header.Set("Content-Type", "application/json") } withIPs := func(remote string, forwarded ...string) func(t *testing.T, r *http.Request) { return func(t *testing.T, r *http.Request) { @@ -128,7 +139,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -241,7 +252,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -262,7 +273,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -274,7 +285,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -287,7 +298,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -300,7 +311,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -313,7 +324,20 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{Match: []string{"192.168.1.0/24"}}, + }, + }, + }, + in: errors.WithStack(&herodot.ErrNotFound), + ee: ErrDoesNotMatchWhen, + }, + { + init: combine(jsonAccept, jsonContentType, withIPs("127.0.0.1:123", "127.0.0.2", "127.0.0.3", "192.168.1.2")), + w: Whens{ + When{ + Error: []string{statusText(http.StatusNotFound)}, + Request: &WhenRequest{ + RemoteIP: &WhenRequestRemoteIP{RespectForwardedForHeader: true, Match: []string{"192.168.1.0/24"}}, }, }, }, @@ -325,7 +349,7 @@ func TestMatchesWhen(t *testing.T) { When{ Error: []string{statusText(http.StatusNotFound)}, Request: &WhenRequest{ - CIDR: []string{"192.168.1.0/24"}, + RemoteIP: &WhenRequestRemoteIP{RespectForwardedForHeader: true, Match: []string{"192.168.1.0/24"}}, Header: &WhenRequestHeader{ ContentType: []string{"application/json"}, Accept: []string{"application/xml", "application/json"}, @@ -365,6 +389,36 @@ func TestMatchesWhen(t *testing.T) { }, in: errors.WithStack(&herodot.ErrNotFound), }, + { + init: combine(chromeAccept), + w: Whens{ + When{ + Error: []string{statusText(http.StatusNotFound)}, + Request: &WhenRequest{ + Header: &WhenRequestHeader{ + Accept: []string{"application/json"}, + }, + }, + }, + }, + in: errors.WithStack(&herodot.ErrNotFound), + ee: ErrDoesNotMatchWhen, + }, + { + init: combine(firefoxAccept), + w: Whens{ + When{ + Error: []string{statusText(http.StatusNotFound)}, + Request: &WhenRequest{ + Header: &WhenRequestHeader{ + Accept: []string{"application/json"}, + }, + }, + }, + }, + in: errors.WithStack(&herodot.ErrNotFound), + ee: ErrDoesNotMatchWhen, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { r := httptest.NewRequest("GET", "/test", nil)