Skip to content

Commit

Permalink
fix #713: CORS headers are lost
Browse files Browse the repository at this point in the history
  • Loading branch information
localvar committed Jul 27, 2022
1 parent c992926 commit 9c44203
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 13 deletions.
4 changes: 2 additions & 2 deletions doc/reference/filters.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ pools:

## CORSAdaptor

The CORSAdaptor handles the [CORS](https://en.wikipedia.org/wiki/Cross-origin_resource_sharing) preflight request for the backend service.
The CORSAdaptor handles the [CORS](https://en.wikipedia.org/wiki/Cross-origin_resource_sharing) preflight, simple and not so simple request for the backend service.

The below example configuration handles the preflight `GET` request from `*.megaease.com`.
The below example configuration handles the CORS `GET` request from `*.megaease.com`.

```yaml
kind: CORSAdaptor
Expand Down
81 changes: 74 additions & 7 deletions pkg/filters/proxy/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (sp *ServerPool) collectMetrics(spCtx *serverPoolContext) {
}

// Now, the body must be a CallbackReader.
body, _ := spCtx.stdResp.Body.(*readers.CallbackReader)
body := spCtx.stdResp.Body.(*readers.CallbackReader)

// Collect when reach EOF or meet an error.
body.OnAfter(func(total int, p []byte, err error) {
Expand Down Expand Up @@ -442,6 +442,9 @@ func (sp *ServerPool) handle(ctx *context.Context, mirror bool) string {
defer sp.collectMetrics(spCtx)

if sp.buildResponseFromCache(spCtx) {
if _, ok := sp.failureCodes[spCtx.resp.StatusCode()]; ok {
return resultFailureCode
}
return ""
}

Expand Down Expand Up @@ -562,11 +565,30 @@ func (sp *ServerPool) doHandle(stdctx stdcontext.Context, spCtx *serverPoolConte
return serverPoolError{resp.StatusCode, resultFailureCode}
}

if sp.memoryCache != nil {
sp.memoryCache.Store(spCtx.req, spCtx.resp)
return nil
}

func copyCORSHeaders(dst, src http.Header) bool {
value := src.Get("Access-Control-Allow-Origin")
if value == "" {
return false
}

return nil
dst.Set("Access-Control-Allow-Origin", value)

if value = src.Get("Access-Control-Expose-Headers"); value != "" {
dst.Set("Access-Control-Expose-Headers", value)
}

if src.Get("Access-Control-Allow-Credentials"); value != "" {
dst.Set("Access-Control-Allow-Credential", value)
}

if !stringtool.StrInSlice("Origin", dst.Values("Vary")) {
dst.Add("Vary", "Origin")
}

return true
}

func (sp *ServerPool) buildResponse(spCtx *serverPoolContext) (err error) {
Expand Down Expand Up @@ -600,6 +622,19 @@ func (sp *ServerPool) buildResponse(spCtx *serverPoolContext) (err error) {
body.Close()
}

if sp.memoryCache != nil {
sp.memoryCache.Store(spCtx.req, resp)
}

if r, _ := spCtx.GetOutputResponse().(*httpprot.Response); r != nil {
copyCORSHeaders(resp.HTTPHeader(), r.HTTPHeader())

// reuse the existing output response, this is to align with
// buildResponseFromCache and buildFailureResponse and other filters.
*r = *resp
resp = r
}

spCtx.resp = resp
spCtx.SetOutputResponse(resp)
return nil
Expand All @@ -614,10 +649,38 @@ func (sp *ServerPool) buildResponseFromCache(spCtx *serverPoolContext) bool {
if ce == nil {
return false
}
header := ce.Header.Clone()

resp, _ := httpprot.NewResponse(nil)
corsHeadersCopied := false
resp, _ := spCtx.GetOutputResponse().(*httpprot.Response)
if resp == nil {
resp, _ = httpprot.NewResponse(nil)
} else {
corsHeadersCopied = copyCORSHeaders(header, resp.HTTPHeader())
}

if !corsHeadersCopied {
if spCtx.req.HTTPHeader().Get("Origin") == "" {
// remove these headers as the request is not a CORS request.
header.Del("Access-Control-Allow-Origin")
header.Del("Access-Control-Expose-Headers")
header.Del("Access-Control-Allow-Credentials")
} else {
// There are 3 cases here:
// 1. the cached response fully matches the request: we have
// nothing to do in this case.
// 2. the cached response does not match the request, because
// the Origin is not the same: we need to update the
// MemoryCache.Load function to make it return false.
// 3. the cached response does not have CORS headers: the user
// need to add a CORSAdaptor into the pipeline.
//
// So, we do nothing for now.
}
}

resp.Std().Header = header
resp.SetStatusCode(ce.StatusCode)
resp.Std().Header = ce.Header.Clone()
resp.SetPayload(ce.Body)

spCtx.resp = resp
Expand All @@ -626,7 +689,11 @@ func (sp *ServerPool) buildResponseFromCache(spCtx *serverPoolContext) bool {
}

func (sp *ServerPool) buildFailureResponse(spCtx *serverPoolContext, statusCode int) {
resp, _ := httpprot.NewResponse(nil)
resp, _ := spCtx.GetOutputResponse().(*httpprot.Response)
if resp == nil {
resp, _ = httpprot.NewResponse(nil)
}

resp.SetStatusCode(statusCode)
spCtx.resp = resp
spCtx.SetOutputResponse(resp)
Expand Down
11 changes: 8 additions & 3 deletions pkg/filters/ratelimiter/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,17 @@ func (rl *RateLimiter) Handle(ctx *context.Context) string {

permitted, d := u.rl.AcquirePermission()
if !permitted {
resp, _ := httpprot.NewResponse(nil)
ctx.SetOutputResponse(resp)
ctx.AddTag("rateLimiter: too many requests")

resp, _ := ctx.GetOutputResponse().(*httpprot.Response)
if resp == nil {
resp, _ = httpprot.NewResponse(nil)
}

resp.SetStatusCode(http.StatusTooManyRequests)
resp.Std().Header.Set("X-EG-Rate-Limiter", "too-many-requests")
resp.HTTPHeader().Set("X-EG-Rate-Limiter", "too-many-requests")

ctx.SetOutputResponse(resp)
return resultRateLimited
}

Expand Down
6 changes: 5 additions & 1 deletion pkg/filters/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ func (v *Validator) Handle(ctx *context.Context) string {
req := ctx.GetInputRequest().(*httpprot.Request)

prepareErrorResponse := func(status int, tagPrefix string, err error) {
resp, _ := httpprot.NewResponse(nil)
resp, _ := ctx.GetOutputResponse().(*httpprot.Response)
if resp == nil {
resp, _ = httpprot.NewResponse(nil)
}

resp.SetStatusCode(status)
ctx.SetOutputResponse(resp)
ctx.AddTag(stringtool.Cat(tagPrefix, err.Error()))
Expand Down

0 comments on commit 9c44203

Please sign in to comment.