diff --git a/bootstrap/handlers/httpserver.go b/bootstrap/handlers/httpserver.go index d56f9425..79f10850 100644 --- a/bootstrap/handlers/httpserver.go +++ b/bootstrap/handlers/httpserver.go @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2019 Dell Inc. - * Copyright 2021 IOTech Ltd + * Copyright 2021-2022 IOTech Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at @@ -17,6 +17,11 @@ package handlers import ( "context" + "encoding/json" + "fmt" + "github.com/edgexfoundry/go-mod-core-contracts/v2/clients/logger" + "github.com/edgexfoundry/go-mod-core-contracts/v2/common" + commonDTO "github.com/edgexfoundry/go-mod-core-contracts/v2/dtos/common" "net/http" "strconv" "sync" @@ -98,6 +103,9 @@ func (b *HttpServer) BootstrapHandler( b.router.Use(func(next http.Handler) http.Handler { return http.TimeoutHandler(next, timeout, "HTTP request timeout") }) + + b.router.Use(RequestLimitMiddleware(bootstrapConfig.Service.MaxRequestSize, lc)) + b.router.Use(ProcessCORS(bootstrapConfig.Service.CORSConfiguration)) // handle the CORS preflight request @@ -142,3 +150,30 @@ func (b *HttpServer) BootstrapHandler( return true } + +// RequestLimitMiddleware is a middleware function that limits the request body size to Service.MaxRequestSize in kilobytes +func RequestLimitMiddleware(sizeLimit int64, lc logger.LoggingClient) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost, http.MethodPut, http.MethodPatch: + if sizeLimit > 0 && r.ContentLength > sizeLimit*1024 { + response := commonDTO.NewBaseResponse("", fmt.Sprintf("request size exceed Service.MaxRequestSize(%d KB)", sizeLimit), http.StatusRequestEntityTooLarge) + lc.Errorf(response.Message) + + w.Header().Set(common.ContentType, common.ContentTypeJSON) + w.WriteHeader(response.StatusCode) + if err := json.NewEncoder(w).Encode(response); err != nil { + lc.Errorf("Error encoding the data: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } else { + next.ServeHTTP(w, r) + } + default: + // ignore the other http methods because they do not have request bodies + next.ServeHTTP(w, r) + } + }) + } +} diff --git a/bootstrap/handlers/httpserver_test.go b/bootstrap/handlers/httpserver_test.go new file mode 100644 index 00000000..9465ecad --- /dev/null +++ b/bootstrap/handlers/httpserver_test.go @@ -0,0 +1,57 @@ +// +// Copyright (C) 2022 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package handlers + +import ( + "encoding/json" + "github.com/edgexfoundry/go-mod-core-contracts/v2/clients/logger" + "github.com/edgexfoundry/go-mod-core-contracts/v2/common" + commonDTO "github.com/edgexfoundry/go-mod-core-contracts/v2/dtos/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRequestLimitMiddleware(t *testing.T) { + lc := logger.NewMockClient() + payload := make([]byte, 2048) + tests := []struct { + name string + sizeLimit int64 + errorExpected bool + }{ + {"Valid unlimited size", int64(0), false}, + {"Valid size", int64(2), false}, + {"Invalid size", int64(1), true}, + } + + for _, testCase := range tests { + middleware := RequestLimitMiddleware(testCase.sizeLimit, lc) + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + reader := strings.NewReader(string(payload)) + req, err := http.NewRequest(http.MethodPost, "/", reader) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + resp := recorder.Result() + + if testCase.errorExpected { + var res commonDTO.BaseResponse + err = json.Unmarshal(recorder.Body.Bytes(), &res) + require.NoError(t, err) + + assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode, "http status code is not as expected") + assert.Equal(t, common.ContentTypeJSON, resp.Header.Get(common.ContentType), "http header Content-Type is not as expected") + assert.Equal(t, http.StatusRequestEntityTooLarge, int(res.StatusCode), "Response status code not as expected") + assert.NotEmpty(t, res.Message, "Response message doesn't contain the error message") + } + } +} diff --git a/config/types.go b/config/types.go index 83a588e3..aa5d23c8 100644 --- a/config/types.go +++ b/config/types.go @@ -43,7 +43,7 @@ type ServiceInfo struct { // MaxResultCount specifies the maximum size list supported // in response to REST calls to other services. MaxResultCount int - // MaxRequestSize defines the maximum size of http request body in bytes + // MaxRequestSize defines the maximum size of http request body in kilobytes MaxRequestSize int64 // RequestTimeout specifies a timeout (in milliseconds) for // processing REST request calls from other services.