diff --git a/Dockerfile b/Dockerfile index c96c694b87..d49c7ca1a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,6 +31,8 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ software-properties-common \ gpg \ apt-utils \ + libc6-dev \ + gcc \ make && \ sudo update-ca-certificates && \ rm -rf /var/lib/apt/lists/* diff --git a/internal/driverutil/hello.go b/internal/driverutil/hello.go new file mode 100644 index 0000000000..356e1d3336 --- /dev/null +++ b/internal/driverutil/hello.go @@ -0,0 +1,128 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import ( + "os" + "strings" +) + +const AwsLambdaPrefix = "AWS_Lambda_" + +const ( + // FaaS environment variable names + + // EnvVarAWSExecutionEnv is the AWS Execution environment variable. + EnvVarAWSExecutionEnv = "AWS_EXECUTION_ENV" + // EnvVarAWSLambdaRuntimeAPI is the AWS Lambda runtime API variable. + EnvVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" + // EnvVarFunctionsWorkerRuntime is the functions worker runtime variable. + EnvVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" + // EnvVarKService is the K Service variable. + EnvVarKService = "K_SERVICE" + // EnvVarFunctionName is the function name variable. + EnvVarFunctionName = "FUNCTION_NAME" + // EnvVarVercel is the Vercel variable. + EnvVarVercel = "VERCEL" + // EnvVarK8s is the K8s veriable. + EnvVarK8s = "KUBERNETES_SERVICE_HOST" +) + +const ( + // FaaS environment variable names + + // EnvVarAWSRegion is the AWS region variable. + EnvVarAWSRegion = "AWS_REGION" + // EnvVarAWSLambdaFunctionMemorySize is the AWS Lambda function memory size variable. + EnvVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" + // EnvVarFunctionMemoryMB is the function memory in megabytes variable. + EnvVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" + // EnvVarFunctionTimeoutSec is the function timeout in seconds variable. + EnvVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" + // EnvVarFunctionRegion is the function region variable. + EnvVarFunctionRegion = "FUNCTION_REGION" + // EnvVarVercelRegion is the Vercel region variable. + EnvVarVercelRegion = "VERCEL_REGION" +) + +const ( + // FaaS environment names used by the client + + // EnvNameAWSLambda is the AWS Lambda environment name. + EnvNameAWSLambda = "aws.lambda" + // EnvNameAzureFunc is the Azure Function environment name. + EnvNameAzureFunc = "azure.func" + // EnvNameGCPFunc is the Google Cloud Function environment name. + EnvNameGCPFunc = "gcp.func" + // EnvNameVercel is the Vercel environment name. + EnvNameVercel = "vercel" +) + +// GetFaasEnvName parses the FaaS environment variable name and returns the +// corresponding name used by the client. If none of the variables or variables +// for multiple names are populated the client.env value MUST be entirely +// omitted. When variables for multiple "client.env.name" values are present, +// "vercel" takes precedence over "aws.lambda"; any other combination MUST cause +// "client.env" to be entirely omitted. +func GetFaasEnvName() string { + envVars := []string{ + EnvVarAWSExecutionEnv, + EnvVarAWSLambdaRuntimeAPI, + EnvVarFunctionsWorkerRuntime, + EnvVarKService, + EnvVarFunctionName, + EnvVarVercel, + } + + // If none of the variables are populated the client.env value MUST be + // entirely omitted. + names := make(map[string]struct{}) + + for _, envVar := range envVars { + val := os.Getenv(envVar) + if val == "" { + continue + } + + var name string + + switch envVar { + case EnvVarAWSExecutionEnv: + if !strings.HasPrefix(val, AwsLambdaPrefix) { + continue + } + + name = EnvNameAWSLambda + case EnvVarAWSLambdaRuntimeAPI: + name = EnvNameAWSLambda + case EnvVarFunctionsWorkerRuntime: + name = EnvNameAzureFunc + case EnvVarKService, EnvVarFunctionName: + name = EnvNameGCPFunc + case EnvVarVercel: + // "vercel" takes precedence over "aws.lambda". + delete(names, EnvNameAWSLambda) + + name = EnvNameVercel + } + + names[name] = struct{}{} + if len(names) > 1 { + // If multiple names are populated the client.env value + // MUST be entirely omitted. + names = nil + + break + } + } + + for name := range names { + return name + } + + return "" +} diff --git a/internal/driverutil/const.go b/internal/driverutil/operation.go similarity index 100% rename from internal/driverutil/const.go rename to internal/driverutil/operation.go diff --git a/internal/test/faas/awslambda/mongodb/main.go b/internal/test/faas/awslambda/mongodb/main.go index a0c55f9085..f9d8765550 100644 --- a/internal/test/faas/awslambda/mongodb/main.go +++ b/internal/test/faas/awslambda/mongodb/main.go @@ -27,11 +27,12 @@ const timeout = 60 * time.Second // event durations, as well as the number of heartbeats, commands, and open // conections. type eventListener struct { - commandCount int - commandDuration int64 - heartbeatCount int - heartbeatDuration int64 - openConnections int + commandCount int + commandDuration int64 + heartbeatAwaitedCount int + heartbeatCount int + heartbeatDuration int64 + openConnections int } // commandMonitor initializes an event.CommandMonitor that will count the number @@ -61,11 +62,19 @@ func (listener *eventListener) serverMonitor() *event.ServerMonitor { succeeded := func(e *event.ServerHeartbeatSucceededEvent) { listener.heartbeatCount++ listener.heartbeatDuration += e.DurationNanos + + if e.Awaited { + listener.heartbeatAwaitedCount++ + } } failed := func(e *event.ServerHeartbeatFailedEvent) { listener.heartbeatCount++ listener.heartbeatDuration += e.DurationNanos + + if e.Awaited { + listener.heartbeatAwaitedCount++ + } } return &event.ServerMonitor{ @@ -150,6 +159,12 @@ func handler(ctx context.Context, request events.APIGatewayProxyRequest) (events return gateway500(), fmt.Errorf("failed to delete: %w", err) } + // Driver must switch to polling monitoring when running within a FaaS + // environment. + if listener.heartbeatAwaitedCount > 0 { + return gateway500(), fmt.Errorf("FaaS environment fialed to switch to polling") + } + var avgCmdDur float64 if count := listener.commandCount; count != 0 { avgCmdDur = float64(listener.commandDuration) / float64(count) diff --git a/mongo/integration/handshake_test.go b/mongo/integration/handshake_test.go index b2cb7562f0..fc1d25eba9 100644 --- a/mongo/integration/handshake_test.go +++ b/mongo/integration/handshake_test.go @@ -55,31 +55,18 @@ func TestHandshakeProse(t *testing.T) { return elems } - const ( - envVarAWSExecutionEnv = "AWS_EXECUTION_ENV" - envVarAWSRegion = "AWS_REGION" - envVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" - envVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" - envVarKService = "K_SERVICE" - envVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" - envVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" - envVarFunctionRegion = "FUNCTION_REGION" - envVarVercel = "VERCEL" - envVarVercelRegion = "VERCEL_REGION" - ) - // Reset the environment variables to avoid environment namespace // collision. - t.Setenv(envVarAWSExecutionEnv, "") - t.Setenv(envVarFunctionsWorkerRuntime, "") - t.Setenv(envVarKService, "") - t.Setenv(envVarVercel, "") - t.Setenv(envVarAWSRegion, "") - t.Setenv(envVarAWSLambdaFunctionMemorySize, "") - t.Setenv(envVarFunctionMemoryMB, "") - t.Setenv(envVarFunctionTimeoutSec, "") - t.Setenv(envVarFunctionRegion, "") - t.Setenv(envVarVercelRegion, "") + t.Setenv("AWS_EXECUTION_ENV", "") + t.Setenv("FUNCTIONS_WORKER_RUNTIME", "") + t.Setenv("K_SERVICE", "") + t.Setenv("VERCEL", "") + t.Setenv("AWS_REGION", "") + t.Setenv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "") + t.Setenv("FUNCTION_MEMORY_MB", "") + t.Setenv("FUNCTION_TIMEOUT_SEC", "") + t.Setenv("FUNCTION_REGION", "") + t.Setenv("VERCEL_REGION", "") for _, test := range []struct { name string @@ -89,9 +76,9 @@ func TestHandshakeProse(t *testing.T) { { name: "1. valid AWS", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarAWSRegion: "us-east-2", - envVarAWSLambdaFunctionMemorySize: "1024", + "AWS_EXECUTION_ENV": "AWS_Lambda_java8", + "AWS_REGION": "us-east-2", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", }, want: clientMetadata(bson.D{ {Key: "name", Value: "aws.lambda"}, @@ -102,7 +89,7 @@ func TestHandshakeProse(t *testing.T) { { name: "2. valid Azure", env: map[string]string{ - envVarFunctionsWorkerRuntime: "node", + "FUNCTIONS_WORKER_RUNTIME": "node", }, want: clientMetadata(bson.D{ {Key: "name", Value: "azure.func"}, @@ -111,10 +98,10 @@ func TestHandshakeProse(t *testing.T) { { name: "3. valid GCP", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionMemoryMB: "1024", - envVarFunctionTimeoutSec: "60", - envVarFunctionRegion: "us-central1", + "K_SERVICE": "servicename", + "FUNCTION_MEMORY_MB": "1024", + "FUNCTION_TIMEOUT_SEC": "60", + "FUNCTION_REGION": "us-central1", }, want: clientMetadata(bson.D{ {Key: "name", Value: "gcp.func"}, @@ -126,8 +113,8 @@ func TestHandshakeProse(t *testing.T) { { name: "4. valid Vercel", env: map[string]string{ - envVarVercel: "1", - envVarVercelRegion: "cdg1", + "VERCEL": "1", + "VERCEL_REGION": "cdg1", }, want: clientMetadata(bson.D{ {Key: "name", Value: "vercel"}, @@ -137,16 +124,16 @@ func TestHandshakeProse(t *testing.T) { { name: "5. invalid multiple providers", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarFunctionsWorkerRuntime: "node", + "AWS_EXECUTION_ENV": "AWS_Lambda_java8", + "FUNCTIONS_WORKER_RUNTIME": "node", }, want: clientMetadata(nil), }, { name: "6. invalid long string", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarAWSRegion: func() string { + "AWS_EXECUTION_ENV": "AWS_Lambda_java8", + "AWS_REGION": func() string { var s string for i := 0; i < 512; i++ { s += "a" @@ -161,8 +148,8 @@ func TestHandshakeProse(t *testing.T) { { name: "7. invalid wrong types", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_java8", - envVarAWSLambdaFunctionMemorySize: "big", + "AWS_EXECUTION_ENV": "AWS_Lambda_java8", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big", }, want: clientMetadata(bson.D{ {Key: "name", Value: "aws.lambda"}, @@ -171,7 +158,7 @@ func TestHandshakeProse(t *testing.T) { { name: "8. Invalid - AWS_EXECUTION_ENV does not start with \"AWS_Lambda_\"", env: map[string]string{ - envVarAWSExecutionEnv: "EC2", + "AWS_EXECUTION_ENV": "EC2", }, want: clientMetadata(nil), }, @@ -188,32 +175,27 @@ func TestHandshakeProse(t *testing.T) { require.NoError(mt, err, "Ping error: %v", err) messages := mt.GetProxiedMessages() + handshakeMessage := messages[:1][0] - // First two messages are handshake messages - for idx, pair := range messages[:2] { - hello := handshake.LegacyHello - // Expect "hello" command name with API version. - if os.Getenv("REQUIRE_API_VERSION") == "true" { - hello = "hello" - } - - assert.Equal(mt, pair.CommandName, hello, "expected and actual command name at index %d are different", idx) + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + } - sent := pair.Sent + assert.Equal(mt, hello, handshakeMessage.CommandName) - // Lookup the "client" field in the command document. - clientVal, err := sent.Command.LookupErr("client") - require.NoError(mt, err, "expected command %s at index %d to contain client field", sent.Command, idx) + // Lookup the "client" field in the command document. + clientVal, err := handshakeMessage.Sent.Command.LookupErr("client") + require.NoError(mt, err, "expected command %s to contain client field", handshakeMessage.Sent.Command) - got, ok := clientVal.DocumentOK() - require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type) + got, ok := clientVal.DocumentOK() + require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type) - wantBytes, err := bson.Marshal(test.want) - require.NoError(mt, err, "error marshaling want document: %v", err) + wantBytes, err := bson.Marshal(test.want) + require.NoError(mt, err, "error marshaling want document: %v", err) - want := bsoncore.Document(wantBytes) - assert.Equal(mt, want, got, "want: %v, got: %v", want, got) - } + want := bsoncore.Document(wantBytes) + assert.Equal(mt, want, got, "want: %v, got: %v", want, got) }) } } diff --git a/mongo/integration/sdam_prose_test.go b/mongo/integration/sdam_prose_test.go index 4e7f7dcab0..435bdc72da 100644 --- a/mongo/integration/sdam_prose_test.go +++ b/mongo/integration/sdam_prose_test.go @@ -32,7 +32,8 @@ func TestSDAMProse(t *testing.T) { heartbeatIntervalMtOpts := mtest.NewOptions(). ClientOptions(heartbeatIntervalClientOpts). CreateCollection(false). - ClientType(mtest.Proxy) + ClientType(mtest.Proxy). + MinServerVersion("4.4") // RTT Monitor / Streaming protocol is not supported for versions < 4.4. mt.RunOpts("heartbeats processed more frequently", heartbeatIntervalMtOpts, func(mt *mtest.T) { // Test that setting heartbeat interval to 500ms causes the client to process heartbeats // approximately every 500ms instead of the default 10s. Note that a Client doesn't diff --git a/mongo/integration/unified/client_entity.go b/mongo/integration/unified/client_entity.go index bd131799a3..6848485c70 100644 --- a/mongo/integration/unified/client_entity.go +++ b/mongo/integration/unified/client_entity.go @@ -599,6 +599,8 @@ func setClientOptionsFromURIOptions(clientOpts *options.ClientOptions, uriOpts b clientOpts.SetTimeout(time.Duration(value.(int32)) * time.Millisecond) case "serverselectiontimeoutms": clientOpts.SetServerSelectionTimeout(time.Duration(value.(int32)) * time.Millisecond) + case "servermonitoringmode": + clientOpts.SetServerMonitoringMode(value.(string)) default: return fmt.Errorf("unrecognized URI option %s", key) } diff --git a/mongo/integration/unified/event_verification.go b/mongo/integration/unified/event_verification.go index 91f7452907..1d54e3fb2a 100644 --- a/mongo/integration/unified/event_verification.go +++ b/mongo/integration/unified/event_verification.go @@ -9,6 +9,7 @@ package unified import ( "bytes" "context" + "errors" "fmt" "go.mongodb.org/mongo-driver/bson" @@ -64,10 +65,37 @@ type cmapEvent struct { } `bson:"poolClearedEvent"` } +type sdamEvent struct { + ServerDescriptionChangedEvent *struct { + NewDescription *struct { + Type *string `bson:"type"` + } `bson:"newDescription"` + + PreviousDescription *struct { + Type *string `bson:"type"` + } `bson:"previousDescription"` + } `bson:"serverDescriptionChangedEvent"` + + ServerHeartbeatStartedEvent *struct { + Awaited *bool `bson:"awaited"` + } `bson:"serverHeartbeatStartedEvent"` + + ServerHeartbeatSucceededEvent *struct { + Awaited *bool `bson:"awaited"` + } `bson:"serverHeartbeatSucceededEvent"` + + ServerHeartbeatFailedEvent *struct { + Awaited *bool `bson:"awaited"` + } `bson:"serverHeartbeatFailedEvent"` + + TopologyDescriptionChangedEvent *struct{} `bson:"topologyDescriptionChangedEvent"` +} + type expectedEvents struct { ClientID string `bson:"client"` CommandEvents []commandMonitoringEvent CMAPEvents []cmapEvent + SDAMEvents []sdamEvent IgnoreExtraEvents *bool } @@ -102,6 +130,8 @@ func (e *expectedEvents) UnmarshalBSON(data []byte) error { target = &e.CommandEvents case "cmap": target = &e.CMAPEvents + case "sdam": + target = &e.SDAMEvents default: return fmt.Errorf("unrecognized 'eventType' value for expectedEvents: %q", temp.EventType) } @@ -127,6 +157,8 @@ func verifyEvents(ctx context.Context, expectedEvents *expectedEvents) error { return verifyCommandEvents(ctx, client, expectedEvents) case expectedEvents.CMAPEvents != nil: return verifyCMAPEvents(client, expectedEvents) + case expectedEvents.SDAMEvents != nil: + return verifySDAMEvents(client, expectedEvents) } return nil } @@ -405,3 +437,145 @@ func stringifyEventsForClient(client *clientEntity) string { return str.String() } + +func getNextServerDescriptionChangedEvent( + events []*event.ServerDescriptionChangedEvent, +) (*event.ServerDescriptionChangedEvent, []*event.ServerDescriptionChangedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no server changed event published") + } + + return events[0], events[1:], nil +} + +func getNextServerHeartbeatStartedEvent( + events []*event.ServerHeartbeatStartedEvent, +) (*event.ServerHeartbeatStartedEvent, []*event.ServerHeartbeatStartedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no heartbeat started event published") + } + + return events[0], events[1:], nil +} + +func getNextServerHeartbeatSucceededEvent( + events []*event.ServerHeartbeatSucceededEvent, +) (*event.ServerHeartbeatSucceededEvent, []*event.ServerHeartbeatSucceededEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no heartbeat succeeded event published") + } + + return events[0], events[:1], nil +} + +func getNextServerHeartbeatFailedEvent( + events []*event.ServerHeartbeatFailedEvent, +) (*event.ServerHeartbeatFailedEvent, []*event.ServerHeartbeatFailedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no heartbeat failed event published") + } + + return events[0], events[:1], nil +} + +func getNextTopologyDescriptionChangedEvent( + events []*event.TopologyDescriptionChangedEvent, +) (*event.TopologyDescriptionChangedEvent, []*event.TopologyDescriptionChangedEvent, error) { + if len(events) == 0 { + return nil, nil, errors.New("no topology description changed event published") + } + + return events[0], events[:1], nil +} + +func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) error { + var ( + changed = client.serverDescriptionChanged + started = client.serverHeartbeatStartedEvent + succeeded = client.serverHeartbeatSucceeded + failed = client.serverHeartbeatFailedEvent + tchanged = client.topologyDescriptionChanged + ) + + vol := func() int { return len(changed) + len(started) + len(succeeded) + len(failed) + len(tchanged) } + + if len(expectedEvents.SDAMEvents) == 0 && vol() != 0 { + return fmt.Errorf("expected no sdam events to be sent but got %s", stringifyEventsForClient(client)) + } + + for idx, evt := range expectedEvents.SDAMEvents { + var err error + + switch { + case evt.ServerDescriptionChangedEvent != nil: + var got *event.ServerDescriptionChangedEvent + if got, changed, err = getNextServerDescriptionChangedEvent(changed); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + prevDesc := evt.ServerDescriptionChangedEvent.NewDescription + + var wantPrevDesc string + if prevDesc != nil && prevDesc.Type != nil { + wantPrevDesc = *prevDesc.Type + } + + gotPrevDesc := got.PreviousDescription.Kind.String() + if gotPrevDesc != wantPrevDesc { + return newEventVerificationError(idx, client, + "expected previous server description %q, got %q", wantPrevDesc, gotPrevDesc) + } + + newDesc := evt.ServerDescriptionChangedEvent.PreviousDescription + + var wantNewDesc string + if newDesc != nil && newDesc.Type != nil { + wantNewDesc = *newDesc.Type + } + + gotNewDesc := got.NewDescription.Kind.String() + if gotNewDesc != wantNewDesc { + return newEventVerificationError(idx, client, + "expected new server description %q, got %q", wantNewDesc, gotNewDesc) + } + case evt.ServerHeartbeatStartedEvent != nil: + var got *event.ServerHeartbeatStartedEvent + if got, started, err = getNextServerHeartbeatStartedEvent(started); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + if want := evt.ServerHeartbeatStartedEvent.Awaited; want != nil && *want != got.Awaited { + return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited) + } + case evt.ServerHeartbeatSucceededEvent != nil: + var got *event.ServerHeartbeatSucceededEvent + if got, succeeded, err = getNextServerHeartbeatSucceededEvent(succeeded); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + if want := evt.ServerHeartbeatSucceededEvent.Awaited; want != nil && *want != got.Awaited { + return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited) + } + case evt.ServerHeartbeatFailedEvent != nil: + var got *event.ServerHeartbeatFailedEvent + if got, failed, err = getNextServerHeartbeatFailedEvent(failed); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + + if want := evt.ServerHeartbeatFailedEvent.Awaited; want != nil && *want != got.Awaited { + return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited) + } + case evt.TopologyDescriptionChangedEvent != nil: + if _, tchanged, err = getNextTopologyDescriptionChangedEvent(tchanged); err != nil { + return newEventVerificationError(idx, client, err.Error()) + } + } + } + + // Verify that there are no remaining events if ignoreExtraEvents is unset or false. + ignoreExtraEvents := expectedEvents.IgnoreExtraEvents != nil && *expectedEvents.IgnoreExtraEvents + if !ignoreExtraEvents && vol() > 0 { + return fmt.Errorf("extra sdam events published; all events for client: %s", stringifyEventsForClient(client)) + } + return nil +} diff --git a/mongo/integration/unified/schema_version.go b/mongo/integration/unified/schema_version.go index c85a2efa79..9aec89a18d 100644 --- a/mongo/integration/unified/schema_version.go +++ b/mongo/integration/unified/schema_version.go @@ -16,7 +16,7 @@ import ( var ( supportedSchemaVersions = map[int]string{ - 1: "1.16", + 1: "1.17", } ) diff --git a/mongo/integration/unified/testrunner_operation.go b/mongo/integration/unified/testrunner_operation.go index 474c01c88a..297ebbdf5d 100644 --- a/mongo/integration/unified/testrunner_operation.go +++ b/mongo/integration/unified/testrunner_operation.go @@ -19,7 +19,12 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -var waitForEventTimeout = 10 * time.Second +// waitForEventTimeout is the amount of time to wait for an event to occur. The +// maximum amount of time expected for this value is currently 10 seconds, which +// is the amoutn of time that the driver will attempt to wait between streamable +// heartbeats. Increase this value if a new maximum time is expected in another +// operation. +var waitForEventTimeout = 11 * time.Second type loopArgs struct { Operations []*operation `bson:"operations"` diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 4572d331cf..764b0c38ef 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -33,6 +33,26 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) +const ( + // ServerMonitoringModeAuto indicates that the client will behave like "poll" + // mode when running on a FaaS (Function as a Service) platform, or like + // "stream" mode otherwise. The client detects its execution environment by + // following the rules for generating the "client.env" handshake metadata field + // as specified in the MongoDB Handshake specification. This is the default + // mode. + ServerMonitoringModeAuto = connstring.ServerMonitoringModeAuto + + // ServerMonitoringModePoll indicates that the client will periodically check + // the server using a hello or legacy hello command and then sleep for + // heartbeatFrequencyMS milliseconds before running another check. + ServerMonitoringModePoll = connstring.ServerMonitoringModePoll + + // ServerMonitoringModeStream indicates that the client will use a streaming + // protocol when the server supports it. The streaming protocol optimally + // reduces the time it takes for a client to discover server state changes. + ServerMonitoringModeStream = connstring.ServerMonitoringModeStream +) + // ContextDialer is an interface that can be implemented by types that can create connections. It should be used to // provide a custom dialer when configuring a Client. // @@ -206,6 +226,7 @@ type ClientOptions struct { RetryReads *bool RetryWrites *bool ServerAPIOptions *ServerAPIOptions + ServerMonitoringMode *string ServerSelectionTimeout *time.Duration SRVMaxHosts *int SRVServiceName *string @@ -300,6 +321,11 @@ func (c *ClientOptions) validate() error { return connstring.ErrSRVMaxHostsWithLoadBalanced } } + + if mode := c.ServerMonitoringMode; mode != nil && !connstring.IsValidServerMonitoringMode(*mode) { + return fmt.Errorf("invalid server monitoring mode: %q", *mode) + } + return nil } @@ -937,6 +963,16 @@ func (c *ClientOptions) SetServerAPIOptions(opts *ServerAPIOptions) *ClientOptio return c } +// SetServerMonitoringMode specifies the server monitoring protocol to use. See +// the helper constants ServerMonitoringModeAuto, ServerMonitoringModePoll, and +// ServerMonitoringModeStream for more information about valid server +// monitoring modes. +func (c *ClientOptions) SetServerMonitoringMode(mode string) *ClientOptions { + c.ServerMonitoringMode = &mode + + return c +} + // SetSRVMaxHosts specifies the maximum number of SRV results to randomly select during polling. To limit the number // of hosts selected in SRV discovery, this function must be called before ApplyURI. This can also be set through // the "srvMaxHosts" URI option. @@ -1096,6 +1132,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.LoggerOptions != nil { c.LoggerOptions = opt.LoggerOptions } + if opt.ServerMonitoringMode != nil { + c.ServerMonitoringMode = opt.ServerMonitoringMode + } } return c diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index 9db9b6b82f..9633f0131d 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -760,6 +760,52 @@ func TestClientOptions(t *testing.T) { }) } }) + t.Run("server monitoring mode validation", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + opts *ClientOptions + err error + }{ + { + name: "undefined", + opts: Client(), + err: nil, + }, + { + name: "auto", + opts: Client().SetServerMonitoringMode(ServerMonitoringModeAuto), + err: nil, + }, + { + name: "poll", + opts: Client().SetServerMonitoringMode(ServerMonitoringModePoll), + err: nil, + }, + { + name: "stream", + opts: Client().SetServerMonitoringMode(ServerMonitoringModeStream), + err: nil, + }, + { + name: "invalid", + opts: Client().SetServerMonitoringMode("invalid"), + err: errors.New("invalid server monitoring mode: \"invalid\""), + }, + } + + for _, tc := range testCases { + tc := tc // Capture the range variable + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := tc.opts.Validate() + assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) + }) + } + }) } func createCertPool(t *testing.T, paths ...string) *x509.CertPool { diff --git a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json new file mode 100644 index 0000000000..7d681b4f9e --- /dev/null +++ b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.json @@ -0,0 +1,449 @@ +{ + "description": "serverMonitoringMode", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + "topologies": [ + "single", + "sharded", + "sharded-replicaset" + ], + "serverless": "forbid" + } + ], + "tests": [ + { + "description": "connect with serverMonitoringMode=auto >=4.4", + "runOnRequirements": [ + { + "minServerVersion": "4.4.0" + } + ], + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "uriOptions": { + "serverMonitoringMode": "auto" + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": true + } + } + ] + } + ] + }, + { + "description": "connect with serverMonitoringMode=auto <4.4", + "runOnRequirements": [ + { + "maxServerVersion": "4.2.99" + } + ], + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "uriOptions": { + "serverMonitoringMode": "auto", + "heartbeatFrequencyMS": 500 + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + } + ] + } + ] + }, + { + "description": "connect with serverMonitoringMode=stream >=4.4", + "runOnRequirements": [ + { + "minServerVersion": "4.4.0" + } + ], + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "uriOptions": { + "serverMonitoringMode": "stream" + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": true + } + } + ] + } + ] + }, + { + "description": "connect with serverMonitoringMode=stream <4.4", + "runOnRequirements": [ + { + "maxServerVersion": "4.2.99" + } + ], + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "uriOptions": { + "serverMonitoringMode": "stream", + "heartbeatFrequencyMS": 500 + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + } + ] + } + ] + }, + { + "description": "connect with serverMonitoringMode=poll", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "uriOptions": { + "serverMonitoringMode": "poll", + "heartbeatFrequencyMS": 500 + }, + "useMultipleMongoses": false, + "observeEvents": [ + "serverHeartbeatStartedEvent", + "serverHeartbeatSucceededEvent", + "serverHeartbeatFailedEvent" + ] + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "sdam-tests" + } + } + ] + } + }, + { + "name": "runCommand", + "object": "db", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectResult": { + "ok": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverHeartbeatStartedEvent": {} + }, + "count": 2 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "ignoreExtraEvents": true, + "events": [ + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + }, + { + "serverHeartbeatSucceededEvent": { + "awaited": false + } + }, + { + "serverHeartbeatStartedEvent": { + "awaited": false + } + } + ] + } + ] + } + ] +} diff --git a/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml new file mode 100644 index 0000000000..28c7853d04 --- /dev/null +++ b/testdata/server-discovery-and-monitoring/unified/serverMonitoringMode.yml @@ -0,0 +1,173 @@ +description: serverMonitoringMode + +schemaVersion: "1.17" +# These tests cannot run on replica sets because the order of the expected +# SDAM events are non-deterministic when monitoring multiple servers. +# They also cannot run on Serverless or load balanced clusters where SDAM is disabled. +runOnRequirements: + - topologies: [single, sharded, sharded-replicaset] + serverless: forbid +tests: + - description: "connect with serverMonitoringMode=auto >=4.4" + runOnRequirements: + - minServerVersion: "4.4.0" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "auto" + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - &ping + name: runCommand + object: db + arguments: + commandName: ping + command: { ping: 1 } + expectResult: { ok: 1 } + # Wait for the second serverHeartbeatStartedEvent to ensure we start streaming. + - &waitForSecondHeartbeatStarted + name: waitForEvent + object: testRunner + arguments: + client: client + event: + serverHeartbeatStartedEvent: {} + count: 2 + expectEvents: &streamingStartedEvents + - client: client + eventType: sdam + ignoreExtraEvents: true + events: + - serverHeartbeatStartedEvent: + awaited: False + - serverHeartbeatSucceededEvent: + awaited: False + - serverHeartbeatStartedEvent: + awaited: True + + - description: "connect with serverMonitoringMode=auto <4.4" + runOnRequirements: + - maxServerVersion: "4.2.99" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "auto" + heartbeatFrequencyMS: 500 + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we do not stream. + - *waitForSecondHeartbeatStarted + expectEvents: &pollingStartedEvents + - client: client + eventType: sdam + ignoreExtraEvents: true + events: + - serverHeartbeatStartedEvent: + awaited: False + - serverHeartbeatSucceededEvent: + awaited: False + - serverHeartbeatStartedEvent: + awaited: False + + - description: "connect with serverMonitoringMode=stream >=4.4" + runOnRequirements: + - minServerVersion: "4.4.0" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "stream" + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we start streaming. + - *waitForSecondHeartbeatStarted + expectEvents: *streamingStartedEvents + + - description: "connect with serverMonitoringMode=stream <4.4" + runOnRequirements: + - maxServerVersion: "4.2.99" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "stream" + heartbeatFrequencyMS: 500 + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we do not stream. + - *waitForSecondHeartbeatStarted + expectEvents: *pollingStartedEvents + + - description: "connect with serverMonitoringMode=poll" + operations: + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: client + uriOptions: + serverMonitoringMode: "poll" + heartbeatFrequencyMS: 500 + useMultipleMongoses: false + observeEvents: + - serverHeartbeatStartedEvent + - serverHeartbeatSucceededEvent + - serverHeartbeatFailedEvent + - database: + id: db + client: client + databaseName: sdam-tests + - *ping + # Wait for the second serverHeartbeatStartedEvent to ensure we do not stream. + - *waitForSecondHeartbeatStarted + expectEvents: *pollingStartedEvents diff --git a/testdata/uri-options/sdam-options.json b/testdata/uri-options/sdam-options.json new file mode 100644 index 0000000000..673f5607ee --- /dev/null +++ b/testdata/uri-options/sdam-options.json @@ -0,0 +1,46 @@ +{ + "tests": [ + { + "description": "serverMonitoringMode=auto", + "uri": "mongodb://example.com/?serverMonitoringMode=auto", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "serverMonitoringMode": "auto" + } + }, + { + "description": "serverMonitoringMode=stream", + "uri": "mongodb://example.com/?serverMonitoringMode=stream", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "serverMonitoringMode": "stream" + } + }, + { + "description": "serverMonitoringMode=poll", + "uri": "mongodb://example.com/?serverMonitoringMode=poll", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "serverMonitoringMode": "poll" + } + }, + { + "description": "invalid serverMonitoringMode", + "uri": "mongodb://example.com/?serverMonitoringMode=invalid", + "valid": true, + "warning": true, + "hosts": null, + "auth": null, + "options": {} + } + ] +} diff --git a/testdata/uri-options/sdam-options.yml b/testdata/uri-options/sdam-options.yml new file mode 100644 index 0000000000..8f72ff4098 --- /dev/null +++ b/testdata/uri-options/sdam-options.yml @@ -0,0 +1,35 @@ +tests: + - description: "serverMonitoringMode=auto" + uri: "mongodb://example.com/?serverMonitoringMode=auto" + valid: true + warning: false + hosts: ~ + auth: ~ + options: + serverMonitoringMode: "auto" + + - description: "serverMonitoringMode=stream" + uri: "mongodb://example.com/?serverMonitoringMode=stream" + valid: true + warning: false + hosts: ~ + auth: ~ + options: + serverMonitoringMode: "stream" + + - description: "serverMonitoringMode=poll" + uri: "mongodb://example.com/?serverMonitoringMode=poll" + valid: true + warning: false + hosts: ~ + auth: ~ + options: + serverMonitoringMode: "poll" + + - description: "invalid serverMonitoringMode" + uri: "mongodb://example.com/?serverMonitoringMode=invalid" + valid: true + warning: true + hosts: ~ + auth: ~ + options: {} diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 983c1dab22..cd43136471 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -21,6 +21,26 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) +const ( + // ServerMonitoringModeAuto indicates that the client will behave like "poll" + // mode when running on a FaaS (Function as a Service) platform, or like + // "stream" mode otherwise. The client detects its execution environment by + // following the rules for generating the "client.env" handshake metadata field + // as specified in the MongoDB Handshake specification. This is the default + // mode. + ServerMonitoringModeAuto = "auto" + + // ServerMonitoringModePoll indicates that the client will periodically check + // the server using a hello or legacy hello command and then sleep for + // heartbeatFrequencyMS milliseconds before running another check. + ServerMonitoringModePoll = "poll" + + // ServerMonitoringModeStream indicates that the client will use a streaming + // protocol when the server supports it. The streaming protocol optimally + // reduces the time it takes for a client to discover server state changes. + ServerMonitoringModeStream = "stream" +) + var ( // ErrLoadBalancedWithMultipleHosts is returned when loadBalanced=true is // specified in a URI with multiple hosts. @@ -125,6 +145,7 @@ type ConnString struct { MaxStalenessSet bool ReplicaSet string Scheme string + ServerMonitoringMode string ServerSelectionTimeout time.Duration ServerSelectionTimeoutSet bool SocketTimeout time.Duration @@ -621,6 +642,14 @@ func (p *parser) addHost(host string) error { return nil } +// IsValidServerMonitoringMode will return true if the given string matches a +// valid server monitoring mode. +func IsValidServerMonitoringMode(mode string) bool { + return mode == ServerMonitoringModeAuto || + mode == ServerMonitoringModeStream || + mode == ServerMonitoringModePoll +} + func (p *parser) addOption(pair string) error { kv := strings.SplitN(pair, "=", 2) if len(kv) != 2 || kv[0] == "" { @@ -823,6 +852,12 @@ func (p *parser) addOption(pair string) error { } p.RetryReadsSet = true + case "servermonitoringmode": + if !IsValidServerMonitoringMode(value) { + return fmt.Errorf("invalid value for %q: %q", key, value) + } + + p.ServerMonitoringMode = value case "serverselectiontimeoutms": n, err := strconv.Atoi(value) if err != nil || n < 0 { diff --git a/x/mongo/driver/connstring/connstring_spec_test.go b/x/mongo/driver/connstring/connstring_spec_test.go index a5f646297c..699ae16bdb 100644 --- a/x/mongo/driver/connstring/connstring_spec_test.go +++ b/x/mongo/driver/connstring/connstring_spec_test.go @@ -286,6 +286,8 @@ func verifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map require.Equal(t, value, float64(cs.ZstdLevel)) case "tlsdisableocspendpointcheck": require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck) + case "servermonitoringmode": + require.Equal(t, value, cs.ServerMonitoringMode) default: opt, ok := cs.UnknownOptions[key] require.True(t, ok) diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index de7e05cb5f..b3c4d7ee7f 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/bsonutil" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/handshake" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" @@ -31,7 +32,6 @@ import ( // sharded clusters is 512. const maxClientMetadataSize = 512 -const awsLambdaPrefix = "AWS_Lambda_" const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. @@ -125,36 +125,7 @@ func (h *Hello) Result(addr address.Address) description.Server { return description.NewServer(addr, bson.Raw(h.res)) } -const ( - // FaaS environment variable names - envVarAWSExecutionEnv = "AWS_EXECUTION_ENV" - envVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" - envVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" - envVarKService = "K_SERVICE" - envVarFunctionName = "FUNCTION_NAME" - envVarVercel = "VERCEL" -) - -const ( - // FaaS environment variable names - envVarAWSRegion = "AWS_REGION" - envVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" - envVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" - envVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" - envVarFunctionRegion = "FUNCTION_REGION" - envVarVercelRegion = "VERCEL_REGION" -) - -const ( - // FaaS environment names used by the client - envNameAWSLambda = "aws.lambda" - envNameAzureFunc = "azure.func" - envNameGCPFunc = "gcp.func" - envNameVercel = "vercel" -) - const dockerEnvPath = "/.dockerenv" -const envVarK8s = "KUBERNETES_SERVICE_HOST" const ( // Runtime names @@ -172,12 +143,12 @@ const ( // values to be entirely omitted. func getFaasEnvName() string { envVars := []string{ - envVarAWSExecutionEnv, - envVarAWSLambdaRuntimeAPI, - envVarFunctionsWorkerRuntime, - envVarKService, - envVarFunctionName, - envVarVercel, + driverutil.EnvVarAWSExecutionEnv, + driverutil.EnvVarAWSLambdaRuntimeAPI, + driverutil.EnvVarFunctionsWorkerRuntime, + driverutil.EnvVarKService, + driverutil.EnvVarFunctionName, + driverutil.EnvVarVercel, } // If none of the variables are populated the client.env value MUST be @@ -193,23 +164,23 @@ func getFaasEnvName() string { var name string switch envVar { - case envVarAWSExecutionEnv: - if !strings.HasPrefix(val, awsLambdaPrefix) { + case driverutil.EnvVarAWSExecutionEnv: + if !strings.HasPrefix(val, driverutil.AwsLambdaPrefix) { continue } - name = envNameAWSLambda - case envVarAWSLambdaRuntimeAPI: - name = envNameAWSLambda - case envVarFunctionsWorkerRuntime: - name = envNameAzureFunc - case envVarKService, envVarFunctionName: - name = envNameGCPFunc - case envVarVercel: + name = driverutil.EnvNameAWSLambda + case driverutil.EnvVarAWSLambdaRuntimeAPI: + name = driverutil.EnvNameAWSLambda + case driverutil.EnvVarFunctionsWorkerRuntime: + name = driverutil.EnvNameAzureFunc + case driverutil.EnvVarKService, driverutil.EnvVarFunctionName: + name = driverutil.EnvNameGCPFunc + case driverutil.EnvVarVercel: // "vercel" takes precedence over "aws.lambda". - delete(names, envNameAWSLambda) + delete(names, driverutil.EnvNameAWSLambda) - name = envNameVercel + name = driverutil.EnvNameVercel } names[name] = struct{}{} @@ -242,7 +213,7 @@ func getContainerEnvInfo() *containerInfo { if _, err := os.Stat(dockerEnvPath); !os.IsNotExist(err) { runtime = runtimeNameDocker } - if v := os.Getenv(envVarK8s); v != "" { + if v := os.Getenv(driverutil.EnvVarK8s); v != "" { orchestrator = orchestratorNameK8s } if runtime != "" || orchestrator != "" { @@ -350,15 +321,15 @@ func appendClientEnv(dst []byte, omitNonName, omitDoc bool) ([]byte, error) { if !omitNonName { // No other FaaS fields will be populated if the name is empty. switch name { - case envNameAWSLambda: - dst = addMem(envVarAWSLambdaFunctionMemorySize) - dst = addRegion(envVarAWSRegion) - case envNameGCPFunc: - dst = addMem(envVarFunctionMemoryMB) - dst = addRegion(envVarFunctionRegion) - dst = addTimeout(envVarFunctionTimeoutSec) - case envNameVercel: - dst = addRegion(envVarVercelRegion) + case driverutil.EnvNameAWSLambda: + dst = addMem(driverutil.EnvVarAWSLambdaFunctionMemorySize) + dst = addRegion(driverutil.EnvVarAWSRegion) + case driverutil.EnvNameGCPFunc: + dst = addMem(driverutil.EnvVarFunctionMemoryMB) + dst = addRegion(driverutil.EnvVarFunctionRegion) + dst = addTimeout(driverutil.EnvVarFunctionTimeoutSec) + case driverutil.EnvNameVercel: + dst = addRegion(driverutil.EnvVarVercelRegion) } } diff --git a/x/mongo/driver/operation/hello_test.go b/x/mongo/driver/operation/hello_test.go index b33d7632cd..114f53b617 100644 --- a/x/mongo/driver/operation/hello_test.go +++ b/x/mongo/driver/operation/hello_test.go @@ -13,6 +13,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -54,18 +55,18 @@ func encodeWithCallback(t *testing.T, cb func(int, []byte) ([]byte, error)) bson // ensure that the local environment does not effect the outcome of a unit // test. func clearTestEnv(t *testing.T) { - t.Setenv(envVarAWSExecutionEnv, "") - t.Setenv(envVarAWSLambdaRuntimeAPI, "") - t.Setenv(envVarFunctionsWorkerRuntime, "") - t.Setenv(envVarKService, "") - t.Setenv(envVarFunctionName, "") - t.Setenv(envVarVercel, "") - t.Setenv(envVarAWSRegion, "") - t.Setenv(envVarAWSLambdaFunctionMemorySize, "") - t.Setenv(envVarFunctionMemoryMB, "") - t.Setenv(envVarFunctionTimeoutSec, "") - t.Setenv(envVarFunctionRegion, "") - t.Setenv(envVarVercelRegion, "") + t.Setenv("AWS_EXECUTION_ENV", "") + t.Setenv("AWS_LAMBDA_RUNTIME_API", "") + t.Setenv("FUNCTIONS_WORKER_RUNTIME", "") + t.Setenv("K_SERVICE", "") + t.Setenv("FUNCTION_NAME", "") + t.Setenv("VERCEL", "") + t.Setenv("AWS_REGION", "") + t.Setenv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "") + t.Setenv("FUNCTION_MEMORY_MB", "") + t.Setenv("FUNCTION_TIMEOUT_SEC", "") + t.Setenv("FUNCTION_REGION", "") + t.Setenv("VERCEL_REGION", "") } func TestAppendClientName(t *testing.T) { @@ -159,32 +160,32 @@ func TestAppendClientEnv(t *testing.T) { { name: "aws only", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", }, want: []byte(`{"env":{"name":"aws.lambda"}}`), }, { name: "aws mem only", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaFunctionMemorySize: "1024", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", }, want: []byte(`{"env":{"name":"aws.lambda","memory_mb":1024}}`), }, { name: "aws region only", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSRegion: "us-east-2", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "AWS_REGION": "us-east-2", }, want: []byte(`{"env":{"name":"aws.lambda","region":"us-east-2"}}`), }, { name: "aws mem and region", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaFunctionMemorySize: "1024", - envVarAWSRegion: "us-east-2", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", + "AWS_REGION": "us-east-2", }, want: []byte(`{"env":{"name":"aws.lambda","memory_mb":1024,"region":"us-east-2"}}`), }, @@ -192,50 +193,50 @@ func TestAppendClientEnv(t *testing.T) { name: "aws mem and region with omit fields", omitEnvFields: true, env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaFunctionMemorySize: "1024", - envVarAWSRegion: "us-east-2", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", + "AWS_REGION": "us-east-2", }, want: []byte(`{"env":{"name":"aws.lambda"}}`), }, { name: "gcp only", env: map[string]string{ - envVarKService: "servicename", + "K_SERVICE": "servicename", }, want: []byte(`{"env":{"name":"gcp.func"}}`), }, { name: "gcp mem", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionMemoryMB: "1024", + "K_SERVICE": "servicename", + "FUNCTION_MEMORY_MB": "1024", }, want: []byte(`{"env":{"name":"gcp.func","memory_mb":1024}}`), }, { name: "gcp region", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionRegion: "us-east-2", + "K_SERVICE": "servicename", + "FUNCTION_REGION": "us-east-2", }, want: []byte(`{"env":{"name":"gcp.func","region":"us-east-2"}}`), }, { name: "gcp timeout", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionTimeoutSec: "1", + "K_SERVICE": "servicename", + "FUNCTION_TIMEOUT_SEC": "1", }, want: []byte(`{"env":{"name":"gcp.func","timeout_sec":1}}`), }, { name: "gcp mem, region, and timeout", env: map[string]string{ - envVarKService: "servicename", - envVarFunctionTimeoutSec: "1", - envVarFunctionRegion: "us-east-2", - envVarFunctionMemoryMB: "1024", + "K_SERVICE": "servicename", + "FUNCTION_TIMEOUT_SEC": "1", + "FUNCTION_REGION": "us-east-2", + "FUNCTION_MEMORY_MB": "1024", }, want: []byte(`{"env":{"name":"gcp.func","memory_mb":1024,"region":"us-east-2","timeout_sec":1}}`), }, @@ -243,39 +244,39 @@ func TestAppendClientEnv(t *testing.T) { name: "gcp mem, region, and timeout with omit fields", omitEnvFields: true, env: map[string]string{ - envVarKService: "servicename", - envVarFunctionTimeoutSec: "1", - envVarFunctionRegion: "us-east-2", - envVarFunctionMemoryMB: "1024", + "K_SERVICE": "servicename", + "FUNCTION_TIMEOUT_SEC": "1", + "FUNCTION_REGION": "us-east-2", + "FUNCTION_MEMORY_MB": "1024", }, want: []byte(`{"env":{"name":"gcp.func"}}`), }, { name: "vercel only", env: map[string]string{ - envVarVercel: "1", + "VERCEL": "1", }, want: []byte(`{"env":{"name":"vercel"}}`), }, { name: "vercel region", env: map[string]string{ - envVarVercel: "1", - envVarVercelRegion: "us-east-2", + "VERCEL": "1", + "VERCEL_REGION": "us-east-2", }, want: []byte(`{"env":{"name":"vercel","region":"us-east-2"}}`), }, { name: "azure only", env: map[string]string{ - envVarFunctionsWorkerRuntime: "go1.x", + "FUNCTIONS_WORKER_RUNTIME": "go1.x", }, want: []byte(`{"env":{"name":"azure.func"}}`), }, { name: "k8s", env: map[string]string{ - envVarK8s: "0.0.0.0", + "KUBERNETES_SERVICE_HOST": "0.0.0.0", }, want: []byte(`{"env":{"container":{"orchestrator":"kubernetes"}}}`), }, @@ -419,10 +420,10 @@ func TestEncodeClientMetadata(t *testing.T) { } // Set environment variables to add `env` field to handshake. - t.Setenv(envVarAWSLambdaRuntimeAPI, "lambda") - t.Setenv(envVarAWSLambdaFunctionMemorySize, "123") - t.Setenv(envVarAWSRegion, "us-east-2") - t.Setenv(envVarK8s, "0.0.0.0") + t.Setenv("AWS_LAMBDA_RUNTIME_API", "lambda") + t.Setenv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "123") + t.Setenv("AWS_REGION", "us-east-2") + t.Setenv("KUBERNETES_SERVICE_HOST", "0.0.0.0") t.Run("nothing is omitted", func(t *testing.T) { got, err := encodeClientMetadata("foo", maxClientMetadataSize) @@ -434,7 +435,7 @@ func TestEncodeClientMetadata(t *testing.T) { OS: &dist{Type: runtime.GOOS, Architecture: runtime.GOARCH}, Platform: runtime.Version(), Env: &env{ - Name: envNameAWSLambda, + Name: "aws.lambda", MemoryMB: 123, Region: "us-east-2", Container: &container{ @@ -460,7 +461,7 @@ func TestEncodeClientMetadata(t *testing.T) { OS: &dist{Type: runtime.GOOS, Architecture: runtime.GOARCH}, Platform: runtime.Version(), Env: &env{ - Name: envNameAWSLambda, + Name: "aws.lambda", Container: &container{ Orchestrator: "kubernetes", }, @@ -480,7 +481,7 @@ func TestEncodeClientMetadata(t *testing.T) { require.NoError(t, err, "error constructing env template: %v", err) // Calculate what the env.name costs. - ndst := bsoncore.AppendStringElement(nil, "name", envNameAWSLambda) + ndst := bsoncore.AppendStringElement(nil, "name", "aws.lambda") idx, ndst := bsoncore.AppendDocumentElementStart(ndst, "container") ndst = bsoncore.AppendStringElement(ndst, "orchestrator", "kubernetes") ndst, err = bsoncore.AppendDocumentEnd(ndst, idx) @@ -498,7 +499,7 @@ func TestEncodeClientMetadata(t *testing.T) { OS: &dist{Type: runtime.GOOS}, Platform: runtime.Version(), Env: &env{ - Name: envNameAWSLambda, + Name: "aws.lambda", Container: &container{ Orchestrator: "kubernetes", }, @@ -588,38 +589,38 @@ func TestParseFaasEnvName(t *testing.T) { { name: "one aws", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", }, - want: envNameAWSLambda, + want: "aws.lambda", }, { name: "both aws options", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarAWSLambdaRuntimeAPI: "hello", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "AWS_LAMBDA_RUNTIME_API": "hello", }, - want: envNameAWSLambda, + want: "aws.lambda", }, { name: "multiple variables", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarFunctionsWorkerRuntime: "hello", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "FUNCTIONS_WORKER_RUNTIME": "hello", }, want: "", }, { name: "vercel and aws lambda", env: map[string]string{ - envVarAWSExecutionEnv: "AWS_Lambda_foo", - envVarVercel: "hello", + "AWS_EXECUTION_ENV": "AWS_Lambda_foo", + "VERCEL": "hello", }, - want: envNameVercel, + want: "vercel", }, { name: "invalid aws prefix", env: map[string]string{ - envVarAWSExecutionEnv: "foo", + "AWS_EXECUTION_ENV": "foo", }, want: "", }, @@ -633,7 +634,7 @@ func TestParseFaasEnvName(t *testing.T) { t.Setenv(key, value) } - got := getFaasEnvName() + got := driverutil.GetFaasEnvName() if got != test.want { t.Errorf("parseFaasEnvName(%s) = %s, want %s", test.name, got, test.want) } @@ -659,14 +660,14 @@ func BenchmarkClientMetadtaLargeEnv(b *testing.B) { b.ReportAllocs() b.ResetTimer() - b.Setenv(envNameAWSLambda, "foo") + b.Setenv("aws.lambda", "foo") str := "" for i := 0; i < 512; i++ { str += "a" } - b.Setenv(envVarAWSLambdaRuntimeAPI, str) + b.Setenv("AWS_LAMBDA_RUNTIME_API", str) b.RunParallel(func(pb *testing.PB) { for pb.Next() { diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 998d2a0253..eacc6bf6d3 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -39,7 +39,12 @@ type rttConfig struct { } type rttMonitor struct { - mu sync.RWMutex // mu guards samples, offset, minRTT, averageRTT, and averageRTTSet + mu sync.RWMutex // mu guards samples, offset, minRTT, averageRTT, and averageRTTSet + + // connMu guards connecting and disconnecting. This is necessary since + // disconnecting will await the cancelation of a started connection. The + // use case for rttMonitor.connect needs to be goroutine safe. + connMu sync.Mutex samples []time.Duration offset int minRTT time.Duration @@ -51,6 +56,7 @@ type rttMonitor struct { cfg *rttConfig ctx context.Context cancelFn context.CancelFunc + started bool } var _ driver.RTTMonitor = &rttMonitor{} @@ -74,19 +80,34 @@ func newRTTMonitor(cfg *rttConfig) *rttMonitor { } func (r *rttMonitor) connect() { + r.connMu.Lock() + defer r.connMu.Unlock() + + r.started = true r.closeWg.Add(1) - go r.start() + + go func() { + defer r.closeWg.Done() + + r.start() + }() } func (r *rttMonitor) disconnect() { - // Signal for the routine to stop. + r.connMu.Lock() + defer r.connMu.Unlock() + + if !r.started { + return + } + r.cancelFn() + + // Wait for the existing connection to complete. r.closeWg.Wait() } func (r *rttMonitor) start() { - defer r.closeWg.Done() - var conn *connection defer func() { if conn != nil { diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 29c96e9a72..a20ad729f1 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -17,10 +17,12 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -131,7 +133,12 @@ type updateTopologyCallback func(description.Server) description.Server // ConnectServer creates a new Server and then initializes it using the // Connect method. -func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) { +func ConnectServer( + addr address.Address, + updateCallback updateTopologyCallback, + topologyID primitive.ObjectID, + opts ...ServerOption, +) (*Server, error) { srvr := NewServer(addr, topologyID, opts...) err := srvr.Connect(updateCallback) if err != nil { @@ -239,7 +246,6 @@ func (s *Server) Connect(updateCallback updateTopologyCallback) error { s.updateTopologyCallback.Store(updateCallback) if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced { - s.rttMonitor.connect() s.closewg.Add(1) go s.update() } @@ -648,12 +654,15 @@ func (s *Server) update() { // If the server supports streaming or we're already streaming, we want to move to streaming the next response // without waiting. If the server has transitioned to Unknown from a network error, we want to do another // check without waiting in case it was a transient error and the server isn't actually down. - serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming() transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil && previousDescription.Kind != description.Unknown - if serverSupportsStreaming || connectionIsStreaming || transitionedFromNetworkError { + if isStreamingEnabled(s) && isStreamable(s) && !s.rttMonitor.started { + s.rttMonitor.connect() + } + + if isStreamable(s) || connectionIsStreaming || transitionedFromNetworkError { continue } @@ -785,10 +794,25 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello { return operation. NewHello(). ClusterClock(s.cfg.clock). - Deployment(driver.SingleConnectionDeployment{conn}). + Deployment(driver.SingleConnectionDeployment{C: conn}). ServerAPI(s.cfg.serverAPI) } +func isStreamingEnabled(srv *Server) bool { + switch srv.cfg.serverMonitoringMode { + case connstring.ServerMonitoringModeStream: + return true + case connstring.ServerMonitoringModePoll: + return false + default: + return driverutil.GetFaasEnvName() == "" + } +} + +func isStreamable(srv *Server) bool { + return srv.Description().Kind != description.Unknown && srv.Description().TopologyVersion != nil +} + func (s *Server) check() (description.Server, error) { var descPtr *description.Server var err error @@ -827,9 +851,10 @@ func (s *Server) check() (description.Server, error) { heartbeatConn := initConnection{s.conn} baseOperation := s.createBaseOperation(heartbeatConn) previousDescription := s.Description() - streamable := previousDescription.TopologyVersion != nil + streamable := isStreamingEnabled(s) && isStreamable(s) s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) + switch { case s.conn.getCurrentlyStreaming(): // The connection is already in a streaming state, so we stream the next response. @@ -860,8 +885,16 @@ func (s *Server) check() (description.Server, error) { s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) err = baseOperation.Execute(s.heartbeatCtx) } + duration = time.Since(start) + // We need to record an RTT sample in the polling case so that if the server + // is < 4.4, or if polling is specified by the user, then the + // RTT-short-circuit feature of CSOT is not disabled. + if !streamable { + s.rttMonitor.addSample(duration) + } + if err == nil { tempDesc := baseOperation.Result(s.address) descPtr = &tempDesc diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index 4272b3f751..4504a25355 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -14,23 +14,25 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { - clock *session.ClusterClock - compressionOpts []string - connectionOpts []ConnectionOption - appname string - heartbeatInterval time.Duration - heartbeatTimeout time.Duration - serverMonitor *event.ServerMonitor - registry *bsoncodec.Registry - monitoringDisabled bool - serverAPI *driver.ServerAPIOptions - loadBalanced bool + clock *session.ClusterClock + compressionOpts []string + connectionOpts []ConnectionOption + appname string + heartbeatInterval time.Duration + heartbeatTimeout time.Duration + serverMonitoringMode string + serverMonitor *event.ServerMonitor + registry *bsoncodec.Registry + monitoringDisabled bool + serverAPI *driver.ServerAPIOptions + loadBalanced bool // Connection pool options. maxConns uint64 @@ -202,3 +204,17 @@ func withLogger(fn func() *logger.Logger) ServerOption { cfg.logger = fn() } } + +// withServerMonitoringMode configures the mode (stream, poll, or auto) to use +// for monitoring. +func withServerMonitoringMode(mode *string) ServerOption { + return func(cfg *serverConfig) { + if mode != nil { + cfg.serverMonitoringMode = *mode + + return + } + + cfg.serverMonitoringMode = connstring.ServerMonitoringModeAuto + } +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 460b82e406..05b748dca6 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -355,6 +355,7 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, serverOpts = append( serverOpts, withLogger(func() *logger.Logger { return lgr }), + withServerMonitoringMode(co.ServerMonitoringMode), ) cfgp.logger = lgr