From c1f0ce96bff1df4508bff11538691c17e1b92002 Mon Sep 17 00:00:00 2001 From: Tamir Sen Date: Tue, 22 Oct 2019 18:34:50 +0200 Subject: [PATCH] Use custom timeout middleware To prevent WriteHeader from being called multiple times we have created a timeout middleware which checks that the response status hasn't already been written before setting the status to http.StatusGatewayTimeout --- services/horizon/internal/middleware.go | 27 +++++++++++++++++++++++++ services/horizon/internal/web.go | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/services/horizon/internal/middleware.go b/services/horizon/internal/middleware.go index f8e678b947..756473b9d6 100644 --- a/services/horizon/internal/middleware.go +++ b/services/horizon/internal/middleware.go @@ -87,6 +87,33 @@ func loggerMiddleware(h http.Handler) http.Handler { }) } +// timeoutMiddleware ensures the request is terminated after the given timeout +func timeoutMiddleware(timeout time.Duration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + mw, ok := w.(middleware.WrapResponseWriter) + if !ok { + mw = middleware.NewWrapResponseWriter(w, r.ProtoMajor) + w = http.ResponseWriter(mw) + } + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer func() { + cancel() + if ctx.Err() == context.DeadlineExceeded { + if mw.Status() == 0 { + // only write the header if it hasn't been written yet + w.WriteHeader(http.StatusGatewayTimeout) + } + } + }() + + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + // getClientData gets client data (name or version) from header or GET parameter // (useful when not possible to set headers, like in EventStream). func getClientData(r *http.Request, headerName string) string { diff --git a/services/horizon/internal/web.go b/services/horizon/internal/web.go index 4d6b2aed11..77d1f44690 100644 --- a/services/horizon/internal/web.go +++ b/services/horizon/internal/web.go @@ -95,7 +95,6 @@ func (w *web) mustInstallMiddlewares(app *App, connTimeout time.Duration) { } r := w.router - r.Use(chimiddleware.Timeout(connTimeout)) r.Use(chimiddleware.StripSlashes) //TODO: remove this middleware @@ -106,6 +105,7 @@ func (w *web) mustInstallMiddlewares(app *App, connTimeout time.Duration) { r.Use(contextMiddleware) r.Use(xff.Handler) r.Use(loggerMiddleware) + r.Use(timeoutMiddleware(connTimeout)) r.Use(requestMetricsMiddleware) r.Use(recoverMiddleware) r.Use(chimiddleware.Compress(flate.DefaultCompression, "application/hal+json"))