Skip to content

Commit

Permalink
server: cancel the request context on timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrus committed Dec 9, 2024
1 parent ca4ef6d commit f22281d
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 14 deletions.
1 change: 1 addition & 0 deletions .changelog/821.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
server: Added `ContextTimeoutMiddleware` to enforce request context timeout
33 changes: 24 additions & 9 deletions api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"reflect"

apiTypes "github.com/oasisprotocol/nexus/api/v1/types"
"github.com/oasisprotocol/nexus/common"
"github.com/oasisprotocol/nexus/log"
)

var (
Expand Down Expand Up @@ -64,16 +66,29 @@ func HttpCodeForError(err error) int {
}
}

// A simple error handler that renders any error as human-readable JSON to
// A simple error handler that logs and renders any error as human-readable JSON to
// the HTTP response stream `w`.
func HumanReadableJsonErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
w.Header().Set("content-type", "application/json; charset=utf-8")
w.Header().Set("x-content-type-options", "nosniff")
w.WriteHeader(HttpCodeForError(err))
func HumanReadableJsonErrorHandler(logger log.Logger) func(http.ResponseWriter, *http.Request, error) {
return func(w http.ResponseWriter, r *http.Request, err error) {
logger.Debug("request failed, handling human readable error",
"err", err,
"request_id", r.Context().Value(common.RequestIDContextKey),
"ctx_err", r.Context().Err(),
)

// Wrap the error into a trivial JSON object as specified in the OpenAPI spec.
msg := err.Error()
errStruct := apiTypes.HumanReadableError{Msg: msg}
// If request context is closed, don't bother writing a response.
if r.Context().Err() != nil {
return
}

_ = json.NewEncoder(w).Encode(errStruct)
w.Header().Set("content-type", "application/json; charset=utf-8")
w.Header().Set("x-content-type-options", "nosniff")
w.WriteHeader(HttpCodeForError(err))

// Wrap the error into a trivial JSON object as specified in the OpenAPI spec.
msg := err.Error()
errStruct := apiTypes.HumanReadableError{Msg: msg}

_ = json.NewEncoder(w).Encode(errStruct)
}
}
12 changes: 12 additions & 0 deletions api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,18 @@ func RuntimeFromURLMiddleware(baseURL string) func(next http.Handler) http.Handl
}
}

// ContextTimeoutMiddleware cancels the request context after a timeout.
func ContextTimeoutMiddleware(timeout time.Duration) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}

// CorsMiddleware is a restrictive CORS middleware that only allows GET requests.
//
// NOTE: To support other methods (e.g. POST), we'd also need to support OPTIONS
Expand Down
52 changes: 52 additions & 0 deletions api/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//nolint:noctx,bodyclose
package api

import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestTimeoutMiddleware(t *testing.T) {
withMiddleware := false
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Wait some time, so that the request timeout (10ms) should be reached.
<-time.After(2 * time.Second)

// Check if the request context is canceled.
if withMiddleware {
// If timeout middleware is used, the request context should be canceled.
require.True(t, r.Context().Err() != nil, "request context should be canceled")
} else {
// The default behavior is that the request context is not canceled.
require.False(t, r.Context().Err() != nil, "request context was not canceled")
}
})
backend := httptest.NewUnstartedServer(baseHandler)
backend.Config.WriteTimeout = 10 * time.Millisecond
backend.Start()
defer backend.Close()

// Test server without timeout middleware.
req, err := http.NewRequest("GET", backend.URL, nil)
require.NoError(t, err)
_, err = http.DefaultClient.Do(req)
require.ErrorIs(t, err, io.EOF, "client received EOF")
backend.Close()

// Test server with timeout middleware.
withMiddleware = true
backendWithTimeout := httptest.NewUnstartedServer(ContextTimeoutMiddleware(10 * time.Millisecond)(baseHandler))
backendWithTimeout.Config.WriteTimeout = 10 * time.Millisecond
backendWithTimeout.Start()
defer backendWithTimeout.Close()

req, err = http.NewRequest("GET", backendWithTimeout.URL, nil)
require.NoError(t, err)
_, err = http.DefaultClient.Do(req)
require.ErrorIs(t, err, io.EOF, "client received EOF")
}
16 changes: 11 additions & 5 deletions cmd/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const (
moduleName = "api"
// The path portion with which all v1 API endpoints start.
v1BaseURL = "/v1"

requestsTimeout = 10 * time.Second
)

var (
Expand Down Expand Up @@ -173,8 +175,8 @@ func (s *Service) Start() {
api.ParseBigIntParamsMiddleware,
},
apiTypes.StrictHTTPServerOptions{
RequestErrorHandlerFunc: api.HumanReadableJsonErrorHandler,
ResponseErrorHandlerFunc: api.HumanReadableJsonErrorHandler,
RequestErrorHandlerFunc: api.HumanReadableJsonErrorHandler(*s.logger),
ResponseErrorHandlerFunc: api.HumanReadableJsonErrorHandler(*s.logger),
},
)

Expand All @@ -187,20 +189,24 @@ func (s *Service) Start() {
api.RuntimeFromURLMiddleware(v1BaseURL),
},
BaseRouter: baseRouter,
ErrorHandlerFunc: api.HumanReadableJsonErrorHandler,
ErrorHandlerFunc: api.HumanReadableJsonErrorHandler(*s.logger),
})
// Manually apply the CORS middleware; we want it to run always.
// HandlerWithOptions() above does not apply it to some requests (404 URLs, requests with bad params, etc.).
handler = api.CorsMiddleware(handler)
// Request context is not cancelled by the server when write timeout is reached. Ensure the context gets canceled.
// Ref: https://github.com/golang/go/issues/59602
// Metrics middleware should be applied after timeout, since we do not want to cancel the context for metrics.
handler = api.ContextTimeoutMiddleware(requestsTimeout)(handler)
// Manually apply the metrics middleware; we want it to run always, and at the outermost layer.
// HandlerWithOptions() above does not apply it to some requests (404 URLs, requests with bad params, etc.).
handler = api.MetricsMiddleware(metrics.NewDefaultRequestMetrics(moduleName), *s.logger)(handler)

server := &http.Server{
Addr: s.address,
Handler: handler,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ReadTimeout: requestsTimeout,
WriteTimeout: requestsTimeout,
MaxHeaderBytes: 1 << 20,
}

Expand Down

0 comments on commit f22281d

Please sign in to comment.