Skip to content

Commit

Permalink
Use custom timeout middleware
Browse files Browse the repository at this point in the history
To prevent WriteHeader from being called multiple times we have created
a timeout middleware which checks that the response status hasn't already
been written  before setting the status to http.StatusGatewayTimeout
  • Loading branch information
tamirms committed Oct 23, 2019
1 parent b0804f1 commit c1f0ce9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
27 changes: 27 additions & 0 deletions services/horizon/internal/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,33 @@ func loggerMiddleware(h http.Handler) http.Handler {
})
}

// timeoutMiddleware ensures the request is terminated after the given timeout
func timeoutMiddleware(timeout time.Duration) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
mw, ok := w.(middleware.WrapResponseWriter)
if !ok {
mw = middleware.NewWrapResponseWriter(w, r.ProtoMajor)
w = http.ResponseWriter(mw)
}
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer func() {
cancel()
if ctx.Err() == context.DeadlineExceeded {
if mw.Status() == 0 {
// only write the header if it hasn't been written yet
w.WriteHeader(http.StatusGatewayTimeout)
}
}
}()

r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}

// getClientData gets client data (name or version) from header or GET parameter
// (useful when not possible to set headers, like in EventStream).
func getClientData(r *http.Request, headerName string) string {
Expand Down
2 changes: 1 addition & 1 deletion services/horizon/internal/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ func (w *web) mustInstallMiddlewares(app *App, connTimeout time.Duration) {
}

r := w.router
r.Use(chimiddleware.Timeout(connTimeout))
r.Use(chimiddleware.StripSlashes)

//TODO: remove this middleware
Expand All @@ -106,6 +105,7 @@ func (w *web) mustInstallMiddlewares(app *App, connTimeout time.Duration) {
r.Use(contextMiddleware)
r.Use(xff.Handler)
r.Use(loggerMiddleware)
r.Use(timeoutMiddleware(connTimeout))
r.Use(requestMetricsMiddleware)
r.Use(recoverMiddleware)
r.Use(chimiddleware.Compress(flate.DefaultCompression, "application/hal+json"))
Expand Down

0 comments on commit c1f0ce9

Please sign in to comment.