diff --git a/internal/pkg/agent/application/application.go b/internal/pkg/agent/application/application.go index b6b47a30382..b312759b505 100644 --- a/internal/pkg/agent/application/application.go +++ b/internal/pkg/agent/application/application.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/elastic/elastic-agent-libs/logp" + "github.com/elastic/elastic-agent/pkg/features" "go.elastic.co/apm" @@ -25,7 +26,6 @@ import ( "github.com/elastic/elastic-agent/pkg/component" "github.com/elastic/elastic-agent/pkg/component/runtime" "github.com/elastic/elastic-agent/pkg/core/logger" - "github.com/elastic/elastic-agent/pkg/features" ) // New creates a new Agent and bootstrap the required subsystem. @@ -38,40 +38,36 @@ func New( tracer *apm.Tracer, disableMonitoring bool, modifiers ...component.PlatformModifier, -) (*coordinator.Coordinator, error) { +) (*coordinator.Coordinator, composable.Controller, error) { platform, err := component.LoadPlatformDetail(modifiers...) if err != nil { - return nil, fmt.Errorf("failed to gather system information: %w", err) + return nil, nil, fmt.Errorf("failed to gather system information: %w", err) } log.Info("Gathered system information") specs, err := component.LoadRuntimeSpecs(paths.Components(), platform) if err != nil { - return nil, fmt.Errorf("failed to detect inputs and outputs: %w", err) + return nil, nil, fmt.Errorf("failed to detect inputs and outputs: %w", err) } log.With("inputs", specs.Inputs()).Info("Detected available inputs and outputs") caps, err := capabilities.Load(paths.AgentCapabilitiesPath(), log) if err != nil { - return nil, fmt.Errorf("failed to determine capabilities: %w", err) + return nil, nil, fmt.Errorf("failed to determine capabilities: %w", err) } log.Info("Determined allowed capabilities") pathConfigFile := paths.ConfigFile() rawConfig, err := config.LoadFile(pathConfigFile) if err != nil { - return nil, fmt.Errorf("failed to load configuration: %w", err) + return nil, nil, fmt.Errorf("failed to load configuration: %w", err) } if err := info.InjectAgentConfig(rawConfig); err != nil { - return nil, fmt.Errorf("failed to load configuration: %w", err) + return nil, nil, fmt.Errorf("failed to load configuration: %w", err) } cfg, err := configuration.NewFromConfig(rawConfig) if err != nil { - return nil, fmt.Errorf("failed to load configuration: %w", err) - } - - if err := features.Apply(rawConfig); err != nil { - return nil, fmt.Errorf("could not parse and apply feature flags config: %w", err) + return nil, nil, fmt.Errorf("failed to load configuration: %w", err) } // monitoring is not supported in bootstrap mode https://github.com/elastic/elastic-agent/issues/1761 @@ -89,7 +85,7 @@ func New( cfg.Settings.GRPC, ) if err != nil { - return nil, fmt.Errorf("failed to initialize runtime manager: %w", err) + return nil, nil, fmt.Errorf("failed to initialize runtime manager: %w", err) } var configMgr coordinator.ConfigManager @@ -114,7 +110,7 @@ func New( var store storage.Store store, cfg, err = mergeFleetConfig(rawConfig) if err != nil { - return nil, err + return nil, nil, err } if configuration.IsFleetServerBootstrap(cfg.Fleet) { @@ -131,7 +127,7 @@ func New( managed, err = newManagedConfigManager(log, agentInfo, cfg, store, runtime) if err != nil { - return nil, err + return nil, nil, err } configMgr = managed } @@ -139,7 +135,7 @@ func New( composable, err := composable.New(log, rawConfig, composableManaged) if err != nil { - return nil, errors.New(err, "failed to initialize composable controller") + return nil, nil, errors.New(err, "failed to initialize composable controller") } coord := coordinator.New(log, logLevel, agentInfo, specs, reexec, upgrader, runtime, configMgr, composable, caps, monitor, isManaged, compModifiers...) @@ -148,7 +144,14 @@ func New( // coordinator, so it must be set here once the coordinator is created managed.coord = coord } - return coord, nil + + // It is important that feature flags from configuration are applied as late as possible. This will ensure that + // any feature flag change callbacks are registered before they get called by `features.Apply`. + if err := features.Apply(rawConfig); err != nil { + return nil, nil, fmt.Errorf("could not parse and apply feature flags config: %w", err) + } + + return coord, composable, nil } func mergeFleetConfig(rawConfig *config.Config) (storage.Store, *configuration.Configuration, error) { diff --git a/internal/pkg/agent/cmd/run.go b/internal/pkg/agent/cmd/run.go index d6fa2127ef8..2d5d0ff4841 100644 --- a/internal/pkg/agent/cmd/run.go +++ b/internal/pkg/agent/cmd/run.go @@ -207,10 +207,11 @@ func run(override cfgOverrider, modifiers ...component.PlatformModifier) error { l.Info("APM instrumentation disabled") } - coord, err := application.New(l, baseLogger, logLvl, agentInfo, rex, tracer, configuration.IsFleetServerBootstrap(cfg.Fleet), modifiers...) + coord, composable, err := application.New(l, baseLogger, logLvl, agentInfo, rex, tracer, configuration.IsFleetServerBootstrap(cfg.Fleet), modifiers...) if err != nil { return err } + defer composable.Close() serverStopFn, err := setupMetrics(l, cfg.Settings.DownloadConfig.OS(), cfg.Settings.MonitoringConfig, tracer, coord) if err != nil { diff --git a/internal/pkg/agent/vars/vars.go b/internal/pkg/agent/vars/vars.go index 65c0ef2ae1f..3a78b4e76f9 100644 --- a/internal/pkg/agent/vars/vars.go +++ b/internal/pkg/agent/vars/vars.go @@ -26,6 +26,7 @@ func WaitForVariables(ctx context.Context, l *logger.Logger, cfg *config.Config, if err != nil { return nil, fmt.Errorf("failed to create composable controller: %w", err) } + defer composable.Close() hasTimeout := false if wait > time.Duration(0) { diff --git a/internal/pkg/composable/controller.go b/internal/pkg/composable/controller.go index 4c736bb7d0f..8c39dc4ed1e 100644 --- a/internal/pkg/composable/controller.go +++ b/internal/pkg/composable/controller.go @@ -34,6 +34,10 @@ type Controller interface { // Watch returns the channel to watch for variable changes. Watch() <-chan []*transpiler.Vars + + // Close closes the controller, allowing for any resource + // cleanup and such. + Close() } // controller manages the state of the providers current context. @@ -251,6 +255,34 @@ func (c *controller) Watch() <-chan []*transpiler.Vars { return c.ch } +// Close closes the controller, allowing for any resource +// cleanup and such. +func (c *controller) Close() { + // Attempt to close all closeable context providers. + for name, state := range c.contextProviders { + cp, ok := state.provider.(corecomp.CloseableProvider) + if !ok { + continue + } + + if err := cp.Close(); err != nil { + c.logger.Errorf("unable to close context provider %q: %s", name, err.Error()) + } + } + + // Attempt to close all closeable dynamic providers. + for name, state := range c.dynamicProviders { + cp, ok := state.provider.(corecomp.CloseableProvider) + if !ok { + continue + } + + if err := cp.Close(); err != nil { + c.logger.Errorf("unable to close dynamic provider %q: %s", name, err.Error()) + } + } +} + type contextProviderState struct { context.Context diff --git a/internal/pkg/composable/controller_test.go b/internal/pkg/composable/controller_test.go index 050fe78ed74..ec418795c19 100644 --- a/internal/pkg/composable/controller_test.go +++ b/internal/pkg/composable/controller_test.go @@ -191,6 +191,8 @@ func TestCancellation(t *testing.T) { t.Run(fmt.Sprintf("test run %d", i), func(t *testing.T) { c, err := composable.New(log, cfg, false) require.NoError(t, err) + defer c.Close() + ctx, cancelFn := context.WithTimeout(context.Background(), timeout) defer cancelFn() err = c.Run(ctx) @@ -205,6 +207,8 @@ func TestCancellation(t *testing.T) { t.Run("immediate cancellation", func(t *testing.T) { c, err := composable.New(log, cfg, false) require.NoError(t, err) + defer c.Close() + ctx, cancelFn := context.WithTimeout(context.Background(), 0) cancelFn() err = c.Run(ctx) diff --git a/internal/pkg/composable/providers/host/host.go b/internal/pkg/composable/providers/host/host.go index ec7fef8927e..343d8d04488 100644 --- a/internal/pkg/composable/providers/host/host.go +++ b/internal/pkg/composable/providers/host/host.go @@ -20,8 +20,12 @@ import ( "github.com/elastic/elastic-agent/pkg/core/logger" ) -// DefaultCheckInterval is the default timeout used to check if any host information has changed. -const DefaultCheckInterval = 5 * time.Minute +const ( + // DefaultCheckInterval is the default timeout used to check if any host information has changed. + DefaultCheckInterval = 5 * time.Minute + + fqdnFeatureFlagCallbackID = "host_provider" +) func init() { composable.Providers.MustAddContextProvider("host", ContextProviderBuilder) @@ -34,6 +38,10 @@ type contextProvider struct { CheckInterval time.Duration `config:"check_interval"` + // fqdnFFChangeCh is used to signal when the FQDN + // feature flag has changed + fqdnFFChangeCh chan struct{} + // used by testing fetcher infoFetcher } @@ -49,21 +57,6 @@ func (c *contextProvider) Run(comm corecomp.ContextProviderComm) error { return errors.New(err, "failed to set mapping", errors.TypeUnexpected) } - const fqdnFeatureFlagCallbackID = "host_provider" - fqdnFFChangeCh := make(chan struct{}) - err = features.AddFQDNOnChangeCallback( - onFQDNFeatureFlagChange(fqdnFFChangeCh), - fqdnFeatureFlagCallbackID, - ) - if err != nil { - return fmt.Errorf("unable to add FQDN onChange callback in host provider: %w", err) - } - - defer func() { - features.RemoveFQDNOnChangeCallback(fqdnFeatureFlagCallbackID) - close(fqdnFFChangeCh) - }() - // Update context when any host information changes. for { t := time.NewTimer(c.CheckInterval) @@ -71,7 +64,7 @@ func (c *contextProvider) Run(comm corecomp.ContextProviderComm) error { case <-comm.Done(): t.Stop() return comm.Err() - case <-fqdnFFChangeCh: + case <-c.fqdnFFChangeCh: case <-t.C: } @@ -92,13 +85,21 @@ func (c *contextProvider) Run(comm corecomp.ContextProviderComm) error { } } -func onFQDNFeatureFlagChange(fqdnFFChangeCh chan struct{}) features.BoolValueOnChangeCallback { - return func(new, old bool) { - // FQDN feature flag was toggled, so notify on channel - fqdnFFChangeCh <- struct{}{} +func (c *contextProvider) onFQDNFeatureFlagChange(new, old bool) { + // FQDN feature flag was toggled, so notify on channel + select { + case c.fqdnFFChangeCh <- struct{}{}: + default: } } +func (c *contextProvider) Close() error { + features.RemoveFQDNOnChangeCallback(fqdnFeatureFlagCallbackID) + close(c.fqdnFFChangeCh) + + return nil +} + // ContextProviderBuilder builds the context provider. func ContextProviderBuilder(log *logger.Logger, c *config.Config, _ bool) (corecomp.ContextProvider, error) { p := &contextProvider{ @@ -114,6 +115,16 @@ func ContextProviderBuilder(log *logger.Logger, c *config.Config, _ bool) (corec if p.CheckInterval <= 0 { p.CheckInterval = DefaultCheckInterval } + + p.fqdnFFChangeCh = make(chan struct{}, 1) + err := features.AddFQDNOnChangeCallback( + p.onFQDNFeatureFlagChange, + fqdnFeatureFlagCallbackID, + ) + if err != nil { + return nil, fmt.Errorf("unable to add FQDN onChange callback in host provider: %w", err) + } + return p, nil } diff --git a/internal/pkg/composable/providers/host/host_test.go b/internal/pkg/composable/providers/host/host_test.go index 62fb9738505..dbdd0ed08e7 100644 --- a/internal/pkg/composable/providers/host/host_test.go +++ b/internal/pkg/composable/providers/host/host_test.go @@ -101,6 +101,13 @@ func TestFQDNFeatureFlagToggle(t *testing.T) { provider, err := builder(log, c, true) require.NoError(t, err) + hostProvider, ok := provider.(*contextProvider) + require.True(t, ok) + defer func() { + err := hostProvider.Close() + require.NoError(t, err) + }() + ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() @@ -109,9 +116,6 @@ func TestFQDNFeatureFlagToggle(t *testing.T) { // Track the number of times hostProvider.fetcher is called. numCalled := 0 - hostProvider, ok := provider.(*contextProvider) - require.True(t, ok) - hostProvider.fetcher = func() (map[string]interface{}, error) { numCalled++ return nil, nil @@ -119,16 +123,9 @@ func TestFQDNFeatureFlagToggle(t *testing.T) { // Run the provider go func() { - err = provider.Run(comm) + err = hostProvider.Run(comm) }() - // Wait long enough for provider.Run to register - // the FQDN feature flag onChange callback. - numCallbacks := features.NumFQDNOnChangeCallbacks() - require.Eventually(t, func() bool { - return features.NumFQDNOnChangeCallbacks() == numCallbacks+1 - }, 100*time.Millisecond, 10*time.Millisecond) - // Trigger the FQDN feature flag callback by // toggling the FQDN feature flag err = features.Apply(config.MustNewConfigFrom(map[string]interface{}{ @@ -143,7 +140,7 @@ func TestFQDNFeatureFlagToggle(t *testing.T) { // - once, right after the provider is run, and // - once again, when the FQDN feature flag callback is triggered return numCalled == 2 - }, 100*time.Millisecond, 10*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) } func returnHostMapping(log *logger.Logger) infoFetcher { diff --git a/internal/pkg/core/composable/providers.go b/internal/pkg/core/composable/providers.go index 235e17d83fa..f6d2a8f3e26 100644 --- a/internal/pkg/core/composable/providers.go +++ b/internal/pkg/core/composable/providers.go @@ -28,3 +28,11 @@ type ContextProvider interface { // Run runs the context provider. Run(ContextProviderComm) error } + +// CloseableProvider is an interface that providers may choose to implement +// if it makes sense for them, e.g. if they have any resources that need +// cleaning up after the provider's (final) run. +type CloseableProvider interface { + // Close is called after all runs of the provider have finished. + Close() error +} diff --git a/pkg/features/features.go b/pkg/features/features.go index 07b48877d87..7613b055fd2 100644 --- a/pkg/features/features.go +++ b/pkg/features/features.go @@ -81,15 +81,6 @@ func RemoveFQDNOnChangeCallback(id string) { delete(current.fqdnCallbacks, id) } -// NumFQDNOnChangeCallbacks returns the number of FQDN onChange -// callbacks currently registered. Useful for testing. -func NumFQDNOnChangeCallbacks() int { - current.mu.RLock() - defer current.mu.RUnlock() - - return len(current.fqdnCallbacks) -} - // setFQDN sets the value of the FQDN flag in Flags. func (f *Flags) setFQDN(newValue bool) { f.mu.Lock()