diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index 8281f7b79ecb..cebe021e4f08 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -174,6 +174,7 @@ Setting environmental variable ELASTIC_NETINFO:false in Elastic Agent pod will d - Add support for PEM-based Okta auth in HTTPJSON. {pull}37772[37772] - Prevent complete loss of long request trace data. {issue}37826[37826] {pull}37836[37836] - Add support for PEM-based Okta auth in CEL. {pull}37813[37813] +- Add support for Active Directory an entity analytics provider. {pull}37919[37919] *Auditbeat* diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/activedirectory.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/activedirectory.go new file mode 100644 index 000000000000..a8a498380e24 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/activedirectory.go @@ -0,0 +1,378 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +// Package activedirectory provides a user identity asset provider for Microsoft +// Active Directory. +package activedirectory + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "time" + + "github.com/go-ldap/ldap/v3" + + v2 "github.com/elastic/beats/v7/filebeat/input/v2" + "github.com/elastic/beats/v7/libbeat/beat" + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/internal/kvstore" + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/provider" + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/provider/activedirectory/internal/activedirectory" + "github.com/elastic/elastic-agent-libs/config" + "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/elastic-agent-libs/mapstr" + "github.com/elastic/elastic-agent-libs/transport/httpcommon" + "github.com/elastic/elastic-agent-libs/transport/tlscommon" + "github.com/elastic/go-concert/ctxtool" +) + +func init() { + err := provider.Register(Name, New) + if err != nil { + panic(err) + } +} + +// Name of this provider. +const Name = "activedirectory" + +// FullName of this provider, including the input name. Prefer using this +// value for full context, especially if the input name isn't present in an +// adjacent log field. +const FullName = "entity-analytics-" + Name + +// adInput implements the provider.Provider interface. +type adInput struct { + *kvstore.Manager + + cfg conf + baseDN *ldap.DN + tlsConfig *tls.Config + + metrics *inputMetrics + logger *logp.Logger +} + +// New creates a new instance of an Active Directory identity provider. +func New(logger *logp.Logger) (provider.Provider, error) { + p := adInput{ + cfg: defaultConfig(), + } + p.Manager = &kvstore.Manager{ + Logger: logger, + Type: FullName, + Configure: p.configure, + } + + return &p, nil +} + +// configure configures this provider using the given configuration. +func (p *adInput) configure(cfg *config.C) (kvstore.Input, error) { + err := cfg.Unpack(&p.cfg) + if err != nil { + return nil, fmt.Errorf("unable to unpack %s input config: %w", Name, err) + } + p.baseDN, err = ldap.ParseDN(p.cfg.BaseDN) + if err != nil { + return nil, err + } + u, err := url.Parse(p.cfg.URL) + if err != nil { + return nil, err + } + if p.cfg.TLS.IsEnabled() && u.Scheme == "ldaps" { + tlsConfig, err := tlscommon.LoadTLSConfig(p.cfg.TLS) + if err != nil { + return nil, err + } + host, _, err := net.SplitHostPort(u.Host) + var addrErr *net.AddrError + switch { + case err == nil: + case errors.As(err, &addrErr): + if addrErr.Err != "missing port in address" { + return nil, err + } + host = u.Host + default: + return nil, err + } + p.tlsConfig = tlsConfig.BuildModuleClientConfig(host) + } + return p, nil +} + +// Name returns the name of this provider. +func (p *adInput) Name() string { + return FullName +} + +func (*adInput) Test(v2.TestContext) error { return nil } + +// Run will start data collection on this provider. +func (p *adInput) Run(inputCtx v2.Context, store *kvstore.Store, client beat.Client) error { + p.logger = inputCtx.Logger.With("provider", Name, "domain", p.cfg.URL) + p.metrics = newMetrics(inputCtx.ID, nil) + defer p.metrics.Close() + + lastSyncTime, _ := getLastSync(store) + syncWaitTime := time.Until(lastSyncTime.Add(p.cfg.SyncInterval)) + lastUpdateTime, _ := getLastUpdate(store) + updateWaitTime := time.Until(lastUpdateTime.Add(p.cfg.UpdateInterval)) + + syncTimer := time.NewTimer(syncWaitTime) + updateTimer := time.NewTimer(updateWaitTime) + + for { + select { + case <-inputCtx.Cancelation.Done(): + if !errors.Is(inputCtx.Cancelation.Err(), context.Canceled) { + return inputCtx.Cancelation.Err() + } + return nil + case <-syncTimer.C: + start := time.Now() + if err := p.runFullSync(inputCtx, store, client); err != nil { + p.logger.Errorw("Error running full sync", "error", err) + p.metrics.syncError.Inc() + } + p.metrics.syncTotal.Inc() + p.metrics.syncProcessingTime.Update(time.Since(start).Nanoseconds()) + + syncTimer.Reset(p.cfg.SyncInterval) + p.logger.Debugf("Next sync expected at: %v", time.Now().Add(p.cfg.SyncInterval)) + + // Reset the update timer and wait the configured interval. If the + // update timer has already fired, then drain the timer's channel + // before resetting. + if !updateTimer.Stop() { + <-updateTimer.C + } + updateTimer.Reset(p.cfg.UpdateInterval) + p.logger.Debugf("Next update expected at: %v", time.Now().Add(p.cfg.UpdateInterval)) + case <-updateTimer.C: + start := time.Now() + if err := p.runIncrementalUpdate(inputCtx, store, client); err != nil { + p.logger.Errorw("Error running incremental update", "error", err) + p.metrics.updateError.Inc() + } + p.metrics.updateTotal.Inc() + p.metrics.updateProcessingTime.Update(time.Since(start).Nanoseconds()) + updateTimer.Reset(p.cfg.UpdateInterval) + p.logger.Debugf("Next update expected at: %v", time.Now().Add(p.cfg.UpdateInterval)) + } + } +} + +// clientOption returns constructed client configuration options, including +// setting up http+unix and http+npipe transports if requested. +func clientOptions(keepalive httpcommon.WithKeepaliveSettings) []httpcommon.TransportOption { + return []httpcommon.TransportOption{ + httpcommon.WithAPMHTTPInstrumentation(), + keepalive, + } +} + +// runFullSync performs a full synchronization. It will fetch user and group +// identities from Azure Active Directory, enrich users with group memberships, +// and publishes all known users (regardless if they have been modified) to the +// given beat.Client. +func (p *adInput) runFullSync(inputCtx v2.Context, store *kvstore.Store, client beat.Client) error { + p.logger.Debugf("Running full sync...") + + p.logger.Debugf("Opening new transaction...") + state, err := newStateStore(store) + if err != nil { + return fmt.Errorf("unable to begin transaction: %w", err) + } + p.logger.Debugf("Transaction opened") + defer func() { // If commit is successful, call to this close will be no-op. + closeErr := state.close(false) + if closeErr != nil { + p.logger.Errorw("Error rolling back full sync transaction", "error", closeErr) + } + }() + + ctx := ctxtool.FromCanceller(inputCtx.Cancelation) + p.logger.Debugf("Starting fetch...") + _, err = p.doFetchUsers(ctx, state, true) + if err != nil { + return err + } + + if len(state.users) != 0 { + tracker := kvstore.NewTxTracker(ctx) + + start := time.Now() + p.publishMarker(start, start, inputCtx.ID, true, client, tracker) + for _, u := range state.users { + p.publishUser(u, state, inputCtx.ID, client, tracker) + } + + end := time.Now() + p.publishMarker(end, end, inputCtx.ID, false, client, tracker) + + tracker.Wait() + } + + if ctx.Err() != nil { + return ctx.Err() + } + + state.lastSync = time.Now() + err = state.close(true) + if err != nil { + return fmt.Errorf("unable to commit state: %w", err) + } + + return nil +} + +// runIncrementalUpdate will run an incremental update. The process is similar +// to full synchronization, except only users which have changed (newly +// discovered, modified, or deleted) will be published. +func (p *adInput) runIncrementalUpdate(inputCtx v2.Context, store *kvstore.Store, client beat.Client) error { + p.logger.Debugf("Running incremental update...") + + state, err := newStateStore(store) + if err != nil { + return fmt.Errorf("unable to begin transaction: %w", err) + } + defer func() { // If commit is successful, call to this close will be no-op. + closeErr := state.close(false) + if closeErr != nil { + p.logger.Errorw("Error rolling back incremental update transaction", "error", closeErr) + } + }() + + ctx := ctxtool.FromCanceller(inputCtx.Cancelation) + updatedUsers, err := p.doFetchUsers(ctx, state, false) + if err != nil { + return err + } + + var tracker *kvstore.TxTracker + if len(updatedUsers) != 0 { + tracker = kvstore.NewTxTracker(ctx) + for _, u := range updatedUsers { + p.publishUser(u, state, inputCtx.ID, client, tracker) + } + tracker.Wait() + } + + if ctx.Err() != nil { + return ctx.Err() + } + + state.lastUpdate = time.Now() + if err = state.close(true); err != nil { + return fmt.Errorf("unable to commit state: %w", err) + } + + return nil +} + +// doFetchUsers handles fetching user identities from Active Directory. If +// fullSync is true, then any existing whenChanged will be ignored, forcing a +// full synchronization from Active Directory. +// Returns a set of modified users by ID. +func (p *adInput) doFetchUsers(ctx context.Context, state *stateStore, fullSync bool) ([]*User, error) { + var since time.Time + if !fullSync { + since = state.whenChanged + } + + entries, err := activedirectory.GetDetails(p.cfg.URL, p.cfg.User, p.cfg.Password, p.baseDN, since, p.cfg.PagingSize, nil, p.tlsConfig) + p.logger.Debugf("received %d users from API", len(entries)) + if err != nil { + return nil, err + } + + var ( + users []*User + whenChanged time.Time + ) + if fullSync { + for _, u := range entries { + state.storeUser(u) + if u.WhenChanged.After(whenChanged) { + whenChanged = u.WhenChanged + } + } + } else { + users = make([]*User, 0, len(entries)) + for _, u := range entries { + users = append(users, state.storeUser(u)) + if u.WhenChanged.After(whenChanged) { + whenChanged = u.WhenChanged + } + } + p.logger.Debugf("processed %d users from API", len(users)) + } + if whenChanged.After(state.whenChanged) { + state.whenChanged = whenChanged + } + + return users, nil +} + +// publishMarker will publish a write marker document using the given beat.Client. +// If start is true, then it will be a start marker, otherwise an end marker. +func (p *adInput) publishMarker(ts, eventTime time.Time, inputID string, start bool, client beat.Client, tracker *kvstore.TxTracker) { + fields := mapstr.M{} + _, _ = fields.Put("labels.identity_source", inputID) + + if start { + _, _ = fields.Put("event.action", "started") + _, _ = fields.Put("event.start", eventTime) + } else { + _, _ = fields.Put("event.action", "completed") + _, _ = fields.Put("event.end", eventTime) + } + + event := beat.Event{ + Timestamp: ts, + Fields: fields, + Private: tracker, + } + tracker.Add() + if start { + p.logger.Debug("Publishing start write marker") + } else { + p.logger.Debug("Publishing end write marker") + } + + client.Publish(event) +} + +// publishUser will publish a user document using the given beat.Client. +func (p *adInput) publishUser(u *User, state *stateStore, inputID string, client beat.Client, tracker *kvstore.TxTracker) { + userDoc := mapstr.M{} + + _, _ = userDoc.Put("activedirectory", u.User) + _, _ = userDoc.Put("labels.identity_source", inputID) + _, _ = userDoc.Put("user.id", u.ID) + + switch u.State { + case Discovered: + _, _ = userDoc.Put("event.action", "user-discovered") + case Modified: + _, _ = userDoc.Put("event.action", "user-modified") + } + + event := beat.Event{ + Timestamp: time.Now(), + Fields: userDoc, + Private: tracker, + } + tracker.Add() + + p.logger.Debugf("Publishing user %q", u.ID) + + client.Publish(event) +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/activedirectory_test.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/activedirectory_test.go new file mode 100644 index 000000000000..fdb3883b0b8b --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/activedirectory_test.go @@ -0,0 +1,139 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package activedirectory + +import ( + "context" + "encoding/json" + "flag" + "os" + "sort" + "testing" + "time" + + "github.com/go-ldap/ldap/v3" + + "github.com/elastic/elastic-agent-libs/logp" +) + +var logResponses = flag.Bool("log_response", false, "use to log users/groups returned from the API") + +func TestActiveDirectoryDoFetch(t *testing.T) { + url, ok := os.LookupEnv("AD_URL") + if !ok { + t.Skip("activedirectory tests require ${AD_URL} to be set") + } + baseDN, ok := os.LookupEnv("AD_BASE") + if !ok { + t.Skip("activedirectory tests require ${AD_BASE} to be set") + } + user, ok := os.LookupEnv("AD_USER") + if !ok { + t.Skip("activedirectory tests require ${AD_USER} to be set") + } + pass, ok := os.LookupEnv("AD_PASS") + if !ok { + t.Skip("activedirectory tests require ${AD_PASS} to be set") + } + + base, err := ldap.ParseDN(baseDN) + if err != nil { + t.Fatalf("invalid base distinguished name: %v", err) + } + + const dbFilename = "TestActiveDirectoryDoFetch.db" + store := testSetupStore(t, dbFilename) + t.Cleanup(func() { + testCleanupStore(store, dbFilename) + }) + a := adInput{ + cfg: conf{ + BaseDN: baseDN, + URL: url, + User: user, + Password: pass, + }, + baseDN: base, + logger: logp.L(), + } + + ss, err := newStateStore(store) + if err != nil { + t.Fatalf("unexpected error making state store: %v", err) + } + defer ss.close(false) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + var times []time.Time + t.Run("full", func(t *testing.T) { + ss.whenChanged = time.Time{} // Reach back to the start of time. + + users, err := a.doFetchUsers(ctx, ss, false) // We are lying about fullSync since we are not getting users via the store. + if err != nil { + t.Fatalf("unexpected error from doFetch: %v", err) + } + + if len(users) == 0 { + t.Error("expected non-empty result from query") + } + found := false + var gotUsers []string + for _, e := range users { + gotUsers = append(gotUsers, e.ID) + if e.ID == user { + found = true + } + + times = append(times, e.WhenChanged) + } + if !found { + t.Errorf("expected login user to be found in directory: got:%q", gotUsers) + } + + if !*logResponses { + return + } + b, err := json.MarshalIndent(users, "", "\t") + if err != nil { + t.Errorf("failed to marshal users for logging: %v", err) + } + t.Logf("user: %s", b) + }) + + // Find the time of the first changed entry for later. + sort.Slice(times, func(i, j int) bool { return times[i].Before(times[j]) }) + since := times[0].Add(time.Second) // Step past first entry by a small amount within LDAP resolution. + var want int + // ... and count all entries since then. + for _, when := range times[1:] { + if !since.After(when) { + want++ + } + } + + t.Run("update", func(t *testing.T) { + ss.whenChanged = since // Reach back until after the first entry. + + users, err := a.doFetchUsers(ctx, ss, false) + if err != nil { + t.Fatalf("unexpected error from doFetchUsers: %v", err) + } + + if len(users) != want { + t.Errorf("unexpected number of results from query since %v: got:%d want:%d", since, len(users), want) + } + + if !*logResponses && !t.Failed() { + return + } + b, err := json.MarshalIndent(users, "", "\t") + if err != nil { + t.Errorf("failed to marshal users for logging: %v", err) + } + t.Logf("user: %s", b) + }) +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/conf.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/conf.go new file mode 100644 index 000000000000..7dab7f5e4569 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/conf.go @@ -0,0 +1,89 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package activedirectory + +import ( + "errors" + "net" + "net/url" + "time" + + "github.com/go-ldap/ldap/v3" + + "github.com/elastic/elastic-agent-libs/transport/tlscommon" +) + +// defaultConfig returns a default configuration. +func defaultConfig() conf { + return conf{ + SyncInterval: 24 * time.Hour, + UpdateInterval: 15 * time.Minute, + } +} + +// conf contains parameters needed to configure the input. +type conf struct { + BaseDN string `config:"ad_base_dn" validate:"required"` + + URL string `config:"ad_url" validate:"required"` + User string `config:"ad_user" validate:"required"` + Password string `config:"ad_password" validate:"required"` + + PagingSize uint32 `config:"ad_paging_size"` + + // SyncInterval is the time between full + // synchronisation operations. + SyncInterval time.Duration `config:"sync_interval"` + // UpdateInterval is the time between + // incremental updated. + UpdateInterval time.Duration `config:"update_interval"` + + // TLS provides ssl/tls setup settings + TLS *tlscommon.Config `config:"ssl" yaml:"ssl,omitempty" json:"ssl,omitempty"` +} + +var ( + errInvalidSyncInterval = errors.New("zero or negative sync_interval") + errInvalidUpdateInterval = errors.New("zero or negative update_interval") + errSyncBeforeUpdate = errors.New("sync_interval not longer than update_interval") +) + +// Validate runs validation against the config. +func (c *conf) Validate() error { + switch { + case c.SyncInterval <= 0: + return errInvalidSyncInterval + case c.UpdateInterval <= 0: + return errInvalidUpdateInterval + case c.SyncInterval <= c.UpdateInterval: + return errSyncBeforeUpdate + } + _, err := ldap.ParseDN(c.BaseDN) + if err != nil { + return err + } + u, err := url.Parse(c.URL) + if err != nil { + return err + } + if c.TLS.IsEnabled() && u.Scheme == "ldaps" { + _, err := tlscommon.LoadTLSConfig(c.TLS) + if err != nil { + return err + } + _, _, err = net.SplitHostPort(u.Host) + var addrErr *net.AddrError + switch { + case err == nil: + case errors.As(err, &addrErr): + if addrErr.Err != "missing port in address" { + return err + } + default: + return err + } + } + return nil +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/conf_test.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/conf_test.go new file mode 100644 index 000000000000..c518c122635c --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/conf_test.go @@ -0,0 +1,57 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package activedirectory + +import ( + "testing" + "time" +) + +var validateTests = []struct { + name string + cfg conf + wantErr error +}{ + { + name: "default", + cfg: defaultConfig(), + wantErr: nil, + }, + { + name: "invalid_sync_interval", + cfg: conf{ + SyncInterval: 0, + UpdateInterval: time.Second * 2, + }, + wantErr: errInvalidSyncInterval, + }, + { + name: "invalid_update_interval", + cfg: conf{ + SyncInterval: time.Second, + UpdateInterval: 0, + }, + wantErr: errInvalidUpdateInterval, + }, + { + name: "invalid_relative_intervals", + cfg: conf{ + SyncInterval: time.Second, + UpdateInterval: time.Second * 2, + }, + wantErr: errSyncBeforeUpdate, + }, +} + +func TestConfValidate(t *testing.T) { + for _, test := range validateTests { + t.Run(test.name, func(t *testing.T) { + err := test.cfg.Validate() + if err != test.wantErr { + t.Errorf("unexpected error: got:%v want:%v", err, test.wantErr) + } + }) + } +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/metrics.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/metrics.go new file mode 100644 index 000000000000..070deab28868 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/metrics.go @@ -0,0 +1,50 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package activedirectory + +import ( + "github.com/rcrowley/go-metrics" + + "github.com/elastic/beats/v7/libbeat/monitoring/inputmon" + "github.com/elastic/elastic-agent-libs/monitoring" + "github.com/elastic/elastic-agent-libs/monitoring/adapter" +) + +// inputMetrics defines metrics for this provider. +type inputMetrics struct { + unregister func() + + syncTotal *monitoring.Uint // The total number of full synchronizations. + syncError *monitoring.Uint // The number of full synchronizations that failed due to an error. + syncProcessingTime metrics.Sample // Histogram of the elapsed full synchronization times in nanoseconds (time of API contact to items sent to output). + updateTotal *monitoring.Uint // The total number of incremental updates. + updateError *monitoring.Uint // The number of incremental updates that failed due to an error. + updateProcessingTime metrics.Sample // Histogram of the elapsed incremental update times in nanoseconds (time of API contact to items sent to output). +} + +// Close removes metrics from the registry. +func (m *inputMetrics) Close() { + m.unregister() +} + +// newMetrics creates a new instance for gathering metrics. +func newMetrics(id string, optionalParent *monitoring.Registry) *inputMetrics { + reg, unreg := inputmon.NewInputRegistry(FullName, id, optionalParent) + + out := inputMetrics{ + unregister: unreg, + syncTotal: monitoring.NewUint(reg, "sync_total"), + syncError: monitoring.NewUint(reg, "sync_error"), + syncProcessingTime: metrics.NewUniformSample(1024), + updateTotal: monitoring.NewUint(reg, "update_total"), + updateError: monitoring.NewUint(reg, "update_error"), + updateProcessingTime: metrics.NewUniformSample(1024), + } + + adapter.NewGoMetrics(reg, "sync_processing_time", adapter.Accept).Register("histogram", metrics.NewHistogram(out.syncProcessingTime)) //nolint:errcheck // A unique namespace is used so name collisions are impossible. + adapter.NewGoMetrics(reg, "update_processing_time", adapter.Accept).Register("histogram", metrics.NewHistogram(out.updateProcessingTime)) //nolint:errcheck // A unique namespace is used so name collisions are impossible. + + return &out +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/state_string.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/state_string.go new file mode 100644 index 000000000000..b584c0b611e7 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/state_string.go @@ -0,0 +1,29 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +// Code generated by "stringer -type State"; DO NOT EDIT. + +package activedirectory + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Discovered-1] + _ = x[Modified-2] +} + +const _State_name = "DiscoveredModified" + +var _State_index = [...]uint8{0, 10, 18} + +func (i State) String() string { + i -= 1 + if i < 0 || i >= State(len(_State_index)-1) { + return "State(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _State_name[_State_index[i]:_State_index[i+1]] +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/statestore.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/statestore.go new file mode 100644 index 000000000000..081176c58bf9 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/statestore.go @@ -0,0 +1,194 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package activedirectory + +import ( + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/internal/kvstore" + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/provider/activedirectory/internal/activedirectory" +) + +var ( + usersBucket = []byte("users") + stateBucket = []byte("state") + + whenChangedKey = []byte("when_changed") + lastSyncKey = []byte("last_sync") + lastUpdateKey = []byte("last_update") +) + +//go:generate stringer -type State +//go:generate go-licenser -license Elastic +type State int + +const ( + Discovered State = iota + 1 + Modified +) + +type User struct { + activedirectory.Entry `json:"properties"` + State State `json:"state"` +} + +// stateStore wraps a kvstore.Transaction and provides convenience methods for +// accessing and store relevant data within the kvstore database. +type stateStore struct { + tx *kvstore.Transaction + + // whenChanged is the last whenChanged time in the set of + // users and their associated groups. + whenChanged time.Time + + // lastSync and lastUpdate are the times of the first update + // or sync operation of users/groups. + lastSync time.Time + lastUpdate time.Time + users map[string]*User +} + +// newStateStore creates a new instance of stateStore. It will open a new write +// transaction on the kvstore and load values from the database. Since this +// opens a write transaction, only one instance of stateStore may be created +// at a time. The close function must be called to release the transaction lock +// on the kvstore database. +func newStateStore(store *kvstore.Store) (*stateStore, error) { + tx, err := store.BeginTx(true) + if err != nil { + return nil, fmt.Errorf("unable to open state store transaction: %w", err) + } + + s := stateStore{ + users: make(map[string]*User), + tx: tx, + } + + err = s.tx.Get(stateBucket, lastSyncKey, &s.lastSync) + if err != nil && !errIsItemNotFound(err) { + return nil, fmt.Errorf("unable to get last sync time from state: %w", err) + } + err = s.tx.Get(stateBucket, lastUpdateKey, &s.lastUpdate) + if err != nil && !errIsItemNotFound(err) { + return nil, fmt.Errorf("unable to get last update time from state: %w", err) + } + err = s.tx.Get(stateBucket, whenChangedKey, &s.whenChanged) + if err != nil && !errIsItemNotFound(err) { + return nil, fmt.Errorf("unable to get last change time from state: %w", err) + } + + err = s.tx.ForEach(usersBucket, func(key, value []byte) error { + var u User + err = json.Unmarshal(value, &u) + if err != nil { + return fmt.Errorf("unable to unmarshal user from state: %w", err) + } + s.users[u.ID] = &u + + return nil + }) + if err != nil && !errIsItemNotFound(err) { + return nil, fmt.Errorf("unable to get users from state: %w", err) + } + + return &s, nil +} + +// storeUser stores a user. If the user does not exist in the store, then the +// user will be marked as discovered. Otherwise, the user will be marked +// as modified. +func (s *stateStore) storeUser(u activedirectory.Entry) *User { + su := User{Entry: u} + if existing, ok := s.users[u.ID]; ok { + su.State = Modified + *existing = su + } else { + su.State = Discovered + s.users[u.ID] = &su + } + return &su +} + +// close will close out the stateStore. If commit is true, the staged values on the +// stateStore will be set in the kvstore database, and the transaction will be +// committed. Otherwise, all changes will be discarded and the transaction will +// be rolled back. The stateStore must NOT be used after close is called, rather, +// a new stateStore should be created. +func (s *stateStore) close(commit bool) (err error) { + if !commit { + return s.tx.Rollback() + } + + // Fallback in case one of the statements below fails. If everything is + // successful and Commit is called, then this call to Rollback will be a no-op. + defer func() { + if err == nil { + return + } + rollbackErr := s.tx.Rollback() + if rollbackErr == nil { + err = fmt.Errorf("multiple errors during statestore close: %w", errors.Join(err, rollbackErr)) + } + }() + + if !s.lastSync.IsZero() { + err = s.tx.Set(stateBucket, lastSyncKey, &s.lastSync) + if err != nil { + return fmt.Errorf("unable to save last sync time to state: %w", err) + } + } + if !s.lastUpdate.IsZero() { + err = s.tx.Set(stateBucket, lastUpdateKey, &s.lastUpdate) + if err != nil { + return fmt.Errorf("unable to save last update time to state: %w", err) + } + } + if !s.whenChanged.IsZero() { + err = s.tx.Set(stateBucket, whenChangedKey, &s.whenChanged) + if err != nil { + return fmt.Errorf("unable to save last change time to state: %w", err) + } + } + + for key, value := range s.users { + err = s.tx.Set(usersBucket, []byte(key), value) + if err != nil { + return fmt.Errorf("unable to save user %q to state: %w", key, err) + } + } + + return s.tx.Commit() +} + +// getLastSync retrieves the last full synchronization time from the kvstore +// database. If the value doesn't exist, a zero time.Time is returned. +func getLastSync(store *kvstore.Store) (time.Time, error) { + var t time.Time + err := store.RunTransaction(false, func(tx *kvstore.Transaction) error { + return tx.Get(stateBucket, lastSyncKey, &t) + }) + + return t, err +} + +// getLastUpdate retrieves the last incremental update time from the kvstore +// database. If the value doesn't exist, a zero time.Time is returned. +func getLastUpdate(store *kvstore.Store) (time.Time, error) { + var t time.Time + err := store.RunTransaction(false, func(tx *kvstore.Transaction) error { + return tx.Get(stateBucket, lastUpdateKey, &t) + }) + + return t, err +} + +// errIsItemNotFound returns true if the error represents an item not found +// error (bucket not found or key not found). +func errIsItemNotFound(err error) bool { + return errors.Is(err, kvstore.ErrBucketNotFound) || errors.Is(err, kvstore.ErrKeyNotFound) +} diff --git a/x-pack/filebeat/input/entityanalytics/provider/activedirectory/statestore_test.go b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/statestore_test.go new file mode 100644 index 000000000000..eeae20431e81 --- /dev/null +++ b/x-pack/filebeat/input/entityanalytics/provider/activedirectory/statestore_test.go @@ -0,0 +1,247 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package activedirectory + +import ( + "bytes" + "encoding/json" + "errors" + "os" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/internal/kvstore" + "github.com/elastic/beats/v7/x-pack/filebeat/input/entityanalytics/provider/activedirectory/internal/activedirectory" + "github.com/elastic/elastic-agent-libs/logp" +) + +func TestStateStore(t *testing.T) { + lastSync, err := time.Parse(time.RFC3339Nano, "2023-01-12T08:47:23.296794-05:00") + if err != nil { + t.Fatalf("failed to parse lastSync") + } + lastUpdate, err := time.Parse(time.RFC3339Nano, "2023-01-12T08:50:04.546457-05:00") + if err != nil { + t.Fatalf("failed to parse lastUpdate") + } + + t.Run("new", func(t *testing.T) { + dbFilename := "TestStateStore_New.db" + store := testSetupStore(t, dbFilename) + t.Cleanup(func() { + testCleanupStore(store, dbFilename) + }) + + // Inject test values into store. + data := []struct { + key []byte + val any + }{ + {key: lastSyncKey, val: lastSync}, + {key: lastUpdateKey, val: lastUpdate}, + } + for _, kv := range data { + err := store.RunTransaction(true, func(tx *kvstore.Transaction) error { + return tx.Set(stateBucket, kv.key, kv.val) + }) + if err != nil { + t.Fatalf("failed to set %s: %v", kv.key, err) + } + } + + ss, err := newStateStore(store) + if err != nil { + t.Fatalf("failed to make new store: %v", err) + } + defer ss.close(false) + + checks := []struct { + name string + got, want any + }{ + {name: "lastSync", got: ss.lastSync, want: lastSync}, + {name: "lastUpdate", got: ss.lastUpdate, want: lastUpdate}, + } + for _, c := range checks { + if !cmp.Equal(c.got, c.want) { + t.Errorf("unexpected results for %s: got:%#v want:%#v", c.name, c.got, c.want) + } + } + }) + + t.Run("close", func(t *testing.T) { + dbFilename := "TestStateStore_Close.db" + store := testSetupStore(t, dbFilename) + t.Cleanup(func() { + testCleanupStore(store, dbFilename) + }) + + wantUsers := map[string]*User{ + "userid": { + State: Discovered, + Entry: activedirectory.Entry{ + ID: "userid", + Status: "STATUS", + }, + }, + } + + ss, err := newStateStore(store) + if err != nil { + t.Fatalf("failed to make new store: %v", err) + } + ss.lastSync = lastSync + ss.lastUpdate = lastUpdate + ss.users = wantUsers + + err = ss.close(true) + if err != nil { + t.Fatalf("unexpected error closing: %v", err) + } + + roundTripChecks := []struct { + name string + key []byte + val any + }{ + {name: "lastSyncKey", key: lastSyncKey, val: &ss.lastSync}, + {name: "lastUpdateKey", key: lastUpdateKey, val: &ss.lastUpdate}, + } + for _, check := range roundTripChecks { + want, err := json.Marshal(check.val) + if err != nil { + t.Errorf("unexpected error marshaling %s: %v", check.name, err) + } + var got []byte + err = store.RunTransaction(false, func(tx *kvstore.Transaction) error { + got, err = tx.GetBytes(stateBucket, check.key) + return err + }) + if err != nil { + t.Errorf("unexpected error from store run transaction %s: %v", check.name, err) + } + if !bytes.Equal(got, want) { + t.Errorf("unexpected result after store round-trip for %s: got:%s want:%s", check.name, got, want) + } + } + + users := map[string]*User{} + err = store.RunTransaction(false, func(tx *kvstore.Transaction) error { + return tx.ForEach(usersBucket, func(key, value []byte) error { + var u User + err = json.Unmarshal(value, &u) + if err != nil { + return err + } + users[u.ID] = &u + return nil + }) + }) + if err != nil { + t.Errorf("unexpected error from store run transaction: %v", err) + } + if !cmp.Equal(wantUsers, users) { + t.Errorf("unexpected result:\n- want\n+ got\n%s", cmp.Diff(wantUsers, users)) + } + }) + + t.Run("get_last_sync", func(t *testing.T) { + dbFilename := "TestGetLastSync.db" + store := testSetupStore(t, dbFilename) + t.Cleanup(func() { + testCleanupStore(store, dbFilename) + }) + + err := store.RunTransaction(true, func(tx *kvstore.Transaction) error { + return tx.Set(stateBucket, lastSyncKey, lastSync) + }) + if err != nil { + t.Fatalf("failed to set value: %v", err) + } + + got, err := getLastSync(store) + if err != nil { + t.Errorf("unexpected error from getLastSync: %v", err) + } + if !lastSync.Equal(got) { + t.Errorf("unexpected result from getLastSync: got:%v want:%v", got, lastSync) + } + }) + + t.Run("get_last_update", func(t *testing.T) { + dbFilename := "TestGetLastUpdate.db" + store := testSetupStore(t, dbFilename) + t.Cleanup(func() { + testCleanupStore(store, dbFilename) + }) + + err := store.RunTransaction(true, func(tx *kvstore.Transaction) error { + return tx.Set(stateBucket, lastUpdateKey, lastUpdate) + }) + if err != nil { + t.Fatalf("failed to set value: %v", err) + } + + got, err := getLastUpdate(store) + if err != nil { + t.Errorf("unexpected error from getLastUpdate: %v", err) + } + if !lastUpdate.Equal(got) { + t.Errorf("unexpected result from getLastUpdate: got:%v want:%v", got, lastUpdate) + } + }) +} + +func TestErrIsItemFound(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "bucket-not-found", + err: kvstore.ErrBucketNotFound, + want: true, + }, + { + name: "key-not-found", + err: kvstore.ErrKeyNotFound, + want: true, + }, + { + name: "invalid error", + err: errors.New("test error"), + want: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := errIsItemNotFound(test.err) + if got != test.want { + t.Errorf("unexpected result for %s: got:%t want:%t", test.name, got, test.want) + } + }) + } +} + +func ptr[T any](v T) *T { return &v } + +func testSetupStore(t *testing.T, path string) *kvstore.Store { + t.Helper() + + store, err := kvstore.NewStore(logp.L(), path, 0644) + if err != nil { + t.Fatalf("unexpected error making store: %v", err) + } + return store +} + +func testCleanupStore(store *kvstore.Store, path string) { + _ = store.Close() + _ = os.Remove(path) +}