diff --git a/audit/entry_filter.go b/audit/entry_filter.go new file mode 100644 index 000000000000..7a7a253b10a7 --- /dev/null +++ b/audit/entry_filter.go @@ -0,0 +1,91 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package audit + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/eventlogger" + "github.com/hashicorp/go-bexpr" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internal/observability/event" +) + +var _ eventlogger.Node = (*EntryFilter)(nil) + +// NewEntryFilter should be used to create an EntryFilter node. +// The filter supplied should be in bexpr format and reference fields from logical.LogInputBexpr. +func NewEntryFilter(filter string) (*EntryFilter, error) { + const op = "audit.NewEntryFilter" + + filter = strings.TrimSpace(filter) + if filter == "" { + return nil, fmt.Errorf("%s: cannot create new audit filter with empty filter expression: %w", op, event.ErrInvalidParameter) + } + + eval, err := bexpr.CreateEvaluator(filter) + if err != nil { + return nil, fmt.Errorf("%s: cannot create new audit filter: %w", op, err) + } + + return &EntryFilter{evaluator: eval}, nil +} + +// Reopen is a no-op for the filter node. +func (*EntryFilter) Reopen() error { + return nil +} + +// Type describes the type of this node (filter). +func (*EntryFilter) Type() eventlogger.NodeType { + return eventlogger.NodeTypeFilter +} + +// Process will attempt to parse the incoming event data and decide whether it +// should be filtered or remain in the pipeline and passed to the next node. +func (f *EntryFilter) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) { + const op = "audit.(EntryFilter).Process" + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if e == nil { + return nil, fmt.Errorf("%s: event is nil: %w", op, event.ErrInvalidParameter) + } + + a, ok := e.Payload.(*AuditEvent) + if !ok { + return nil, fmt.Errorf("%s: cannot parse event payload: %w", op, event.ErrInvalidParameter) + } + + // If we don't have data to process, then we're done. + if a.Data == nil { + return nil, nil + } + + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, fmt.Errorf("%s: cannot obtain namespace: %w", op, err) + } + + datum := a.Data.BexprDatum(ns.Path) + + result, err := f.evaluator.Evaluate(datum) + if err != nil { + return nil, fmt.Errorf("%s: unable to evaluate filter: %w", op, err) + } + + if result { + // Allow this event to carry on through the pipeline. + return e, nil + } + + // End process of this pipeline. + return nil, nil +} diff --git a/audit/entry_filter_test.go b/audit/entry_filter_test.go new file mode 100644 index 000000000000..a5efea1dc69a --- /dev/null +++ b/audit/entry_filter_test.go @@ -0,0 +1,249 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package audit + +import ( + "context" + "testing" + "time" + + "github.com/hashicorp/eventlogger" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internal/observability/event" + "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" +) + +// TestEntryFilter_NewEntryFilter tests that we can create EntryFilter types correctly. +func TestEntryFilter_NewEntryFilter(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + Filter string + IsErrorExpected bool + ExpectedErrorMessage string + }{ + "empty-filter": { + Filter: "", + IsErrorExpected: true, + ExpectedErrorMessage: "audit.NewEntryFilter: cannot create new audit filter with empty filter expression: invalid parameter", + }, + "spacey-filter": { + Filter: " ", + IsErrorExpected: true, + ExpectedErrorMessage: "audit.NewEntryFilter: cannot create new audit filter with empty filter expression: invalid parameter", + }, + "bad-filter": { + Filter: "____", + IsErrorExpected: true, + ExpectedErrorMessage: "audit.NewEntryFilter: cannot create new audit filter", + }, + "good-filter": { + Filter: "foo == bar", + IsErrorExpected: false, + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + f, err := NewEntryFilter(tc.Filter) + switch { + case tc.IsErrorExpected: + require.ErrorContains(t, err, tc.ExpectedErrorMessage) + require.Nil(t, f) + default: + require.NoError(t, err) + require.NotNil(t, f) + } + }) + } +} + +// TestEntryFilter_Reopen ensures we can reopen the filter node. +func TestEntryFilter_Reopen(t *testing.T) { + t.Parallel() + + f := &EntryFilter{} + res := f.Reopen() + require.Nil(t, res) +} + +// TestEntryFilter_Type ensures we always return the right type for this node. +func TestEntryFilter_Type(t *testing.T) { + t.Parallel() + + f := &EntryFilter{} + require.Equal(t, eventlogger.NodeTypeFilter, f.Type()) +} + +// TestEntryFilter_Process_ContextDone ensures that we stop processing the event +// if the context was cancelled. +func TestEntryFilter_Process_ContextDone(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + + // Explicitly cancel the context + cancel() + + l, err := NewEntryFilter("foo == bar") + require.NoError(t, err) + + // Fake audit event + a, err := NewEvent(RequestType) + require.NoError(t, err) + + // Fake event logger event + e := &eventlogger.Event{ + Type: eventlogger.EventType(event.AuditType.String()), + CreatedAt: time.Now(), + Formatted: make(map[string][]byte), + Payload: a, + } + + e2, err := l.Process(ctx, e) + + require.Error(t, err) + require.ErrorContains(t, err, "context canceled") + + // Ensure that the pipeline won't continue. + require.Nil(t, e2) +} + +// TestEntryFilter_Process_NilEvent ensures we receive the right error when the +// event we are trying to process is nil. +func TestEntryFilter_Process_NilEvent(t *testing.T) { + t.Parallel() + + l, err := NewEntryFilter("foo == bar") + require.NoError(t, err) + e, err := l.Process(context.Background(), nil) + require.Error(t, err) + require.EqualError(t, err, "audit.(EntryFilter).Process: event is nil: invalid parameter") + + // Ensure that the pipeline won't continue. + require.Nil(t, e) +} + +// TestEntryFilter_Process_BadPayload ensures we receive the correct error when +// attempting to process an event with a payload that cannot be parsed back to +// an audit event. +func TestEntryFilter_Process_BadPayload(t *testing.T) { + t.Parallel() + + l, err := NewEntryFilter("foo == bar") + require.NoError(t, err) + + e := &eventlogger.Event{ + Type: eventlogger.EventType(event.AuditType.String()), + CreatedAt: time.Now(), + Formatted: make(map[string][]byte), + Payload: nil, + } + + e2, err := l.Process(context.Background(), e) + require.Error(t, err) + require.EqualError(t, err, "audit.(EntryFilter).Process: cannot parse event payload: invalid parameter") + + // Ensure that the pipeline won't continue. + require.Nil(t, e2) +} + +// TestEntryFilter_Process_NoAuditDataInPayload ensure we stop processing a pipeline +// when the data in the audit event is nil. +func TestEntryFilter_Process_NoAuditDataInPayload(t *testing.T) { + t.Parallel() + + l, err := NewEntryFilter("foo == bar") + require.NoError(t, err) + + a, err := NewEvent(RequestType) + require.NoError(t, err) + + // Ensure audit data is nil + a.Data = nil + + e := &eventlogger.Event{ + Type: eventlogger.EventType(event.AuditType.String()), + CreatedAt: time.Now(), + Formatted: make(map[string][]byte), + Payload: a, + } + + e2, err := l.Process(context.Background(), e) + + // Make sure we get the 'nil, nil' response to stop processing this pipeline. + require.NoError(t, err) + require.Nil(t, e2) +} + +// TestEntryFilter_Process_FilterSuccess tests that when a filter matches we +// receive no error and the event is not nil so it continues in the pipeline. +func TestEntryFilter_Process_FilterSuccess(t *testing.T) { + t.Parallel() + + l, err := NewEntryFilter("mount_type == juan") + require.NoError(t, err) + + a, err := NewEvent(RequestType) + require.NoError(t, err) + + a.Data = &logical.LogInput{ + Request: &logical.Request{ + Operation: logical.CreateOperation, + MountType: "juan", + }, + } + + e := &eventlogger.Event{ + Type: eventlogger.EventType(event.AuditType.String()), + CreatedAt: time.Now(), + Formatted: make(map[string][]byte), + Payload: a, + } + + ctx := namespace.ContextWithNamespace(context.Background(), namespace.RootNamespace) + + e2, err := l.Process(ctx, e) + + require.NoError(t, err) + require.NotNil(t, e2) +} + +// TestEntryFilter_Process_FilterFail tests that when a filter fails to match we +// receive no error, but also the event is nil so that the pipeline completes. +func TestEntryFilter_Process_FilterFail(t *testing.T) { + t.Parallel() + + l, err := NewEntryFilter("mount_type == john and operation == create and namespace == root") + require.NoError(t, err) + + a, err := NewEvent(RequestType) + require.NoError(t, err) + + a.Data = &logical.LogInput{ + Request: &logical.Request{ + Operation: logical.CreateOperation, + MountType: "juan", + }, + } + + e := &eventlogger.Event{ + Type: eventlogger.EventType(event.AuditType.String()), + CreatedAt: time.Now(), + Formatted: make(map[string][]byte), + Payload: a, + } + + ctx := namespace.ContextWithNamespace(context.Background(), namespace.RootNamespace) + + e2, err := l.Process(ctx, e) + + require.NoError(t, err) + require.Nil(t, e2) +} diff --git a/audit/entry_formatter.go b/audit/entry_formatter.go index a6b836d0a446..6937949db368 100644 --- a/audit/entry_formatter.go +++ b/audit/entry_formatter.go @@ -11,16 +11,13 @@ import ( "strings" "time" - "github.com/jefferai/jsonx" - - "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/sdk/logical" - "github.com/go-jose/go-jose/v3/jwt" + "github.com/hashicorp/eventlogger" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internal/observability/event" "github.com/hashicorp/vault/sdk/helper/jsonutil" - - "github.com/hashicorp/eventlogger" + "github.com/hashicorp/vault/sdk/logical" + "github.com/jefferai/jsonx" ) var ( @@ -29,7 +26,7 @@ var ( ) // NewEntryFormatter should be used to create an EntryFormatter. -// Accepted options: WithPrefix, WithHeaderFormatter. +// Accepted options: WithHeaderFormatter, WithPrefix. func NewEntryFormatter(config FormatterConfig, salter Salter, opt ...Option) (*EntryFormatter, error) { const op = "audit.NewEntryFormatter" diff --git a/audit/nodes.go b/audit/nodes.go index 01602d4b1389..624777e72515 100644 --- a/audit/nodes.go +++ b/audit/nodes.go @@ -15,10 +15,12 @@ import ( ) // ProcessManual will attempt to create an (audit) event with the specified data -// and manually iterate over the supplied nodes calling Process on each. +// and manually iterate over the supplied nodes calling Process on each until the +// event is nil (which indicates the pipeline has completed). // Order of IDs in the NodeID slice determines the order they are processed. // (Audit) Event will be of RequestType (as opposed to ResponseType). -// The last node must be a sink node (eventlogger.NodeTypeSink). +// The last node must be a filter node (eventlogger.NodeTypeFilter) or +// sink node (eventlogger.NodeTypeSink). func ProcessManual(ctx context.Context, data *logical.LogInput, ids []eventlogger.NodeID, nodes map[eventlogger.NodeID]eventlogger.Node) error { switch { case data == nil: @@ -52,9 +54,15 @@ func ProcessManual(ctx context.Context, data *logical.LogInput, ids []eventlogge // Process nodes in order, updating the event with the result. // This means we *should* do: - // 1. formatter (temporary) - // 2. sink + // 1. filter (optional if configured) + // 2. formatter (temporary) + // 3. sink for _, id := range ids { + // If the event is nil, we've completed processing the pipeline (hopefully + // by either a filter node or a sink node). + if e == nil { + break + } node, ok := nodes[id] if !ok { return fmt.Errorf("node not found: %v", id) @@ -74,12 +82,14 @@ func ProcessManual(ctx context.Context, data *logical.LogInput, ids []eventlogge return err } - // Track the last node we have processed, as we should end with a sink. + // Track the last node we have processed, as we should end with a filter or sink. lastSeen = node.Type() } - if lastSeen != eventlogger.NodeTypeSink { - return errors.New("last node must be a sink") + switch lastSeen { + case eventlogger.NodeTypeSink, eventlogger.NodeTypeFilter: + default: + return errors.New("last node must be a filter or sink") } return nil diff --git a/audit/nodes_test.go b/audit/nodes_test.go index a50034c1d418..3aa4ef533210 100644 --- a/audit/nodes_test.go +++ b/audit/nodes_test.go @@ -185,12 +185,13 @@ func TestProcessManual_LastNodeNotSink(t *testing.T) { err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) require.Error(t, err) - require.EqualError(t, err, "last node must be a sink") + require.EqualError(t, err, "last node must be a filter or sink") } -// TestProcessManual ensures that the manual processing of a test message works -// as expected with proper inputs. -func TestProcessManual(t *testing.T) { +// TestProcessManualEndWithSink ensures that the manual processing of a test +// message works as expected with proper inputs, which mean processing ends with +// sink node. +func TestProcessManualEndWithSink(t *testing.T) { t.Parallel() var ids []eventlogger.NodeID @@ -215,6 +216,39 @@ func TestProcessManual(t *testing.T) { require.NoError(t, err) } +// TestProcessManual_EndWithFilter ensures that the manual processing of a test +// message works as expected with proper inputs, which mean processing ends with +// sink node. +func TestProcessManual_EndWithFilter(t *testing.T) { + t.Parallel() + + var ids []eventlogger.NodeID + nodes := make(map[eventlogger.NodeID]eventlogger.Node) + + // Filter node + filterId, filterNode := newFilterNode(t) + ids = append(ids, filterId) + nodes[filterId] = filterNode + + // Formatter node + formatterId, formatterNode := newFormatterNode(t) + ids = append(ids, formatterId) + nodes[formatterId] = formatterNode + + // Sink node + sinkId, sinkNode := newSinkNode(t) + ids = append(ids, sinkId) + nodes[sinkId] = sinkNode + + // Data + requestId, err := uuid.GenerateUUID() + require.NoError(t, err) + data := newData(requestId) + + err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) + require.NoError(t, err) +} + // newSinkNode creates a new UUID and NoopSink (sink node). func newSinkNode(t *testing.T) (eventlogger.NodeID, *event.NoopSink) { t.Helper() @@ -226,6 +260,25 @@ func newSinkNode(t *testing.T) (eventlogger.NodeID, *event.NoopSink) { return sinkId, sinkNode } +// TestFilter is a trivial implementation of eventlogger.Node used as a placeholder +// for Filter nodes in tests. +type TestFilter struct{} + +// Process trivially filters the event preventing it from being processed by subsequent nodes. +func (f *TestFilter) Process(_ context.Context, e *eventlogger.Event) (*eventlogger.Event, error) { + return nil, nil +} + +// Reopen does nothing. +func (f *TestFilter) Reopen() error { + return nil +} + +// Type returns the eventlogger.NodeTypeFormatter type. +func (f *TestFilter) Type() eventlogger.NodeType { + return eventlogger.NodeTypeFilter +} + // TestFormatter is a trivial implementation of the eventlogger.Node interface // used as a place-holder for Formatter nodes in tests. type TestFormatter struct{} @@ -248,6 +301,15 @@ func (f *TestFormatter) Type() eventlogger.NodeType { return eventlogger.NodeTypeFormatter } +// newFilterNode creates a new TestFormatter (filter node). +func newFilterNode(t *testing.T) (eventlogger.NodeID, *TestFilter) { + nodeId, err := event.GenerateNodeID() + require.NoError(t, err) + node := &TestFilter{} + + return nodeId, node +} + // newFormatterNode creates a new TestFormatter (formatter node). func newFormatterNode(t *testing.T) (eventlogger.NodeID, *TestFormatter) { nodeId, err := event.GenerateNodeID() diff --git a/audit/sink_wrapper.go b/audit/sink_wrapper.go index 3cf79c709535..f61c908a687c 100644 --- a/audit/sink_wrapper.go +++ b/audit/sink_wrapper.go @@ -11,6 +11,8 @@ import ( "github.com/hashicorp/eventlogger" ) +var _ eventlogger.Node = (*SinkWrapper)(nil) + // SinkWrapper is a wrapper for any kind of Sink Node that processes events // containing an AuditEvent payload. type SinkWrapper struct { diff --git a/audit/types.go b/audit/types.go index af5a5830dfae..3434ff84d840 100644 --- a/audit/types.go +++ b/audit/types.go @@ -8,9 +8,9 @@ import ( "io" "time" - "github.com/hashicorp/eventlogger" + "github.com/hashicorp/go-bexpr" + "github.com/hashicorp/vault/internal/observability/event" "github.com/hashicorp/vault/sdk/helper/salt" - "github.com/hashicorp/vault/sdk/logical" ) @@ -144,6 +144,13 @@ type FormatterConfig struct { RequiredFormat format } +// EntryFilter should be used to filter audit requests and responses which should +// make it to a sink. +type EntryFilter struct { + // the evaluator for the bexpr expression that should be applied by the node. + evaluator *bexpr.Evaluator +} + // RequestEntry is the structure of a request audit log entry. type RequestEntry struct { Time string `json:"time,omitempty"` @@ -268,6 +275,10 @@ type Backend interface { // Salter interface must be implemented by anything implementing Backend. Salter + // The PipelineReader interface allows backends to surface information about their + // nodes for node and pipeline registration. + event.PipelineReader + // LogRequest is used to synchronously log a request. This is done after the // request is authorized but before the request is executed. The arguments // MUST not be modified in any way. They should be deep copied if this is @@ -291,12 +302,6 @@ type Backend interface { // Invalidate is called for path invalidation Invalidate(context.Context) - - // RegisterNodesAndPipeline provides an eventlogger.Broker pointer so that - // the Backend can call its RegisterNode and RegisterPipeline methods with - // the nodes and the pipeline that were created in the corresponding - // Factory function. - RegisterNodesAndPipeline(*eventlogger.Broker, string) error } // BackendConfig contains configuration parameters used in the factory func to diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index fc6a44a58719..2681ee244e99 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -27,75 +27,71 @@ const ( discard = "discard" ) -func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { - if conf.SaltConfig == nil { - return nil, fmt.Errorf("nil salt config") - } - if conf.SaltView == nil { - return nil, fmt.Errorf("nil salt view") - } - - path, ok := conf.Config["file_path"] - if !ok { - path, ok = conf.Config["path"] - if !ok { - return nil, fmt.Errorf("file_path is required") - } - } +var _ audit.Backend = (*Backend)(nil) - // normalize path if configured for stdout - if strings.EqualFold(path, stdout) { - path = stdout - } - if strings.EqualFold(path, discard) { - path = discard - } +// Backend is the audit backend for the file-based audit store. +// +// NOTE: This audit backend is currently very simple: it appends to a file. +// It doesn't do anything more at the moment to assist with rotation +// or reset the write cursor, this should be done in the future. +type Backend struct { + f *os.File + fileLock sync.RWMutex + formatter *audit.EntryFormatterWriter + formatConfig audit.FormatterConfig + mode os.FileMode + name string + nodeIDList []eventlogger.NodeID + nodeMap map[eventlogger.NodeID]eventlogger.Node + filePath string + salt *atomic.Value + saltConfig *salt.Config + saltMutex sync.RWMutex + saltView logical.Storage +} - var cfgOpts []audit.Option +func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { + const op = "file.Factory" - if format, ok := conf.Config["format"]; ok { - cfgOpts = append(cfgOpts, audit.WithFormat(format)) + if conf.SaltConfig == nil { + return nil, fmt.Errorf("%s: nil salt config", op) } - - // Check if hashing of accessor is disabled - if hmacAccessorRaw, ok := conf.Config["hmac_accessor"]; ok { - v, err := strconv.ParseBool(hmacAccessorRaw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithHMACAccessor(v)) + if conf.SaltView == nil { + return nil, fmt.Errorf("%s: nil salt view", op) } - // Check if raw logging is enabled - if raw, ok := conf.Config["log_raw"]; ok { - v, err := strconv.ParseBool(raw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithRaw(v)) + // Get file path from config or fall back to the old option name ('path') for compatibility + // (see commit bac4fe0799a372ba1245db642f3f6cd1f1d02669). + var filePath string + if p, ok := conf.Config["file_path"]; ok { + filePath = p + } else if p, ok = conf.Config["path"]; ok { + filePath = p + } else { + return nil, fmt.Errorf("%s: file_path is required", op) } - if elideListResponsesRaw, ok := conf.Config["elide_list_responses"]; ok { - v, err := strconv.ParseBool(elideListResponsesRaw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithElision(v)) + // normalize file path if configured for stdout + if strings.EqualFold(filePath, stdout) { + filePath = stdout + } + if strings.EqualFold(filePath, discard) { + filePath = discard } mode := os.FileMode(0o600) if modeRaw, ok := conf.Config["mode"]; ok { m, err := strconv.ParseUint(modeRaw, 8, 32) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: unable to parse 'mode': %w", op, err) } switch m { case 0: // if mode is 0000, then do not modify file mode - if path != stdout && path != discard { - fileInfo, err := os.Stat(path) + if filePath != stdout && filePath != discard { + fileInfo, err := os.Stat(filePath) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: unable to stat %q: %w", op, filePath, err) } mode = fileInfo.Mode() } @@ -104,18 +100,19 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool } } - cfg, err := audit.NewFormatterConfig(cfgOpts...) + cfg, err := formatterConfig(conf.Config) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) } b := &Backend{ - path: path, + filePath: filePath, + formatConfig: cfg, mode: mode, + name: conf.MountPath, saltConfig: conf.SaltConfig, saltView: conf.SaltView, salt: new(atomic.Value), - formatConfig: cfg, } // Ensure we are working with the right type by explicitly storing a nil of @@ -125,8 +122,9 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool // Configure the formatter for either case. f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithHeaderFormatter(headersConfig), audit.WithPrefix(conf.Config["prefix"])) if err != nil { - return nil, fmt.Errorf("error creating formatter: %w", err) + return nil, fmt.Errorf("%s: error creating formatter: %w", op, err) } + var w audit.Writer switch b.formatConfig.RequiredFormat { case audit.JSONFormat: @@ -134,63 +132,40 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool case audit.JSONxFormat: w = &audit.JSONxWriter{Prefix: conf.Config["prefix"]} default: - return nil, fmt.Errorf("unknown format type %q", b.formatConfig.RequiredFormat) + return nil, fmt.Errorf("%s: unknown format type %q", op, b.formatConfig.RequiredFormat) } fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w) if err != nil { - return nil, fmt.Errorf("error creating formatter writer: %w", err) + return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err) } b.formatter = fw if useEventLogger { - b.nodeIDList = make([]eventlogger.NodeID, 2) + b.nodeIDList = []eventlogger.NodeID{} b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) - formatterNodeID, err := event.GenerateNodeID() + err := b.configureFilterNode(conf.Config["filter"]) if err != nil { - return nil, fmt.Errorf("error generating random NodeID for formatter node: %w", err) + return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) } - b.nodeIDList[0] = formatterNodeID - b.nodeMap[formatterNodeID] = f - - var sinkNode eventlogger.Node - - switch path { - case stdout: - sinkNode = &audit.SinkWrapper{Name: path, Sink: event.NewStdoutSinkNode(b.formatConfig.RequiredFormat.String())} - case discard: - sinkNode = &audit.SinkWrapper{Name: path, Sink: event.NewNoopSink()} - default: - var err error - - var opts []event.Option - // Check if mode is provided - if modeRaw, ok := conf.Config["mode"]; ok { - opts = append(opts, event.WithFileMode(modeRaw)) - } - - // The NewFileSink function attempts to open the file and will - // return an error if it can't. - n, err := event.NewFileSink( - b.path, - b.formatConfig.RequiredFormat.String(), opts...) - if err != nil { - return nil, fmt.Errorf("file sink creation failed for path %q: %w", path, err) - } - sinkNode = &audit.SinkWrapper{Name: conf.MountPath, Sink: n} + formatterOpts := []audit.Option{ + audit.WithHeaderFormatter(headersConfig), + audit.WithPrefix(conf.Config["prefix"]), } - sinkNodeID, err := event.GenerateNodeID() + err = b.configureFormatterNode(cfg, formatterOpts...) if err != nil { - return nil, fmt.Errorf("error generating random NodeID for sink node: %w", err) + return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) } - b.nodeIDList[1] = sinkNodeID - b.nodeMap[sinkNodeID] = sinkNode + err = b.configureSinkNode(conf.MountPath, filePath, conf.Config["mode"], cfg.RequiredFormat.String()) + if err != nil { + return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) + } } else { - switch path { + switch filePath { case stdout: case discard: default: @@ -198,7 +173,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool // otherwise it will be too late to catch later without problems // (ref: https://github.com/hashicorp/vault/issues/550) if err := b.open(); err != nil { - return nil, fmt.Errorf("sanity check failed; unable to open %q for writing: %w", path, err) + return nil, fmt.Errorf("%s: sanity check failed; unable to open %q for writing: %w", op, filePath, err) } } } @@ -206,32 +181,6 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool return b, nil } -// Backend is the audit backend for the file-based audit store. -// -// NOTE: This audit backend is currently very simple: it appends to a file. -// It doesn't do anything more at the moment to assist with rotation -// or reset the write cursor, this should be done in the future. -type Backend struct { - path string - - formatter *audit.EntryFormatterWriter - formatConfig audit.FormatterConfig - - fileLock sync.RWMutex - f *os.File - mode os.FileMode - - saltMutex sync.RWMutex - salt *atomic.Value - saltConfig *salt.Config - saltView logical.Storage - - nodeIDList []eventlogger.NodeID - nodeMap map[eventlogger.NodeID]eventlogger.Node -} - -var _ audit.Backend = (*Backend)(nil) - func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { s := b.salt.Load().(*salt.Salt) if s != nil { @@ -256,9 +205,10 @@ func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { return newSalt, nil } +// Deprecated: Use eventlogger. func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { var writer io.Writer - switch b.path { + switch b.filePath { case stdout: writer = os.Stdout case discard: @@ -274,6 +224,7 @@ func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { return b.log(ctx, buf, writer) } +// Deprecated: Use eventlogger. func (b *Backend) log(_ context.Context, buf *bytes.Buffer, writer io.Writer) error { reader := bytes.NewReader(buf.Bytes()) @@ -290,7 +241,7 @@ func (b *Backend) log(_ context.Context, buf *bytes.Buffer, writer io.Writer) er if _, err := reader.WriteTo(writer); err == nil { b.fileLock.Unlock() return nil - } else if b.path == stdout { + } else if b.filePath == stdout { b.fileLock.Unlock() return err } @@ -312,9 +263,10 @@ func (b *Backend) log(_ context.Context, buf *bytes.Buffer, writer io.Writer) er return err } +// Deprecated: Use eventlogger. func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { var writer io.Writer - switch b.path { + switch b.filePath { case stdout: writer = os.Stdout case discard: @@ -338,7 +290,7 @@ func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, conf // Old behavior var writer io.Writer - switch b.path { + switch b.filePath { case stdout: writer = os.Stdout case discard: @@ -360,27 +312,28 @@ func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, conf } // The file lock must be held before calling this +// Deprecated: Use eventlogger. func (b *Backend) open() error { if b.f != nil { return nil } - if err := os.MkdirAll(filepath.Dir(b.path), b.mode); err != nil { + if err := os.MkdirAll(filepath.Dir(b.filePath), b.mode); err != nil { return err } var err error - b.f, err = os.OpenFile(b.path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, b.mode) + b.f, err = os.OpenFile(b.filePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, b.mode) if err != nil { return err } // Change the file mode in case the log file already existed. We special // case /dev/null since we can't chmod it and bypass if the mode is zero - switch b.path { + switch b.filePath { case "/dev/null": default: if b.mode != 0 { - err = os.Chmod(b.path, b.mode) + err = os.Chmod(b.filePath, b.mode) if err != nil { return err } @@ -402,7 +355,7 @@ func (b *Backend) Reload(_ context.Context) error { return nil } else { // old non-eventlogger behavior - switch b.path { + switch b.filePath { case stdout, discard: return nil } @@ -432,20 +385,168 @@ func (b *Backend) Invalidate(_ context.Context) { b.salt.Store((*salt.Salt)(nil)) } -// RegisterNodesAndPipeline registers the nodes and a pipeline as required by -// the audit.Backend interface. -func (b *Backend) RegisterNodesAndPipeline(broker *eventlogger.Broker, name string) error { - for id, node := range b.nodeMap { - if err := broker.RegisterNode(id, node, eventlogger.WithNodeRegistrationPolicy(eventlogger.DenyOverwrite)); err != nil { - return err +// formatterConfig creates the configuration required by a formatter node using +// the config map supplied to the factory. +func formatterConfig(config map[string]string) (audit.FormatterConfig, error) { + const op = "file.formatterConfig" + + var opts []audit.Option + + if format, ok := config["format"]; ok { + opts = append(opts, audit.WithFormat(format)) + } + + // Check if hashing of accessor is disabled + if hmacAccessorRaw, ok := config["hmac_accessor"]; ok { + v, err := strconv.ParseBool(hmacAccessorRaw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'hmac_accessor': %w", op, err) + } + opts = append(opts, audit.WithHMACAccessor(v)) + } + + // Check if raw logging is enabled + if raw, ok := config["log_raw"]; ok { + v, err := strconv.ParseBool(raw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'log_raw': %w", op, err) } + opts = append(opts, audit.WithRaw(v)) + } + + if elideListResponsesRaw, ok := config["elide_list_responses"]; ok { + v, err := strconv.ParseBool(elideListResponsesRaw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'elide_list_responses': %w", op, err) + } + opts = append(opts, audit.WithElision(v)) + } + + return audit.NewFormatterConfig(opts...) +} + +// configureFilterNode is used to configure a filter node and associated ID on the Backend. +func (b *Backend) configureFilterNode(filter string) error { + const op = "file.(Backend).configureFilterNode" + + filter = strings.TrimSpace(filter) + if filter == "" { + return nil + } + + filterNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for filter node: %w", op, err) + } + + filterNode, err := audit.NewEntryFilter(filter) + if err != nil { + return fmt.Errorf("%s: error creating filter node: %w", op, err) } - pipeline := eventlogger.Pipeline{ - PipelineID: eventlogger.PipelineID(name), - EventType: eventlogger.EventType(event.AuditType.String()), - NodeIDs: b.nodeIDList, + b.nodeIDList = append(b.nodeIDList, filterNodeID) + b.nodeMap[filterNodeID] = filterNode + return nil +} + +// configureFormatterNode is used to configure a formatter node and associated ID on the Backend. +func (b *Backend) configureFormatterNode(formatConfig audit.FormatterConfig, opts ...audit.Option) error { + const op = "file.(Backend).configureFormatterNode" + + formatterNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for formatter node: %w", op, err) } - return broker.RegisterPipeline(pipeline, eventlogger.WithPipelineRegistrationPolicy(eventlogger.DenyOverwrite)) + formatterNode, err := audit.NewEntryFormatter(formatConfig, b, opts...) + if err != nil { + return fmt.Errorf("%s: error creating formatter: %w", op, err) + } + + b.nodeIDList = append(b.nodeIDList, formatterNodeID) + b.nodeMap[formatterNodeID] = formatterNode + return nil +} + +// configureSinkNode is used to configure a sink node and associated ID on the Backend. +func (b *Backend) configureSinkNode(name string, filePath string, mode string, format string) error { + const op = "file.(Backend).configureSinkNode" + + name = strings.TrimSpace(name) + if name == "" { + return fmt.Errorf("%s: name is required: %w", op, event.ErrInvalidParameter) + } + + filePath = strings.TrimSpace(filePath) + if filePath == "" { + return fmt.Errorf("%s: file path is required: %w", op, event.ErrInvalidParameter) + } + + format = strings.TrimSpace(format) + if format == "" { + return fmt.Errorf("%s: format is required: %w", op, event.ErrInvalidParameter) + } + + sinkNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for sink node: %w", op, err) + } + + // normalize file path if configured for stdout or discard + if strings.EqualFold(filePath, stdout) { + filePath = stdout + } else if strings.EqualFold(filePath, discard) { + filePath = discard + } + + var sinkNode eventlogger.Node + var sinkName string + + switch filePath { + case stdout: + sinkName = stdout + sinkNode, err = event.NewStdoutSinkNode(format) + case discard: + sinkName = discard + sinkNode = event.NewNoopSink() + default: + // The NewFileSink function attempts to open the file and will return an error if it can't. + sinkName = name + sinkNode, err = event.NewFileSink(filePath, format, []event.Option{event.WithFileMode(mode)}...) + } + + if err != nil { + return fmt.Errorf("%s: file sink creation failed for path %q: %w", op, filePath, err) + } + + sinkNode = &audit.SinkWrapper{Name: sinkName, Sink: sinkNode} + + b.nodeIDList = append(b.nodeIDList, sinkNodeID) + b.nodeMap[sinkNodeID] = sinkNode + return nil +} + +// Name for this backend, this would ideally correspond to the mount path for the audit device. +func (b *Backend) Name() string { + return b.name +} + +// Nodes returns the nodes which should be used by the event framework to process audit entries. +func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node { + return b.nodeMap +} + +// NodeIDs returns the IDs of the nodes, in the order they are required. +func (b *Backend) NodeIDs() []eventlogger.NodeID { + return b.nodeIDList +} + +// EventType returns the event type for the backend. +func (b *Backend) EventType() eventlogger.EventType { + return eventlogger.EventType(event.AuditType.String()) +} + +// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter. +func (b *Backend) HasFiltering() bool { + return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter } diff --git a/builtin/audit/file/backend_test.go b/builtin/audit/file/backend_test.go index e0ba06319ca5..17ea7fd20365 100644 --- a/builtin/audit/file/backend_test.go +++ b/builtin/audit/file/backend_test.go @@ -12,10 +12,12 @@ import ( "testing" "time" + "github.com/hashicorp/eventlogger" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/salt" "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" ) func TestAuditFile_fileModeNew(t *testing.T) { @@ -145,6 +147,7 @@ func TestAuditFile_EventLogger_fileModeNew(t *testing.T) { } _, err = Factory(context.Background(), &audit.BackendConfig{ + MountPath: "foo/bar", SaltConfig: &salt.Config{}, SaltView: &logical.InmemStorage{}, Config: config, @@ -210,3 +213,366 @@ func BenchmarkAuditFile_request(b *testing.B) { } }) } + +// TestBackend_formatterConfig ensures that all the configuration values are parsed correctly. +func TestBackend_formatterConfig(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + config map[string]string + want audit.FormatterConfig + wantErr bool + expectedMessage string + }{ + "happy-path-json": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{ + Raw: true, + HMACAccessor: true, + ElideListResponses: true, + RequiredFormat: "json", + }, wantErr: false, + }, + "happy-path-jsonx": { + config: map[string]string{ + "format": audit.JSONxFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{ + Raw: true, + HMACAccessor: true, + ElideListResponses: true, + RequiredFormat: "jsonx", + }, + wantErr: false, + }, + "invalid-format": { + config: map[string]string{ + "format": " squiggly ", + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedMessage: "audit.NewFormatterConfig: error applying options: audit.(format).validate: 'squiggly' is not a valid format: invalid parameter", + }, + "invalid-hmac-accessor": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedMessage: "file.formatterConfig: unable to parse 'hmac_accessor': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + "invalid-log-raw": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedMessage: "file.formatterConfig: unable to parse 'log_raw': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + "invalid-elide-bool": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedMessage: "file.formatterConfig: unable to parse 'elide_list_responses': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + } + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := formatterConfig(tc.config) + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedMessage) + } else { + require.NoError(t, err) + } + require.Equal(t, tc.want, got) + }) + } +} + +// TestBackend_configureFilterNode ensures that configureFilterNode handles various +// filter values as expected. Empty (including whitespace) strings should return +// no error but skip configuration of the node. +func TestBackend_configureFilterNode(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + filter string + shouldSkipNode bool + wantErr bool + expectedErrorMsg string + }{ + "happy": { + filter: "foo == bar", + }, + "empty": { + filter: "", + shouldSkipNode: true, + }, + "spacey": { + filter: " ", + shouldSkipNode: true, + }, + "bad": { + filter: "___qwerty", + wantErr: true, + expectedErrorMsg: "file.(Backend).configureFilterNode: error creating filter node: audit.NewEntryFilter: cannot create new audit filter", + }, + } + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + err := b.configureFilterNode(tc.filter) + + switch { + case tc.wantErr: + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedErrorMsg) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + case tc.shouldSkipNode: + require.NoError(t, err) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + default: + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFilter, node.Type()) + } + }) + } +} + +// TestBackend_configureFormatterNode ensures that configureFormatterNode +// populates the nodeIDList and nodeMap on Backend when given valid formatConfig. +func TestBackend_configureFormatterNode(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + formatConfig, err := audit.NewFormatterConfig() + require.NoError(t, err) + + err = b.configureFormatterNode(formatConfig) + + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFormatter, node.Type()) +} + +// TestBackend_configureSinkNode ensures that we can correctly configure the sink +// node on the Backend, and any incorrect parameters result in the relevant errors. +func TestBackend_configureSinkNode(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + name string + filePath string + mode string + format string + wantErr bool + expectedErrMsg string + expectedName string + }{ + "name-empty": { + name: "", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: name is required: invalid parameter", + }, + "name-whitespace": { + name: " ", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: name is required: invalid parameter", + }, + "filePath-empty": { + name: "foo", + filePath: "", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: file path is required: invalid parameter", + }, + "filePath-whitespace": { + name: "foo", + filePath: " ", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: file path is required: invalid parameter", + }, + "filePath-stdout-lower": { + name: "foo", + expectedName: "stdout", + filePath: "stdout", + format: "json", + }, + "filePath-stdout-upper": { + name: "foo", + expectedName: "stdout", + filePath: "STDOUT", + format: "json", + }, + "filePath-stdout-mixed": { + name: "foo", + expectedName: "stdout", + filePath: "StdOut", + format: "json", + }, + "filePath-discard-lower": { + name: "foo", + expectedName: "discard", + filePath: "discard", + format: "json", + }, + "filePath-discard-upper": { + name: "foo", + expectedName: "discard", + filePath: "DISCARD", + format: "json", + }, + "filePath-discard-mixed": { + name: "foo", + expectedName: "discard", + filePath: "DisCArd", + format: "json", + }, + "format-empty": { + name: "foo", + filePath: "/tmp/", + format: "", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: format is required: invalid parameter", + }, + "format-whitespace": { + name: "foo", + filePath: "/tmp/", + format: " ", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: format is required: invalid parameter", + }, + "filePath-weird-with-mode-zero": { + name: "foo", + filePath: "/tmp/qwerty", + format: "json", + mode: "0", + wantErr: true, + expectedErrMsg: "file.(Backend).configureSinkNode: file sink creation failed for path \"/tmp/qwerty\": event.NewFileSink: unable to determine existing file mode: stat /tmp/qwerty: no such file or directory", + }, + "happy": { + name: "foo", + filePath: "/tmp/audit.log", + mode: "", + format: "json", + wantErr: false, + expectedName: "foo", + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + err := b.configureSinkNode(tc.name, tc.filePath, tc.mode, tc.format) + + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + } else { + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeSink, node.Type()) + sw, ok := node.(*audit.SinkWrapper) + require.True(t, ok) + require.Equal(t, tc.expectedName, sw.Name) + } + }) + } +} + +// TestBackend_configureFilterFormatterSink ensures that configuring all three +// types of nodes on a Backend works as expected, i.e. we have all three nodes +// at the end and nothing gets overwritten. The order of calls influences the +// slice of IDs on the Backend. +func TestBackend_configureFilterFormatterSink(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + formatConfig, err := audit.NewFormatterConfig() + require.NoError(t, err) + + err = b.configureFilterNode("foo == bar") + require.NoError(t, err) + + err = b.configureFormatterNode(formatConfig) + require.NoError(t, err) + + err = b.configureSinkNode("foo", "/tmp/foo", "0777", "json") + require.NoError(t, err) + + require.Len(t, b.nodeIDList, 3) + require.Len(t, b.nodeMap, 3) + + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFilter, node.Type()) + + id = b.nodeIDList[1] + node = b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFormatter, node.Type()) + + id = b.nodeIDList[2] + node = b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeSink, node.Type()) +} diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index 1e906468c7f8..09662c2ab683 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "strconv" + "strings" "sync" "time" @@ -21,83 +22,76 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { +var _ audit.Backend = (*Backend)(nil) + +// Backend is the audit backend for the socket audit transport. +type Backend struct { + sync.Mutex + address string + connection net.Conn + formatter *audit.EntryFormatterWriter + formatConfig audit.FormatterConfig + name string + nodeIDList []eventlogger.NodeID + nodeMap map[eventlogger.NodeID]eventlogger.Node + salt *salt.Salt + saltConfig *salt.Config + saltMutex sync.RWMutex + saltView logical.Storage + socketType string + writeDuration time.Duration +} + +func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { + const op = "socket.Factory" + if conf.SaltConfig == nil { - return nil, fmt.Errorf("nil salt config") + return nil, fmt.Errorf("%s: nil salt config", op) } + if conf.SaltView == nil { - return nil, fmt.Errorf("nil salt view") + return nil, fmt.Errorf("%s: nil salt view", op) } address, ok := conf.Config["address"] if !ok { - return nil, fmt.Errorf("address is required") + return nil, fmt.Errorf("%s: address is required", op) } socketType, ok := conf.Config["socket_type"] if !ok { socketType = "tcp" } + writeDeadline, ok := conf.Config["write_timeout"] if !ok { writeDeadline = "2s" } + writeDuration, err := parseutil.ParseDurationSecond(writeDeadline) if err != nil { - return nil, err - } - - var cfgOpts []audit.Option - - if format, ok := conf.Config["format"]; ok { - cfgOpts = append(cfgOpts, audit.WithFormat(format)) - } - - // Check if hashing of accessor is disabled - if hmacAccessorRaw, ok := conf.Config["hmac_accessor"]; ok { - v, err := strconv.ParseBool(hmacAccessorRaw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithHMACAccessor(v)) - } - - // Check if raw logging is enabled - if raw, ok := conf.Config["log_raw"]; ok { - v, err := strconv.ParseBool(raw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithRaw(v)) + return nil, fmt.Errorf("%s: failed to parse 'write_timeout': %w", op, err) } - if elideListResponsesRaw, ok := conf.Config["elide_list_responses"]; ok { - v, err := strconv.ParseBool(elideListResponsesRaw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithElision(v)) - } - - cfg, err := audit.NewFormatterConfig(cfgOpts...) + cfg, err := formatterConfig(conf.Config) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) } b := &Backend{ - saltConfig: conf.SaltConfig, - saltView: conf.SaltView, - formatConfig: cfg, - - writeDuration: writeDuration, address: address, + formatConfig: cfg, + name: conf.MountPath, + saltConfig: conf.SaltConfig, + saltView: conf.SaltView, socketType: socketType, + writeDuration: writeDuration, } // Configure the formatter for either case. - f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithHeaderFormatter(headersConfig)) + f, err := audit.NewEntryFormatter(cfg, b, audit.WithHeaderFormatter(headersConfig)) if err != nil { - return nil, fmt.Errorf("error creating formatter: %w", err) + return nil, fmt.Errorf("%s: error creating formatter: %w", op, err) } var w audit.Writer switch b.formatConfig.RequiredFormat { @@ -109,72 +103,44 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w) if err != nil { - return nil, fmt.Errorf("error creating formatter writer: %w", err) + return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err) } b.formatter = fw if useEventLogger { - var opts []event.Option + b.nodeIDList = []eventlogger.NodeID{} + b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) - if socketType, ok := conf.Config["socket_type"]; ok { - opts = append(opts, event.WithSocketType(socketType)) + err := b.configureFilterNode(conf.Config["filter"]) + if err != nil { + return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) } - if writeDeadline, ok := conf.Config["write_timeout"]; ok { - opts = append(opts, event.WithMaxDuration(writeDeadline)) + opts := []audit.Option{ + audit.WithHeaderFormatter(headersConfig), } - b.nodeIDList = make([]eventlogger.NodeID, 2) - b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) - - formatterNodeID, err := event.GenerateNodeID() + err = b.configureFormatterNode(cfg, opts...) if err != nil { - return nil, fmt.Errorf("error generating random NodeID for formatter node: %w", err) + return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) } - b.nodeIDList[0] = formatterNodeID - b.nodeMap[formatterNodeID] = f - n, err := event.NewSocketSink(b.formatConfig.RequiredFormat.String(), address, opts...) - if err != nil { - return nil, fmt.Errorf("error creating socket sink node: %w", err) + sinkOpts := []event.Option{ + event.WithSocketType(socketType), + event.WithMaxDuration(writeDeadline), } - sinkNode := &audit.SinkWrapper{Name: conf.MountPath, Sink: n} - sinkNodeID, err := event.GenerateNodeID() + + err = b.configureSinkNode(conf.MountPath, address, cfg.RequiredFormat.String(), sinkOpts...) if err != nil { - return nil, fmt.Errorf("error generating random NodeID for sink node: %w", err) + return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) } - b.nodeIDList[1] = sinkNodeID - b.nodeMap[sinkNodeID] = sinkNode } return b, nil } -// Backend is the audit backend for the socket audit transport. -type Backend struct { - connection net.Conn - - formatter *audit.EntryFormatterWriter - formatConfig audit.FormatterConfig - - writeDuration time.Duration - address string - socketType string - - sync.Mutex - - saltMutex sync.RWMutex - salt *salt.Salt - saltConfig *salt.Config - saltView logical.Storage - - nodeIDList []eventlogger.NodeID - nodeMap map[eventlogger.NodeID]eventlogger.Node -} - -var _ audit.Backend = (*Backend)(nil) - +// Deprecated: Use eventlogger. func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { var buf bytes.Buffer if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { @@ -198,6 +164,7 @@ func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { return err } +// Deprecated: Use eventlogger. func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { var buf bytes.Buffer if err := b.formatter.FormatAndWriteResponse(ctx, &buf, in); err != nil { @@ -256,6 +223,7 @@ func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, conf return err } +// Deprecated: Use eventlogger. func (b *Backend) write(ctx context.Context, buf []byte) error { if b.connection == nil { if err := b.reconnect(ctx); err != nil { @@ -276,6 +244,7 @@ func (b *Backend) write(ctx context.Context, buf []byte) error { return nil } +// Deprecated: Use eventlogger. func (b *Backend) reconnect(ctx context.Context) error { if b.connection != nil { b.connection.Close() @@ -317,12 +286,12 @@ func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) + s, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) if err != nil { return nil, err } - b.salt = salt - return salt, nil + b.salt = s + return s, nil } func (b *Backend) Invalidate(_ context.Context) { @@ -331,20 +300,146 @@ func (b *Backend) Invalidate(_ context.Context) { b.salt = nil } -// RegisterNodesAndPipeline registers the nodes and a pipeline as required by -// the audit.Backend interface. -func (b *Backend) RegisterNodesAndPipeline(broker *eventlogger.Broker, name string) error { - for id, node := range b.nodeMap { - if err := broker.RegisterNode(id, node, eventlogger.WithNodeRegistrationPolicy(eventlogger.DenyOverwrite)); err != nil { - return err +// formatterConfig creates the configuration required by a formatter node using +// the config map supplied to the factory. +func formatterConfig(config map[string]string) (audit.FormatterConfig, error) { + const op = "socket.formatterConfig" + + var cfgOpts []audit.Option + + if format, ok := config["format"]; ok { + cfgOpts = append(cfgOpts, audit.WithFormat(format)) + } + + // Check if hashing of accessor is disabled + if hmacAccessorRaw, ok := config["hmac_accessor"]; ok { + v, err := strconv.ParseBool(hmacAccessorRaw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'hmac_accessor': %w", op, err) + } + cfgOpts = append(cfgOpts, audit.WithHMACAccessor(v)) + } + + // Check if raw logging is enabled + if raw, ok := config["log_raw"]; ok { + v, err := strconv.ParseBool(raw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'log_raw': %w", op, err) + } + cfgOpts = append(cfgOpts, audit.WithRaw(v)) + } + + if elideListResponsesRaw, ok := config["elide_list_responses"]; ok { + v, err := strconv.ParseBool(elideListResponsesRaw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'elide_list_responses': %w", op, err) } + cfgOpts = append(cfgOpts, audit.WithElision(v)) + } + + return audit.NewFormatterConfig(cfgOpts...) +} + +// configureFilterNode is used to configure a filter node and associated ID on the Backend. +func (b *Backend) configureFilterNode(filter string) error { + const op = "socket.(Backend).configureFilterNode" + + filter = strings.TrimSpace(filter) + if filter == "" { + return nil + } + + filterNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for filter node: %w", op, err) + } + + filterNode, err := audit.NewEntryFilter(filter) + if err != nil { + return fmt.Errorf("%s: error creating filter node: %w", op, err) + } + + b.nodeIDList = append(b.nodeIDList, filterNodeID) + b.nodeMap[filterNodeID] = filterNode + return nil +} + +// configureFormatterNode is used to configure a formatter node and associated ID on the Backend. +func (b *Backend) configureFormatterNode(formatConfig audit.FormatterConfig, opts ...audit.Option) error { + const op = "socket.(Backend).configureFormatterNode" + + formatterNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for formatter node: %w", op, err) + } + + formatterNode, err := audit.NewEntryFormatter(formatConfig, b, opts...) + if err != nil { + return fmt.Errorf("%s: error creating formatter: %w", op, err) + } + + b.nodeIDList = append(b.nodeIDList, formatterNodeID) + b.nodeMap[formatterNodeID] = formatterNode + return nil +} + +// configureSinkNode is used to configure a sink node and associated ID on the Backend. +func (b *Backend) configureSinkNode(name string, address string, format string, opts ...event.Option) error { + const op = "socket.(Backend).configureSinkNode" + + name = strings.TrimSpace(name) + if name == "" { + return fmt.Errorf("%s: name is required: %w", op, event.ErrInvalidParameter) + } + + address = strings.TrimSpace(address) + if address == "" { + return fmt.Errorf("%s: address is required: %w", op, event.ErrInvalidParameter) + } + + format = strings.TrimSpace(format) + if format == "" { + return fmt.Errorf("%s: format is required: %w", op, event.ErrInvalidParameter) } - pipeline := eventlogger.Pipeline{ - PipelineID: eventlogger.PipelineID(name), - EventType: eventlogger.EventType(event.AuditType.String()), - NodeIDs: b.nodeIDList, + sinkNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for sink node: %w", op, err) + } + + n, err := event.NewSocketSink(address, format, opts...) + if err != nil { + return fmt.Errorf("%s: error creating socket sink node: %w", op, err) } - return broker.RegisterPipeline(pipeline, eventlogger.WithPipelineRegistrationPolicy(eventlogger.DenyOverwrite)) + sinkNode := &audit.SinkWrapper{Name: name, Sink: n} + + b.nodeIDList = append(b.nodeIDList, sinkNodeID) + b.nodeMap[sinkNodeID] = sinkNode + return nil +} + +// Name for this backend, this would ideally correspond to the mount path for the audit device. +func (b *Backend) Name() string { + return b.name +} + +// Nodes returns the nodes which should be used by the event framework to process audit entries. +func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node { + return b.nodeMap +} + +// NodeIDs returns the IDs of the nodes, in the order they are required. +func (b *Backend) NodeIDs() []eventlogger.NodeID { + return b.nodeIDList +} + +// EventType returns the event type for the backend. +func (b *Backend) EventType() eventlogger.EventType { + return eventlogger.EventType(event.AuditType.String()) +} + +// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter. +func (b *Backend) HasFiltering() bool { + return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter } diff --git a/builtin/audit/socket/backend_test.go b/builtin/audit/socket/backend_test.go new file mode 100644 index 000000000000..d1dfc384720c --- /dev/null +++ b/builtin/audit/socket/backend_test.go @@ -0,0 +1,331 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package socket + +import ( + "testing" + + "github.com/hashicorp/eventlogger" + "github.com/hashicorp/vault/audit" + "github.com/stretchr/testify/require" +) + +// TestBackend_formatterConfig ensures that all the configuration values are parsed correctly. +func TestBackend_formatterConfig(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + config map[string]string + want audit.FormatterConfig + wantErr bool + expectedErrMsg string + }{ + "happy-path-json": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{ + Raw: true, + HMACAccessor: true, + ElideListResponses: true, + RequiredFormat: "json", + }, wantErr: false, + }, + "happy-path-jsonx": { + config: map[string]string{ + "format": audit.JSONxFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{ + Raw: true, + HMACAccessor: true, + ElideListResponses: true, + RequiredFormat: "jsonx", + }, + wantErr: false, + }, + "invalid-format": { + config: map[string]string{ + "format": " squiggly ", + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "audit.NewFormatterConfig: error applying options: audit.(format).validate: 'squiggly' is not a valid format: invalid parameter", + }, + "invalid-hmac-accessor": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "socket.formatterConfig: unable to parse 'hmac_accessor': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + "invalid-log-raw": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "socket.formatterConfig: unable to parse 'log_raw': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + "invalid-elide-bool": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "socket.formatterConfig: unable to parse 'elide_list_responses': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + } + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := formatterConfig(tc.config) + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + } else { + require.NoError(t, err) + } + require.Equal(t, tc.want, got) + }) + } +} + +// TestBackend_configureFilterNode ensures that configureFilterNode handles various +// filter values as expected. Empty (including whitespace) strings should return +// no error but skip configuration of the node. +func TestBackend_configureFilterNode(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + filter string + shouldSkipNode bool + wantErr bool + expectedErrorMsg string + }{ + "happy": { + filter: "foo == bar", + }, + "empty": { + filter: "", + shouldSkipNode: true, + }, + "spacey": { + filter: " ", + shouldSkipNode: true, + }, + "bad": { + filter: "___qwerty", + wantErr: true, + expectedErrorMsg: "socket.(Backend).configureFilterNode: error creating filter node: audit.NewEntryFilter: cannot create new audit filter", + }, + } + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + err := b.configureFilterNode(tc.filter) + + switch { + case tc.wantErr: + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedErrorMsg) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + case tc.shouldSkipNode: + require.NoError(t, err) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + default: + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFilter, node.Type()) + } + }) + } +} + +// TestBackend_configureFormatterNode ensures that configureFormatterNode +// populates the nodeIDList and nodeMap on Backend when given valid formatConfig. +func TestBackend_configureFormatterNode(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + formatConfig, err := audit.NewFormatterConfig() + require.NoError(t, err) + + err = b.configureFormatterNode(formatConfig) + + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFormatter, node.Type()) +} + +// TestBackend_configureSinkNode ensures that we can correctly configure the sink +// node on the Backend, and any incorrect parameters result in the relevant errors. +func TestBackend_configureSinkNode(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + name string + address string + format string + wantErr bool + expectedErrMsg string + expectedName string + }{ + "name-empty": { + name: "", + address: "wss://foo", + wantErr: true, + expectedErrMsg: "socket.(Backend).configureSinkNode: name is required: invalid parameter", + }, + "name-whitespace": { + name: " ", + address: "wss://foo", + wantErr: true, + expectedErrMsg: "socket.(Backend).configureSinkNode: name is required: invalid parameter", + }, + "address-empty": { + name: "foo", + address: "", + wantErr: true, + expectedErrMsg: "socket.(Backend).configureSinkNode: address is required: invalid parameter", + }, + "address-whitespace": { + name: "foo", + address: " ", + wantErr: true, + expectedErrMsg: "socket.(Backend).configureSinkNode: address is required: invalid parameter", + }, + "format-empty": { + name: "foo", + address: "wss://foo", + format: "", + wantErr: true, + expectedErrMsg: "socket.(Backend).configureSinkNode: format is required: invalid parameter", + }, + "format-whitespace": { + name: "foo", + address: "wss://foo", + format: " ", + wantErr: true, + expectedErrMsg: "socket.(Backend).configureSinkNode: format is required: invalid parameter", + }, + "happy": { + name: "foo", + address: "wss://foo", + format: "json", + wantErr: false, + expectedName: "foo", + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + err := b.configureSinkNode(tc.name, tc.address, tc.format) + + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + } else { + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeSink, node.Type()) + sw, ok := node.(*audit.SinkWrapper) + require.True(t, ok) + require.Equal(t, tc.expectedName, sw.Name) + } + }) + } +} + +// TestBackend_configureFilterFormatterSink ensures that configuring all three +// types of nodes on a Backend works as expected, i.e. we have all three nodes +// at the end and nothing gets overwritten. The order of calls influences the +// slice of IDs on the Backend. +func TestBackend_configureFilterFormatterSink(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + formatConfig, err := audit.NewFormatterConfig() + require.NoError(t, err) + + err = b.configureFilterNode("foo == bar") + require.NoError(t, err) + + err = b.configureFormatterNode(formatConfig) + require.NoError(t, err) + + err = b.configureSinkNode("foo", "https://hashicorp.com", "json") + require.NoError(t, err) + + require.Len(t, b.nodeIDList, 3) + require.Len(t, b.nodeMap, 3) + + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFilter, node.Type()) + + id = b.nodeIDList[1] + node = b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFormatter, node.Type()) + + id = b.nodeIDList[2] + node = b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeSink, node.Type()) +} diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index 9dc0298f64f6..45d6e0762daa 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "strconv" + "strings" "sync" "github.com/hashicorp/eventlogger" @@ -18,13 +19,31 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { +var _ audit.Backend = (*Backend)(nil) + +// Backend is the audit backend for the syslog-based audit store. +type Backend struct { + formatter *audit.EntryFormatterWriter + formatConfig audit.FormatterConfig + logger gsyslog.Syslogger + name string + nodeIDList []eventlogger.NodeID + nodeMap map[eventlogger.NodeID]eventlogger.Node + salt *salt.Salt + saltConfig *salt.Config + saltMutex sync.RWMutex + saltView logical.Storage +} + +func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { + const op = "syslog.Factory" + if conf.SaltConfig == nil { - return nil, fmt.Errorf("nil salt config") + return nil, fmt.Errorf("%s: nil salt config", op) } if conf.SaltView == nil { - return nil, fmt.Errorf("nil salt view") + return nil, fmt.Errorf("%s: nil salt view", op) } // Get facility or default to AUTH @@ -39,60 +58,29 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool tag = "vault" } - var cfgOpts []audit.Option - - if format, ok := conf.Config["format"]; ok { - cfgOpts = append(cfgOpts, audit.WithFormat(format)) - } - - // Check if hashing of accessor is disabled - if hmacAccessorRaw, ok := conf.Config["hmac_accessor"]; ok { - v, err := strconv.ParseBool(hmacAccessorRaw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithHMACAccessor(v)) - } - - // Check if raw logging is enabled - if raw, ok := conf.Config["log_raw"]; ok { - v, err := strconv.ParseBool(raw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithRaw(v)) - } - - if elideListResponsesRaw, ok := conf.Config["elide_list_responses"]; ok { - v, err := strconv.ParseBool(elideListResponsesRaw) - if err != nil { - return nil, err - } - cfgOpts = append(cfgOpts, audit.WithElision(v)) - } - - cfg, err := audit.NewFormatterConfig(cfgOpts...) + cfg, err := formatterConfig(conf.Config) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err) } // Get the logger logger, err := gsyslog.NewLogger(gsyslog.LOG_INFO, facility, tag) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: cannot create logger: %w", op, err) } b := &Backend{ + formatConfig: cfg, logger: logger, + name: conf.MountPath, saltConfig: conf.SaltConfig, saltView: conf.SaltView, - formatConfig: cfg, } // Configure the formatter for either case. f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithHeaderFormatter(headersConfig), audit.WithPrefix(conf.Config["prefix"])) if err != nil { - return nil, fmt.Errorf("error creating formatter: %w", err) + return nil, fmt.Errorf("%s: error creating formatter: %w", op, err) } var w audit.Writer @@ -105,67 +93,45 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w) if err != nil { - return nil, fmt.Errorf("error creating formatter writer: %w", err) + return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err) } b.formatter = fw if useEventLogger { - var opts []event.Option + b.nodeIDList = []eventlogger.NodeID{} + b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) - // Get facility or default to AUTH - if facility, ok := conf.Config["facility"]; ok { - opts = append(opts, event.WithFacility(facility)) + err := b.configureFilterNode(conf.Config["filter"]) + if err != nil { + return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err) } - if tag, ok := conf.Config["tag"]; ok { - opts = append(opts, event.WithTag(tag)) + formatterOpts := []audit.Option{ + audit.WithHeaderFormatter(headersConfig), + audit.WithPrefix(conf.Config["prefix"]), } - b.nodeIDList = make([]eventlogger.NodeID, 2) - b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node) - - formatterNodeID, err := event.GenerateNodeID() + err = b.configureFormatterNode(cfg, formatterOpts...) if err != nil { - return nil, fmt.Errorf("error generating random NodeID for formatter node: %w", err) + return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err) } - b.nodeIDList[0] = formatterNodeID - b.nodeMap[formatterNodeID] = f - n, err := event.NewSyslogSink(b.formatConfig.RequiredFormat.String(), opts...) - if err != nil { - return nil, fmt.Errorf("error creating syslog sink node: %w", err) + sinkOpts := []event.Option{ + event.WithFacility(facility), + event.WithTag(tag), } - sinkNode := &audit.SinkWrapper{Name: conf.MountPath, Sink: n} - sinkNodeID, err := event.GenerateNodeID() + err = b.configureSinkNode(conf.MountPath, cfg.RequiredFormat.String(), sinkOpts...) if err != nil { - return nil, fmt.Errorf("error generating random NodeID for sink node: %w", err) + return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err) } - b.nodeIDList[1] = sinkNodeID - b.nodeMap[sinkNodeID] = sinkNode } - return b, nil -} - -// Backend is the audit backend for the syslog-based audit store. -type Backend struct { - logger gsyslog.Syslogger - formatter *audit.EntryFormatterWriter - formatConfig audit.FormatterConfig - - saltMutex sync.RWMutex - salt *salt.Salt - saltConfig *salt.Config - saltView logical.Storage - - nodeIDList []eventlogger.NodeID - nodeMap map[eventlogger.NodeID]eventlogger.Node + return b, nil } -var _ audit.Backend = (*Backend)(nil) - +// Deprecated: Use eventlogger. func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { var buf bytes.Buffer if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { @@ -177,6 +143,7 @@ func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { return err } +// Deprecated: Use eventlogger. func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { var buf bytes.Buffer if err := b.formatter.FormatAndWriteResponse(ctx, &buf, in); err != nil { @@ -227,12 +194,12 @@ func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { if b.salt != nil { return b.salt, nil } - salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) + s, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) if err != nil { return nil, err } - b.salt = salt - return salt, nil + b.salt = s + return s, nil } func (b *Backend) Invalidate(_ context.Context) { @@ -241,20 +208,142 @@ func (b *Backend) Invalidate(_ context.Context) { b.salt = nil } -// RegisterNodesAndPipeline registers the nodes and a pipeline as required by -// the audit.Backend interface. -func (b *Backend) RegisterNodesAndPipeline(broker *eventlogger.Broker, name string) error { - for id, node := range b.nodeMap { - if err := broker.RegisterNode(id, node, eventlogger.WithNodeRegistrationPolicy(eventlogger.DenyOverwrite)); err != nil { - return err +// formatterConfig creates the configuration required by a formatter node using +// the config map supplied to the factory. +func formatterConfig(config map[string]string) (audit.FormatterConfig, error) { + const op = "syslog.formatterConfig" + + var opts []audit.Option + + if format, ok := config["format"]; ok { + opts = append(opts, audit.WithFormat(format)) + } + + // Check if hashing of accessor is disabled + if hmacAccessorRaw, ok := config["hmac_accessor"]; ok { + v, err := strconv.ParseBool(hmacAccessorRaw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'hmac_accessor': %w", op, err) } + opts = append(opts, audit.WithHMACAccessor(v)) } - pipeline := eventlogger.Pipeline{ - PipelineID: eventlogger.PipelineID(name), - EventType: eventlogger.EventType(event.AuditType.String()), - NodeIDs: b.nodeIDList, + // Check if raw logging is enabled + if raw, ok := config["log_raw"]; ok { + v, err := strconv.ParseBool(raw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'log_raw': %w", op, err) + } + opts = append(opts, audit.WithRaw(v)) + } + + if elideListResponsesRaw, ok := config["elide_list_responses"]; ok { + v, err := strconv.ParseBool(elideListResponsesRaw) + if err != nil { + return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'elide_list_responses': %w", op, err) + } + opts = append(opts, audit.WithElision(v)) + } + + return audit.NewFormatterConfig(opts...) +} + +// configureFilterNode is used to configure a filter node and associated ID on the Backend. +func (b *Backend) configureFilterNode(filter string) error { + const op = "syslog.(Backend).configureFilterNode" + + filter = strings.TrimSpace(filter) + if filter == "" { + return nil + } + + filterNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for filter node: %w", op, err) + } + + filterNode, err := audit.NewEntryFilter(filter) + if err != nil { + return fmt.Errorf("%s: error creating filter node: %w", op, err) + } + + b.nodeIDList = append(b.nodeIDList, filterNodeID) + b.nodeMap[filterNodeID] = filterNode + return nil +} + +// configureFormatterNode is used to configure a formatter node and associated ID on the Backend. +func (b *Backend) configureFormatterNode(formatConfig audit.FormatterConfig, opts ...audit.Option) error { + const op = "syslog.(Backend).configureFormatterNode" + + formatterNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for formatter node: %w", op, err) + } + + formatterNode, err := audit.NewEntryFormatter(formatConfig, b, opts...) + if err != nil { + return fmt.Errorf("%s: error creating formatter: %w", op, err) + } + + b.nodeIDList = append(b.nodeIDList, formatterNodeID) + b.nodeMap[formatterNodeID] = formatterNode + return nil +} + +// configureSinkNode is used to configure a sink node and associated ID on the Backend. +func (b *Backend) configureSinkNode(name string, format string, opts ...event.Option) error { + const op = "syslog.(Backend).configureSinkNode" + + name = strings.TrimSpace(name) + if name == "" { + return fmt.Errorf("%s: name is required: %w", op, event.ErrInvalidParameter) + } + + format = strings.TrimSpace(format) + if format == "" { + return fmt.Errorf("%s: format is required: %w", op, event.ErrInvalidParameter) + } + + sinkNodeID, err := event.GenerateNodeID() + if err != nil { + return fmt.Errorf("%s: error generating random NodeID for sink node: %w", op, err) + } + + n, err := event.NewSyslogSink(format, opts...) + if err != nil { + return fmt.Errorf("%s: error creating syslog sink node: %w", op, err) } - return broker.RegisterPipeline(pipeline, eventlogger.WithPipelineRegistrationPolicy(eventlogger.DenyOverwrite)) + // wrap the sink node with metrics middleware + sinkNode := &audit.SinkWrapper{Name: name, Sink: n} + + b.nodeIDList = append(b.nodeIDList, sinkNodeID) + b.nodeMap[sinkNodeID] = sinkNode + return nil +} + +// Name for this backend, this would ideally correspond to the mount path for the audit device. +func (b *Backend) Name() string { + return b.name +} + +// Nodes returns the nodes which should be used by the event framework to process audit entries. +func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node { + return b.nodeMap +} + +// NodeIDs returns the IDs of the nodes, in the order they are required. +func (b *Backend) NodeIDs() []eventlogger.NodeID { + return b.nodeIDList +} + +// EventType returns the event type for the backend. +func (b *Backend) EventType() eventlogger.EventType { + return eventlogger.EventType(event.AuditType.String()) +} + +// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter. +func (b *Backend) HasFiltering() bool { + return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter } diff --git a/builtin/audit/syslog/backend_test.go b/builtin/audit/syslog/backend_test.go new file mode 100644 index 000000000000..4aeaa5d0da5c --- /dev/null +++ b/builtin/audit/syslog/backend_test.go @@ -0,0 +1,313 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package syslog + +import ( + "testing" + + "github.com/hashicorp/eventlogger" + "github.com/hashicorp/vault/audit" + "github.com/stretchr/testify/require" +) + +// TestBackend_formatterConfig ensures that all the configuration values are parsed correctly. +func TestBackend_formatterConfig(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + config map[string]string + want audit.FormatterConfig + wantErr bool + expectedErrMsg string + }{ + "happy-path-json": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{ + Raw: true, + HMACAccessor: true, + ElideListResponses: true, + RequiredFormat: "json", + }, wantErr: false, + }, + "happy-path-jsonx": { + config: map[string]string{ + "format": audit.JSONxFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{ + Raw: true, + HMACAccessor: true, + ElideListResponses: true, + RequiredFormat: "jsonx", + }, + wantErr: false, + }, + "invalid-format": { + config: map[string]string{ + "format": " squiggly ", + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "true", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "audit.NewFormatterConfig: error applying options: audit.(format).validate: 'squiggly' is not a valid format: invalid parameter", + }, + "invalid-hmac-accessor": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "syslog.formatterConfig: unable to parse 'hmac_accessor': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + "invalid-log-raw": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "syslog.formatterConfig: unable to parse 'log_raw': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + "invalid-elide-bool": { + config: map[string]string{ + "format": audit.JSONFormat.String(), + "hmac_accessor": "true", + "log_raw": "true", + "elide_list_responses": "maybe", + }, + want: audit.FormatterConfig{}, + wantErr: true, + expectedErrMsg: "syslog.formatterConfig: unable to parse 'elide_list_responses': strconv.ParseBool: parsing \"maybe\": invalid syntax", + }, + } + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := formatterConfig(tc.config) + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + } else { + require.NoError(t, err) + } + require.Equal(t, tc.want, got) + }) + } +} + +// TestBackend_configureFilterNode ensures that configureFilterNode handles various +// filter values as expected. Empty (including whitespace) strings should return +// no error but skip configuration of the node. +func TestBackend_configureFilterNode(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + filter string + shouldSkipNode bool + wantErr bool + expectedErrorMsg string + }{ + "happy": { + filter: "foo == bar", + }, + "empty": { + filter: "", + shouldSkipNode: true, + }, + "spacey": { + filter: " ", + shouldSkipNode: true, + }, + "bad": { + filter: "___qwerty", + wantErr: true, + expectedErrorMsg: "syslog.(Backend).configureFilterNode: error creating filter node: audit.NewEntryFilter: cannot create new audit filter", + }, + } + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + err := b.configureFilterNode(tc.filter) + + switch { + case tc.wantErr: + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedErrorMsg) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + case tc.shouldSkipNode: + require.NoError(t, err) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + default: + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFilter, node.Type()) + } + }) + } +} + +// TestBackend_configureFormatterNode ensures that configureFormatterNode +// populates the nodeIDList and nodeMap on Backend when given valid formatConfig. +func TestBackend_configureFormatterNode(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + formatConfig, err := audit.NewFormatterConfig() + require.NoError(t, err) + + err = b.configureFormatterNode(formatConfig) + + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFormatter, node.Type()) +} + +// TestBackend_configureSinkNode ensures that we can correctly configure the sink +// node on the Backend, and any incorrect parameters result in the relevant errors. +func TestBackend_configureSinkNode(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + name string + format string + wantErr bool + expectedErrMsg string + expectedName string + }{ + "name-empty": { + name: "", + wantErr: true, + expectedErrMsg: "syslog.(Backend).configureSinkNode: name is required: invalid parameter", + }, + "name-whitespace": { + name: " ", + wantErr: true, + expectedErrMsg: "syslog.(Backend).configureSinkNode: name is required: invalid parameter", + }, + "format-empty": { + name: "foo", + format: "", + wantErr: true, + expectedErrMsg: "syslog.(Backend).configureSinkNode: format is required: invalid parameter", + }, + "format-whitespace": { + name: "foo", + format: " ", + wantErr: true, + expectedErrMsg: "syslog.(Backend).configureSinkNode: format is required: invalid parameter", + }, + "happy": { + name: "foo", + format: "json", + wantErr: false, + expectedName: "foo", + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + err := b.configureSinkNode(tc.name, tc.format) + + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + require.Len(t, b.nodeIDList, 0) + require.Len(t, b.nodeMap, 0) + } else { + require.NoError(t, err) + require.Len(t, b.nodeIDList, 1) + require.Len(t, b.nodeMap, 1) + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeSink, node.Type()) + sw, ok := node.(*audit.SinkWrapper) + require.True(t, ok) + require.Equal(t, tc.expectedName, sw.Name) + } + }) + } +} + +// TestBackend_configureFilterFormatterSink ensures that configuring all three +// types of nodes on a Backend works as expected, i.e. we have all three nodes +// at the end and nothing gets overwritten. The order of calls influences the +// slice of IDs on the Backend. +func TestBackend_configureFilterFormatterSink(t *testing.T) { + t.Parallel() + + b := &Backend{ + nodeIDList: []eventlogger.NodeID{}, + nodeMap: map[eventlogger.NodeID]eventlogger.Node{}, + } + + formatConfig, err := audit.NewFormatterConfig() + require.NoError(t, err) + + err = b.configureFilterNode("foo == bar") + require.NoError(t, err) + + err = b.configureFormatterNode(formatConfig) + require.NoError(t, err) + + err = b.configureSinkNode("foo", "json") + require.NoError(t, err) + + require.Len(t, b.nodeIDList, 3) + require.Len(t, b.nodeMap, 3) + + id := b.nodeIDList[0] + node := b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFilter, node.Type()) + + id = b.nodeIDList[1] + node = b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeFormatter, node.Type()) + + id = b.nodeIDList[2] + node = b.nodeMap[id] + require.Equal(t, eventlogger.NodeTypeSink, node.Type()) +} diff --git a/changelog/24558.txt b/changelog/24558.txt new file mode 100644 index 000000000000..cd573e6d5849 --- /dev/null +++ b/changelog/24558.txt @@ -0,0 +1,3 @@ +```release-note:feature +core/audit: add filter parameter when enabling an audit device, allowing filtering (using go-bexpr expressions) of audit entries written to the device's audit log +``` diff --git a/helper/testhelpers/corehelpers/corehelpers.go b/helper/testhelpers/corehelpers/corehelpers.go index b8d1def9ca73..c2d6bc8a3ce7 100644 --- a/helper/testhelpers/corehelpers/corehelpers.go +++ b/helper/testhelpers/corehelpers/corehelpers.go @@ -612,3 +612,23 @@ func NewTestLogger(t testing.T) *TestLogger { func (tl *TestLogger) StopLogging() { tl.InterceptLogger.DeregisterSink(tl.sink) } + +func (n *NoopAudit) EventType() eventlogger.EventType { + return eventlogger.EventType(event.AuditType.String()) +} + +func (n *NoopAudit) HasFiltering() bool { + return false +} + +func (n *NoopAudit) Name() string { + return n.Config.MountPath +} + +func (n *NoopAudit) Nodes() map[eventlogger.NodeID]eventlogger.Node { + return n.nodeMap +} + +func (n *NoopAudit) NodeIDs() []eventlogger.NodeID { + return n.nodeIDList +} diff --git a/http/logical_test.go b/http/logical_test.go index 01e6762ee0a9..88964ac874c2 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -570,7 +570,7 @@ func TestLogical_RespondWithStatusCode(t *testing.T) { func TestLogical_Audit_invalidWrappingToken(t *testing.T) { // Create a noop audit backend - noop := corehelpers.TestNoopAudit(t, "noop", nil) + noop := corehelpers.TestNoopAudit(t, "noop/", nil) c, _, root := vault.TestCoreUnsealedWithConfig(t, &vault.CoreConfig{ AuditBackends: map[string]audit.Factory{ "noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { diff --git a/internal/observability/event/options.go b/internal/observability/event/options.go index 80ad1bf9996e..30d667740b0d 100644 --- a/internal/observability/event/options.go +++ b/internal/observability/event/options.go @@ -113,7 +113,11 @@ func WithNow(now time.Time) Option { // WithFacility provides an Option to represent a 'facility' for a syslog sink. func WithFacility(facility string) Option { return func(o *options) error { - o.withFacility = facility + facility = strings.TrimSpace(facility) + + if facility != "" { + o.withFacility = facility + } return nil } @@ -122,7 +126,11 @@ func WithFacility(facility string) Option { // WithTag provides an Option to represent a 'tag' for a syslog sink. func WithTag(tag string) Option { return func(o *options) error { - o.withTag = tag + tag = strings.TrimSpace(tag) + + if tag != "" { + o.withTag = tag + } return nil } diff --git a/internal/observability/event/options_test.go b/internal/observability/event/options_test.go index 676c79833078..0f36014740cf 100644 --- a/internal/observability/event/options_test.go +++ b/internal/observability/event/options_test.go @@ -205,7 +205,7 @@ func TestOptions_WithFacility(t *testing.T) { }, "whitespace": { Value: " ", - ExpectedValue: " ", + ExpectedValue: "", }, "value": { Value: "juan", @@ -213,7 +213,7 @@ func TestOptions_WithFacility(t *testing.T) { }, "spacey-value": { Value: " juan ", - ExpectedValue: " juan ", + ExpectedValue: "juan", }, } @@ -243,7 +243,7 @@ func TestOptions_WithTag(t *testing.T) { }, "whitespace": { Value: " ", - ExpectedValue: " ", + ExpectedValue: "", }, "value": { Value: "juan", @@ -251,7 +251,7 @@ func TestOptions_WithTag(t *testing.T) { }, "spacey-value": { Value: " juan ", - ExpectedValue: " juan ", + ExpectedValue: "juan", }, } diff --git a/internal/observability/event/pipeline_reader.go b/internal/observability/event/pipeline_reader.go new file mode 100644 index 000000000000..f35672f8efa6 --- /dev/null +++ b/internal/observability/event/pipeline_reader.go @@ -0,0 +1,24 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package event + +import "github.com/hashicorp/eventlogger" + +// PipelineReader surfaces information required for pipeline registration. +type PipelineReader interface { + // EventType should return the event type to be used for pipeline registration. + EventType() eventlogger.EventType + + // HasFiltering should determine if filter nodes are used by this pipeline. + HasFiltering() bool + + // Name for the pipeline which should be used for the eventlogger.PipelineID. + Name() string + + // Nodes should return the nodes which should be used by the framework to process events. + Nodes() map[eventlogger.NodeID]eventlogger.Node + + // NodeIDs should return the IDs of the nodes, in the order they are required. + NodeIDs() []eventlogger.NodeID +} diff --git a/internal/observability/event/sink_socket.go b/internal/observability/event/sink_socket.go index 69f482560b67..e9cb00c19662 100644 --- a/internal/observability/event/sink_socket.go +++ b/internal/observability/event/sink_socket.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net" + "strings" "sync" "time" @@ -29,9 +30,19 @@ type SocketSink struct { // NewSocketSink should be used to create a new SocketSink. // Accepted options: WithMaxDuration and WithSocketType. -func NewSocketSink(format string, address string, opt ...Option) (*SocketSink, error) { +func NewSocketSink(address string, format string, opt ...Option) (*SocketSink, error) { const op = "event.NewSocketSink" + address = strings.TrimSpace(address) + if address == "" { + return nil, fmt.Errorf("%s: address is required: %w", op, ErrInvalidParameter) + } + + format = strings.TrimSpace(format) + if format == "" { + return nil, fmt.Errorf("%s: format is required: %w", op, ErrInvalidParameter) + } + opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("%s: error applying options: %w", op, err) diff --git a/internal/observability/event/sink_socket_test.go b/internal/observability/event/sink_socket_test.go new file mode 100644 index 000000000000..3c647f7b3ea1 --- /dev/null +++ b/internal/observability/event/sink_socket_test.go @@ -0,0 +1,85 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package event + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// TestNewSocketSink ensures that we validate the input arguments and can create +// the SocketSink if everything goes to plan. +func TestNewSocketSink(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + address string + format string + opts []Option + want *SocketSink + wantErr bool + expectedErrMsg string + }{ + "address-empty": { + address: "", + wantErr: true, + expectedErrMsg: "event.NewSocketSink: address is required: invalid parameter", + }, + "address-whitespace": { + address: " ", + wantErr: true, + expectedErrMsg: "event.NewSocketSink: address is required: invalid parameter", + }, + "format-empty": { + address: "addr", + format: "", + wantErr: true, + expectedErrMsg: "event.NewSocketSink: format is required: invalid parameter", + }, + "format-whitespace": { + address: "addr", + format: " ", + wantErr: true, + expectedErrMsg: "event.NewSocketSink: format is required: invalid parameter", + }, + "bad-max-duration": { + address: "addr", + format: "json", + opts: []Option{WithMaxDuration("bar")}, + wantErr: true, + expectedErrMsg: "event.NewSocketSink: error applying options: time: invalid duration \"bar\"", + }, + "happy": { + address: "wss://foo", + format: "json", + want: &SocketSink{ + requiredFormat: "json", + address: "wss://foo", + socketType: "tcp", // defaults to tcp + maxDuration: 2 * time.Second, // defaults to 2 secs + }, + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := NewSocketSink(tc.address, tc.format, tc.opts...) + + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + require.Nil(t, got) + } else { + require.NoError(t, err) + require.Equal(t, tc.want, got) + } + }) + } +} diff --git a/internal/observability/event/sink_stdout.go b/internal/observability/event/sink_stdout.go index 34307251d415..6b1f43dace8f 100644 --- a/internal/observability/event/sink_stdout.go +++ b/internal/observability/event/sink_stdout.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "os" + "strings" "github.com/hashicorp/eventlogger" ) @@ -21,10 +22,17 @@ type StdoutSink struct { // NewStdoutSinkNode creates a new StdoutSink that will persist the events // it processes using the specified expected format. -func NewStdoutSinkNode(format string) *StdoutSink { +func NewStdoutSinkNode(format string) (*StdoutSink, error) { + const op = "event.NewStdoutSinkNode" + + format = strings.TrimSpace(format) + if format == "" { + return nil, fmt.Errorf("%s: format is required: %w", op, ErrInvalidParameter) + } + return &StdoutSink{ requiredFormat: format, - } + }, nil } // Process persists the provided eventlogger.Event to the standard output stream. diff --git a/internal/observability/event/sink_syslog.go b/internal/observability/event/sink_syslog.go index 72ac6cdd1e1c..d099ed5c7349 100644 --- a/internal/observability/event/sink_syslog.go +++ b/internal/observability/event/sink_syslog.go @@ -6,6 +6,7 @@ package event import ( "context" "fmt" + "strings" gsyslog "github.com/hashicorp/go-syslog" @@ -25,6 +26,11 @@ type SyslogSink struct { func NewSyslogSink(format string, opt ...Option) (*SyslogSink, error) { const op = "event.NewSyslogSink" + format = strings.TrimSpace(format) + if format == "" { + return nil, fmt.Errorf("%s: format is required: %w", op, ErrInvalidParameter) + } + opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("%s: error applying options: %w", op, err) diff --git a/internal/observability/event/sink_syslog_test.go b/internal/observability/event/sink_syslog_test.go new file mode 100644 index 000000000000..f977a4a50538 --- /dev/null +++ b/internal/observability/event/sink_syslog_test.go @@ -0,0 +1,57 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package event + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestNewSyslogSink ensures that we validate the input arguments and can create +// the SyslogSink if everything goes to plan. +func TestNewSyslogSink(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + format string + opts []Option + want *SyslogSink + wantErr bool + expectedErrMsg string + }{ + "format-empty": { + format: "", + wantErr: true, + expectedErrMsg: "event.NewSyslogSink: format is required: invalid parameter", + }, + "format-whitespace": { + format: " ", + wantErr: true, + expectedErrMsg: "event.NewSyslogSink: format is required: invalid parameter", + }, + "happy": { + format: "json", + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := NewSyslogSink(tc.format, tc.opts...) + + if tc.wantErr { + require.Error(t, err) + require.EqualError(t, err, tc.expectedErrMsg) + require.Nil(t, got) + } else { + require.NoError(t, err) + require.NotNil(t, got) + } + }) + } +} diff --git a/sdk/logical/audit.go b/sdk/logical/audit.go index 30c03e6113ac..12b8bed1cbdb 100644 --- a/sdk/logical/audit.go +++ b/sdk/logical/audit.go @@ -20,3 +20,36 @@ type MarshalOptions struct { type OptMarshaler interface { MarshalJSONWithOptions(*MarshalOptions) ([]byte, error) } + +// LogInputBexpr is used for evaluating boolean expressions with go-bexpr. +type LogInputBexpr struct { + MountPoint string `bexpr:"mount_point"` + MountType string `bexpr:"mount_type"` + Namespace string `bexpr:"namespace"` + Operation string `bexpr:"operation"` + Path string `bexpr:"path"` +} + +// BexprDatum returns values from a LogInput formatted for use in evaluating go-bexpr boolean expressions. +// The namespace should be supplied from the current request's context. +func (l *LogInput) BexprDatum(namespace string) *LogInputBexpr { + var mountPoint string + var mountType string + var operation string + var path string + + if l.Request != nil { + mountPoint = l.Request.MountPoint + mountType = l.Request.MountType + operation = string(l.Request.Operation) + path = l.Request.Path + } + + return &LogInputBexpr{ + MountPoint: mountPoint, + MountType: mountType, + Namespace: namespace, + Operation: operation, + Path: path, + } +} diff --git a/sdk/logical/audit_test.go b/sdk/logical/audit_test.go new file mode 100644 index 000000000000..710450c2f303 --- /dev/null +++ b/sdk/logical/audit_test.go @@ -0,0 +1,77 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package logical + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestLogInput_BexprDatum ensures that we can transform a LogInput +// into a LogInputBexpr to be used in audit filtering. +func TestLogInput_BexprDatum(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + Request *Request + Namespace string + ExpectedPath string + ExpectedMountPoint string + ExpectedMountType string + ExpectedNamespace string + ExpectedOperation string + }{ + "nil-no-namespace": { + Request: nil, + Namespace: "", + ExpectedPath: "", + ExpectedMountPoint: "", + ExpectedMountType: "", + ExpectedNamespace: "", + ExpectedOperation: "", + }, + "nil-namespace": { + Request: nil, + Namespace: "juan", + ExpectedPath: "", + ExpectedMountPoint: "", + ExpectedMountType: "", + ExpectedNamespace: "juan", + ExpectedOperation: "", + }, + "happy-path": { + Request: &Request{ + MountPoint: "IAmAMountPoint", + MountType: "IAmAMountType", + Operation: CreateOperation, + Path: "IAmAPath", + }, + Namespace: "juan", + ExpectedPath: "IAmAPath", + ExpectedMountPoint: "IAmAMountPoint", + ExpectedMountType: "IAmAMountType", + ExpectedNamespace: "juan", + ExpectedOperation: "create", + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + l := &LogInput{Request: tc.Request} + + d := l.BexprDatum(tc.Namespace) + + require.Equal(t, tc.ExpectedPath, d.Path) + require.Equal(t, tc.ExpectedMountPoint, d.MountPoint) + require.Equal(t, tc.ExpectedMountType, d.MountType) + require.Equal(t, tc.ExpectedNamespace, d.Namespace) + require.Equal(t, tc.ExpectedOperation, d.Operation) + }) + } +} diff --git a/vault/audit.go b/vault/audit.go index bdbae6471464..a3d6fcfab9b8 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -397,8 +397,7 @@ func (c *Core) persistAudit(ctx context.Context, table *MountTable, localOnly bo return nil } -// setupAudit is invoked after we've loaded the audit able to -// initialize the audit backends +// setupAudits is invoked after we've loaded the audit table to initialize the audit backends func (c *Core) setupAudits(ctx context.Context) error { c.auditLock.Lock() defer c.auditLock.Unlock() @@ -539,7 +538,6 @@ func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logi } auditLogger := c.baseLogger.Named("audit") - c.AddLogger(auditLogger) switch entry.Type { case "file": @@ -570,6 +568,7 @@ func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logi } } + c.AddLogger(auditLogger) return be, err } diff --git a/vault/audit_broker.go b/vault/audit_broker.go index 15cef4409d89..468d4b6171d1 100644 --- a/vault/audit_broker.go +++ b/vault/audit_broker.go @@ -5,6 +5,7 @@ package vault import ( "context" + "errors" "fmt" "runtime/debug" "sync" @@ -57,6 +58,8 @@ func NewAuditBroker(log log.Logger, useEventLogger bool) (*AuditBroker, error) { // Register is used to add new audit backend to the broker func (a *AuditBroker) Register(name string, b audit.Backend, local bool) error { + const op = "vault.(AuditBroker).Register" + a.Lock() defer a.Unlock() @@ -66,16 +69,36 @@ func (a *AuditBroker) Register(name string, b audit.Backend, local bool) error { } if a.broker != nil { - // Attempt to register the pipeline before enabling 'broker level' enforcement - // of how many successful sinks we expect. - err := b.RegisterNodesAndPipeline(a.broker, name) + if name != b.Name() { + return fmt.Errorf("%s: audit registration failed due to device name mismatch: %q, %q", op, name, b.Name()) + } + + for id, node := range b.Nodes() { + err := a.broker.RegisterNode(id, node, eventlogger.WithNodeRegistrationPolicy(eventlogger.DenyOverwrite)) + if err != nil { + return fmt.Errorf("%s: unable to register nodes for %q: %w", op, name, err) + } + } + + pipeline := eventlogger.Pipeline{ + PipelineID: eventlogger.PipelineID(b.Name()), + EventType: b.EventType(), + NodeIDs: b.NodeIDs(), + } + + err := a.broker.RegisterPipeline(pipeline, eventlogger.WithPipelineRegistrationPolicy(eventlogger.DenyOverwrite)) if err != nil { - return err + return fmt.Errorf("%s: unable to register pipeline for %q: %w", op, name, err) } + + // Establish if we ONLY have pipelines that include filter nodes. + // Otherwise, we can rely on the eventlogger broker guarantee. + threshold := a.requiredSuccessThresholdSinks() + // Update the success threshold now that the pipeline is registered. - err = a.broker.SetSuccessThresholdSinks(eventlogger.EventType(event.AuditType.String()), 1) + err = a.broker.SetSuccessThresholdSinks(eventlogger.EventType(event.AuditType.String()), threshold) if err != nil { - return err + return fmt.Errorf("%s: unable to configure sink success threshold (%d) for %q: %w", op, threshold, name, err) } } @@ -84,6 +107,8 @@ func (a *AuditBroker) Register(name string, b audit.Backend, local bool) error { // Deregister is used to remove an audit backend from the broker func (a *AuditBroker) Deregister(ctx context.Context, name string) error { + const op = "vault.(AuditBroker).Deregister" + a.Lock() defer a.Unlock() @@ -93,20 +118,22 @@ func (a *AuditBroker) Deregister(ctx context.Context, name string) error { delete(a.backends, name) if a.broker != nil { - if len(a.backends) == 0 { - err := a.broker.SetSuccessThresholdSinks(eventlogger.EventType(event.AuditType.String()), 0) - if err != nil { - return err - } + // Establish if we ONLY have pipelines that include filter nodes. + // Otherwise, we can rely on the eventlogger broker guarantee. + threshold := a.requiredSuccessThresholdSinks() + + err := a.broker.SetSuccessThresholdSinks(eventlogger.EventType(event.AuditType.String()), threshold) + if err != nil { + return fmt.Errorf("%s: unable to configure sink success threshold (%d) for %q: %w", op, threshold, name, err) } // The first return value, a bool, indicates whether // RemovePipelineAndNodes encountered the error while evaluating // pre-conditions (false) or once it started removing the pipeline and // the nodes (true). This code doesn't care either way. - _, err := a.broker.RemovePipelineAndNodes(ctx, eventlogger.EventType(event.AuditType.String()), eventlogger.PipelineID(name)) + _, err = a.broker.RemovePipelineAndNodes(ctx, eventlogger.EventType(event.AuditType.String()), eventlogger.PipelineID(name)) if err != nil { - return err + return fmt.Errorf("%s: unable to remove pipeline and nodes for %q: %w", op, name, err) } } @@ -221,6 +248,18 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, head status, err := a.broker.Send(ctx, eventlogger.EventType(event.AuditType.String()), e) if err != nil { retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...)) + return retErr.ErrorOrNil() + } + + // Audit event ended up in at least 1 sink. + if len(status.CompleteSinks()) > 0 { + return retErr.ErrorOrNil() + } + + // There were errors from inside the pipeline and we didn't write to a sink. + if len(status.Warnings) > 0 { + retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...)) + return retErr.ErrorOrNil() } } } @@ -317,6 +356,18 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, hea status, err := a.broker.Send(auditContext, eventlogger.EventType(event.AuditType.String()), e) if err != nil { retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...)) + return retErr.ErrorOrNil() + } + + // Audit event ended up in at least 1 sink. + if len(status.CompleteSinks()) > 0 { + return retErr.ErrorOrNil() + } + + // There were errors from inside the pipeline and we didn't write to a sink. + if len(status.Warnings) > 0 { + retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...)) + return retErr.ErrorOrNil() } } } @@ -333,3 +384,19 @@ func (a *AuditBroker) Invalidate(ctx context.Context, key string) { be.backend.Invalidate(ctx) } } + +// requiredSuccessThresholdSinks returns the value that should be used for configuring +// success threshold sinks on the eventlogger broker. +// If all backends have nodes which provide filtering, then we cannot rely on the +// guarantee provided by setting the threshold to 1, and must set it to 0. +func (a *AuditBroker) requiredSuccessThresholdSinks() int { + threshold := 0 + for _, be := range a.backends { + if !be.backend.HasFiltering() { + threshold = 1 + break + } + } + + return threshold +} diff --git a/vault/audit_broker_test.go b/vault/audit_broker_test.go new file mode 100644 index 000000000000..a7fa891fcfdb --- /dev/null +++ b/vault/audit_broker_test.go @@ -0,0 +1,143 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package vault + +import ( + "context" + "crypto/sha256" + "testing" + + "github.com/hashicorp/eventlogger" + "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/builtin/audit/syslog" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" + "github.com/hashicorp/vault/internal/observability/event" + "github.com/hashicorp/vault/sdk/helper/salt" + "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" +) + +// testAuditBackend will create an audit.Backend (which expects to use the eventlogger). +func testAuditBackend(t *testing.T, path string, config map[string]string) audit.Backend { + t.Helper() + + headersCfg := &AuditedHeadersConfig{ + Headers: make(map[string]*auditedHeaderSettings), + view: nil, + } + + view := &logical.InmemStorage{} + se := &logical.StorageEntry{Key: "salt", Value: []byte("juan")} + err := view.Put(context.Background(), se) + require.NoError(t, err) + + cfg := &audit.BackendConfig{ + SaltView: view, + SaltConfig: &salt.Config{ + HMAC: sha256.New, + HMACType: "hmac-sha256", + }, + Config: config, + MountPath: path, + } + + be, err := syslog.Factory(context.Background(), cfg, true, headersCfg) + require.NoError(t, err) + require.NotNil(t, be) + + return be +} + +// TestAuditBroker_Register_SuccessThresholdSinks tests that we are able to +// correctly identify what the required success threshold sinks value on the +// eventlogger broker should be set to. +// We expect: +// * 0 for only filtered backends +// * 1 for any other combination +func TestAuditBroker_Register_SuccessThresholdSinks(t *testing.T) { + t.Parallel() + l := corehelpers.NewTestLogger(t) + a, err := NewAuditBroker(l, true) + require.NoError(t, err) + require.NotNil(t, a) + + filterBackend := testAuditBackend(t, "b1-filter", map[string]string{"filter": "foo == bar"}) + noFilterBackend := testAuditBackend(t, "b2-no-filter", map[string]string{}) + + // Should be set to 0 for required sinks (and not found, as we've never registered before). + res, ok := a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.False(t, ok) + require.Equal(t, 0, res) + + // Register the filtered backend first, this shouldn't change the + // success threshold sinks to 1 as we can't guarantee any device yet. + err = a.Register("b1-filter", filterBackend, false) + require.NoError(t, err) + + // Check the SuccessThresholdSinks (we expect 0 still, but found). + res, ok = a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.True(t, ok) + require.Equal(t, 0, res) + + // Register the non-filtered backend second, this should mean we + // can rely on guarantees from the broker again. + err = a.Register("b2-no-filter", noFilterBackend, false) + require.NoError(t, err) + + // Check the SuccessThresholdSinks (we expect 1 now). + res, ok = a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.True(t, ok) + require.Equal(t, 1, res) +} + +// TestAuditBroker_Deregister_SuccessThresholdSinks tests that we are able to +// correctly identify what the required success threshold sinks value on the +// eventlogger broker should be set to when deregistering audit backends. +// We expect: +// * 0 for only filtered backends +// * 1 for any other combination +func TestAuditBroker_Deregister_SuccessThresholdSinks(t *testing.T) { + t.Parallel() + l := corehelpers.NewTestLogger(t) + a, err := NewAuditBroker(l, true) + require.NoError(t, err) + require.NotNil(t, a) + + filterBackend := testAuditBackend(t, "b1-filter", map[string]string{"filter": "foo == bar"}) + noFilterBackend := testAuditBackend(t, "b2-no-filter", map[string]string{}) + + err = a.Register("b1-filter", filterBackend, false) + require.NoError(t, err) + err = a.Register("b2-no-filter", noFilterBackend, false) + require.NoError(t, err) + + // We have a mix of filtered and non-filtered backends, so the + // successThresholdSinks should be 1. + res, ok := a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.True(t, ok) + require.Equal(t, 1, res) + + // Deregister the non-filtered backend, there is one filtered backend left, + // so the successThresholdSinks should be 0. + err = a.Deregister(context.Background(), "b2-no-filter") + require.NoError(t, err) + res, ok = a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.True(t, ok) + require.Equal(t, 0, res) + + // Deregister the last backend, disabling audit. The value of + // successThresholdSinks should still be 0. + err = a.Deregister(context.Background(), "b1-filter") + require.NoError(t, err) + res, ok = a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.True(t, ok) + require.Equal(t, 0, res) + + // Re-register a backend that doesn't use filtering. + err = a.Register("b2-no-filter", noFilterBackend, false) + require.NoError(t, err) + res, ok = a.broker.SuccessThresholdSinks(eventlogger.EventType(event.AuditType.String())) + require.True(t, ok) + require.Equal(t, 1, res) +} diff --git a/vault/audit_test.go b/vault/audit_test.go index 60f9f45e1cef..afecafaea245 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -530,16 +530,13 @@ func TestAuditBroker_LogResponse(t *testing.T) { OuterErr: respErr, } err = b.LogResponse(ctx, logInput, headersConf) - if err != nil { - t.Fatalf("err: %v", err) - } + require.NoError(t, err) // Should FAIL work with both failing backends a2.RespErr = fmt.Errorf("failed") err = b.LogResponse(ctx, logInput, headersConf) - if !strings.Contains(err.Error(), "event not processed by enough 'sink' nodes") { - t.Fatalf("err: %v", err) - } + require.Error(t, err) + require.ErrorContains(t, err, "event not processed by enough 'sink' nodes") } func TestAuditBroker_AuditHeaders(t *testing.T) { diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index d975f13f51d0..8e6c476d726f 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -51,7 +51,7 @@ func doTwoPhaseLogin(t *testing.T, client *api.Client, totpCodePath, methodID, u } func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { - noop := corehelpers.TestNoopAudit(t, "noop", nil) + noop := corehelpers.TestNoopAudit(t, "noop/", nil) cluster := vault.NewTestCluster(t, &vault.CoreConfig{ CredentialBackends: map[string]logical.Factory{