diff --git a/http/events.go b/http/events.go index 22dfe35630db..e957cd7a3798 100644 --- a/http/events.go +++ b/http/events.go @@ -20,13 +20,13 @@ import ( ) type eventSubscribeArgs struct { - ctx context.Context - logger hclog.Logger - events *eventbus.EventBus - ns *namespace.Namespace - eventType logical.EventType - conn *websocket.Conn - json bool + ctx context.Context + logger hclog.Logger + events *eventbus.EventBus + ns *namespace.Namespace + pattern string + conn *websocket.Conn + json bool } // handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason @@ -34,7 +34,7 @@ type eventSubscribeArgs struct { func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) { ctx := args.ctx logger := args.logger - ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.eventType) + ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.pattern) if err != nil { logger.Info("Error subscribing", "error", err) return websocket.StatusUnsupportedData, "Error subscribing", nil @@ -97,12 +97,11 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler if ns.ID != namespace.RootNamespaceID { prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path) } - eventTypeStr := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix)) - if eventTypeStr == "" { + pattern := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix)) + if pattern == "" { respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to")) return } - eventType := logical.EventType(eventTypeStr) json := false jsonRaw := r.URL.Query().Get("json") @@ -135,7 +134,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler } }() - closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, eventType, conn, json}) + closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, pattern, conn, json}) if err != nil { closeStatus = websocket.CloseStatus(err) if closeStatus == -1 { diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index cf789ef26d9d..cc954435fc93 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -16,10 +16,17 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" + "github.com/ryanuber/go-glob" "google.golang.org/protobuf/types/known/timestamppb" ) -const defaultTimeout = 60 * time.Second +const ( + // eventTypeAll is purely internal to the event bus. We use it to send all + // events down one big firehose, and pipelines define their own filtering + // based on what each subscriber is interested in. + eventTypeAll = "*" + defaultTimeout = 60 * time.Second +) var ( ErrNotStarted = errors.New("event broker has not been started") @@ -45,16 +52,14 @@ type pluginEventBus struct { type asyncChanNode struct { // TODO: add bounded deque buffer of *EventReceived - ctx context.Context - ch chan *logical.EventReceived - namespace *namespace.Namespace - logger hclog.Logger + ctx context.Context + ch chan *logical.EventReceived + logger hclog.Logger // used to close the connection closeOnce sync.Once cancelFunc context.CancelFunc pipelineID eventlogger.PipelineID - eventType eventlogger.EventType broker *eventlogger.Broker } @@ -97,7 +102,7 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace, // We can't easily know when the Send is complete, so we can't call the cancel function. // But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long. ctx, _ = context.WithTimeout(ctx, bus.timeout) - _, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived) + _, err := bus.broker.Send(ctx, eventTypeAll, eventReceived) if err != nil { // if no listeners for this event type are registered, that's okay, the event // will just not be sent anywhere @@ -164,32 +169,42 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) { }, nil } -func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eventType logical.EventType) (<-chan *logical.EventReceived, context.CancelFunc, error) { +func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pattern string) (<-chan *logical.EventReceived, context.CancelFunc, error) { // subscriptions are still stored even if the bus has not been started pipelineID, err := uuid.GenerateUUID() if err != nil { return nil, nil, err } - nodeID, err := uuid.GenerateUUID() + filterNodeID, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, err + } + + filterNode := newFilterNode(ns, pattern) + err = bus.broker.RegisterNode(eventlogger.NodeID(filterNodeID), filterNode) + if err != nil { + return nil, nil, err + } + + sinkNodeID, err := uuid.GenerateUUID() if err != nil { return nil, nil, err } - // TODO: should we have just one node per namespace, and handle all the routing ourselves? ctx, cancel := context.WithCancel(ctx) asyncNode := newAsyncNode(ctx, ns, bus.logger) - err = bus.broker.RegisterNode(eventlogger.NodeID(nodeID), asyncNode) + err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode) if err != nil { defer cancel() return nil, nil, err } - nodes := []eventlogger.NodeID{bus.formatterNodeID, eventlogger.NodeID(nodeID)} + nodes := []eventlogger.NodeID{eventlogger.NodeID(filterNodeID), bus.formatterNodeID, eventlogger.NodeID(sinkNodeID)} pipeline := eventlogger.Pipeline{ PipelineID: eventlogger.PipelineID(pipelineID), - EventType: eventlogger.EventType(eventType), + EventType: eventTypeAll, NodeIDs: nodes, } err = bus.broker.RegisterPipeline(pipeline) @@ -197,10 +212,10 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve defer cancel() return nil, nil, err } + addSubscriptions(1) // add info needed to cancel the subscription asyncNode.pipelineID = eventlogger.PipelineID(pipelineID) - asyncNode.eventType = eventlogger.EventType(eventType) asyncNode.cancelFunc = cancel return asyncNode.ch, asyncNode.Close, nil } @@ -211,12 +226,32 @@ func (bus *EventBus) SetSendTimeout(timeout time.Duration) { bus.timeout = timeout } +func newFilterNode(ns *namespace.Namespace, pattern string) *eventlogger.Filter { + return &eventlogger.Filter{ + Predicate: func(e *eventlogger.Event) (bool, error) { + eventRecv := e.Payload.(*logical.EventReceived) + + // Drop if event is not in our namespace. + // TODO: add wildcard/child namespace processing here in some cases? + if eventRecv.Namespace != ns.Path { + return false, nil + } + + // Filter for correct event type, including wildcards. + if !glob.Glob(pattern, eventRecv.EventType) { + return false, nil + } + + return true, nil + }, + } +} + func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode { return &asyncChanNode{ - ctx: ctx, - ch: make(chan *logical.EventReceived), - namespace: namespace, - logger: logger, + ctx: ctx, + ch: make(chan *logical.EventReceived), + logger: logger, } } @@ -225,7 +260,7 @@ func (node *asyncChanNode) Close() { node.closeOnce.Do(func() { defer node.cancelFunc() if node.broker != nil { - err := node.broker.RemovePipeline(node.eventType, node.pipelineID) + err := node.broker.RemovePipeline(eventTypeAll, node.pipelineID) if err != nil { node.logger.Warn("Error removing pipeline for closing node", "error", err) } @@ -238,11 +273,6 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (* // sends to the channel async in another goroutine go func() { eventRecv := e.Payload.(*logical.EventReceived) - // drop if event is not in our namespace - // TODO: add wildcard processing here in some cases? - if eventRecv.Namespace != node.namespace.Path { - return - } var timeout bool select { case node.ch <- eventRecv: diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index 94b3dd2c5ecb..cf7d26e318eb 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" ) @@ -38,7 +39,7 @@ func TestBusBasics(t *testing.T) { t.Errorf("Expected no error sending: %v", err) } - ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) + ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType)) if err != nil { t.Fatal(err) } @@ -81,7 +82,7 @@ func TestNamespaceFiltering(t *testing.T) { t.Fatal(err) } - ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) + ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType)) if err != nil { t.Fatal(err) } @@ -137,13 +138,13 @@ func TestBus2Subscriptions(t *testing.T) { eventType2 := logical.EventType("someType2") bus.Start() - ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType1) + ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType1)) if err != nil { t.Fatal(err) } defer cancel1() - ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType2) + ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType2)) if err != nil { t.Fatal(err) } @@ -222,7 +223,7 @@ func TestBusSubscriptionsCancel(t *testing.T) { received := atomic.Int32{} for i := 0; i < create; i++ { - ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) + ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType)) if err != nil { t.Fatal(err) } @@ -297,3 +298,78 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) { } t.Error("Timeout waiting for condition") } + +// TestBusWildcardSubscriptions tests that a single subscription can receive +// multiple event types using * for glob patterns. +func TestBusWildcardSubscriptions(t *testing.T) { + bus, err := NewEventBus(nil) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + + fooEventType := logical.EventType("kv/foo") + barEventType := logical.EventType("kv/bar") + bus.Start() + + ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, "kv/*") + if err != nil { + t.Fatal(err) + } + defer cancel1() + + ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, "*/bar") + if err != nil { + t.Fatal(err) + } + defer cancel2() + + event1, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + event2, err := logical.NewEvent() + if err != nil { + t.Fatal(err) + } + + err = bus.SendInternal(ctx, namespace.RootNamespace, nil, barEventType, event2) + if err != nil { + t.Error(err) + } + err = bus.SendInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1) + if err != nil { + t.Error(err) + } + + timeout := time.After(1 * time.Second) + // Expect to receive both events on ch1, which subscribed to kv/* + var ch1Seen []string + for i := 0; i < 2; i++ { + select { + case message := <-ch1: + ch1Seen = append(ch1Seen, message.Event.ID()) + case <-timeout: + t.Error("Timeout waiting for event1") + } + } + if len(ch1Seen) != 2 { + t.Errorf("Expected 2 events but got: %v", ch1Seen) + } else { + if !strutil.StrListContains(ch1Seen, event1.ID()) { + t.Errorf("Did not find %s event1 ID in ch1seen", event1.ID()) + } + if !strutil.StrListContains(ch1Seen, event2.ID()) { + t.Errorf("Did not find %s event2 ID in ch1seen", event2.ID()) + } + } + // Expect to receive just kv/bar on ch2, which subscribed to */bar + select { + case message := <-ch2: + if message.Event.ID() != event2.ID() { + t.Errorf("Got unexpected message: %v", message) + } + case <-timeout: + t.Error("Timeout waiting for event2") + } +} diff --git a/vault/events_test.go b/vault/events_test.go index d57be82325be..107be1913efb 100644 --- a/vault/events_test.go +++ b/vault/events_test.go @@ -19,7 +19,7 @@ func TestCanSendEventsFromBuiltinPlugin(t *testing.T) { if err != nil { t.Fatal(err) } - ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, logical.EventType(eventType)) + ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, eventType) if err != nil { t.Fatal(err) }