Skip to content

Commit

Permalink
backport of commit 184939e (hashicorp#19234)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom Proctor <[email protected]>
  • Loading branch information
hc-github-team-secure-vault-core and tomhjp authored Feb 17, 2023
1 parent eb2d03e commit 1e3dcee
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 42 deletions.
23 changes: 11 additions & 12 deletions http/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ import (
)

type eventSubscribeArgs struct {
ctx context.Context
logger hclog.Logger
events *eventbus.EventBus
ns *namespace.Namespace
eventType logical.EventType
conn *websocket.Conn
json bool
ctx context.Context
logger hclog.Logger
events *eventbus.EventBus
ns *namespace.Namespace
pattern string
conn *websocket.Conn
json bool
}

// handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason
// only if the connection closes or there was an error.
func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) {
ctx := args.ctx
logger := args.logger
ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.eventType)
ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.pattern)
if err != nil {
logger.Info("Error subscribing", "error", err)
return websocket.StatusUnsupportedData, "Error subscribing", nil
Expand Down Expand Up @@ -97,12 +97,11 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
if ns.ID != namespace.RootNamespaceID {
prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path)
}
eventTypeStr := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix))
if eventTypeStr == "" {
pattern := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix))
if pattern == "" {
respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to"))
return
}
eventType := logical.EventType(eventTypeStr)

json := false
jsonRaw := r.URL.Query().Get("json")
Expand Down Expand Up @@ -135,7 +134,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
}
}()

closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, eventType, conn, json})
closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, pattern, conn, json})
if err != nil {
closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 {
Expand Down
78 changes: 54 additions & 24 deletions vault/eventbus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ import (
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
"github.com/ryanuber/go-glob"
"google.golang.org/protobuf/types/known/timestamppb"
)

const defaultTimeout = 60 * time.Second
const (
// eventTypeAll is purely internal to the event bus. We use it to send all
// events down one big firehose, and pipelines define their own filtering
// based on what each subscriber is interested in.
eventTypeAll = "*"
defaultTimeout = 60 * time.Second
)

var (
ErrNotStarted = errors.New("event broker has not been started")
Expand All @@ -45,16 +52,14 @@ type pluginEventBus struct {

type asyncChanNode struct {
// TODO: add bounded deque buffer of *EventReceived
ctx context.Context
ch chan *logical.EventReceived
namespace *namespace.Namespace
logger hclog.Logger
ctx context.Context
ch chan *logical.EventReceived
logger hclog.Logger

// used to close the connection
closeOnce sync.Once
cancelFunc context.CancelFunc
pipelineID eventlogger.PipelineID
eventType eventlogger.EventType
broker *eventlogger.Broker
}

Expand Down Expand Up @@ -97,7 +102,7 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
// We can't easily know when the Send is complete, so we can't call the cancel function.
// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long.
ctx, _ = context.WithTimeout(ctx, bus.timeout)
_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived)
_, err := bus.broker.Send(ctx, eventTypeAll, eventReceived)
if err != nil {
// if no listeners for this event type are registered, that's okay, the event
// will just not be sent anywhere
Expand Down Expand Up @@ -164,43 +169,53 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
}, nil
}

func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eventType logical.EventType) (<-chan *logical.EventReceived, context.CancelFunc, error) {
func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pattern string) (<-chan *logical.EventReceived, context.CancelFunc, error) {
// subscriptions are still stored even if the bus has not been started
pipelineID, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
}

nodeID, err := uuid.GenerateUUID()
filterNodeID, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
}

filterNode := newFilterNode(ns, pattern)
err = bus.broker.RegisterNode(eventlogger.NodeID(filterNodeID), filterNode)
if err != nil {
return nil, nil, err
}

sinkNodeID, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
}

// TODO: should we have just one node per namespace, and handle all the routing ourselves?
ctx, cancel := context.WithCancel(ctx)
asyncNode := newAsyncNode(ctx, ns, bus.logger)
err = bus.broker.RegisterNode(eventlogger.NodeID(nodeID), asyncNode)
err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode)
if err != nil {
defer cancel()
return nil, nil, err
}

nodes := []eventlogger.NodeID{bus.formatterNodeID, eventlogger.NodeID(nodeID)}
nodes := []eventlogger.NodeID{eventlogger.NodeID(filterNodeID), bus.formatterNodeID, eventlogger.NodeID(sinkNodeID)}

pipeline := eventlogger.Pipeline{
PipelineID: eventlogger.PipelineID(pipelineID),
EventType: eventlogger.EventType(eventType),
EventType: eventTypeAll,
NodeIDs: nodes,
}
err = bus.broker.RegisterPipeline(pipeline)
if err != nil {
defer cancel()
return nil, nil, err
}

addSubscriptions(1)
// add info needed to cancel the subscription
asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
asyncNode.eventType = eventlogger.EventType(eventType)
asyncNode.cancelFunc = cancel
return asyncNode.ch, asyncNode.Close, nil
}
Expand All @@ -211,12 +226,32 @@ func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
bus.timeout = timeout
}

func newFilterNode(ns *namespace.Namespace, pattern string) *eventlogger.Filter {
return &eventlogger.Filter{
Predicate: func(e *eventlogger.Event) (bool, error) {
eventRecv := e.Payload.(*logical.EventReceived)

// Drop if event is not in our namespace.
// TODO: add wildcard/child namespace processing here in some cases?
if eventRecv.Namespace != ns.Path {
return false, nil
}

// Filter for correct event type, including wildcards.
if !glob.Glob(pattern, eventRecv.EventType) {
return false, nil
}

return true, nil
},
}
}

func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
return &asyncChanNode{
ctx: ctx,
ch: make(chan *logical.EventReceived),
namespace: namespace,
logger: logger,
ctx: ctx,
ch: make(chan *logical.EventReceived),
logger: logger,
}
}

Expand All @@ -225,7 +260,7 @@ func (node *asyncChanNode) Close() {
node.closeOnce.Do(func() {
defer node.cancelFunc()
if node.broker != nil {
err := node.broker.RemovePipeline(node.eventType, node.pipelineID)
err := node.broker.RemovePipeline(eventTypeAll, node.pipelineID)
if err != nil {
node.logger.Warn("Error removing pipeline for closing node", "error", err)
}
Expand All @@ -238,11 +273,6 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
// sends to the channel async in another goroutine
go func() {
eventRecv := e.Payload.(*logical.EventReceived)
// drop if event is not in our namespace
// TODO: add wildcard processing here in some cases?
if eventRecv.Namespace != node.namespace.Path {
return
}
var timeout bool
select {
case node.ch <- eventRecv:
Expand Down
86 changes: 81 additions & 5 deletions vault/eventbus/bus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -38,7 +39,7 @@ func TestBusBasics(t *testing.T) {
t.Errorf("Expected no error sending: %v", err)
}

ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -81,7 +82,7 @@ func TestNamespaceFiltering(t *testing.T) {
t.Fatal(err)
}

ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -137,13 +138,13 @@ func TestBus2Subscriptions(t *testing.T) {
eventType2 := logical.EventType("someType2")
bus.Start()

ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType1)
ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType1))
if err != nil {
t.Fatal(err)
}
defer cancel1()

ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType2)
ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType2))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -222,7 +223,7 @@ func TestBusSubscriptionsCancel(t *testing.T) {
received := atomic.Int32{}

for i := 0; i < create; i++ {
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -297,3 +298,78 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
}
t.Error("Timeout waiting for condition")
}

// TestBusWildcardSubscriptions tests that a single subscription can receive
// multiple event types using * for glob patterns.
func TestBusWildcardSubscriptions(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()

fooEventType := logical.EventType("kv/foo")
barEventType := logical.EventType("kv/bar")
bus.Start()

ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, "kv/*")
if err != nil {
t.Fatal(err)
}
defer cancel1()

ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, "*/bar")
if err != nil {
t.Fatal(err)
}
defer cancel2()

event1, err := logical.NewEvent()
if err != nil {
t.Fatal(err)
}
event2, err := logical.NewEvent()
if err != nil {
t.Fatal(err)
}

err = bus.SendInternal(ctx, namespace.RootNamespace, nil, barEventType, event2)
if err != nil {
t.Error(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1)
if err != nil {
t.Error(err)
}

timeout := time.After(1 * time.Second)
// Expect to receive both events on ch1, which subscribed to kv/*
var ch1Seen []string
for i := 0; i < 2; i++ {
select {
case message := <-ch1:
ch1Seen = append(ch1Seen, message.Event.ID())
case <-timeout:
t.Error("Timeout waiting for event1")
}
}
if len(ch1Seen) != 2 {
t.Errorf("Expected 2 events but got: %v", ch1Seen)
} else {
if !strutil.StrListContains(ch1Seen, event1.ID()) {
t.Errorf("Did not find %s event1 ID in ch1seen", event1.ID())
}
if !strutil.StrListContains(ch1Seen, event2.ID()) {
t.Errorf("Did not find %s event2 ID in ch1seen", event2.ID())
}
}
// Expect to receive just kv/bar on ch2, which subscribed to */bar
select {
case message := <-ch2:
if message.Event.ID() != event2.ID() {
t.Errorf("Got unexpected message: %v", message)
}
case <-timeout:
t.Error("Timeout waiting for event2")
}
}
2 changes: 1 addition & 1 deletion vault/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestCanSendEventsFromBuiltinPlugin(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, logical.EventType(eventType))
ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, eventType)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 1e3dcee

Please sign in to comment.