diff --git a/x-pack/filebeat/include/list.go b/x-pack/filebeat/include/list.go index 48ac4c4ce6c..ae739d07510 100644 --- a/x-pack/filebeat/include/list.go +++ b/x-pack/filebeat/include/list.go @@ -15,7 +15,6 @@ import ( _ "github.com/elastic/beats/v7/x-pack/filebeat/input/http_endpoint" _ "github.com/elastic/beats/v7/x-pack/filebeat/input/httpjson" _ "github.com/elastic/beats/v7/x-pack/filebeat/input/netflow" - _ "github.com/elastic/beats/v7/x-pack/filebeat/input/o365audit" _ "github.com/elastic/beats/v7/x-pack/filebeat/input/s3" _ "github.com/elastic/beats/v7/x-pack/filebeat/module/activemq" _ "github.com/elastic/beats/v7/x-pack/filebeat/module/aws" diff --git a/x-pack/filebeat/input/default-inputs/inputs.go b/x-pack/filebeat/input/default-inputs/inputs.go index 525bbe2a578..0af51f2c18a 100644 --- a/x-pack/filebeat/input/default-inputs/inputs.go +++ b/x-pack/filebeat/input/default-inputs/inputs.go @@ -10,6 +10,7 @@ import ( v2 "github.com/elastic/beats/v7/filebeat/input/v2" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/logp" + "github.com/elastic/beats/v7/x-pack/filebeat/input/o365audit" ) func Init(info beat.Info, log *logp.Logger, store beater.StateStore) []v2.Plugin { @@ -20,5 +21,7 @@ func Init(info beat.Info, log *logp.Logger, store beater.StateStore) []v2.Plugin } func xpackInputs(info beat.Info, log *logp.Logger, store beater.StateStore) []v2.Plugin { - return []v2.Plugin{} + return []v2.Plugin{ + o365audit.Plugin(log, store), + } } diff --git a/x-pack/filebeat/input/o365audit/contentblob.go b/x-pack/filebeat/input/o365audit/contentblob.go index 44ddb911f46..0eca809b637 100644 --- a/x-pack/filebeat/input/o365audit/contentblob.go +++ b/x-pack/filebeat/input/o365audit/contentblob.go @@ -22,7 +22,7 @@ type contentBlob struct { env apiEnvironment id, url string // cursor is used to ACK the resulting events. - cursor cursor + cursor checkpoint // skipLines is used when resuming from a saved cursor so that already // acknowledged objects are not duplicated. skipLines int @@ -115,7 +115,7 @@ func (c contentBlob) handleError(response *http.Response) (actions []poll.Action } // ContentBlob creates a new contentBlob. -func ContentBlob(url string, cursor cursor, env apiEnvironment) contentBlob { +func ContentBlob(url string, cursor checkpoint, env apiEnvironment) contentBlob { return contentBlob{ url: url, env: env, diff --git a/x-pack/filebeat/input/o365audit/contentblob_test.go b/x-pack/filebeat/input/o365audit/contentblob_test.go index 1a08c69fb36..0e7ea7fb411 100644 --- a/x-pack/filebeat/input/o365audit/contentblob_test.go +++ b/x-pack/filebeat/input/o365audit/contentblob_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/elastic/beats/v7/libbeat/beat" @@ -21,9 +22,15 @@ type contentStore struct { stopped bool } -func (s *contentStore) onEvent(b beat.Event) bool { +var errStopped = errors.New("stopped") + +func (s *contentStore) onEvent(b beat.Event, checkpointUpdate interface{}) error { + b.Private = checkpointUpdate s.events = append(s.events, b) - return !s.stopped + if s.stopped { + return errStopped + } + return nil } func (f *fakePoll) BlobContent(t testing.TB, b poll.Transaction, data []common.MapStr, nextUrl string) poll.Transaction { @@ -41,7 +48,7 @@ func makeEvent(ts time.Time, id string) common.MapStr { } } -func validateBlobs(t testing.TB, store contentStore, expected []string, c cursor) cursor { +func validateBlobs(t testing.TB, store contentStore, expected []string, c checkpoint) checkpoint { assert.Len(t, store.events, len(expected)) for idx := range expected { id, err := getString(store.events[idx].Fields, fieldsPrefix+".Id") @@ -51,14 +58,14 @@ func validateBlobs(t testing.TB, store contentStore, expected []string, c cursor assert.Equal(t, expected[idx], id) } prev := c - baseLine := c.line + baseLine := c.Line for idx, id := range expected { ev := store.events[idx] - cursor, ok := ev.Private.(cursor) + cursor, ok := ev.Private.(checkpoint) if !assert.True(t, ok) { t.Fatal("no cursor for event id", id) } - assert.Equal(t, idx+1+baseLine, cursor.line) + assert.Equal(t, idx+1+baseLine, cursor.Line) assert.True(t, prev.Before(cursor)) prev = cursor } @@ -72,7 +79,7 @@ func TestContentBlob(t *testing.T) { Logger: logp.L(), Callback: store.onEvent, } - baseCursor := newCursor(stream{"myTenant", "contentype"}, time.Now()) + baseCursor := checkpoint{Timestamp: time.Now()} query := ContentBlob("http://test.localhost/", baseCursor, ctx) data := []common.MapStr{ makeEvent(now.Add(-time.Hour), "e1"), @@ -85,7 +92,7 @@ func TestContentBlob(t *testing.T) { next := f.BlobContent(t, query, data, "") assert.Nil(t, next) c := validateBlobs(t, store, expected, baseCursor) - assert.Equal(t, len(expected), c.line) + assert.Equal(t, len(expected), c.Line) } func TestContentBlobResumeToLine(t *testing.T) { @@ -93,9 +100,9 @@ func TestContentBlobResumeToLine(t *testing.T) { var store contentStore ctx := testConfig() ctx.Callback = store.onEvent - baseCursor := newCursor(stream{"myTenant", "contentype"}, time.Now()) + baseCursor := checkpoint{Timestamp: time.Now()} const skip = 3 - baseCursor.line = skip + baseCursor.Line = skip query := ContentBlob("http://test.localhost/", baseCursor, ctx).WithSkipLines(skip) data := []common.MapStr{ makeEvent(now.Add(-time.Hour), "e1"), @@ -108,7 +115,7 @@ func TestContentBlobResumeToLine(t *testing.T) { next := f.BlobContent(t, query, data, "") assert.Nil(t, next) c := validateBlobs(t, store, expected, baseCursor) - assert.Equal(t, len(expected), c.line-skip) + assert.Equal(t, len(expected), c.Line-skip) } func TestContentBlobPaged(t *testing.T) { @@ -118,7 +125,7 @@ func TestContentBlobPaged(t *testing.T) { Logger: logp.L(), Callback: store.onEvent, } - baseCursor := newCursor(stream{"myTenant", "contentype"}, time.Now()) + baseCursor := checkpoint{Timestamp: time.Now()} query := ContentBlob("http://test.localhost/", baseCursor, ctx) data := []common.MapStr{ makeEvent(now.Add(-time.Hour), "e1"), @@ -133,17 +140,17 @@ func TestContentBlobPaged(t *testing.T) { assert.NotNil(t, next) assert.IsType(t, paginator{}, next) c := validateBlobs(t, store, expected, baseCursor) - assert.Equal(t, 3, c.line) + assert.Equal(t, 3, c.Line) store.events = nil next = f.BlobContent(t, next, data[3:5], "http://test.localhost/page/3") assert.IsType(t, paginator{}, next) expected = []string{"e4", "e5"} c = validateBlobs(t, store, expected, c) - assert.Equal(t, 5, c.line) + assert.Equal(t, 5, c.Line) store.events = nil next = f.BlobContent(t, next, data[5:], "") assert.Nil(t, next) expected = []string{"e6"} c = validateBlobs(t, store, expected, c) - assert.Equal(t, 6, c.line) + assert.Equal(t, 6, c.Line) } diff --git a/x-pack/filebeat/input/o365audit/dates.go b/x-pack/filebeat/input/o365audit/dates.go index 5eb53d4d6de..848df7a0c22 100644 --- a/x-pack/filebeat/input/o365audit/dates.go +++ b/x-pack/filebeat/input/o365audit/dates.go @@ -6,10 +6,8 @@ package o365audit import ( "fmt" - "sort" "time" - "github.com/joeshaw/multierror" "github.com/pkg/errors" "github.com/elastic/beats/v7/libbeat/common" @@ -79,23 +77,6 @@ func getDateKey(m common.MapStr, key string, formats dateFormats) (t time.Time, return formats.Parse(str) } -// Sort a slice of maps by one of its keys parsed as a date in the given format(s). -func sortMapSliceByDate(s []common.MapStr, dateKey string, formats dateFormats) error { - var errs multierror.Errors - sort.Slice(s, func(i, j int) bool { - di, e1 := getDateKey(s[i], dateKey, formats) - dj, e2 := getDateKey(s[j], dateKey, formats) - if e1 != nil { - errs = append(errs, e1) - } - if e2 != nil { - errs = append(errs, e2) - } - return di.Before(dj) - }) - return errors.Wrapf(errs.Err(), "failed sorting by date key:%s", dateKey) -} - func inRange(d, maxLimit time.Duration) bool { if maxLimit < 0 { maxLimit = -maxLimit diff --git a/x-pack/filebeat/input/o365audit/input.go b/x-pack/filebeat/input/o365audit/input.go index 6dbaa3ab2f6..ceff10751de 100644 --- a/x-pack/filebeat/input/o365audit/input.go +++ b/x-pack/filebeat/input/o365audit/input.go @@ -6,274 +6,218 @@ package o365audit import ( "context" - "sync" "time" "github.com/Azure/go-autorest/autorest" "github.com/joeshaw/multierror" "github.com/pkg/errors" - "github.com/elastic/beats/v7/filebeat/channel" - "github.com/elastic/beats/v7/filebeat/input" + v2 "github.com/elastic/beats/v7/filebeat/input/v2" + cursor "github.com/elastic/beats/v7/filebeat/input/v2/input-cursor" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/common" - "github.com/elastic/beats/v7/libbeat/common/acker" - "github.com/elastic/beats/v7/libbeat/common/cfgwarn" "github.com/elastic/beats/v7/libbeat/common/useragent" + "github.com/elastic/beats/v7/libbeat/feature" "github.com/elastic/beats/v7/libbeat/logp" "github.com/elastic/beats/v7/x-pack/filebeat/input/o365audit/poll" + "github.com/elastic/go-concert/ctxtool" ) const ( - inputName = "o365audit" - fieldsPrefix = inputName + pluginName = "o365audit" + fieldsPrefix = pluginName ) -func init() { - if err := input.Register(inputName, NewInput); err != nil { - panic(errors.Wrapf(err, "unable to create %s input", inputName)) - } +type o365input struct { + config Config } -type o365input struct { - config Config - outlet channel.Outleter - storage *stateStorage - log *logp.Logger - pollers map[stream]*poll.Poller - cancel func() - ctx context.Context - wg sync.WaitGroup - runOnce sync.Once +// Stream represents an event stream. +type stream struct { + tenantID string + contentType string } type apiEnvironment struct { TenantID string ContentType string Config APIConfig - Callback func(beat.Event) bool + Callback func(event beat.Event, cursor interface{}) error Logger *logp.Logger Clock func() time.Time } -// NewInput creates a new o365audit input. -func NewInput( - cfg *common.Config, - connector channel.Connector, - inputContext input.Context, -) (inp input.Input, err error) { - cfgwarn.Beta("The %s input is beta", inputName) - inp, err = newInput(cfg, connector, inputContext) - return inp, errors.Wrap(err, inputName) +func Plugin(log *logp.Logger, store cursor.StateStore) v2.Plugin { + return v2.Plugin{ + Name: pluginName, + Stability: feature.Experimental, + Deprecated: false, + Info: "O365 logs", + Doc: "Collect logs from O365 service", + Manager: &cursor.InputManager{ + Logger: log, + StateStore: store, + Type: pluginName, + Configure: configure, + }, + } } -func newInput( - cfg *common.Config, - connector channel.Connector, - inputContext input.Context, -) (inp input.Input, err error) { +func configure(cfg *common.Config) ([]cursor.Source, cursor.Input, error) { config := defaultConfig() if err := cfg.Unpack(&config); err != nil { - return nil, errors.Wrap(err, "reading config") + return nil, nil, errors.Wrap(err, "reading config") } - log := logp.NewLogger(inputName) - - // TODO: Update with input v2 state. - storage := newStateStorage(noopPersister{}) - - var out channel.Outleter - out, err = connector.ConnectWith(cfg, beat.ClientConfig{ - ACKHandler: acker.ConnectionOnly( - acker.LastEventPrivateReporter(func(_ int, private interface{}) { - // Errors don't have a cursor. - if cursor, ok := private.(cursor); ok { - log.Debugf("ACKed cursor %+v", cursor) - if err := storage.Save(cursor); err != nil && err != errNoUpdate { - log.Errorf("Error saving state: %v", err) - } - } - }), - ), - }) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - if err != nil { - cancel() - } - }() - - pollers := make(map[stream]*poll.Poller) + var sources []cursor.Source for _, tenantID := range config.TenantID { - // MaxRequestsPerMinute limitation is per tenant. - delay := time.Duration(len(config.ContentType)) * time.Minute / time.Duration(config.API.MaxRequestsPerMinute) - auth, err := config.NewTokenProvider(tenantID) - if err != nil { - return nil, err - } - if _, err = auth.Token(); err != nil { - return nil, errors.Wrapf(err, "unable to acquire authentication token for tenant:%s", tenantID) - } for _, contentType := range config.ContentType { - key := stream{ + sources = append(sources, &stream{ tenantID: tenantID, contentType: contentType, - } - poller, err := poll.New( - poll.WithTokenProvider(auth), - poll.WithMinRequestInterval(delay), - poll.WithLogger(log.With("tenantID", tenantID, "contentType", contentType)), - poll.WithContext(ctx), - poll.WithRequestDecorator( - autorest.WithUserAgent(useragent.UserAgent("Filebeat-"+inputName)), - autorest.WithQueryParameters(common.MapStr{ - "publisherIdentifier": tenantID, - }), - ), - ) - if err != nil { - return nil, errors.Wrap(err, "failed to create API poller") - } - pollers[key] = poller + }) } } - return &o365input{ - config: config, - outlet: out, - storage: storage, - log: log, - pollers: pollers, - ctx: ctx, - cancel: cancel, - }, nil + return sources, &o365input{config: config}, nil } -// Run starts the o365input. Only has effect the first time it's called. -func (inp *o365input) Run() { - inp.runOnce.Do(inp.run) +func (s *stream) Name() string { + return s.tenantID + "::" + s.contentType } -func (inp *o365input) run() { - for stream, poller := range inp.pollers { - start := inp.loadLastLocation(stream) - inp.log.Infow("Start fetching events", - "cursor", start, - "tenantID", stream.tenantID, - "contentType", stream.contentType) - inp.runPoller(poller, start) +func (inp *o365input) Name() string { return pluginName } + +func (inp *o365input) Test(src cursor.Source, ctx v2.TestContext) error { + tenantID := src.(*stream).tenantID + auth, err := inp.config.NewTokenProvider(tenantID) + if err != nil { + return err + } + + if _, err := auth.Token(); err != nil { + return errors.Wrapf(err, "unable to acquire authentication token for tenant:%s", tenantID) } + + return nil } -func (inp *o365input) runPoller(poller *poll.Poller, start cursor) { - ctx := apiEnvironment{ - TenantID: start.tenantID, - ContentType: start.contentType, +func (inp *o365input) Run( + ctx v2.Context, + src cursor.Source, + cursor cursor.Cursor, + publisher cursor.Publisher, +) error { + stream := src.(*stream) + tenantID, contentType := stream.tenantID, stream.contentType + log := ctx.Logger.With("tenantID", tenantID, "contentType", contentType) + + tokenProvider, err := inp.config.NewTokenProvider(stream.tenantID) + if err != nil { + return err + } + + if _, err := tokenProvider.Token(); err != nil { + return errors.Wrapf(err, "unable to acquire authentication token for tenant:%s", stream.tenantID) + } + + config := &inp.config + + // MaxRequestsPerMinute limitation is per tenant. + delay := time.Duration(len(config.ContentType)) * time.Minute / time.Duration(config.API.MaxRequestsPerMinute) + + poller, err := poll.New( + poll.WithTokenProvider(tokenProvider), + poll.WithMinRequestInterval(delay), + poll.WithLogger(log), + poll.WithContext(ctxtool.FromCanceller(ctx.Cancelation)), + poll.WithRequestDecorator( + autorest.WithUserAgent(useragent.UserAgent("Filebeat-"+pluginName)), + autorest.WithQueryParameters(common.MapStr{ + "publisherIdentifier": tenantID, + }), + ), + ) + if err != nil { + return errors.Wrap(err, "failed to create API poller") + } + + start := initCheckpoint(log, cursor, config.API.MaxRetention) + action := makeListBlob(start, apiEnvironment{ + Logger: log, + TenantID: tenantID, + ContentType: contentType, Config: inp.config.API, - Callback: inp.reportEvent, - Logger: poller.Logger(), + Callback: publisher.Publish, Clock: time.Now, + }) + if start.Line > 0 { + action = action.WithStartTime(start.StartTime) } - inp.wg.Add(1) - go func() { - defer logp.Recover("panic in " + inputName + " runner.") - defer inp.wg.Done() - action := ListBlob(start, ctx) - // When resuming from a saved state, it's necessary to query for the - // same startTime that provided the last ACKed event. Otherwise there's - // the risk of observing partial blobs with different line counts, due to - // how the backend works. - if start.line > 0 { - action = action.WithStartTime(start.startTime) - } - if err := poller.Run(action); err != nil { - ctx.Logger.Errorf("API polling terminated with error: %v", err.Error()) - msg := common.MapStr{} - msg.Put("error.message", err.Error()) - msg.Put("event.kind", "pipeline_error") - event := beat.Event{ - Timestamp: time.Now(), - Fields: msg, - } - inp.reportEvent(event) - } - }() -} -func (inp *o365input) reportEvent(event beat.Event) bool { - return inp.outlet.OnEvent(event) -} - -// Stop terminates the o365 input. -func (inp *o365input) Stop() { - inp.log.Info("Stopping input " + inputName) - defer inp.log.Info(inputName + " stopped.") - defer inp.outlet.Close() - inp.cancel() + log.Infow("Start fetching events", "cursor", start) + err = poller.Run(action) + if err != nil && ctx.Cancelation.Err() != err && err != context.Canceled { + msg := common.MapStr{} + msg.Put("error.message", err.Error()) + msg.Put("event.kind", "pipeline_error") + event := beat.Event{ + Timestamp: time.Now(), + Fields: msg, + } + publisher.Publish(event, nil) + } + return err } -// Wait terminates the o365input and waits for all the pollers to finalize. -func (inp *o365input) Wait() { - inp.Stop() - inp.wg.Wait() -} +func initCheckpoint(log *logp.Logger, c cursor.Cursor, maxRetention time.Duration) checkpoint { + var cp checkpoint + retentionLimit := time.Now().UTC().Add(-maxRetention) -func (inp *o365input) loadLastLocation(key stream) cursor { - period := inp.config.API.MaxRetention - retentionLimit := time.Now().UTC().Add(-period) - cursor, err := inp.storage.Load(key) - if err != nil { - if err == errStateNotFound { - inp.log.Infof("No saved state found. Will fetch events for the last %v.", period.String()) - } else { - inp.log.Errorw("Error loading saved state. Will fetch all retained events. "+ + if c.IsNew() { + log.Infof("No saved state found. Will fetch events for the last %v.", maxRetention.String()) + cp.Timestamp = retentionLimit + } else { + err := c.Unpack(&cp) + if err != nil { + log.Errorw("Error loading saved state. Will fetch all retained events. "+ "Depending on max_retention, this can cause event loss or duplication.", "error", err, - "max_retention", period.String()) + "max_retention", maxRetention.String()) + cp.Timestamp = retentionLimit } - cursor.timestamp = retentionLimit } - if cursor.timestamp.Before(retentionLimit) { - inp.log.Warnw("Last update exceeds the retention limit. "+ + + if cp.Timestamp.Before(retentionLimit) { + log.Warnw("Last update exceeds the retention limit. "+ "Probably some events have been lost.", - "resume_since", cursor, + "resume_since", cp, "retention_limit", retentionLimit, - "max_retention", period.String()) + "max_retention", maxRetention.String()) // Due to API limitations, it's necessary to perform a query for each // day. These avoids performing a lot of queries that will return empty // when the input hasn't run in a long time. - cursor.timestamp = retentionLimit + cp.Timestamp = retentionLimit } - return cursor -} -var errTerminated = errors.New("terminated due to output closed") + return cp +} // Report returns an action that produces a beat.Event from the given object. func (env apiEnvironment) Report(doc common.MapStr, private interface{}) poll.Action { return func(poll.Enqueuer) error { - if !env.Callback(env.toBeatEvent(doc, private)) { - return errTerminated - } - return nil + return env.Callback(env.toBeatEvent(doc), private) } } // ReportAPIError returns an action that produces a beat.Event from an API error. func (env apiEnvironment) ReportAPIError(err apiError) poll.Action { return func(poll.Enqueuer) error { - if !env.Callback(err.ToBeatEvent()) { - return errTerminated - } - return nil + return env.Callback(err.ToBeatEvent(), nil) } } -func (env apiEnvironment) toBeatEvent(doc common.MapStr, private interface{}) beat.Event { +func (env apiEnvironment) toBeatEvent(doc common.MapStr) beat.Event { var errs multierror.Errors ts, err := getDateKey(doc, "CreationTime", apiDateFormats) if err != nil { @@ -285,7 +229,6 @@ func (env apiEnvironment) toBeatEvent(doc common.MapStr, private interface{}) be Fields: common.MapStr{ fieldsPrefix: doc, }, - Private: private, } if env.Config.SetIDFromAuditRecord { if id, err := getString(doc, "Id"); err == nil && len(id) > 0 { diff --git a/x-pack/filebeat/input/o365audit/listblobs.go b/x-pack/filebeat/input/o365audit/listblobs.go index 5be65a8d67d..ebdbddf79f6 100644 --- a/x-pack/filebeat/input/o365audit/listblobs.go +++ b/x-pack/filebeat/input/o365audit/listblobs.go @@ -20,20 +20,20 @@ import ( // listBlob is a poll.Transaction that handles the content/"blobs" list. type listBlob struct { - cursor cursor + cursor checkpoint startTime, endTime time.Time delay time.Duration env apiEnvironment } -// ListBlob creates a new poll.Transaction that lists content starting from +// makeListBlob creates a new poll.Transaction that lists content starting from // the given cursor position. -func ListBlob(cursor cursor, env apiEnvironment) listBlob { +func makeListBlob(cursor checkpoint, env apiEnvironment) listBlob { l := listBlob{ cursor: cursor, env: env, } - return l.adjustTimes(cursor.timestamp) + return l.adjustTimes(cursor.Timestamp) } // WithStartTime allows to alter the startTime of a listBlob. This is necessary @@ -57,8 +57,8 @@ func (l listBlob) adjustTimes(since time.Time) listBlob { var delay time.Duration if to.After(now) { since = now.Add(-l.env.Config.MaxQuerySize) - if since.Before(l.cursor.timestamp) { - since = l.cursor.timestamp + if since.Before(l.cursor.Timestamp) { + since = l.cursor.Timestamp } to = now delay = l.env.Config.PollInterval @@ -84,11 +84,11 @@ func (l listBlob) RequestDecorators() []autorest.PrepareDecorator { return []autorest.PrepareDecorator{ autorest.WithBaseURL(l.env.Config.Resource), autorest.WithPath("api/v1.0"), - autorest.WithPath(l.cursor.tenantID), + autorest.WithPath(l.env.TenantID), autorest.WithPath("activity/feed/subscriptions/content"), autorest.WithQueryParameters( map[string]interface{}{ - "contentType": l.cursor.contentType, + "contentType": l.env.ContentType, "startTime": l.startTime.Format(apiDateFormat), "endTime": l.endTime.Format(apiDateFormat), }), @@ -130,7 +130,7 @@ func (l listBlob) OnResponse(response *http.Response) (actions []poll.Action) { actions = append(actions, poll.Fetch( ContentBlob(entry.URI, l.cursor, l.env). WithID(entry.ID). - WithSkipLines(l.cursor.line))) + WithSkipLines(l.cursor.Line))) } else { l.env.Logger.Debugf("- skip blob date:%v id:%s", entry.Created.UTC(), entry.ID) } @@ -293,5 +293,5 @@ func getServerTimeDelta(response *http.Response) time.Duration { if err != nil { return 0 } - return serverDate.Sub(time.Now()) + return time.Until(serverDate) } diff --git a/x-pack/filebeat/input/o365audit/listblobs_test.go b/x-pack/filebeat/input/o365audit/listblobs_test.go index 148ee2273e8..0479ade9d6e 100644 --- a/x-pack/filebeat/input/o365audit/listblobs_test.go +++ b/x-pack/filebeat/input/o365audit/listblobs_test.go @@ -187,7 +187,7 @@ func testConfig() apiEnvironment { config := defaultConfig() return apiEnvironment{ Config: config.API, - Logger: logp.NewLogger(inputName + " test"), + Logger: logp.NewLogger(pluginName + " test"), Clock: func() time.Time { return now }, @@ -215,7 +215,9 @@ func TestListBlob(t *testing.T) { makeBlob(now.Add(-time.Hour*12), "today_1"), makeBlob(now.Add(-time.Hour*7), "today_2"), } - lb := ListBlob(newCursor(stream{"1234", contentType}, time.Time{}), ctx) + ctx.TenantID = "1234" + ctx.ContentType = contentType + lb := makeListBlob(checkpoint{}, ctx) var f fakePoll // 6 days ago blobs, next := f.SearchQuery(t, lb, db) @@ -282,7 +284,7 @@ func TestListBlob(t *testing.T) { blobs, next = f.SearchQuery(t, next.(listBlob), db) assert.Equal(t, []string{"live_4a", "live_4b", "live_4c"}, blobs) - blobs, next = f.SearchQuery(t, next.(listBlob), db) + blobs, _ = f.SearchQuery(t, next.(listBlob), db) assert.Empty(t, blobs) } @@ -297,7 +299,9 @@ func TestSubscriptionStart(t *testing.T) { return now }, } - lb := ListBlob(newCursor(stream{"1234", contentType}, time.Time{}), ctx) + ctx.TenantID = "1234" + ctx.ContentType = contentType + lb := makeListBlob(checkpoint{}, ctx) var f fakePoll s, l := f.subscriptionError(t, lb) assert.Equal(t, lb.cursor, l.cursor) @@ -309,7 +313,7 @@ func TestSubscriptionStart(t *testing.T) { assert.Equal(t, lb.env.ContentType, l.env.ContentType) assert.Equal(t, lb.env.Logger, l.env.Logger) assert.Equal(t, contentType, s.ContentType) - assert.Equal(t, lb.cursor.tenantID, s.TenantID) + assert.Equal(t, "1234", s.TenantID) } func TestPagination(t *testing.T) { @@ -324,7 +328,9 @@ func TestPagination(t *testing.T) { makeBlob(now.Add(-time.Hour*47+7*time.Nanosecond), "e7"), makeBlob(now.Add(-time.Hour*47+8*time.Nanosecond), "e8"), } - lb := ListBlob(newCursor(stream{"1234", contentType}, now.Add(-time.Hour*48)), ctx) + ctx.TenantID = "1234" + ctx.ContentType = contentType + lb := makeListBlob(checkpoint{Timestamp: now.Add(-time.Hour * 48)}, ctx) var f fakePoll // 6 days ago blobs, next := f.PagedSearchQuery(t, lb, db) @@ -369,7 +375,9 @@ func TestAdvance(t *testing.T) { ctx.Clock = func() time.Time { return *now } - lb := ListBlob(newCursor(stream{"tenant", contentType}, start), ctx) + ctx.TenantID = "tenant" + ctx.ContentType = contentType + lb := makeListBlob(checkpoint{Timestamp: start}, ctx) assert.Equal(t, start, lb.startTime) assert.Equal(t, start.Add(time.Hour*24), lb.endTime) assert.True(t, lb.endTime.Before(now1)) diff --git a/x-pack/filebeat/input/o365audit/state.go b/x-pack/filebeat/input/o365audit/state.go index 6992437ccab..5036ae63807 100644 --- a/x-pack/filebeat/input/o365audit/state.go +++ b/x-pack/filebeat/input/o365audit/state.go @@ -5,151 +5,59 @@ package o365audit import ( - "errors" "fmt" - "sync" "time" ) -var errNoUpdate = errors.New("new cursor doesn't preceed the existing cursor") - -// Stream represents an event stream. -type stream struct { - tenantID, contentType string -} - -// A cursor represents a point in time within an event stream +// A checkpoint represents a point in time within an event stream // that can be persisted and used to resume processing from that point. -type cursor struct { - // Identifier for the event stream. - stream - +type checkpoint struct { // createdTime for the last seen blob. - timestamp time.Time + Timestamp time.Time `struct:"timestamp"` + // index of object count (1...n) within a blob. - line int + Line int `struct:"line"` + // startTime used in the last list content query. // This is necessary to ensure that the same blobs are observed. - startTime time.Time + StartTime time.Time `struct:"start_time"` } -// Create a new cursor. -func newCursor(s stream, time time.Time) cursor { - return cursor{ - stream: s, - timestamp: time, - } +func (c *checkpoint) Before(other checkpoint) bool { + return c.Timestamp.Before(other.Timestamp) || (c.Timestamp.Equal(other.Timestamp) && c.Line < other.Line) } // TryAdvance advances the cursor to the given content blob // if it's not in the past. // Returns whether the given content needs to be processed. -func (c *cursor) TryAdvance(ct content) bool { - if ct.Created.Before(c.timestamp) { +func (c *checkpoint) TryAdvance(ct content) bool { + if ct.Created.Before(c.Timestamp) { return false } - if ct.Created.Equal(c.timestamp) { + if ct.Created.Equal(c.Timestamp) { // Only need to re-process the current content blob if we're // seeking to a line inside it. - return c.line > 0 + return c.Line > 0 } - c.timestamp = ct.Created - c.line = 0 + c.Timestamp = ct.Created + c.Line = 0 return true } -// Before allows to compare cursors to see if the new cursor needs to be persisted. -func (c cursor) Before(b cursor) bool { - if c.contentType != b.contentType || c.tenantID != b.tenantID { - panic(fmt.Sprintf("assertion failed: %+v vs %+v", c, b)) - } - - if c.timestamp.Before(b.timestamp) { - return true - } - if c.timestamp.Equal(b.timestamp) { - return c.line < b.line - } - return false -} - // WithStartTime allows to create a cursor with an updated startTime. -func (c cursor) WithStartTime(s time.Time) cursor { - c.startTime = s +func (c checkpoint) WithStartTime(s time.Time) checkpoint { + c.StartTime = s return c } // ForNextLine returns a new cursor for the next line within a blob. -func (c cursor) ForNextLine() cursor { - c.line++ +func (c checkpoint) ForNextLine() checkpoint { + c.Line++ return c } // String returns the printable representation of a cursor. -func (c cursor) String() string { - return fmt.Sprintf("cursor{tenantID:%s contentType:%s timestamp:%s line:%d start:%s}", - c.tenantID, c.contentType, c.timestamp, c.line, c.startTime) -} - -// ErrStateNotFound is the error returned by a statePersister when a cursor -// is not found for a stream. -var errStateNotFound = errors.New("no saved state found") - -type statePersister interface { - Load(key stream) (cursor, error) - Save(cursor cursor) error -} - -type stateStorage struct { - sync.Mutex - saved map[stream]cursor - persister statePersister -} - -func (s *stateStorage) Load(key stream) (cursor, error) { - s.Lock() - defer s.Unlock() - if st, found := s.saved[key]; found { - return st, nil - } - cur, err := s.persister.Load(key) - if err != nil { - return newCursor(key, time.Time{}), err - } - return cur, s.saveUnsafe(cur) -} - -func (s *stateStorage) Save(c cursor) error { - s.Lock() - defer s.Unlock() - return s.saveUnsafe(c) -} - -func (s *stateStorage) saveUnsafe(c cursor) error { - if prev, found := s.saved[c.stream]; found { - if !prev.Before(c) { - return errNoUpdate - } - } - if s.saved == nil { - s.saved = make(map[stream]cursor) - } - s.saved[c.stream] = c - return s.persister.Save(c) -} - -func newStateStorage(underlying statePersister) *stateStorage { - return &stateStorage{ - persister: underlying, - } -} - -type noopPersister struct{} - -func (p noopPersister) Load(key stream) (cursor, error) { - return cursor{}, errStateNotFound -} - -func (p noopPersister) Save(cursor cursor) error { - return nil +func (c checkpoint) String() string { + return fmt.Sprintf("cursor{timestamp:%s line:%d start:%s}", + c.Timestamp, c.Line, c.StartTime) } diff --git a/x-pack/filebeat/input/o365audit/state_test.go b/x-pack/filebeat/input/o365audit/state_test.go deleted file mode 100644 index 4e274578e83..00000000000 --- a/x-pack/filebeat/input/o365audit/state_test.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License; -// you may not use this file except in compliance with the Elastic License. - -package o365audit - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestNoopState(t *testing.T) { - const ( - ct = "content-type" - tn = "my_tenant" - ) - myStream := stream{tn, ct} - t.Run("new state", func(t *testing.T) { - st := newStateStorage(noopPersister{}) - cur, err := st.Load(myStream) - assert.Equal(t, errStateNotFound, err) - empty := newCursor(myStream, time.Time{}) - assert.Equal(t, empty, cur) - }) - t.Run("update state", func(t *testing.T) { - st := newStateStorage(noopPersister{}) - cur, err := st.Load(myStream) - assert.Equal(t, errStateNotFound, err) - advanced := cur.TryAdvance(content{ - Type: tn, - ID: "1234", - URI: "http://localhost.test/my_uri", - Created: time.Now(), - Expiration: time.Now().Add(time.Hour), - }) - assert.True(t, advanced) - err = st.Save(cur) - if !assert.NoError(t, err) { - t.Fatal(err) - } - saved, err := st.Load(myStream) - if !assert.NoError(t, err) { - t.Fatal(err) - } - assert.Equal(t, cur, saved) - }) - t.Run("forbid reversal", func(t *testing.T) { - st := newStateStorage(noopPersister{}) - cur := newCursor(myStream, time.Now()) - next := cur.ForNextLine() - err := st.Save(next) - if !assert.NoError(t, err) { - t.Fatal(err) - } - err = st.Save(cur) - assert.Equal(t, errNoUpdate, err) - }) - t.Run("multiple contexts", func(t *testing.T) { - st := newStateStorage(noopPersister{}) - cursors := []cursor{ - newCursor(myStream, time.Time{}), - newCursor(stream{"tenant2", ct}, time.Time{}), - newCursor(stream{ct, "bananas"}, time.Time{}), - } - for idx, cur := range cursors { - msg := fmt.Sprintf("idx:%d cur:%+v", idx, cur) - err := st.Save(cur) - if !assert.NoError(t, err, msg) { - t.Fatal(err) - } - } - for idx, cur := range cursors { - msg := fmt.Sprintf("idx:%d cur:%+v", idx, cur) - saved, err := st.Load(cur.stream) - if !assert.NoError(t, err, msg) { - t.Fatal(err) - } - assert.Equal(t, cur, saved) - } - for idx, cur := range cursors { - cur = cur.ForNextLine() - cursors[idx] = cur - msg := fmt.Sprintf("idx:%d cur:%+v", idx, cur) - err := st.Save(cur) - if !assert.NoError(t, err, msg) { - t.Fatal(err) - } - } - for idx, cur := range cursors { - msg := fmt.Sprintf("idx:%d cur:%+v", idx, cur) - saved, err := st.Load(cur.stream) - if !assert.NoError(t, err, msg) { - t.Fatal(err) - } - assert.Equal(t, cur, saved) - } - }) -}