Skip to content

Commit

Permalink
Backport 1.9.x: Fixing ResponseWriter to assert to http.Flusher in /s…
Browse files Browse the repository at this point in the history
…ys/monitor endpoint (#13200) (#13260)

* Unify HTTPResponseWriter and StatusHeaderResponseWriter (#13200)

* Unify NewHTTPResponseWriter ant NewStatusHeaderResponseWriter to fix ResponseWriter issues

* adding changelog

* removing unnecessary function from the WrappingResponseWriter interface

* changing logical requests responseWriter type

* reverting change to HTTPResponseWriter

* Update changelog/13200.txt

* Update changelog/13200.txt
  • Loading branch information
hghaf099 authored Nov 24, 2021
1 parent 5213593 commit 5d6e7b5
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 117 deletions.
3 changes: 3 additions & 0 deletions changelog/13200.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
http:Fix /sys/monitor endpoint returning streaming not supported
```
112 changes: 12 additions & 100 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@ import (
"net/url"
"os"
"regexp"
"strconv"
"strings"
"time"

"github.com/NYTimes/gziphandler"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/helper/namespace"
Expand Down Expand Up @@ -215,89 +213,6 @@ func Handler(props *vault.HandlerProperties) http.Handler {
return printablePathCheckHandler
}

type WrappingResponseWriter interface {
http.ResponseWriter
Wrapped() http.ResponseWriter
}

type statusHeaderResponseWriter struct {
wrapped http.ResponseWriter
logger log.Logger
wroteHeader bool
statusCode int
headers map[string][]*vault.CustomHeader
}

func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter {
return w.wrapped
}

func (w *statusHeaderResponseWriter) Header() http.Header {
return w.wrapped.Header()
}

func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) {
// It is allowed to only call ResponseWriter.Write and skip
// ResponseWriter.WriteHeader. An example of such a situation is
// "handleUIStub". The Write function will internally set the status code
// 200 for the response for which that call might invoke other
// implementations of the WriteHeader function. So, we still need to set
// the custom headers. In cases where both WriteHeader and Write of
// statusHeaderResponseWriter struct are called the internal call to the
// WriterHeader invoked from inside Write method won't change the headers.
if !w.wroteHeader {
w.setCustomResponseHeaders(w.statusCode)
}

return w.wrapped.Write(buf)
}

func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) {
w.setCustomResponseHeaders(statusCode)
w.wrapped.WriteHeader(statusCode)
w.statusCode = statusCode
// in cases where Write is called after WriteHeader, let's prevent setting
// ResponseWriter headers twice
w.wroteHeader = true
}

func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) {
sch := w.headers
if sch == nil {
w.logger.Warn("status code header map not configured")
return
}

// Checking the validity of the status code
if status >= 600 || status < 100 {
return
}

// setter function to set the headers
setter := func(hvl []*vault.CustomHeader) {
for _, hv := range hvl {
w.Header().Set(hv.Name, hv.Value)
}
}

// Setting the default headers first
setter(sch["default"])

// setting the Xyy pattern first
d := fmt.Sprintf("%vxx", status/100)
if val, ok := sch[d]; ok {
setter(val)
}

// Setting the specific headers
if val, ok := sch[strconv.Itoa(status)]; ok {
setter(val)
}

return
}

var _ WrappingResponseWriter = &statusHeaderResponseWriter{}

type copyResponseWriter struct {
wrapped http.ResponseWriter
Expand Down Expand Up @@ -389,25 +304,22 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
hostname, _ := os.Hostname()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// This block needs to be here so that upon sending SIGHUP, custom response
// headers are also reloaded into the handlers.
var customHeaders map[string][]*logical.CustomHeader
if props.ListenerConfig != nil {
la := props.ListenerConfig.Address
listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la)
if listenerCustomHeaders != nil {
w = &statusHeaderResponseWriter{
wrapped: w,
logger: core.Logger(),
wroteHeader: false,
statusCode: 200,
headers: listenerCustomHeaders.StatusCodeHeaderMap,
}
customHeaders = listenerCustomHeaders.StatusCodeHeaderMap
}
}

nw := logical.NewStatusHeaderResponseWriter(w, customHeaders)
// Set the Cache-Control header for all the responses returned
// by Vault
w.Header().Set("Cache-Control", "no-store")
nw.Header().Set("Cache-Control", "no-store")

// Start with the request context
ctx := r.Context()
Expand All @@ -431,38 +343,38 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
if core.RaftNodeIDHeaderEnabled() {
nodeID := core.GetRaftNodeID()
if nodeID != "" {
w.Header().Set("X-Vault-Raft-Node-ID", nodeID)
nw.Header().Set("X-Vault-Raft-Node-ID", nodeID)
}
}

if core.HostnameHeaderEnabled() && hostname != "" {
w.Header().Set("X-Vault-Hostname", hostname)
nw.Header().Set("X-Vault-Hostname", hostname)
}

switch {
case strings.HasPrefix(r.URL.Path, "/v1/"):
newR, status := adjustRequest(core, r)
if status != 0 {
respondError(w, status, nil)
respondError(nw, status, nil)
cancelFunc()
return
}
r = newR

case strings.HasPrefix(r.URL.Path, "/ui"), r.URL.Path == "/robots.txt", r.URL.Path == "/":
default:
respondError(w, http.StatusNotFound, nil)
respondError(nw, http.StatusNotFound, nil)
cancelFunc()
return
}

// Setting the namespace in the header to be included in the error message
ns := r.Header.Get(consts.NamespaceHeaderName)
if ns != "" {
w.Header().Set(consts.NamespaceHeaderName, ns)
nw.Header().Set(consts.NamespaceHeaderName, ns)
}

h.ServeHTTP(w, r)
h.ServeHTTP(nw, r)

cancelFunc()
return
Expand Down Expand Up @@ -742,7 +654,7 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
// requestTooLarger. So we let it have access to the underlying
// ResponseWriter.
inw := w
if myw, ok := inw.(WrappingResponseWriter); ok {
if myw, ok := inw.(logical.WrappingResponseWriter); ok {
inw = myw.Wrapped()
}
reader = http.MaxBytesReader(inw, r.Body, max)
Expand Down
5 changes: 5 additions & 0 deletions sdk/logical/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,8 @@ type InitializationRequest struct {
// Storage can be used to durably store and retrieve state.
Storage Storage
}

type CustomHeader struct {
Name string
Value string
}
103 changes: 97 additions & 6 deletions sdk/logical/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"sync/atomic"

"github.com/hashicorp/vault/sdk/helper/wrapping"
Expand Down Expand Up @@ -209,13 +210,103 @@ func NewHTTPResponseWriter(w http.ResponseWriter) *HTTPResponseWriter {
}

// Write will write the bytes to the underlying io.Writer.
func (rw *HTTPResponseWriter) Write(bytes []byte) (int, error) {
atomic.StoreUint32(rw.written, 1)

return rw.ResponseWriter.Write(bytes)
func (w *HTTPResponseWriter) Write(bytes []byte) (int, error) {
atomic.StoreUint32(w.written, 1)
return w.ResponseWriter.Write(bytes)
}

// Written tells us if the writer has been written to yet.
func (rw *HTTPResponseWriter) Written() bool {
return atomic.LoadUint32(rw.written) == 1
func (w *HTTPResponseWriter) Written() bool {
return atomic.LoadUint32(w.written) == 1
}

type WrappingResponseWriter interface {
http.ResponseWriter
Wrapped() http.ResponseWriter
}

type StatusHeaderResponseWriter struct {
wrapped http.ResponseWriter
wroteHeader bool
statusCode int
headers map[string][]*CustomHeader
}

func NewStatusHeaderResponseWriter(w http.ResponseWriter, h map[string][]*CustomHeader) *StatusHeaderResponseWriter {
return &StatusHeaderResponseWriter{
wrapped: w,
wroteHeader: false,
statusCode: 200,
headers: h,
}
}

func (w *StatusHeaderResponseWriter) Wrapped() http.ResponseWriter {
return w.wrapped
}

func (w *StatusHeaderResponseWriter) Header() http.Header {
return w.wrapped.Header()
}

func (w *StatusHeaderResponseWriter) Write(buf []byte) (int, error) {
// It is allowed to only call ResponseWriter.Write and skip
// ResponseWriter.WriteHeader. An example of such a situation is
// "handleUIStub". The Write function will internally set the status code
// 200 for the response for which that call might invoke other
// implementations of the WriteHeader function. So, we still need to set
// the custom headers. In cases where both WriteHeader and Write of
// statusHeaderResponseWriter struct are called the internal call to the
// WriterHeader invoked from inside Write method won't change the headers.
if !w.wroteHeader {
w.setCustomResponseHeaders(w.statusCode)
}

return w.wrapped.Write(buf)
}

func (w *StatusHeaderResponseWriter) WriteHeader(statusCode int) {
w.setCustomResponseHeaders(statusCode)
w.wrapped.WriteHeader(statusCode)
w.statusCode = statusCode
// in cases where Write is called after WriteHeader, let's prevent setting
// ResponseWriter headers twice
w.wroteHeader = true
}

func (w *StatusHeaderResponseWriter) setCustomResponseHeaders(status int) {
sch := w.headers
if sch == nil {
return
}

// Checking the validity of the status code
if status >= 600 || status < 100 {
return
}

// setter function to set the headers
setter := func(hvl []*CustomHeader) {
for _, hv := range hvl {
w.Header().Set(hv.Name, hv.Value)
}
}

// Setting the default headers first
setter(sch["default"])

// setting the Xyy pattern first
d := fmt.Sprintf("%vxx", status/100)
if val, ok := sch[d]; ok {
setter(val)
}

// Setting the specific headers
if val, ok := sch[strconv.Itoa(status)]; ok {
setter(val)
}

return
}

var _ WrappingResponseWriter = &StatusHeaderResponseWriter{}
17 changes: 7 additions & 10 deletions vault/custom_response_headers.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,35 @@
package vault

import (
"fmt"
"net/http"
"net/textproto"
"strings"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/logical"
)

type ListenerCustomHeaders struct {
Address string
StatusCodeHeaderMap map[string][]*CustomHeader
StatusCodeHeaderMap map[string][]*logical.CustomHeader
// ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through
// StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names
configuredHeadersStatusCodeMap map[string][]string
}

type CustomHeader struct {
Name string
Value string
}

func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders {
var listenerCustomHeadersList []*ListenerCustomHeaders

for _, l := range ln {
listenerCustomHeaderStruct := &ListenerCustomHeaders{
Address: l.Address,
}
listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*CustomHeader)
listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*logical.CustomHeader)
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string)
for statusCode, headerValMap := range l.CustomResponseHeaders {
var customHeaderList []*CustomHeader
var customHeaderList []*logical.CustomHeader
for headerName, headerVal := range headerValMap {
// Sanitizing custom headers
// X-Vault- prefix is reserved for Vault internal processes
Expand All @@ -45,7 +42,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea
if uiHeaders != nil {
exist := uiHeaders.Get(headerName)
if exist != "" {
logger.Warn("found a duplicate header in UI", "header:", headerName, "Headers defined in the server configuration take precedence.")
logger.Warn(fmt.Sprintf("found a duplicate header in UI: header=%s. Headers defined in the server configuration take precedence.", headerName))
}
}

Expand All @@ -55,7 +52,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea
continue
}

ch := &CustomHeader{
ch := &logical.CustomHeader{
Name: headerName,
Value: headerVal,
}
Expand Down
Loading

0 comments on commit 5d6e7b5

Please sign in to comment.