diff --git a/pkg/object/httpserver/cache.go b/pkg/object/httpserver/cache.go index 7e4ac438df..47061e6af6 100644 --- a/pkg/object/httpserver/cache.go +++ b/pkg/object/httpserver/cache.go @@ -35,7 +35,7 @@ type ( ipFilterChan *ipfilter.IPFilters notFound bool methodNotAllowed bool - path *muxPath + path *MuxPath } ) diff --git a/pkg/object/httpserver/httpserver_test.go b/pkg/object/httpserver/httpserver_test.go new file mode 100644 index 0000000000..d79083ef1f --- /dev/null +++ b/pkg/object/httpserver/httpserver_test.go @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package httpserver + +import ( + "fmt" + + "net/http" + "net/http/httptest" + "testing" + + "github.com/megaease/easegress/pkg/context" + "github.com/megaease/easegress/pkg/context/contexttest" + "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/protocol" + "github.com/megaease/easegress/pkg/supervisor" + "github.com/megaease/easegress/pkg/util/httpheader" + "github.com/megaease/easegress/pkg/util/ipfilter" + "github.com/megaease/easegress/pkg/util/stringtool" + "github.com/stretchr/testify/assert" +) + +func init() { + logger.InitNop() +} + +type testCase struct { + path string + method string + headers map[string]string + realIP string + rules []*muxRule + expectedResult SearchResult +} + +func (tc *testCase) toCtx() *contexttest.MockedHTTPContext { + ctx := &contexttest.MockedHTTPContext{} + header := http.Header{} + for k, v := range tc.headers { + header.Add(k, v) + } + ctx.MockedRequest.MockedPath = func() string { + return tc.path + } + ctx.MockedRequest.MockedMethod = func() string { + return tc.method + } + ctx.MockedRequest.MockedHeader = func() *httpheader.HTTPHeader { + return httpheader.New(header) + } + ctx.MockedRequest.MockedRealIP = func() string { return tc.realIP } + return ctx +} + +func TestSearchPath(t *testing.T) { + assert := assert.New(t) + emptyHeaders := make(map[string]string) + jsonHeader := make(map[string]string) + jsonHeader["content-type"] = "application/json" + tests := []testCase{ + { + "/path/1", http.MethodGet, emptyHeaders, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{Path: "/path/2"}), + }), + }, NotFound, + }, + { + "/path/1", http.MethodGet, emptyHeaders, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{Path: "/path/1"}), + }), + }, Found, + }, + { + "/path/1", http.MethodGet, emptyHeaders, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/path/1", Methods: []string{http.MethodPost}, + }), + }), + }, MethodNotAllowed, + }, + { + "/otherpath", http.MethodPost, emptyHeaders, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/otherpath", Methods: []string{http.MethodPost}, + }), + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/blaa", Methods: []string{http.MethodPost}, + }), + }), + }, Found, + }, + { + "/otherpath", http.MethodPost, emptyHeaders, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, + &Rule{IPFilter: &ipfilter.Spec{BlockByDefault: true}}, []*MuxPath{ + newMuxPath( + &ipfilter.IPFilters{}, + &Path{Path: "/otherpath", Methods: []string{http.MethodPost}}, + ), + }, + ), + }, IPNotAllowed, + }, + { + "/route", http.MethodGet, jsonHeader, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/route", + Headers: []*Header{&Header{Key: "content-type", Values: []string{"application/json"}}}, + }), + }), + }, FoundSkipCache, + }, + { + "/route", http.MethodGet, jsonHeader, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/route", + Headers: []*Header{&Header{Key: "content-type", Values: []string{"application/csv"}}}, + }), + }), + }, MethodNotAllowed, + }, + { + "/multimethod", http.MethodPut, emptyHeaders, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/multimethod", Methods: []string{http.MethodGet}, + }), + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/multimethod", Methods: []string{http.MethodPost}, + }), + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/multimethod", Methods: []string{http.MethodPut}, + }), + }), + }, Found, + }, + { + "/multiheader", http.MethodPut, jsonHeader, "", []*muxRule{ + newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{ + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/multiheader", Methods: []string{http.MethodPut}, + Headers: []*Header{&Header{Key: "content-type", Values: []string{"application/csv"}}}, + }), + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/multiheader", Methods: []string{http.MethodPut}, + Headers: []*Header{&Header{Key: "content-type", Values: []string{"application/json"}}}, + }), + newMuxPath(&ipfilter.IPFilters{}, &Path{ + Path: "/multiheader", Methods: []string{http.MethodPut}, + Headers: []*Header{&Header{Key: "content-type", Values: []string{"application/txt"}}}, + }), + }), + }, FoundSkipCache, + }, + } + + for i := 0; i < len(tests); i++ { + t.Run("test case "+fmt.Sprint(i), func(t *testing.T) { + testcase := tests[i] + result, _ := SearchPath(testcase.toCtx(), testcase.rules) + assert.NotNil(result) + + assert.Equal(result, testcase.expectedResult) + }) + } +} + +func TestSearchPathHeadersAndIPs(t *testing.T) { + assert := assert.New(t) + ipfilter1 := &ipfilter.Spec{AllowIPs: []string{"8.8.8.8"}, BlockByDefault: true} + ipfilter2 := &ipfilter.Spec{AllowIPs: []string{"9.9.9.9"}, BlockByDefault: true} + path1 := newMuxPath(&ipfilter.IPFilters{}, &Path{ + IPFilter: ipfilter1, + Path: "/pipeline", Methods: []string{http.MethodPost}, + Headers: []*Header{&Header{Key: "X-version", Values: []string{"v1"}}}, + }) + path2 := newMuxPath(&ipfilter.IPFilters{}, &Path{ + IPFilter: ipfilter2, + Path: "/pipeline", Methods: []string{http.MethodPost}, + Headers: []*Header{&Header{Key: "X-version", Values: []string{"v2"}}}, + }) + muxRules := []*muxRule{newMuxRule(&ipfilter.IPFilters{}, &Rule{}, []*MuxPath{path1, path2})} + + hdr := make(map[string]string) + hdr["X-version"] = "v1" + testCaseA := testCase{ + path: "/pipeline", + method: http.MethodPost, + headers: hdr, + realIP: "8.8.8.8", + } + + result, pathRes := SearchPath(testCaseA.toCtx(), muxRules) + assert.NotNil(result) + assert.NotNil(pathRes) + assert.Equal(result, FoundSkipCache) + assert.Equal(pathRes, path1) + + hdr["X-version"] = "v2" + testCaseB := testCase{ + path: "/pipeline", + method: http.MethodPost, + headers: hdr, + realIP: "9.9.9.9", + } + + result, pathRes = SearchPath(testCaseB.toCtx(), muxRules) + assert.NotNil(result) + assert.NotNil(pathRes) + assert.Equal(result, FoundSkipCache) + assert.Equal(pathRes, path2) + + hdr["X-version"] = "v3" + testCaseC := testCase{ + path: "/pipeline", + method: http.MethodPost, + headers: hdr, + realIP: "9.9.9.9", + } + + result, _ = SearchPath(testCaseC.toCtx(), muxRules) + assert.NotNil(result) + assert.Equal(result, MethodNotAllowed) + + hdr["X-version"] = "v1" + testCaseD := testCase{ + path: "/pipeline", + method: http.MethodPost, + headers: hdr, + realIP: "9.1.1.9", + } + + result, _ = SearchPath(testCaseD.toCtx(), muxRules) + assert.NotNil(result) + assert.Equal(result, IPNotAllowed) + + testCaseE := testCase{ + path: "/pipeline", + method: http.MethodPost, + realIP: "9.9.9.9", + } + + result, _ = SearchPath(testCaseE.toCtx(), muxRules) + assert.NotNil(result) + assert.Equal(result, MethodNotAllowed) // Missing required header + + // This might be unintuitive but as IP is used only to validate request and + // not route/choose between path, this request produces IPNotAllowed. + // 1st path is mathed and it has IPFilter "8.8.8.8" which does not match with the request. + hdr["X-version"] = "v1" + testCaseF := testCase{ + path: "/pipeline", + method: http.MethodPost, + headers: hdr, + realIP: "9.9.9.9", + } + + result, _ = SearchPath(testCaseF.toCtx(), muxRules) + assert.NotNil(result) + assert.Equal(result, IPNotAllowed) +} + +type handlerMock struct{} + +func (hm *handlerMock) Handle(ctx context.HTTPContext) string { + return "test" +} + +type muxMapperMock struct { + hm *handlerMock +} + +func (mmm *muxMapperMock) GetHandler(name string) (protocol.HTTPHandler, bool) { + handler := mmm.hm + return handler, true +} + +func TestServeHTTP(t *testing.T) { + assert := assert.New(t) + superSpecYaml := ` +name: http-server-test +kind: HTTPServer +port: 10080 +cacheSize: 200 +rules: + - paths: + - pathPrefix: /api +` + + superSpec, err := supervisor.NewSpec(superSpecYaml) + assert.Nil(err) + assert.NotNil(superSpec.ObjectSpec()) + mux := &muxMapperMock{&handlerMock{}} + httpServer := HTTPServer{} + httpServer.Init(superSpec, mux) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cacheKey := stringtool.Cat(r.Host, r.Method, r.URL.Path) + cacheItem := httpServer.runtime.mux.rules.Load().(*muxRules).cache.get(cacheKey) + assert.Nil(cacheItem) + + httpServer.runtime.mux.ServeHTTP(w, r) + + cacheItem = httpServer.runtime.mux.rules.Load().(*muxRules).cache.get(cacheKey) + assert.NotNil(cacheItem) + assert.Equal(true, cacheItem.notFound) + })) + res, err := http.Get(ts.URL + "/unknown-path") + assert.Nil(err) + assert.Equal("404 Not Found", res.Status) + ts.Close() + + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cacheKey := stringtool.Cat(r.Host, r.Method, r.URL.Path) + + httpServer.runtime.mux.ServeHTTP(w, r) + + cacheItem := httpServer.runtime.mux.rules.Load().(*muxRules).cache.get(cacheKey) + assert.NotNil(cacheItem) + assert.Equal(false, cacheItem.notFound) + assert.Equal("/api", cacheItem.path.pathPrefix) + })) + res, err = http.Get(ts.URL + "/api") + assert.Nil(err) + assert.Equal("200 OK", res.Status) + ts.Close() + + httpServer.Close() +} diff --git a/pkg/object/httpserver/mux.go b/pkg/object/httpserver/mux.go index 114b7dcb4b..99f3af4cc1 100644 --- a/pkg/object/httpserver/mux.go +++ b/pkg/object/httpserver/mux.go @@ -70,10 +70,11 @@ type ( host string hostRegexp string hostRE *regexp.Regexp - paths []*muxPath + paths []*MuxPath } - muxPath struct { + // MuxPath describes httpserver's path + MuxPath struct { ipFilter *ipfilter.IPFilter ipFilterChain *ipfilter.IPFilters @@ -86,6 +87,9 @@ type ( backend string headers []*Header } + + // SearchResult is returned by SearchPath + SearchResult string ) // newIPFilterChain returns nil if the number of final filters is zero. @@ -146,7 +150,7 @@ func (mr *muxRules) putCacheItem(ctx context.HTTPContext, ci *cacheItem) { mr.cache.put(key, ci) } -func newMuxRule(parentIPFilters *ipfilter.IPFilters, rule *Rule, paths []*muxPath) *muxRule { +func newMuxRule(parentIPFilters *ipfilter.IPFilters, rule *Rule, paths []*MuxPath) *muxRule { var hostRE *regexp.Regexp if rule.HostRegexp != "" { @@ -198,7 +202,7 @@ func (mr *muxRule) match(ctx context.HTTPContext) bool { return false } -func newMuxPath(parentIPFilters *ipfilter.IPFilters, path *Path) *muxPath { +func newMuxPath(parentIPFilters *ipfilter.IPFilters, path *Path) *MuxPath { var pathRE *regexp.Regexp if path.PathRegexp != "" { var err error @@ -214,7 +218,7 @@ func newMuxPath(parentIPFilters *ipfilter.IPFilters, path *Path) *muxPath { p.initHeaderRoute() } - return &muxPath{ + return &MuxPath{ ipFilter: newIPFilter(path.IPFilter), ipFilterChain: newIPFilterChain(parentIPFilters, path.IPFilter), @@ -229,7 +233,7 @@ func newMuxPath(parentIPFilters *ipfilter.IPFilters, path *Path) *muxPath { } } -func (mp *muxPath) pass(ctx context.HTTPContext) bool { +func (mp *MuxPath) pass(ctx context.HTTPContext) bool { if mp.ipFilter == nil { return true } @@ -237,7 +241,7 @@ func (mp *muxPath) pass(ctx context.HTTPContext) bool { return mp.ipFilter.AllowHTTPContext(ctx) } -func (mp *muxPath) matchPath(ctx context.HTTPContext) bool { +func (mp *MuxPath) matchPath(ctx context.HTTPContext) bool { r := ctx.Request() if mp.path == "" && mp.pathPrefix == "" && mp.pathRE == nil { @@ -257,7 +261,7 @@ func (mp *muxPath) matchPath(ctx context.HTTPContext) bool { return false } -func (mp *muxPath) matchMethod(ctx context.HTTPContext) bool { +func (mp *MuxPath) matchMethod(ctx context.HTTPContext) bool { if len(mp.methods) == 0 { return true } @@ -265,11 +269,11 @@ func (mp *muxPath) matchMethod(ctx context.HTTPContext) bool { return stringtool.StrInSlice(ctx.Request().Method(), mp.methods) } -func (mp *muxPath) hasHeaders() bool { +func (mp *MuxPath) hasHeaders() bool { return len(mp.headers) > 0 } -func (mp *muxPath) matchHeaders(ctx context.HTTPContext) bool { +func (mp *MuxPath) matchHeaders(ctx context.HTTPContext) bool { for _, h := range mp.headers { v := ctx.Request().Header().Get(h.Key) if stringtool.StrInSlice(v, h.Values) { @@ -340,7 +344,7 @@ func (m *mux) reloadRules(superSpec *supervisor.Spec, muxMapper protocol.MuxMapp ruleIPFilterChain := newIPFilterChain(rules.ipFilterChan, specRule.IPFilter) - paths := make([]*muxPath, len(specRule.Paths)) + paths := make([]*MuxPath, len(specRule.Paths)) for j := 0; j < len(paths); j++ { paths[j] = newMuxPath(ruleIPFilterChain, specRule.Paths[j]) } @@ -382,14 +386,58 @@ func (m *mux) ServeHTTP(stdw http.ResponseWriter, stdr *http.Request) { return } - for _, host := range rules.rules { + handleNotFound := func() { + ci = &cacheItem{ipFilterChan: rules.ipFilterChan, notFound: true} + rules.putCacheItem(ctx, ci) + m.handleRequestWithCache(rules, ctx, ci) + } + + result, path := SearchPath(ctx, rules.rules) + switch result { + case Found: + ci = &cacheItem{ipFilterChan: path.ipFilterChain, path: path} + rules.putCacheItem(ctx, ci) + m.handleRequestWithCache(rules, ctx, ci) + case FoundSkipCache: + ci := &cacheItem{ipFilterChan: path.ipFilterChain, path: path} + m.handleRequestWithCache(rules, ctx, ci) + case MethodNotAllowed: + ci := &cacheItem{ipFilterChan: path.ipFilterChain, methodNotAllowed: true} + rules.putCacheItem(ctx, ci) + m.handleRequestWithCache(rules, ctx, ci) + case IPNotAllowed: + m.handleIPNotAllow(ctx) + case NotFound: + handleNotFound() + default: + handleNotFound() + } +} + +const ( + // NotFound means no path found + NotFound SearchResult = "not-found" + // IPNotAllowed means context IP is not allowd + IPNotAllowed SearchResult = "ip-not-allowed" + // MethodNotAllowed means context method is not allowd + MethodNotAllowed SearchResult = "method-not-allowed" + // Found path + Found SearchResult = "found" + // FoundSkipCache means found path but skip caching result + FoundSkipCache SearchResult = "found-skip-cache" +) + +// SearchPath searches path among list of mux rules +func SearchPath(ctx context.HTTPContext, rulesToCheck []*muxRule) (SearchResult, *MuxPath) { + pathFound := false + var notAllowedPath *MuxPath + for _, host := range rulesToCheck { if !host.match(ctx) { continue } if !host.pass(ctx) { - m.handleIPNotAllow(ctx) - return + return IPNotAllowed, nil } for _, path := range host.paths { @@ -397,37 +445,35 @@ func (m *mux) ServeHTTP(stdw http.ResponseWriter, stdr *http.Request) { continue } + // at least one path matches + pathFound = true + notAllowedPath = path if !path.matchMethod(ctx) { - ci = &cacheItem{ipFilterChan: path.ipFilterChain, methodNotAllowed: true} - rules.putCacheItem(ctx, ci) - m.handleRequestWithCache(rules, ctx, ci) - return + continue } - if !path.pass(ctx) { - m.handleIPNotAllow(ctx) - return - } + var searchResult SearchResult - if !path.hasHeaders() { - ci = &cacheItem{ipFilterChan: path.ipFilterChain, path: path} - rules.putCacheItem(ctx, ci) - m.handleRequestWithCache(rules, ctx, ci) - return + if path.hasHeaders() { + if !path.matchHeaders(ctx) { + continue + } + searchResult = FoundSkipCache + } else { + searchResult = Found } - if path.matchHeaders(ctx) { - // NOTE: No cache for the request matching headers. - ci = &cacheItem{ipFilterChan: path.ipFilterChain, path: path} - m.handleRequestWithCache(rules, ctx, ci) - return + if !path.pass(ctx) { + return IPNotAllowed, path } + + return searchResult, path } } - - ci = &cacheItem{ipFilterChan: rules.ipFilterChan, notFound: true} - rules.putCacheItem(ctx, ci) - m.handleRequestWithCache(rules, ctx, ci) + if !pathFound { + return NotFound, nil + } + return MethodNotAllowed, notAllowedPath } func (m *mux) handleIPNotAllow(ctx context.HTTPContext) {