Skip to content

Commit

Permalink
Request Limiter reloadable config (#25095)
Browse files Browse the repository at this point in the history
This commit introduces a new reloadable stanza to the server config to allow disabling the Request Limiter.
  • Loading branch information
mpalmi authored Jan 26, 2024
1 parent 43be9fc commit 5933768
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 1 deletion.
3 changes: 3 additions & 0 deletions changelog/25095.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
limits: Introduce a reloadable disable configuration for the Request Limiter.
```
12 changes: 12 additions & 0 deletions command/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,12 @@ func (c *ServerCommand) Run(args []string) int {
infoKeys = append(infoKeys, "administrative namespace")
info["administrative namespace"] = config.AdministrativeNamespacePath

infoKeys = append(infoKeys, "request limiter")
info["request limiter"] = "enabled"
if config.RequestLimiter != nil && config.RequestLimiter.Disable {
info["request limiter"] = "disabled"
}

sort.Strings(infoKeys)
c.UI.Output("==> Vault server configuration:\n")

Expand Down Expand Up @@ -1661,6 +1667,8 @@ func (c *ServerCommand) Run(args []string) int {
// Setting log request with the new value in the config after reload
core.ReloadLogRequestsLevel()

core.ReloadRequestLimiter()

// reloading HCP link
hcpLink, err = c.reloadHCPLink(hcpLink, config, core, hcpLogger)
if err != nil {
Expand Down Expand Up @@ -3095,6 +3103,10 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical.
AdministrativeNamespacePath: config.AdministrativeNamespacePath,
}

if config.RequestLimiter != nil {
coreConfig.DisableRequestLimiter = config.RequestLimiter.Disable
}

if c.flagDev {
coreConfig.EnableRaw = true
coreConfig.EnableIntrospection = true
Expand Down
53 changes: 53 additions & 0 deletions command/server/config_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package server

import (
"fmt"
"testing"

"github.com/hashicorp/vault/internalshared/configutil"
Expand Down Expand Up @@ -86,3 +87,55 @@ func TestCheckSealConfig(t *testing.T) {
})
}
}

// TestRequestLimiterConfig verifies that the census config is correctly instantiated from HCL
func TestRequestLimiterConfig(t *testing.T) {
testCases := []struct {
name string
inConfig string
outErr bool
outRequestLimiter *configutil.RequestLimiter
}{
{
name: "empty",
outRequestLimiter: nil,
},
{
name: "disabled",
inConfig: `
request_limiter {
disable = true
}`,
outRequestLimiter: &configutil.RequestLimiter{Disable: true},
},
{
name: "invalid disable",
inConfig: `
request_limiter {
disable = "whywouldyoudothis"
}`,
outErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
config := fmt.Sprintf(`
ui = false
storage "file" {
path = "/tmp/test"
}
listener "tcp" {
address = "0.0.0.0:8200"
}
%s`, tc.inConfig)
gotConfig, err := ParseConfig(config, "")
if tc.outErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.outRequestLimiter, gotConfig.RequestLimiter)
}
})
}
}
16 changes: 16 additions & 0 deletions internalshared/configutil/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type SharedConfig struct {
ClusterName string `hcl:"cluster_name"`

AdministrativeNamespacePath string `hcl:"administrative_namespace_path"`

RequestLimiter *RequestLimiter `hcl:"request_limiter"`
}

func ParseConfig(d string) (*SharedConfig, error) {
Expand Down Expand Up @@ -156,6 +158,13 @@ func ParseConfig(d string) (*SharedConfig, error) {
}
}

if o := list.Filter("request_limiter"); len(o.Items) > 0 {
result.found("request_limiter", "RequestLimiter")
if err := parseRequestLimiter(&result, o); err != nil {
return nil, fmt.Errorf("error parsing 'request_limiter': %w", err)
}
}

entConfig := &(result.EntSharedConfig)
if err := entConfig.ParseConfig(list); err != nil {
return nil, fmt.Errorf("error parsing enterprise config: %w", err)
Expand Down Expand Up @@ -284,6 +293,13 @@ func (c *SharedConfig) Sanitized() map[string]interface{} {
result["telemetry"] = sanitizedTelemetry
}

if c.RequestLimiter != nil {
sanitizedRequestLimiter := map[string]interface{}{
"disable": c.RequestLimiter.Disable,
}
result["request_limiter"] = sanitizedRequestLimiter
}

return result
}

Expand Down
5 changes: 5 additions & 0 deletions internalshared/configutil/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,10 @@ func (c *SharedConfig) Merge(c2 *SharedConfig) *SharedConfig {
result.ClusterName = c2.ClusterName
}

result.RequestLimiter = c.RequestLimiter
if c2.RequestLimiter != nil {
result.RequestLimiter = c2.RequestLimiter
}

return result
}
59 changes: 59 additions & 0 deletions internalshared/configutil/request_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package configutil

import (
"fmt"

"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast"
)

type RequestLimiter struct {
UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"`

Disable bool `hcl:"-"`
DisableRaw interface{} `hcl:"disable"`
}

func (r *RequestLimiter) Validate(source string) []ConfigError {
return ValidateUnusedFields(r.UnusedKeys, source)
}

func (r *RequestLimiter) GoString() string {
return fmt.Sprintf("*%#v", *r)
}

var DefaultRequestLimiter = &RequestLimiter{
Disable: false,
}

func parseRequestLimiter(result *SharedConfig, list *ast.ObjectList) error {
if len(list.Items) > 1 {
return fmt.Errorf("only one 'request_limiter' block is permitted")
}

result.RequestLimiter = DefaultRequestLimiter

// Get our one item
item := list.Items[0]

if err := hcl.DecodeObject(&result.RequestLimiter, item.Val); err != nil {
return multierror.Prefix(err, "request_limiter:")
}

if result.RequestLimiter.DisableRaw != nil {
var err error
if result.RequestLimiter.Disable, err = parseutil.ParseBool(result.RequestLimiter.DisableRaw); err != nil {
return err
}
result.RequestLimiter.DisableRaw = nil
} else {
result.RequestLimiter.Disable = false
}

return nil
}
17 changes: 17 additions & 0 deletions limits/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) {
r.Limiters[flags.Name] = limiter
}

// Disable drops its references to underlying limiters.
func (r *LimiterRegistry) Disable() {
r.Lock()

if !r.Enabled {
return
}

r.Logger.Info("disabling request limiters")
// Any outstanding tokens will be flushed when their request completes, as
// they've already acquired a listener. Just drop the limiter references
// here and the garbage-collector should take care of the rest.
r.Limiters = map[string]*RequestLimiter{}
r.Enabled = false
r.Unlock()
}

// GetLimiter looks up a RequestLimiter by key in the LimiterRegistry.
func (r *LimiterRegistry) GetLimiter(key string) *RequestLimiter {
r.RLock()
Expand Down
33 changes: 32 additions & 1 deletion vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,8 @@ type CoreConfig struct {

ClusterAddrBridge *raft.ClusterAddrBridge

LimiterRegistry *limits.LimiterRegistry
DisableRequestLimiter bool
LimiterRegistry *limits.LimiterRegistry
}

// GetServiceRegistration returns the config's ServiceRegistration, or nil if it does
Expand Down Expand Up @@ -1293,6 +1294,15 @@ func NewCore(conf *CoreConfig) (*Core, error) {
return nil, err
}

c.limiterRegistry = conf.LimiterRegistry
c.limiterRegistryLock.Lock()
if conf.DisableRequestLimiter {
c.limiterRegistry.Disable()
} else {
c.limiterRegistry.Enable()
}
c.limiterRegistryLock.Unlock()

err = c.adjustForSealMigration(conf.UnwrapSeal)
if err != nil {
return nil, err
Expand Down Expand Up @@ -4056,6 +4066,27 @@ func (c *Core) ReloadLogRequestsLevel() {
}
}

func (c *Core) ReloadRequestLimiter() {
c.limiterRegistry.Logger.Info("reloading request limiter config")
conf := c.rawConfig.Load()
if conf == nil {
return
}

disable := false
requestLimiterConfig := conf.(*server.Config).RequestLimiter
if requestLimiterConfig != nil {
disable = requestLimiterConfig.Disable
}

switch disable {
case true:
c.limiterRegistry.Disable()
default:
c.limiterRegistry.Enable()
}
}

func (c *Core) ReloadIntrospectionEndpointEnabled() {
conf := c.rawConfig.Load()
if conf == nil {
Expand Down

0 comments on commit 5933768

Please sign in to comment.