Skip to content

Commit

Permalink
Buffer body read up to MaxRequestSize (#24354) (#24369)
Browse files Browse the repository at this point in the history
* Buffer body read up to MaxRequestSize (#24354)

* adding back a context
  • Loading branch information
hghaf099 authored Dec 5, 2023
1 parent c5c7c98 commit 9b61934
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 83 deletions.
16 changes: 1 addition & 15 deletions helper/forwarding/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -63,19 +61,7 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request

func GenerateForwardedRequest(req *http.Request) (*Request, error) {
var reader io.Reader = req.Body
ctx := req.Context()
maxRequestSize := ctx.Value("max_request_size")
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
reader = io.LimitReader(req.Body, max)
}
}

body, err := ioutil.ReadAll(reader)
body, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
Expand Down
46 changes: 3 additions & 43 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,13 @@ func handler(props *vault.HandlerProperties) http.Handler {
corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core)
quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core)
genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props)
wrappedHandler := wrapMaxRequestSizeHandler(genericWrappedHandler, props)

// Wrap the handler with PrintablePathCheckHandler to check for non-printable
// characters in the request path.
printablePathCheckHandler := genericWrappedHandler
printablePathCheckHandler := wrappedHandler
if !props.DisablePrintableCheck {
printablePathCheckHandler = cleanhttp.PrintablePathCheckHandler(genericWrappedHandler, nil)
printablePathCheckHandler = cleanhttp.PrintablePathCheckHandler(wrappedHandler, nil)
}

return printablePathCheckHandler
Expand Down Expand Up @@ -321,18 +322,12 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
// are performed.
func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler {
var maxRequestDuration time.Duration
var maxRequestSize int64
if props.ListenerConfig != nil {
maxRequestDuration = props.ListenerConfig.MaxRequestDuration
maxRequestSize = props.ListenerConfig.MaxRequestSize
}
if maxRequestDuration == 0 {
maxRequestDuration = vault.DefaultMaxRequestDuration
}
if maxRequestSize == 0 {
maxRequestSize = DefaultMaxRequestSize
}

// Swallow this error since we don't want to pollute the logs and we also don't want to
// return an HTTP error here. This information is best effort.
hostname, _ := os.Hostname()
Expand Down Expand Up @@ -366,12 +361,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
} else {
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
}

// if maxRequestSize < 0, no need to set context value
// Add a size limiter if desired
if maxRequestSize > 0 {
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
}
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
r = r.WithContext(ctx)
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
Expand Down Expand Up @@ -717,25 +706,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
// Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
reader := r.Body
ctx := r.Context()
maxRequestSize := ctx.Value("max_request_size")
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
// MaxBytesReader won't do all the internal stuff it must unless it's
// given a ResponseWriter that implements the internal http interface
// requestTooLarger. So we let it have access to the underlying
// ResponseWriter.
inw := w
if myw, ok := inw.(logical.WrappingResponseWriter); ok {
inw = myw.Wrapped()
}
reader = http.MaxBytesReader(inw, r.Body, max)
}
}
var origBody io.ReadWriter
if perfStandby {
// Since we're checking PerfStandby here we key on origBody being nil
Expand All @@ -757,16 +727,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
//
// A nil map will be returned if the format is empty or invalid.
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
maxRequestSize := r.Context().Value("max_request_size")
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
}
}
if err := r.ParseForm(); err != nil {
return nil, err
}
Expand Down
60 changes: 60 additions & 0 deletions http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package http

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
Expand All @@ -14,16 +15,19 @@ import (
"net/textproto"
"net/url"
"reflect"
"runtime"
"strings"
"testing"

"github.com/go-test/deep"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/versions"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)

func TestHandler_parseMFAHandler(t *testing.T) {
Expand Down Expand Up @@ -887,3 +891,59 @@ func TestHandler_Parse_Form(t *testing.T) {
t.Fatal(diff)
}
}

// TestHandler_MaxRequestSize verifies that a request larger than the
// MaxRequestSize fails
func TestHandler_MaxRequestSize(t *testing.T) {
t.Parallel()
cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{
DefaultHandlerProperties: vault.HandlerProperties{
ListenerConfig: &configutil.Listener{
MaxRequestSize: 1024,
},
},
HandlerFunc: Handler,
NumCores: 1,
})
cluster.Start()
defer cluster.Cleanup()

client := cluster.Cores[0].Client
_, err := client.KVv2("secret").Put(context.Background(), "foo", map[string]interface{}{
"bar": strings.Repeat("a", 1025),
})

require.ErrorContains(t, err, "error parsing JSON")
}

// TestHandler_MaxRequestSize_Memory sets the max request size to 1024 bytes,
// and creates a 1MB request. The test verifies that less than 1MB of memory is
// allocated when the request is sent. This test shouldn't be run in parallel,
// because it modifies GOMAXPROCS
func TestHandler_MaxRequestSize_Memory(t *testing.T) {
ln, addr := TestListener(t)
core, _, token := vault.TestCoreUnsealed(t)
TestServerWithListenerAndProperties(t, ln, addr, core, &vault.HandlerProperties{
Core: core,
ListenerConfig: &configutil.Listener{
Address: addr,
MaxRequestSize: 1024,
},
})
defer ln.Close()

data := bytes.Repeat([]byte{0x1}, 1024*1024)

req, err := http.NewRequest("POST", addr+"/v1/sys/unseal", bytes.NewReader(data))
require.NoError(t, err)
req.Header.Set(consts.AuthHeaderName, token)

client := cleanhttp.DefaultClient()
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
var start, end runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&start)
client.Do(req)
runtime.ReadMemStats(&end)
require.Less(t, end.TotalAlloc-start.TotalAlloc, uint64(1024*1024))
}
68 changes: 57 additions & 11 deletions http/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ package http
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"io"
"net"
"net/http"
"strings"

"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/logical"

"github.com/hashicorp/vault/helper/namespace"
Expand All @@ -34,6 +34,27 @@ var (
adjustResponse = func(core *vault.Core, w http.ResponseWriter, req *logical.Request) {}
)

func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var maxRequestSize int64
if props.ListenerConfig != nil {
maxRequestSize = props.ListenerConfig.MaxRequestSize
}
if maxRequestSize == 0 {
maxRequestSize = DefaultMaxRequestSize
}
ctx := r.Context()
originalBody := r.Body
if maxRequestSize > 0 {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
}
ctx = logical.CreateContextOriginalBody(ctx, originalBody)
r = r.WithContext(ctx)

handler.ServeHTTP(w, r)
})
}

func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ns, err := namespace.FromContext(r.Context())
Expand All @@ -52,14 +73,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
}
mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path)

// Clone body, so we do not close the request body reader
bodyBytes, err := ioutil.ReadAll(r.Body)
if err != nil {
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
return
}
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))

quotaReq := &quotas.Request{
Type: quotas.TypeRateLimit,
Path: path,
Expand All @@ -79,7 +92,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
// If any role-based quotas are enabled for this namespace/mount, just
// do the role resolution once here.
if requiresResolveRole {
role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes)
buf := bytes.Buffer{}
teeReader := io.TeeReader(r.Body, &buf)
role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader)

// Reset the body if it was read
if buf.Len() > 0 {
r.Body = io.NopCloser(&buf)
originalBody, ok := logical.ContextOriginalBodyValue(r.Context())
if ok {
r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody)))
}
}
// add an entry to the context to prevent recalculating request role unnecessarily
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))
quotaReq.Role = role
Expand Down Expand Up @@ -138,3 +162,25 @@ func parseRemoteIPAddress(r *http.Request) string {

return ip
}

type multiReaderCloser struct {
readers []io.Reader
io.Reader
}

func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser {
return &multiReaderCloser{
readers: readers,
Reader: io.MultiReader(readers...),
}
}

func (m *multiReaderCloser) Close() error {
var err error
for _, r := range m.readers {
if c, ok := r.(io.Closer); ok {
err = multierror.Append(err, c.Close())
}
}
return err
}
12 changes: 12 additions & 0 deletions sdk/logical/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package logical
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -453,3 +454,14 @@ type CtxKeyRequestRole struct{}
func (c CtxKeyRequestRole) String() string {
return "request-role"
}

type ctxKeyOriginalBody struct{}

func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) {
value, ok := ctx.Value(ctxKeyOriginalBody{}).(io.ReadCloser)
return value, ok
}

func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context {
return context.WithValue(parent, ctxKeyOriginalBody{}, body)
}
37 changes: 25 additions & 12 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -4059,22 +4059,24 @@ func (c *Core) LoadNodeID() (string, error) {
return hostname, nil
}

// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given
// login request, accepting a byte payload
func (c *Core) DetermineRoleFromLoginRequestFromBytes(ctx context.Context, mountPoint string, payload []byte) string {
data := make(map[string]interface{})
err := jsonutil.DecodeJSON(payload, &data)
if err != nil {
// Cannot discern a role from a request we cannot parse
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
// login request
func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string {
c.authLock.RLock()
defer c.authLock.RUnlock()
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential {
// Role based quotas do not apply to this request
return ""
}

return c.DetermineRoleFromLoginRequest(ctx, mountPoint, data)
return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data)
}

// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
// login request
func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string {
// DetermineRoleFromLoginRequestFromReader will determine the role that should
// be applied to a quota for a given login request. The reader will only be
// consumed if the matching backend for the mount point exists and is a secret
// backend
func (c *Core) DetermineRoleFromLoginRequestFromReader(ctx context.Context, mountPoint string, reader io.Reader) string {
c.authLock.RLock()
defer c.authLock.RUnlock()
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
Expand All @@ -4083,6 +4085,17 @@ func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint str
return ""
}

data := make(map[string]interface{})
err := jsonutil.DecodeJSONFromReader(reader, &data)
if err != nil {
return ""
}
return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data)
}

// doResolveRoleLocked does a login and resolve role request on the matching
// backend. Callers should have a read lock on c.authLock
func (c *Core) doResolveRoleLocked(ctx context.Context, mountPoint string, matchingBackend logical.Backend, data map[string]interface{}) string {
resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{
MountPoint: mountPoint,
Path: "login",
Expand Down
Loading

0 comments on commit 9b61934

Please sign in to comment.