Skip to content

Commit

Permalink
backport of commit 5ff44bd (hashicorp#19144)
Browse files Browse the repository at this point in the history
Co-authored-by: Christopher Swenson <[email protected]>
  • Loading branch information
hc-github-team-secure-vault-core and Christopher Swenson authored Feb 10, 2023
1 parent cfb8f08 commit 13313d0
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 8 deletions.
17 changes: 15 additions & 2 deletions http/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -66,13 +67,25 @@ func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCo
}
}

func handleEventsSubscribe(core *vault.Core) http.Handler {
func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger := core.Logger().Named("events-subscribe")

logger.Debug("Got request to", "url", r.URL, "version", r.Proto)

ctx := r.Context()

// ACL check
_, _, err := core.CheckToken(ctx, req, false)
if err != nil {
if errors.Is(err, logical.ErrPermissionDenied) {
respondError(w, http.StatusUnauthorized, logical.ErrPermissionDenied)
return
}
logger.Debug("Error validating token", "error", err)
respondError(w, http.StatusInternalServerError, fmt.Errorf("error validating token"))
return
}

ns, err := namespace.FromContext(ctx)
if err != nil {
logger.Info("Could not find namespace", "error", err)
Expand Down
23 changes: 22 additions & 1 deletion http/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"context"
"net/http"
"strings"
"sync/atomic"
"testing"
Expand All @@ -21,6 +22,15 @@ func TestEventsSubscribe(t *testing.T) {
ln, addr := TestServer(t, core)
defer ln.Close()

// unseal the core
keys, token := vault.TestCoreInit(t, core)
for _, key := range keys {
_, err := core.Unseal(key)
if err != nil {
t.Fatal(err)
}
}

stop := atomic.Bool{}

eventType := "abc"
Expand Down Expand Up @@ -53,7 +63,18 @@ func TestEventsSubscribe(t *testing.T) {
t.Cleanup(cancelFunc)

wsAddr := strings.Replace(addr, "http", "ws", 1)
conn, _, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/"+eventType+"?json=true", nil)

// check that the connection fails if we don't have a token
_, _, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/"+eventType+"?json=true", nil)
if err == nil {
t.Error("Expected websocket error but got none")
} else if !strings.HasSuffix(err.Error(), "401") {
t.Errorf("Expected 401 websocket but got %v", err)
}

conn, _, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/"+eventType+"?json=true", &websocket.DialOptions{
HTTPHeader: http.Header{"x-vault-token": []string{token}},
})
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions http/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"strings"
"time"

uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/experiments"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/consts"
Expand Down Expand Up @@ -359,7 +359,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
nsPath = ""
}
if strings.HasPrefix(r.URL.Path, fmt.Sprintf("/v1/%ssys/events/subscribe/", nsPath)) {
handler := handleEventsSubscribe(core)
handler := handleEventsSubscribe(core, req)
handler.ServeHTTP(w, r)
return
}
Expand Down
6 changes: 3 additions & 3 deletions vault/request_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func (c *Core) fetchACLTokenEntryAndEntity(ctx context.Context, req *logical.Req
return acl, te, entity, identityPolicies, nil
}

func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool) (*logical.Auth, *logical.TokenEntry, error) {
func (c *Core) CheckToken(ctx context.Context, req *logical.Request, unauth bool) (*logical.Auth, *logical.TokenEntry, error) {
defer metrics.MeasureSince([]string{"core", "check_token"}, time.Now())

var acl *ACL
Expand Down Expand Up @@ -857,7 +857,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
}

// Validate the token
auth, te, ctErr := c.checkToken(ctx, req, false)
auth, te, ctErr := c.CheckToken(ctx, req, false)
if ctErr == logical.ErrRelativePath {
return logical.ErrorResponse(ctErr.Error()), nil, ctErr
}
Expand Down Expand Up @@ -1272,7 +1272,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// Do an unauth check. This will cause EGP policies to be checked
var auth *logical.Auth
var ctErr error
auth, _, ctErr = c.checkToken(ctx, req, true)
auth, _, ctErr = c.CheckToken(ctx, req, true)
if ctErr == logical.ErrPerfStandbyPleaseForward {
return nil, nil, ctErr
}
Expand Down

0 comments on commit 13313d0

Please sign in to comment.