diff --git a/internal/pkg/agent/application/dispatcher/dispatcher.go b/internal/pkg/agent/application/dispatcher/dispatcher.go index 8628cf5a59f..700c7d35349 100644 --- a/internal/pkg/agent/application/dispatcher/dispatcher.go +++ b/internal/pkg/agent/application/dispatcher/dispatcher.go @@ -9,6 +9,7 @@ import ( "fmt" "reflect" "strings" + "time" "go.elastic.co/apm" @@ -21,6 +22,12 @@ import ( type actionHandlers map[string]actions.Handler +type priorityQueue interface { + Add(fleetapi.Action, int64) + DequeueActions() []fleetapi.Action + Save() error +} + // Dispatcher processes actions coming from fleet api. type Dispatcher interface { Dispatch(context.Context, acker.Acker, ...fleetapi.Action) error @@ -31,10 +38,11 @@ type ActionDispatcher struct { log *logger.Logger handlers actionHandlers def actions.Handler + queue priorityQueue } // New creates a new action dispatcher. -func New(log *logger.Logger, def actions.Handler) (*ActionDispatcher, error) { +func New(log *logger.Logger, def actions.Handler, queue priorityQueue) (*ActionDispatcher, error) { var err error if log == nil { log, err = logger.New("action_dispatcher", false) @@ -51,6 +59,7 @@ func New(log *logger.Logger, def actions.Handler) (*ActionDispatcher, error) { log: log, handlers: make(actionHandlers), def: def, + queue: queue, }, nil } @@ -86,6 +95,17 @@ func (ad *ActionDispatcher) Dispatch(ctx context.Context, acker acker.Acker, act span.End() }() + actions = ad.queueScheduledActions(actions) + actions = ad.dispatchCancelActions(ctx, actions, acker) + queued, expired := ad.gatherQueuedActions(time.Now().UTC()) + ad.log.Debugf("Gathered %d actions from queue, %d actions expired", len(queued), len(expired)) + ad.log.Debugf("Expired actions: %v", expired) + actions = append(actions, queued...) 
+ + if err := ad.queue.Save(); err != nil { + ad.log.Errorf("failed to persist action_queue: %v", err) + } + if len(actions) == 0 { ad.log.Debug("No action to dispatch") return nil @@ -128,3 +148,52 @@ func detectTypes(actions []fleetapi.Action) []string { } return str } + +// queueScheduledActions will add any action in actions with a valid start time to the queue and return the rest. +// start time to current time comparisons are purposefully not made in case of cancel actions. +func (ad *ActionDispatcher) queueScheduledActions(input []fleetapi.Action) []fleetapi.Action { + actions := make([]fleetapi.Action, 0, len(input)) + for _, action := range input { + start, err := action.StartTime() + if err == nil { + ad.log.Debugf("Adding action id: %s to queue.", action.ID()) + ad.queue.Add(action, start.Unix()) + continue + } + if !errors.Is(err, fleetapi.ErrNoStartTime) { + ad.log.Warnf("Issue gathering start time from action id %s: %v", action.ID(), err) + } + actions = append(actions, action) + } + return actions +} + +// dispatchCancelActions will separate and dispatch any cancel actions from the actions list and return the rest of the list. +// cancel actions are dispatched separately as they may remove items from the queue. +func (ad *ActionDispatcher) dispatchCancelActions(ctx context.Context, actions []fleetapi.Action, acker acker.Acker) []fleetapi.Action { + for i := len(actions) - 1; i >= 0; i-- { + action := actions[i] + // If it is a cancel action, remove from list and dispatch + if action.Type() == fleetapi.ActionTypeCancel { + actions = append(actions[:i], actions[i+1:]...) + if err := ad.dispatchAction(ctx, action, acker); err != nil { + ad.log.Errorf("Unable to dispatch cancel action id %s: %v", action.ID(), err) + } + } + } + return actions +} + +// gatherQueuedActions will dequeue actions from the action queue and separate those that have already expired. 
+func (ad *ActionDispatcher) gatherQueuedActions(ts time.Time) (queued, expired []fleetapi.Action) { + actions := ad.queue.DequeueActions() + for _, action := range actions { + exp, _ := action.Expiration() + if ts.After(exp) { + expired = append(expired, action) + continue + } + queued = append(queued, action) + } + return queued, expired +} diff --git a/internal/pkg/agent/application/dispatcher/dispatcher_test.go b/internal/pkg/agent/application/dispatcher/dispatcher_test.go index 4c19779688a..d140033655c 100644 --- a/internal/pkg/agent/application/dispatcher/dispatcher_test.go +++ b/internal/pkg/agent/application/dispatcher/dispatcher_test.go @@ -58,13 +58,34 @@ func (m *mockAction) Expiration() (time.Time, error) { return args.Get(0).(time.Time), args.Error(1) } +type mockQueue struct { + mock.Mock +} + +func (m *mockQueue) Add(action fleetapi.Action, n int64) { + m.Called(action, n) +} + +func (m *mockQueue) DequeueActions() []fleetapi.Action { + args := m.Called() + return args.Get(0).([]fleetapi.Action) +} + +func (m *mockQueue) Save() error { + args := m.Called() + return args.Error(0) +} + func TestActionDispatcher(t *testing.T) { ack := noop.New() t.Run("Success to dispatch multiples events", func(t *testing.T) { ctx := context.Background() def := &mockHandler{} - d, err := New(nil, def) + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + d, err := New(nil, def, queue) require.NoError(t, err) success1 := &mockHandler{} @@ -76,7 +97,13 @@ func TestActionDispatcher(t *testing.T) { require.NoError(t, err) action1 := &mockAction{} + action1.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action1.On("Type").Return("action") + action1.On("ID").Return("id") action2 := &mockOtherAction{} + action2.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action2.On("Type").Return("action") + action2.On("ID").Return("id") // TODO better matching for actions 
success1.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() @@ -88,20 +115,28 @@ func TestActionDispatcher(t *testing.T) { success1.AssertExpectations(t) success2.AssertExpectations(t) def.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything, mock.Anything) + queue.AssertExpectations(t) }) t.Run("Unknown action are caught by the unknown handler", func(t *testing.T) { def := &mockHandler{} def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() ctx := context.Background() - d, err := New(nil, def) + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + d, err := New(nil, def, queue) require.NoError(t, err) action := &mockUnknownAction{} + action.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action.On("Type").Return("action") + action.On("ID").Return("id") err = d.Dispatch(ctx, ack, action) require.NoError(t, err) def.AssertExpectations(t) + queue.AssertExpectations(t) }) t.Run("Could not register two handlers on the same action", func(t *testing.T) { @@ -109,7 +144,8 @@ func TestActionDispatcher(t *testing.T) { success2 := &mockHandler{} def := &mockHandler{} - d, err := New(nil, def) + queue := &mockQueue{} + d, err := New(nil, def, queue) require.NoError(t, err) err = d.Register(&mockAction{}, success1) @@ -117,5 +153,107 @@ func TestActionDispatcher(t *testing.T) { err = d.Register(&mockAction{}, success2) require.Error(t, err) + queue.AssertExpectations(t) + }) + + t.Run("Dispatched action is queued", func(t *testing.T) { + def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + queue.On("Add", mock.Anything, mock.Anything).Once() + + d, err := New(nil, def, queue) + require.NoError(t, err) + err = d.Register(&mockAction{}, def) + 
require.NoError(t, err) + + action1 := &mockAction{} + action1.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action1.On("Type").Return("action") + action1.On("ID").Return("id") + action2 := &mockAction{} + action2.On("StartTime").Return(time.Now().Add(time.Hour), nil) + action2.On("Type").Return("action") + action2.On("ID").Return("id") + + err = d.Dispatch(context.Background(), ack, action1, action2) + require.NoError(t, err) + def.AssertExpectations(t) + queue.AssertExpectations(t) + }) + + t.Run("Cancel queued action", func(t *testing.T) { + def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + + d, err := New(nil, def, queue) + require.NoError(t, err) + err = d.Register(&mockAction{}, def) + require.NoError(t, err) + + action := &mockAction{} + action.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action.On("Type").Return(fleetapi.ActionTypeCancel) + action.On("ID").Return("id") + + err = d.Dispatch(context.Background(), ack, action) + require.NoError(t, err) + def.AssertExpectations(t) + queue.AssertExpectations(t) + }) + + t.Run("Retrieve actions from queue", func(t *testing.T) { + def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + + action1 := &mockAction{} + action1.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action1.On("Expiration").Return(time.Now().Add(time.Hour), fleetapi.ErrNoStartTime) + action1.On("Type").Return(fleetapi.ActionTypeCancel) + action1.On("ID").Return("id") + + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.Action{action1}).Once() + + d, err := New(nil, def, queue) + require.NoError(t, err) + err = d.Register(&mockAction{}, def) + require.NoError(t, err) + + action2 := &mockAction{} + 
action2.On("StartTime").Return(time.Time{}, fleetapi.ErrNoStartTime) + action2.On("Type").Return(fleetapi.ActionTypeCancel) + action2.On("ID").Return("id") + + err = d.Dispatch(context.Background(), ack, action2) + require.NoError(t, err) + def.AssertExpectations(t) + queue.AssertExpectations(t) + }) + + t.Run("Retrieve no actions from queue", func(t *testing.T) { + def := &mockHandler{} + def.On("Handle", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + queue := &mockQueue{} + queue.On("Save").Return(nil).Once() + queue.On("DequeueActions").Return([]fleetapi.Action{}).Once() + + d, err := New(nil, def, queue) + require.NoError(t, err) + err = d.Register(&mockAction{}, def) + require.NoError(t, err) + + err = d.Dispatch(context.Background(), ack) + require.NoError(t, err) + def.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything, mock.Anything) }) } diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go index fe5028b0fce..38fad92057c 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go @@ -6,7 +6,6 @@ package fleet import ( "context" - stderr "errors" "fmt" "time" @@ -62,14 +61,6 @@ type stateStore interface { AckToken() string SetAckToken(ackToken string) Save() error - SetQueue([]fleetapi.Action) - Actions() []fleetapi.Action -} - -type actionQueue interface { - Add(fleetapi.Action, int64) - DequeueActions() []fleetapi.Action - Cancel(string) int Actions() []fleetapi.Action } @@ -84,7 +75,6 @@ type fleetGateway struct { unauthCounter int stateFetcher coordinator.StateFetcher stateStore stateStore - queue actionQueue errCh chan error } @@ -97,7 +87,6 @@ func New( acker acker.Acker, stateFetcher coordinator.StateFetcher, stateStore stateStore, - queue actionQueue, ) (gateway.FleetGateway, error) { scheduler := scheduler.NewPeriodicJitter(defaultGatewaySettings.Duration, 
defaultGatewaySettings.Jitter) @@ -111,7 +100,6 @@ func New( acker, stateFetcher, stateStore, - queue, ) } @@ -125,7 +113,6 @@ func newFleetGatewayWithScheduler( acker acker.Acker, stateFetcher coordinator.StateFetcher, stateStore stateStore, - queue actionQueue, ) (gateway.FleetGateway, error) { return &fleetGateway{ log: log, @@ -137,7 +124,6 @@ func newFleetGatewayWithScheduler( acker: acker, stateFetcher: stateFetcher, stateStore: stateStore, - queue: queue, errCh: make(chan error), }, nil } @@ -163,7 +149,7 @@ func (f *fleetGateway) Run(ctx context.Context) error { f.scheduler.Stop() f.log.Info("Fleet gateway stopped") return ctx.Err() - case ts := <-f.scheduler.WaitTick(): + case <-f.scheduler.WaitTick(): f.log.Debug("FleetGateway calling Checkin API") // Execute the checkin call and for any errors returned by the fleet-server API @@ -174,28 +160,11 @@ func (f *fleetGateway) Run(ctx context.Context) error { continue } - actions := f.queueScheduledActions(resp.Actions) - actions, err = f.dispatchCancelActions(actions) - if err != nil { - f.log.Error(err.Error()) - } - - queued, expired := f.gatherQueuedActions(ts.UTC()) - f.log.Debugf("Gathered %d actions from queue, %d actions expired", len(queued), len(expired)) - f.log.Debugf("Expired actions: %v", expired) - - actions = append(actions, queued...) 
+ actions := make([]fleetapi.Action, len(resp.Actions)) + copy(actions, resp.Actions) // Persist state hadErr := false - f.stateStore.SetQueue(f.queue.Actions()) - if err := f.stateStore.Save(); err != nil { - err = fmt.Errorf("failed to persist action_queue, error: %w", err) - f.log.Error(err) - f.errCh <- err - hadErr = true - } - if err := f.dispatcher.Dispatch(context.Background(), f.acker, actions...); err != nil { err = fmt.Errorf("failed to dispatch actions, error: %w", err) f.log.Error(err) @@ -216,60 +185,6 @@ func (f *fleetGateway) Errors() <-chan error { return f.errCh } -// queueScheduledActions will add any action in actions with a valid start time to the queue and return the rest. -// start time to current time comparisons are purposefully not made in case of cancel actions. -func (f *fleetGateway) queueScheduledActions(input fleetapi.Actions) []fleetapi.Action { - actions := make([]fleetapi.Action, 0, len(input)) - for _, action := range input { - start, err := action.StartTime() - if err == nil { - f.log.Debugf("Adding action id: %s to queue.", action.ID()) - f.queue.Add(action, start.Unix()) - continue - } - if !stderr.Is(err, fleetapi.ErrNoStartTime) { - f.log.Warnf("Issue gathering start time from action id %s: %v", action.ID(), err) - } - actions = append(actions, action) - } - return actions -} - -// dispatchCancelActions will separate and dispatch any cancel actions from the actions list and return the rest of the list. -// cancel actions are dispatched seperatly as they may remove items from the queue. -func (f *fleetGateway) dispatchCancelActions(actions []fleetapi.Action) ([]fleetapi.Action, error) { - // separate cancel actions from the actions list - cancelActions := make([]fleetapi.Action, 0, len(actions)) - for i := len(actions) - 1; i >= 0; i-- { - action := actions[i] - if action.Type() == fleetapi.ActionTypeCancel { - cancelActions = append(cancelActions, action) - actions = append(actions[:i], actions[i+1:]...) 
- } - } - // Dispatch cancel actions - if len(cancelActions) > 0 { - if err := f.dispatcher.Dispatch(context.Background(), f.acker, cancelActions...); err != nil { - return actions, fmt.Errorf("failed to dispatch cancel actions: %w", err) - } - } - return actions, nil -} - -// gatherQueuedActions will dequeue actions from the action queue and separate those that have already expired. -func (f *fleetGateway) gatherQueuedActions(ts time.Time) (queued, expired []fleetapi.Action) { - actions := f.queue.DequeueActions() - for _, action := range actions { - exp, _ := action.Expiration() - if ts.After(exp) { - expired = append(expired, action) - continue - } - queued = append(queued, action) - } - return queued, expired -} - func (f *fleetGateway) doExecute(ctx context.Context, bo backoff.Backoff) (*fleetapi.CheckinResponse, error) { bo.Reset() diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go index f7ba6ec961d..49c05112e18 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway_test.go @@ -19,7 +19,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator" @@ -110,29 +109,6 @@ func newTestingDispatcher() *testingDispatcher { return &testingDispatcher{received: make(chan struct{}, 1)} } -type mockQueue struct { - mock.Mock -} - -func (m *mockQueue) Add(action fleetapi.Action, n int64) { - m.Called(action, n) -} - -func (m *mockQueue) DequeueActions() []fleetapi.Action { - args := m.Called() - return args.Get(0).([]fleetapi.Action) -} - -func (m *mockQueue) Cancel(id string) int { - args := m.Called(id) - return args.Int(0) -} - -func (m *mockQueue) Actions() []fleetapi.Action { - args := m.Called() - return args.Get(0).([]fleetapi.Action) -} - type 
withGatewayFunc func(*testing.T, gateway.FleetGateway, *testingClient, *testingDispatcher, *scheduler.Stepper) func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGatewayFunc) func(t *testing.T) { @@ -145,10 +121,6 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat stateStore := newStateStore(t, log) - queue := &mockQueue{} - queue.On("DequeueActions").Return([]fleetapi.Action{}) - queue.On("Actions").Return([]fleetapi.Action{}) - gateway, err := newFleetGatewayWithScheduler( log, settings, @@ -159,7 +131,6 @@ func withGateway(agentInfo agentInfo, settings *fleetGatewaySettings, fn withGat noop.New(), &emptyStateFetcher{}, stateStore, - queue, ) require.NoError(t, err) @@ -290,10 +261,6 @@ func TestFleetGateway(t *testing.T) { log, _ := logger.New("tst", false) stateStore := newStateStore(t, log) - queue := &mockQueue{} - queue.On("DequeueActions").Return([]fleetapi.Action{}) - queue.On("Actions").Return([]fleetapi.Action{}) - gateway, err := newFleetGatewayWithScheduler( log, settings, @@ -304,7 +271,6 @@ func TestFleetGateway(t *testing.T) { noop.New(), &emptyStateFetcher{}, stateStore, - queue, ) require.NoError(t, err) @@ -337,244 +303,6 @@ func TestFleetGateway(t *testing.T) { require.NoError(t, err) }) - t.Run("queue action from checkin", func(t *testing.T) { - scheduler := scheduler.NewStepper() - client := newTestingClient() - dispatcher := newTestingDispatcher() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - log, _ := logger.New("tst", false) - stateStore := newStateStore(t, log) - - ts := time.Now().UTC().Round(time.Second) - queue := &mockQueue{} - queue.On("Add", mock.Anything, ts.Add(time.Hour).Unix()).Return().Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}) - queue.On("Actions").Return([]fleetapi.Action{}) - - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - dispatcher, - scheduler, - noop.New(), - 
&emptyStateFetcher{}, - stateStore, - queue, - ) - require.NoError(t, err) - - waitFn := ackSeq( - client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { - resp := wrapStrToResp(http.StatusOK, fmt.Sprintf(`{"actions": [{ - "type": "UPGRADE", - "id": "id1", - "start_time": "%s", - "expiration": "%s", - "data": { - "version": "1.2.3" - } - }]}`, - ts.Add(time.Hour).Format(time.RFC3339), - ts.Add(2*time.Hour).Format(time.RFC3339), - )) - return resp, nil - }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 0, len(actions)) - return nil - }), - ) - - errCh := runFleetGateway(ctx, gateway) - - scheduler.Next() - waitFn() - queue.AssertExpectations(t) - - cancel() - err = <-errCh - require.NoError(t, err) - }) - - t.Run("run action from queue", func(t *testing.T) { - scheduler := scheduler.NewStepper() - client := newTestingClient() - dispatcher := newTestingDispatcher() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - log, _ := logger.New("tst", false) - stateStore := newStateStore(t, log) - - ts := time.Now().UTC().Round(time.Second) - queue := &mockQueue{} - queue.On("DequeueActions").Return([]fleetapi.Action{&fleetapi.ActionUpgrade{ActionID: "id1", ActionType: "UPGRADE", ActionStartTime: ts.Add(-1 * time.Hour).Format(time.RFC3339), ActionExpiration: ts.Add(time.Hour).Format(time.RFC3339)}}).Once() - queue.On("Actions").Return([]fleetapi.Action{}) - - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - dispatcher, - scheduler, - noop.New(), - &emptyStateFetcher{}, - stateStore, - queue, - ) - require.NoError(t, err) - - waitFn := ackSeq( - client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { - resp := wrapStrToResp(http.StatusOK, `{"actions": []}`) - return resp, nil - }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 1, len(actions)) - return nil - }), - ) - - errCh := 
runFleetGateway(ctx, gateway) - - scheduler.Next() - waitFn() - queue.AssertExpectations(t) - - cancel() - err = <-errCh - require.NoError(t, err) - }) - - t.Run("discard expired action from queue", func(t *testing.T) { - scheduler := scheduler.NewStepper() - client := newTestingClient() - dispatcher := newTestingDispatcher() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - log, _ := logger.New("tst", false) - stateStore := newStateStore(t, log) - - ts := time.Now().UTC().Round(time.Second) - queue := &mockQueue{} - queue.On("DequeueActions").Return([]fleetapi.Action{&fleetapi.ActionUpgrade{ActionID: "id1", ActionType: "UPGRADE", ActionStartTime: ts.Add(-2 * time.Hour).Format(time.RFC3339), ActionExpiration: ts.Add(-1 * time.Hour).Format(time.RFC3339)}}).Once() - queue.On("Actions").Return([]fleetapi.Action{}) - - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - dispatcher, - scheduler, - noop.New(), - &emptyStateFetcher{}, - stateStore, - queue, - ) - require.NoError(t, err) - - waitFn := ackSeq( - client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { - resp := wrapStrToResp(http.StatusOK, `{"actions": []}`) - return resp, nil - }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - require.Equal(t, 0, len(actions)) - return nil - }), - ) - - errCh := runFleetGateway(ctx, gateway) - - scheduler.Next() - waitFn() - queue.AssertExpectations(t) - - cancel() - err = <-errCh - require.NoError(t, err) - }) - - t.Run("cancel action from checkin", func(t *testing.T) { - scheduler := scheduler.NewStepper() - client := newTestingClient() - dispatcher := newTestingDispatcher() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - log, _ := logger.New("tst", false) - stateStore := newStateStore(t, log) - - ts := time.Now().UTC().Round(time.Second) - queue := &mockQueue{} - queue.On("Add", mock.Anything, 
ts.Add(-1*time.Hour).Unix()).Return().Once() - queue.On("DequeueActions").Return([]fleetapi.Action{}) - queue.On("Actions").Return([]fleetapi.Action{}).Maybe() // this test seems flakey if we check for this call - // queue.Cancel does not need to be mocked here as it is ran in the cancel action dispatcher. - - gateway, err := newFleetGatewayWithScheduler( - log, - settings, - agentInfo, - client, - dispatcher, - scheduler, - noop.New(), - &emptyStateFetcher{}, - stateStore, - queue, - ) - require.NoError(t, err) - - waitFn := ackSeq( - client.Answer(func(headers http.Header, body io.Reader) (*http.Response, error) { - resp := wrapStrToResp(http.StatusOK, fmt.Sprintf(`{"actions": [{ - "type": "UPGRADE", - "id": "id1", - "start_time": "%s", - "expiration": "%s", - "data": { - "version": "1.2.3" - } - }, { - "type": "CANCEL", - "id": "id2", - "data": { - "target_id": "id1" - } - }]}`, - ts.Add(-1*time.Hour).Format(time.RFC3339), - ts.Add(2*time.Hour).Format(time.RFC3339), - )) - return resp, nil - }), - dispatcher.Answer(func(actions ...fleetapi.Action) error { - return nil - }), - ) - - errCh := runFleetGateway(ctx, gateway) - - scheduler.Next() - waitFn() - queue.AssertExpectations(t) - - cancel() - err = <-errCh - require.NoError(t, err) - }) - t.Run("Test the wait loop is interruptible", func(t *testing.T) { // 20mins is the double of the base timeout values for golang test suites. // If we cannot interrupt we will timeout. 
@@ -588,10 +316,6 @@ func TestFleetGateway(t *testing.T) { log, _ := logger.New("tst", false) stateStore := newStateStore(t, log) - queue := &mockQueue{} - queue.On("DequeueActions").Return([]fleetapi.Action{}) - queue.On("Actions").Return([]fleetapi.Action{}) - gateway, err := newFleetGatewayWithScheduler( log, &fleetGatewaySettings{ @@ -605,7 +329,6 @@ func TestFleetGateway(t *testing.T) { noop.New(), &emptyStateFetcher{}, stateStore, - queue, ) require.NoError(t, err) diff --git a/internal/pkg/agent/application/managed_mode.go b/internal/pkg/agent/application/managed_mode.go index 8abeab60eba..cd477753a1f 100644 --- a/internal/pkg/agent/application/managed_mode.go +++ b/internal/pkg/agent/application/managed_mode.go @@ -67,7 +67,7 @@ func newManagedConfigManager( return nil, errors.New(err, fmt.Sprintf("fail to read action store '%s'", paths.AgentActionStoreFile())) } - actionQueue, err := queue.NewActionQueue(stateStore.Queue()) + actionQueue, err := queue.NewActionQueue(stateStore.Queue(), stateStore) if err != nil { return nil, fmt.Errorf("unable to initialize action queue: %w", err) } @@ -170,7 +170,6 @@ func (m *managedConfigManager) Run(ctx context.Context) error { actionAcker, m.coord, m.stateStore, - m.actionQueue, ) if err != nil { return err @@ -281,7 +280,7 @@ func fleetServerRunning(state runtime.ComponentState) bool { } func newManagedActionDispatcher(m *managedConfigManager, canceller context.CancelFunc) (*dispatcher.ActionDispatcher, *handlers.PolicyChange, error) { - actionDispatcher, err := dispatcher.New(m.log, handlers.NewDefault(m.log)) + actionDispatcher, err := dispatcher.New(m.log, handlers.NewDefault(m.log), m.actionQueue) if err != nil { return nil, nil, err } diff --git a/internal/pkg/queue/actionqueue.go b/internal/pkg/queue/actionqueue.go index 671291639a2..0f3a2c20ffc 100644 --- a/internal/pkg/queue/actionqueue.go +++ b/internal/pkg/queue/actionqueue.go @@ -11,6 +11,12 @@ import ( 
"github.com/elastic/elastic-agent/internal/pkg/fleetapi" ) +// saver is an the minimal interface needed for state storage. +type saver interface { + SetQueue(a []fleetapi.Action) + Save() error +} + // item tracks an action in the action queue type item struct { action fleetapi.Action @@ -18,23 +24,28 @@ type item struct { index int } -// ActionQueue uses the standard library's container/heap to implement a priority queue -// This queue should not be indexed directly, instead use the provided Add, DequeueActions, or Cancel methods to add or remove items -// Actions() is indended to get the list of actions in the queue for serialization. -type ActionQueue []*item +// queue uses the standard library's container/heap to implement a priority queue +// This queue should not be used directly, instead the exported ActionQueue should be used. +type queue []*item + +// ActionQueue is a priority queue with the ability to persist to disk. +type ActionQueue struct { + q *queue + s saver +} // Len returns the length of the queue -func (q ActionQueue) Len() int { +func (q queue) Len() int { return len(q) } // Less will determine if item i's priority is less then item j's -func (q ActionQueue) Less(i, j int) bool { +func (q queue) Less(i, j int) bool { return q[i].priority < q[j].priority } // Swap will swap the items at index i and j -func (q ActionQueue) Swap(i, j int) { +func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] q[i].index = i q[j].index = j @@ -42,7 +53,7 @@ func (q ActionQueue) Swap(i, j int) { // Push will add x as an item to the queue // When using the queue, the Add method should be used instead. 
-func (q *ActionQueue) Push(x interface{}) { +func (q *queue) Push(x interface{}) { n := len(*q) e := x.(*item) //nolint:errcheck // should be an *item e.index = n @@ -51,7 +62,7 @@ func (q *ActionQueue) Push(x interface{}) { // Pop will return the last item from the queue // When using the queue, DequeueActions should be used instead -func (q *ActionQueue) Pop() interface{} { +func (q *queue) Pop() interface{} { old := *q n := len(old) e := old[n-1] @@ -61,10 +72,10 @@ func (q *ActionQueue) Pop() interface{} { return e } -// NewActionQueue creates a new ActionQueue initialized with the passed actions. +// newQueue creates a new priority queue using container/heap. // Will return an error if StartTime fails for any action. -func NewActionQueue(actions []fleetapi.Action) (*ActionQueue, error) { - q := make(ActionQueue, len(actions)) +func newQueue(actions []fleetapi.Action) (*queue, error) { + q := make(queue, len(actions)) for i, action := range actions { ts, err := action.StartTime() if err != nil { @@ -80,6 +91,18 @@ func NewActionQueue(actions []fleetapi.Action) (*ActionQueue, error) { return &q, nil } +// NewActionQueue creates a new queue with the passed actions using the persistor for state storage. +func NewActionQueue(actions []fleetapi.Action, s saver) (*ActionQueue, error) { + q, err := newQueue(actions) + if err != nil { + return nil, err + } + return &ActionQueue{ + q: q, + s: s, + }, nil +} + // Add will add an action to the queue with the associated priority. // The priority is meant to be the start-time of the action as a unix epoch time. // Complexity: O(log n) @@ -88,7 +111,7 @@ func (q *ActionQueue) Add(action fleetapi.Action, priority int64) { action: action, priority: priority, } - heap.Push(q, e) + heap.Push(q.q, e) } // DequeueActions will dequeue all actions that have a priority less then time.Now(). 
@@ -96,11 +119,11 @@ func (q *ActionQueue) Add(action fleetapi.Action, priority int64) { func (q *ActionQueue) DequeueActions() []fleetapi.Action { ts := time.Now().Unix() actions := make([]fleetapi.Action, 0) - for q.Len() != 0 { - if (*q)[0].priority > ts { + for q.q.Len() != 0 { + if (*q.q)[0].priority > ts { break } - item := heap.Pop(q).(*item) //nolint:errcheck // should be an *item + item := heap.Pop(q.q).(*item) //nolint:errcheck // should be an *item actions = append(actions, item.action) } return actions @@ -110,22 +133,28 @@ func (q *ActionQueue) DequeueActions() []fleetapi.Action { // Complexity: O(n*log n) func (q *ActionQueue) Cancel(actionID string) int { items := make([]*item, 0) - for _, item := range *q { + for _, item := range *q.q { if item.action.ID() == actionID { items = append(items, item) } } for _, item := range items { - heap.Remove(q, item.index) + heap.Remove(q.q, item.index) } return len(items) } // Actions returns all actions in the queue, item 0 is garunteed to be the min, the rest may not be in sorted order. func (q *ActionQueue) Actions() []fleetapi.Action { - actions := make([]fleetapi.Action, q.Len()) - for i, item := range *q { + actions := make([]fleetapi.Action, q.q.Len()) + for i, item := range *q.q { actions[i] = item.action } return actions } + +// Save persists the queue to disk. 
+func (q *ActionQueue) Save() error { + q.s.SetQueue(q.Actions()) + return q.s.Save() +} diff --git a/internal/pkg/queue/actionqueue_test.go b/internal/pkg/queue/actionqueue_test.go index 1c1e1959a9f..d951f855737 100644 --- a/internal/pkg/queue/actionqueue_test.go +++ b/internal/pkg/queue/actionqueue_test.go @@ -47,7 +47,20 @@ func (m *mockAction) Expiration() (time.Time, error) { return args.Get(0).(time.Time), args.Error(1) } -func TestNewActionQueue(t *testing.T) { +type mockPersistor struct { + mock.Mock +} + +func (m *mockPersistor) SetQueue(a []fleetapi.Action) { + m.Called(a) +} + +func (m *mockPersistor) Save() error { + args := m.Called() + return args.Error(0) +} + +func TestNewQueue(t *testing.T) { ts := time.Now() a1 := &mockAction{} a1.On("ID").Return("test-1") @@ -60,21 +73,21 @@ func TestNewActionQueue(t *testing.T) { a3.On("StartTime").Return(ts.Add(time.Minute), nil) t.Run("nil actions slice", func(t *testing.T) { - q, err := NewActionQueue(nil) + q, err := newQueue(nil) require.NoError(t, err) assert.NotNil(t, q) assert.Empty(t, q) }) t.Run("empty actions slice", func(t *testing.T) { - q, err := NewActionQueue([]fleetapi.Action{}) + q, err := newQueue([]fleetapi.Action{}) require.NoError(t, err) assert.NotNil(t, q) assert.Empty(t, q) }) t.Run("ordered actions list", func(t *testing.T) { - q, err := NewActionQueue([]fleetapi.Action{a1, a2, a3}) + q, err := newQueue([]fleetapi.Action{a1, a2, a3}) assert.NotNil(t, q) require.NoError(t, err) assert.Len(t, *q, 3) @@ -89,7 +102,7 @@ func TestNewActionQueue(t *testing.T) { }) t.Run("unordered actions list", func(t *testing.T) { - q, err := NewActionQueue([]fleetapi.Action{a3, a2, a1}) + q, err := newQueue([]fleetapi.Action{a3, a2, a1}) require.NoError(t, err) assert.NotNil(t, q) assert.Len(t, *q, 3) @@ -106,13 +119,13 @@ func TestNewActionQueue(t *testing.T) { t.Run("start time error", func(t *testing.T) { a := &mockAction{} a.On("StartTime").Return(time.Time{}, errors.New("oh no")) - q, err := 
NewActionQueue([]fleetapi.Action{a}) + q, err := newQueue([]fleetapi.Action{a}) assert.EqualError(t, err, "oh no") assert.Nil(t, q) }) } -func assertOrdered(t *testing.T, q *ActionQueue) { +func assertOrdered(t *testing.T, q *queue) { t.Helper() require.Len(t, *q, 3) i := heap.Pop(q).(*item) @@ -137,48 +150,56 @@ func Test_ActionQueue_Add(t *testing.T) { a3.On("ID").Return("test-3") t.Run("ascending order", func(t *testing.T) { - q := &ActionQueue{} - q.Add(a1, 1) - q.Add(a2, 2) - q.Add(a3, 3) - - assertOrdered(t, q) + aq := &ActionQueue{ + q: &queue{}, + } + aq.Add(a1, 1) + aq.Add(a2, 2) + aq.Add(a3, 3) + + assertOrdered(t, aq.q) }) t.Run("Add descending order", func(t *testing.T) { - q := &ActionQueue{} - q.Add(a3, 3) - q.Add(a2, 2) - q.Add(a1, 1) - - assertOrdered(t, q) + aq := &ActionQueue{ + q: &queue{}, + } + aq.Add(a3, 3) + aq.Add(a2, 2) + aq.Add(a1, 1) + + assertOrdered(t, aq.q) }) t.Run("mixed order", func(t *testing.T) { - q := &ActionQueue{} - q.Add(a1, 1) - q.Add(a3, 3) - q.Add(a2, 2) - - assertOrdered(t, q) + aq := &ActionQueue{ + q: &queue{}, + } + aq.Add(a1, 1) + aq.Add(a3, 3) + aq.Add(a2, 2) + + assertOrdered(t, aq.q) }) t.Run("two items have same priority", func(t *testing.T) { - q := &ActionQueue{} - q.Add(a1, 1) - q.Add(a2, 2) - q.Add(a3, 2) - - require.Len(t, *q, 3) - i := heap.Pop(q).(*item) + aq := &ActionQueue{ + q: &queue{}, + } + aq.Add(a1, 1) + aq.Add(a2, 2) + aq.Add(a3, 2) + + require.Len(t, *aq.q, 3) + i := heap.Pop(aq.q).(*item) assert.Equal(t, int64(1), i.priority) assert.Equal(t, "test-1", i.action.ID()) // next two items have same priority, however the ids may not match insertion order - i = heap.Pop(q).(*item) + i = heap.Pop(aq.q).(*item) assert.Equal(t, int64(2), i.priority) - i = heap.Pop(q).(*item) + i = heap.Pop(aq.q).(*item) assert.Equal(t, int64(2), i.priority) - assert.Empty(t, *q) + assert.Empty(t, *aq.q) }) } @@ -191,17 +212,19 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { a3.On("ID").Return("test-3") t.Run("empty 
queue", func(t *testing.T) { - q := &ActionQueue{} + aq := &ActionQueue{ + q: &queue{}, + } - actions := q.DequeueActions() + actions := aq.DequeueActions() assert.Empty(t, actions) - assert.Empty(t, *q) + assert.Empty(t, *aq.q) }) t.Run("one action from queue", func(t *testing.T) { ts := time.Now() - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: ts.Add(-1 * time.Minute).Unix(), index: 0, @@ -215,8 +238,9 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - actions := q.DequeueActions() + actions := aq.DequeueActions() require.Len(t, actions, 1) assert.Equal(t, "test-1", actions[0].ID()) @@ -234,7 +258,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { t.Run("two actions from queue", func(t *testing.T) { ts := time.Now() - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: ts.Add(-1 * time.Minute).Unix(), index: 0, @@ -248,8 +272,9 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - actions := q.DequeueActions() + actions := aq.DequeueActions() require.Len(t, actions, 2) assert.Equal(t, "test-2", actions[0].ID()) @@ -265,7 +290,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { t.Run("all actions from queue", func(t *testing.T) { ts := time.Now() - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: ts.Add(-1 * time.Minute).Unix(), index: 0, @@ -279,8 +304,9 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - actions := q.DequeueActions() + actions := aq.DequeueActions() require.Len(t, actions, 3) assert.Equal(t, "test-3", actions[0].ID()) @@ -292,7 +318,7 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { t.Run("no actions from queue", func(t *testing.T) { ts := time.Now() - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: ts.Add(1 * time.Minute).Unix(), 
index: 0, @@ -306,8 +332,9 @@ func Test_ActionQueue_DequeueActions(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - actions := q.DequeueActions() + actions := aq.DequeueActions() assert.Empty(t, actions) require.Len(t, *q, 3) @@ -333,15 +360,16 @@ func Test_ActionQueue_Cancel(t *testing.T) { a3.On("ID").Return("test-3") t.Run("empty queue", func(t *testing.T) { - q := &ActionQueue{} + q := &queue{} + aq := &ActionQueue{q, &mockPersistor{}} - n := q.Cancel("test-1") + n := aq.Cancel("test-1") assert.Zero(t, n) assert.Empty(t, *q) }) t.Run("one item cancelled", func(t *testing.T) { - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: 1, index: 0, @@ -355,8 +383,9 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - n := q.Cancel("test-1") + n := aq.Cancel("test-1") assert.Equal(t, 1, n) assert.Len(t, *q, 2) @@ -370,7 +399,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { }) t.Run("two items cancelled", func(t *testing.T) { - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: 1, index: 0, @@ -384,8 +413,9 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - n := q.Cancel("test-1") + n := aq.Cancel("test-1") assert.Equal(t, 2, n) assert.Len(t, *q, 1) @@ -396,7 +426,7 @@ func Test_ActionQueue_Cancel(t *testing.T) { }) t.Run("all items cancelled", func(t *testing.T) { - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: 1, index: 0, @@ -410,14 +440,15 @@ func Test_ActionQueue_Cancel(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - n := q.Cancel("test-1") + n := aq.Cancel("test-1") assert.Equal(t, 3, n) assert.Empty(t, *q) }) t.Run("no items cancelled", func(t *testing.T) { - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: 1, index: 0, @@ -431,8 +462,9 @@ func Test_ActionQueue_Cancel(t *testing.T) { 
index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - n := q.Cancel("test-0") + n := aq.Cancel("test-0") assert.Zero(t, n) assert.Len(t, *q, 3) @@ -451,8 +483,9 @@ func Test_ActionQueue_Cancel(t *testing.T) { func Test_ActionQueue_Actions(t *testing.T) { t.Run("empty queue", func(t *testing.T) { - q := &ActionQueue{} - actions := q.Actions() + q := &queue{} + aq := &ActionQueue{q, &mockPersistor{}} + actions := aq.Actions() assert.Len(t, actions, 0) }) @@ -463,7 +496,7 @@ func Test_ActionQueue_Actions(t *testing.T) { a2.On("ID").Return("test-2") a3 := &mockAction{} a3.On("ID").Return("test-3") - q := &ActionQueue{&item{ + q := &queue{&item{ action: a1, priority: 1, index: 0, @@ -477,8 +510,9 @@ func Test_ActionQueue_Actions(t *testing.T) { index: 2, }} heap.Init(q) + aq := &ActionQueue{q, &mockPersistor{}} - actions := q.Actions() + actions := aq.Actions() assert.Len(t, actions, 3) assert.Equal(t, "test-1", actions[0].ID()) })