diff --git a/command/server.go b/command/server.go index aef950683343..09fd651af74b 100644 --- a/command/server.go +++ b/command/server.go @@ -38,7 +38,6 @@ import ( "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/mlock" "github.com/hashicorp/vault/helper/parseutil" - "github.com/hashicorp/vault/helper/proxyutil" "github.com/hashicorp/vault/helper/reload" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" @@ -459,43 +458,6 @@ CLUSTER_SYNTHESIS_COMPLETE: return 1 } - if val, ok := lnConfig.Config["proxy_protocol_behavior"]; ok { - behavior, ok := val.(string) - if !ok { - c.Ui.Output(fmt.Sprintf( - "Error parsing proxy_protocol_behavior value for listener of type %s: not a string", - lnConfig.Type)) - return 1 - } - - authorizedAddrsRaw, ok := lnConfig.Config["proxy_protocol_authorized_addrs"] - if !ok { - c.Ui.Output(fmt.Sprintf( - "proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value for listener of type %s", - lnConfig.Type)) - return 1 - } - - proxyProtoConfig := &proxyutil.ProxyProtoConfig{ - Behavior: behavior, - } - if err := proxyProtoConfig.SetAuthorizedAddrs(authorizedAddrsRaw); err != nil { - c.Ui.Output(fmt.Sprintf( - "Error parsing proxy_protocol_authorized_addrs for listener of type %s: %v", - lnConfig.Type, err)) - return 1 - } - - newLn, err := proxyutil.WrapInProxyProto(ln, proxyProtoConfig) - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error configuring PROXY protocol wrapper: %s", err)) - return 1 - } - - ln = newLn - } - lns = append(lns, ln) if reloadFunc != nil { diff --git a/command/server/listener.go b/command/server/listener.go index 9a1f0e1196cd..4f9aedf7d3ee 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -12,6 +12,7 @@ import ( "net" "github.com/hashicorp/vault/helper/parseutil" + "github.com/hashicorp/vault/helper/proxyutil" "github.com/hashicorp/vault/helper/reload" "github.com/hashicorp/vault/helper/tlsutil" ) @@ -35,6 +36,37 @@ func NewListener(t string, config map[string]interface{}, logger io.Writer) (net return f(config, logger) } +func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.Listener, error) { + behaviorRaw, ok := config["proxy_protocol_behavior"] + if !ok { + return ln, nil + } + + behavior, ok := behaviorRaw.(string) + if !ok { + return nil, fmt.Errorf("failed parsing proxy_protocol_behavior value: not a string") + } + + authorizedAddrsRaw, ok := config["proxy_protocol_authorized_addrs"] + if !ok { + return nil, fmt.Errorf("proxy_protocol_behavior set but no proxy_protocol_authorized_addrs value") + } + + proxyProtoConfig := &proxyutil.ProxyProtoConfig{ + Behavior: behavior, + } + if err := proxyProtoConfig.SetAuthorizedAddrs(authorizedAddrsRaw); err != nil { + return nil, fmt.Errorf("failed parsing proxy_protocol_authorized_addrs: %v", err) + } + + newLn, err := proxyutil.WrapInProxyProto(ln, proxyProtoConfig) + if err != nil { + return nil, fmt.Errorf("failed configuring PROXY protocol wrapper: %s", err) + } + + return newLn, nil +} + func listenerWrapTLS( ln net.Listener, props map[string]string, diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index 07c3158ed3a7..b0ab68764806 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -31,6 +31,12 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer) (net.Listene } ln = tcpKeepAliveListener{ln.(*net.TCPListener)} + + ln, err = listenerWrapProxy(ln, config) + if err != nil { + return nil, nil, nil, err + } + props := map[string]string{"addr": addr} return listenerWrapTLS(ln, props, config) }