diff --git a/event-handler/app.go b/event-handler/app.go index bdd195435..8c644e4cd 100644 --- a/event-handler/app.go +++ b/event-handler/app.go @@ -78,6 +78,9 @@ func (a *App) Run(ctx context.Context) error { a.SpawnCriticalJob(a.sessionEventsJob) <-a.Process.Done() + lastWindow := a.EventWatcher.getWindowStartTime() + a.State.SetLastWindowTime(&lastWindow) + return a.Err() } @@ -179,7 +182,18 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - t, err := NewTeleportEventsWatcher(ctx, a.Config, *startTime, latestCursor, latestID) + lastWindowTime, err := s.GetLastWindowTime() + if err != nil { + return trace.Wrap(err) + } + // if lastWindowTime is nil, set it to startTime + // lastWindowTime is used to track the last window of events ingested + // and is updated on exit + if lastWindowTime == nil { + lastWindowTime = startTime + } + + t, err := NewTeleportEventsWatcher(ctx, a.Config, *lastWindowTime, latestCursor, latestID) if err != nil { return trace.Wrap(err) } diff --git a/event-handler/state.go b/event-handler/state.go index 6ce1bae35..0f7df8cad 100644 --- a/event-handler/state.go +++ b/event-handler/state.go @@ -38,6 +38,9 @@ const ( // startTimeName is the start time variable name startTimeName = "start_time" + // windowTimeName is the start time of the last window. + windowTimeName = "window_time" + // cursorName is the cursor variable name cursorName = "cursor" @@ -120,11 +123,25 @@ func createStorageDir(c *StartCmdConfig) (string, error) { // GetStartTime gets current start time func (s *State) GetStartTime() (*time.Time, error) { - if !s.dv.Has(startTimeName) { + return s.getTimeKey(startTimeName) +} + +// SetStartTime sets current start time +func (s *State) SetStartTime(t *time.Time) error { + return s.setTimeKey(startTimeName, t) +} + +// GetLastWindowTime gets current start time +func (s *State) GetLastWindowTime() (*time.Time, error) { + return s.getTimeKey(windowTimeName) +} + +func (s *State) getTimeKey(keyName string) (*time.Time, error) { + if !s.dv.Has(keyName) { return nil, nil } - b, err := s.dv.Read(startTimeName) + b, err := s.dv.Read(keyName) if err != nil { return nil, trace.Wrap(err) } @@ -144,14 +161,18 @@ func (s *State) GetStartTime() (*time.Time, error) { return &t, nil } -// SetStartTime sets current start time -func (s *State) SetStartTime(t *time.Time) error { +func (s *State) setTimeKey(keyName string, t *time.Time) error { if t == nil { - return s.dv.Write(startTimeName, []byte("")) + return s.dv.Write(keyName, []byte("")) } v := t.Truncate(time.Second).Format(time.RFC3339) - return s.dv.Write(startTimeName, []byte(v)) + return s.dv.Write(keyName, []byte(v)) +} + +// SetLastWindowTime sets current start time of the last window used. +func (s *State) SetLastWindowTime(t *time.Time) error { + return s.setTimeKey(windowTimeName, t) } // GetCursor gets current cursor value diff --git a/event-handler/teleport_events_watcher.go b/event-handler/teleport_events_watcher.go index f715f3191..ca799a2a2 100644 --- a/event-handler/teleport_events_watcher.go +++ b/event-handler/teleport_events_watcher.go @@ -19,6 +19,7 @@ package main import ( "context" "fmt" + "sync" "time" "github.com/gravitational/teleport/api/client" @@ -69,15 +70,17 @@ type TeleportEventsWatcher struct { batch []*TeleportEvent // config is teleport config config *StartCmdConfig - // startTime is event time frame start - startTime time.Time + + // windowStartTime is event time frame start + windowStartTime time.Time + windowStartTimeMu sync.Mutex } // NewTeleportEventsWatcher builds Teleport client instance func NewTeleportEventsWatcher( ctx context.Context, c *StartCmdConfig, - startTime time.Time, + windowStartTime time.Time, cursor string, id string, ) (*TeleportEventsWatcher, error) { @@ -118,12 +121,12 @@ func NewTeleportEventsWatcher( } tc := TeleportEventsWatcher{ - client: teleportClient, - pos: -1, - cursor: cursor, - config: c, - id: id, - startTime: startTime, + client: teleportClient, + pos: -1, + cursor: cursor, + config: c, + id: id, + windowStartTime: windowStartTime, } return &tc, nil @@ -207,16 +210,45 @@ func (t *TeleportEventsWatcher) fetch(ctx context.Context) error { // getEvents calls Teleport client and loads events func (t *TeleportEventsWatcher) getEvents(ctx context.Context) ([]*auditlogpb.EventUnstructured, string, error) { - return t.client.SearchUnstructuredEvents( - ctx, - t.startTime, - time.Now().UTC(), - "default", - t.config.Types, - t.config.BatchSize, - types.EventOrderAscending, - t.cursor, - ) + rangeSplitByDay := splitRangeByDay(t.getWindowStartTime(), time.Now().UTC()) + for i := 1; i < len(rangeSplitByDay); i++ { + startTime := rangeSplitByDay[i-1] + endTime := rangeSplitByDay[i] + log.Debugf("Fetching events from %v to %v", startTime, endTime) + evts, cursor, err := t.client.SearchUnstructuredEvents( + ctx, + startTime, + endTime, + "default", + t.config.Types, + t.config.BatchSize, + types.EventOrderAscending, + t.cursor, + ) + if err != nil { + return nil, "", trace.Wrap(err) + } + + // if no events are found, the cursor is out of the range [startTime, endTime] + // and it's the last complete day, update start time to the next day. + if len(evts) == 0 && i < len(rangeSplitByDay)-1 { + log.Infof("No events found for the range %v to %v", startTime, endTime) + t.setWindowStartTime(endTime) + continue + } + // if any events are found, return them + return evts, cursor, nil + } + return nil, t.cursor, nil +} + +func splitRangeByDay(from, to time.Time) []time.Time { + // splitRangeByDay splits the range into days + var days []time.Time + for d := from; d.Before(to); d = d.AddDate(0, 0, 1) { + days = append(days, d) + } + return append(days, to) // add the last date } // pause sleeps for timeout seconds @@ -345,3 +377,15 @@ func (t *TeleportEventsWatcher) UpsertLock(ctx context.Context, user string, log return t.client.UpsertLock(ctx, lock) } + +func (t *TeleportEventsWatcher) getWindowStartTime() time.Time { + t.windowStartTimeMu.Lock() + defer t.windowStartTimeMu.Unlock() + return t.windowStartTime +} + +func (t *TeleportEventsWatcher) setWindowStartTime(time time.Time) { + t.windowStartTimeMu.Lock() + defer t.windowStartTimeMu.Unlock() + t.windowStartTime = time +} diff --git a/event-handler/teleport_events_watcher_test.go b/event-handler/teleport_events_watcher_test.go index b3242ac7c..63ab475f0 100644 --- a/event-handler/teleport_events_watcher_test.go +++ b/event-handler/teleport_events_watcher_test.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "context" "strconv" "sync" "testing" @@ -28,7 +29,6 @@ import ( "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/trace" "github.com/stretchr/testify/require" - "golang.org/x/net/context" ) // mockTeleportEventWatcher is Teleport client mock @@ -121,6 +121,7 @@ func (c *mockTeleportEventWatcher) Close() error { } func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient) *TeleportEventsWatcher { + client := &TeleportEventsWatcher{ client: eventsClient, pos: -1, @@ -169,7 +170,7 @@ func TestEvents(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } } @@ -178,14 +179,14 @@ func TestEvents(t *testing.T) { select { case _, ok := <-chEvt: require.False(t, ok, "Events channel should be closed") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } select { case _, ok := <-chErr: require.False(t, ok, "Error channel should be closed") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } @@ -196,7 +197,7 @@ func TestEvents(t *testing.T) { select { case err := <-chErr: require.Error(t, mockErr, err) - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } @@ -204,14 +205,14 @@ func TestEvents(t *testing.T) { select { case _, ok := <-chEvt: require.False(t, ok, "Events channel should be closed") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } select { case _, ok := <-chErr: require.False(t, ok, "Error channel should be closed") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } } @@ -252,7 +253,7 @@ func TestUpdatePage(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } } @@ -263,7 +264,7 @@ func TestUpdatePage(t *testing.T) { t.Fatalf("Events channel should be open") case <-chErr: t.Fatalf("Events channel should be open") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): } // Update the event watcher with the full page of events an collect. @@ -279,7 +280,7 @@ func TestUpdatePage(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } } @@ -290,7 +291,7 @@ func TestUpdatePage(t *testing.T) { t.Fatalf("Events channel should be open") case <-chErr: t.Fatalf("Events channel should be open") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): } // Add another partial page and collect the events @@ -306,7 +307,7 @@ func TestUpdatePage(t *testing.T) { case err := <-chErr: t.Fatalf("Received unexpected error from error channel: %v", err) return - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } } @@ -318,7 +319,7 @@ func TestUpdatePage(t *testing.T) { select { case err := <-chErr: require.Error(t, mockErr, err) - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } @@ -326,14 +327,14 @@ func TestUpdatePage(t *testing.T) { select { case _, ok := <-chEvt: require.False(t, ok, "Events channel should be closed") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } select { case _, ok := <-chErr: require.False(t, ok, "Error channel should be closed") - case <-time.After(100 * time.Millisecond): + case <-time.After(2 * time.Second): t.Fatalf("No events received within deadline") } } @@ -414,3 +415,62 @@ func TestValidateConfig(t *testing.T) { }) } } + +func Test_splitRangeByDay(t *testing.T) { + type args struct { + from time.Time + to time.Time + } + tests := []struct { + name string + args args + want []time.Time + }{ + { + name: "Same day", + args: args{ + from: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + to: time.Date(2021, 1, 1, 23, 59, 59, 0, time.UTC), + }, + want: []time.Time{ + time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 1, 23, 59, 59, 0, time.UTC), + }, + }, + { + name: "Two days", + args: args{ + from: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + to: time.Date(2021, 1, 2, 23, 59, 59, 0, time.UTC), + }, + want: []time.Time{ + time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 2, 23, 59, 59, 0, time.UTC), + }, + }, + { + name: "week", + args: args{ + from: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + to: time.Date(2021, 1, 7, 23, 59, 59, 0, time.UTC), + }, + want: []time.Time{ + time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 3, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 4, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 5, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 6, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 7, 0, 0, 0, 0, time.UTC), + time.Date(2021, 1, 7, 23, 59, 59, 0, time.UTC), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := splitRangeByDay(tt.args.from, tt.args.to) + require.Equal(t, tt.want, got) + }) + } +}