diff --git a/beater/beater.go b/beater/beater.go index 89486168f00..f4e363c2eed 100644 --- a/beater/beater.go +++ b/beater/beater.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/pkg/errors" + "go.elastic.co/apm" "golang.org/x/sync/errgroup" "github.com/elastic/beats/v7/libbeat/beat" @@ -44,18 +45,16 @@ var ( // CreatorParams holds parameters for creating beat.Beaters. type CreatorParams struct { - // RunServer is used to run the APM Server. + // WrapRunServer is used to wrap the RunServerFunc used to run the APM Server. // - // This should be set to beater.RunServer, or a function which wraps it. - RunServer RunServerFunc + // WrapRunServer is optional. If provided, it must return a function that calls + // its input, possibly modifying the parameters on the way in. + WrapRunServer func(RunServerFunc) RunServerFunc } // NewCreator returns a new beat.Creator which creates beaters // using the provided CreatorParams. func NewCreator(args CreatorParams) beat.Creator { - if args.RunServer == nil { - panic("args.RunServer must be non-nil") - } return func(b *beat.Beat, ucfg *common.Config) (beat.Beater, error) { logger := logp.NewLogger(logs.Beater) if err := checkConfig(logger); err != nil { @@ -72,10 +71,10 @@ func NewCreator(args CreatorParams) beat.Creator { } bt := &beater{ - config: beaterConfig, - stopped: false, - logger: logger, - runServer: args.RunServer, + config: beaterConfig, + stopped: false, + logger: logger, + wrapRunServer: args.WrapRunServer, } // setup pipelines if explicitly directed to or setup --pipelines and config is not set at all @@ -119,9 +118,9 @@ func checkConfig(logger *logp.Logger) error { } type beater struct { - config *config.Config - logger *logp.Logger - runServer RunServerFunc + config *config.Config + logger *logp.Logger + wrapRunServer func(RunServerFunc) RunServerFunc mutex sync.Mutex // guards stopServer and stopped stopServer func() @@ -137,30 +136,14 @@ func (bt *beater) Run(b *beat.Beat) error { } defer tracer.Close() - runServer := bt.runServer + runServer := runServer if tracerServer != nil { - // Self-instrumentation enabled, so running the APM Server - // should run an internal server for receiving trace data. - origRunServer := runServer - runServer = func(ctx context.Context, args ServerParams) error { - g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - defer tracerServer.stop() - <-ctx.Done() - // Close the tracer now to prevent the server - // from waiting for more events during graceful - // shutdown. - tracer.Close() - return nil - }) - g.Go(func() error { - return tracerServer.serve(args.Reporter) - }) - g.Go(func() error { - return origRunServer(ctx, args) - }) - return g.Wait() - } + runServer = runServerWithTracerServer(runServer, tracerServer, tracer) + } + if bt.wrapRunServer != nil { + // Wrap runServer function, enabling injection of + // behaviour into the processing/reporting pipeline. + runServer = bt.wrapRunServer(runServer) } publisher, err := publish.NewPublisher(b.Publisher, tracer, &publish.PublisherConfig{ @@ -244,3 +227,27 @@ func (bt *beater) Stop() { bt.stopServer() bt.stopped = true } + +// runServerWithTracerServer wraps runServer such that it also runs +// tracerServer, stopping it and the tracer when the server shuts down. +func runServerWithTracerServer(runServer RunServerFunc, tracerServer *tracerServer, tracer *apm.Tracer) RunServerFunc { + return func(ctx context.Context, args ServerParams) error { + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + defer tracerServer.stop() + <-ctx.Done() + // Close the tracer now to prevent the server + // from waiting for more events during graceful + // shutdown. + tracer.Close() + return nil + }) + g.Go(func() error { + return tracerServer.serve(args.Reporter) + }) + g.Go(func() error { + return runServer(ctx, args) + }) + return g.Wait() + } +} diff --git a/beater/beater_test.go b/beater/beater_test.go index 0e1dd5fe817..748c798bc4e 100644 --- a/beater/beater_test.go +++ b/beater/beater_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/elastic/apm-server/beater/config" + "github.com/elastic/apm-server/model" "github.com/elastic/apm-server/publish" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/common" @@ -42,26 +43,44 @@ type testBeater struct { client *http.Client } -func setupBeater(t *testing.T, apmBeat *beat.Beat, ucfg *common.Config, beatConfig *beat.BeatConfig) (*testBeater, error) { +func setupBeater( + t *testing.T, + apmBeat *beat.Beat, + ucfg *common.Config, + beatConfig *beat.BeatConfig, +) (*testBeater, error) { + onboardingDocs := make(chan onboardingDoc, 1) createBeater := NewCreator(CreatorParams{ - RunServer: func(ctx context.Context, args ServerParams) error { - // Wrap the reporter so we can intercept the - // onboarding doc, to extract the listen address. - origReporter := args.Reporter - args.Reporter = func(ctx context.Context, req publish.PendingReq) error { - for _, tf := range req.Transformables { - if o, ok := tf.(onboardingDoc); ok { - select { - case <-ctx.Done(): - return ctx.Err() - case onboardingDocs <- o: + WrapRunServer: func(runServer RunServerFunc) RunServerFunc { + return func(ctx context.Context, args ServerParams) error { + // Wrap the reporter so we can intercept the + // onboarding doc, to extract the listen address. + origReporter := args.Reporter + args.Reporter = func(ctx context.Context, req publish.PendingReq) error { + for _, tf := range req.Transformables { + switch tf := tf.(type) { + case onboardingDoc: + select { + case <-ctx.Done(): + return ctx.Err() + case onboardingDocs <- tf: + } + + case *model.Transaction: + // Add a label to test that everything + // goes through the wrapped reporter. + if tf.Labels == nil { + labels := make(model.Labels) + tf.Labels = &labels + } + (*tf.Labels)["wrapped_reporter"] = true } } + return origReporter(ctx, req) } - return origReporter(ctx, req) + return runServer(ctx, args) } - return RunServer(ctx, args) }, }) diff --git a/beater/server.go b/beater/server.go index bdb49902e00..e23c129387e 100644 --- a/beater/server.go +++ b/beater/server.go @@ -53,8 +53,8 @@ type ServerParams struct { Reporter publish.Reporter } -// RunServer runs the APM Server until a fatal error occurs, or ctx is cancelled. -func RunServer(ctx context.Context, args ServerParams) error { +// runServer runs the APM Server until a fatal error occurs, or ctx is cancelled. +func runServer(ctx context.Context, args ServerParams) error { srv, err := newServer(args.Logger, args.Config, args.Tracer, args.Reporter) if err != nil { return err diff --git a/beater/test_approved_es_documents/TestPublishIntegrationEvents.approved.json b/beater/test_approved_es_documents/TestPublishIntegrationEvents.approved.json index a96f7fd22ee..eafddce8827 100644 --- a/beater/test_approved_es_documents/TestPublishIntegrationEvents.approved.json +++ b/beater/test_approved_es_documents/TestPublishIntegrationEvents.approved.json @@ -100,7 +100,8 @@ "ab_testing": true, "group": "experimental", "organization_uuid": "9f0e9d64-c185-4d21-a6f4-4673ed561ec8", - "segment": 5 + "segment": 5, + "wrapped_reporter": true }, "observer": { "ephemeral_id": "00000000-0000-0000-0000-000000000000", diff --git a/beater/test_approved_es_documents/TestPublishIntegrationMinimalEvents.approved.json b/beater/test_approved_es_documents/TestPublishIntegrationMinimalEvents.approved.json index e2dcb79c6c4..73ad6671ba8 100644 --- a/beater/test_approved_es_documents/TestPublishIntegrationMinimalEvents.approved.json +++ b/beater/test_approved_es_documents/TestPublishIntegrationMinimalEvents.approved.json @@ -18,6 +18,9 @@ "host": { "ip": "127.0.0.1" }, + "labels": { + "wrapped_reporter": true + }, "observer": { "ephemeral_id": "00000000-0000-0000-0000-000000000000", "hostname": "", diff --git a/beater/test_approved_es_documents/TestPublishIntegrationTransactions.approved.json b/beater/test_approved_es_documents/TestPublishIntegrationTransactions.approved.json index 267384f2393..5277293a092 100644 --- a/beater/test_approved_es_documents/TestPublishIntegrationTransactions.approved.json +++ b/beater/test_approved_es_documents/TestPublishIntegrationTransactions.approved.json @@ -59,7 +59,8 @@ }, "labels": { "tag1": "one", - "tag2": 2 + "tag2": 2, + "wrapped_reporter": true }, "observer": { "ephemeral_id": "00000000-0000-0000-0000-000000000000", @@ -253,7 +254,8 @@ "tag1": "one", "tag2": 12, "tag3": 12.45, - "tag4": false + "tag4": false, + "wrapped_reporter": true }, "observer": { "ephemeral_id": "00000000-0000-0000-0000-000000000000", @@ -419,7 +421,8 @@ }, "labels": { "tag1": "one", - "tag2": 2 + "tag2": 2, + "wrapped_reporter": true }, "observer": { "ephemeral_id": "00000000-0000-0000-0000-000000000000", @@ -555,7 +558,8 @@ }, "labels": { "tag1": "one", - "tag2": 2 + "tag2": 2, + "wrapped_reporter": true }, "observer": { "ephemeral_id": "00000000-0000-0000-0000-000000000000", diff --git a/beater/tracing_test.go b/beater/tracing_test.go index 5e3723ba077..d4ca7aaec7a 100644 --- a/beater/tracing_test.go +++ b/beater/tracing_test.go @@ -54,6 +54,12 @@ func TestServerTracingEnabled(t *testing.T) { if testTransactionIds.Contains(eventTransactionId(e)) { continue } + + // Check that self-instrumentation goes through the + // reporter wrapped by setupBeater. + wrapped, _ := e.GetValue("labels.wrapped_reporter") + assert.Equal(t, true, wrapped) + selfTransactions = append(selfTransactions, eventTransactionName(e)) case <-time.After(5 * time.Second): assert.FailNow(t, "timed out waiting for transaction") diff --git a/main.go b/main.go index 98ba8b60ff6..9bdaf0ec8fb 100644 --- a/main.go +++ b/main.go @@ -26,9 +26,7 @@ import ( "github.com/elastic/apm-server/cmd" ) -var rootCmd = cmd.NewRootCommand(beater.NewCreator(beater.CreatorParams{ - RunServer: beater.RunServer, -})) +var rootCmd = cmd.NewRootCommand(beater.NewCreator(beater.CreatorParams{})) func main() { if err := rootCmd.Execute(); err != nil { diff --git a/x-pack/apm-server/cmd/root_test.go b/x-pack/apm-server/cmd/root_test.go index a4e1a82eb04..368aff3cefe 100644 --- a/x-pack/apm-server/cmd/root_test.go +++ b/x-pack/apm-server/cmd/root_test.go @@ -22,9 +22,7 @@ func TestSubCommands(t *testing.T) { "version": {}, } - rootCmd := NewXPackRootCommand(beater.NewCreator(beater.CreatorParams{ - RunServer: beater.RunServer, - })) + rootCmd := NewXPackRootCommand(beater.NewCreator(beater.CreatorParams{})) for _, cmd := range rootCmd.Commands() { name := cmd.Name() if _, ok := validCommands[name]; !ok { diff --git a/x-pack/apm-server/main.go b/x-pack/apm-server/main.go index dfc60387a69..ff8f57e9ca1 100644 --- a/x-pack/apm-server/main.go +++ b/x-pack/apm-server/main.go @@ -24,9 +24,9 @@ import ( // and the publish.Reporter will be wrapped such that all // transactions pass through the aggregator before being // published to libbeat. -func runServerWithAggregator(ctx context.Context, args beater.ServerParams) error { +func runServerWithAggregator(ctx context.Context, runServer beater.RunServerFunc, args beater.ServerParams) error { if !args.Config.Aggregation.Enabled { - return beater.RunServer(ctx, args) + return runServer(ctx, args) } agg, err := txmetrics.NewAggregator(txmetrics.AggregatorConfig{ @@ -58,14 +58,16 @@ func runServerWithAggregator(ctx context.Context, args beater.ServerParams) erro } }) g.Go(func() error { - return beater.RunServer(ctx, args) + return runServer(ctx, args) }) return g.Wait() } var rootCmd = cmd.NewXPackRootCommand(beater.NewCreator(beater.CreatorParams{ - RunServer: func(ctx context.Context, args beater.ServerParams) error { - return runServerWithAggregator(ctx, args) + WrapRunServer: func(runServer beater.RunServerFunc) beater.RunServerFunc { + return func(ctx context.Context, args beater.ServerParams) error { + return runServerWithAggregator(ctx, runServer, args) + } }, }))