From 797e1606a645057f10917e4e95917f6827264aa0 Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Fri, 16 Apr 2021 12:51:54 -0400 Subject: [PATCH] GODRIVER-1934 Ensure correct CursorOptions are used (#625) (#630) --- mongo/change_stream.go | 15 +++--- mongo/client.go | 7 +++ mongo/collection.go | 10 +--- mongo/database.go | 7 ++- mongo/index_view.go | 2 +- mongo/integration/change_stream_test.go | 24 ++++++++++ mongo/integration/collection_test.go | 62 +++++++++++++++++++++++++ mongo/integration/database_test.go | 55 ++++++++++++++++++++-- mongo/integration/index_view_test.go | 41 ++++++++++++++-- 9 files changed, 195 insertions(+), 28 deletions(-) diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 0bfd99726e..37d72862e5 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -103,11 +103,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in } cs := &ChangeStream{ - client: config.client, - registry: config.registry, - streamType: config.streamType, - options: options.MergeChangeStreamOptions(opts...), - selector: description.ReadPrefSelector(config.readPreference), + client: config.client, + registry: config.registry, + streamType: config.streamType, + options: options.MergeChangeStreamOptions(opts...), + selector: description.ReadPrefSelector(config.readPreference), + cursorOptions: config.client.createBaseCursorOptions(), } cs.sess = sessionFromContext(ctx) @@ -128,9 +129,6 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). Crypt(config.crypt) - if config.crypt != nil { - cs.cursorOptions.Crypt = config.crypt - } if cs.options.Collation != nil { cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument())) } @@ -141,7 +139,6 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in if cs.options.MaxAwaitTime != nil { cs.cursorOptions.MaxTimeMS = int64(time.Duration(*cs.options.MaxAwaitTime) / time.Millisecond) } - cs.cursorOptions.CommandMonitor = cs.client.monitor switch cs.streamType { case ClientStream: diff --git a/mongo/client.go b/mongo/client.go index d266da46d5..fcf627194d 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -926,3 +926,10 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{}, func (c *Client) NumberSessionsInProgress() int { return c.sessionPool.CheckedOut() } + +func (c *Client) createBaseCursorOptions() driver.CursorOptions { + return driver.CursorOptions{ + CommandMonitor: c.monitor, + Crypt: c.cryptFLE, + } +} diff --git a/mongo/collection.go b/mongo/collection.go index 6921678a04..246880cf16 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -780,10 +780,7 @@ func aggregate(a aggregateParams) (*Cursor, error) { } ao := options.MergeAggregateOptions(a.opts...) - cursorOpts := driver.CursorOptions{ - CommandMonitor: a.client.monitor, - Crypt: a.client.cryptFLE, - } + cursorOpts := a.client.createBaseCursorOptions() op := operation.NewAggregate(pipelineArr). Session(sess). @@ -1139,10 +1136,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE) fo := options.MergeFindOptions(opts...) - cursorOpts := driver.CursorOptions{ - CommandMonitor: coll.client.monitor, - Crypt: coll.client.cryptFLE, - } + cursorOpts := coll.client.createBaseCursorOptions() if fo.AllowDiskUse != nil { op.AllowDiskUse(*fo.AllowDiskUse) diff --git a/mongo/database.go b/mongo/database.go index ebf70843bb..2974cd3d16 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -227,7 +227,7 @@ func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{} return nil, replaceErrors(err) } - bc, err := op.ResultCursor(driver.CursorOptions{}) + bc, err := op.ResultCursor(db.client.createBaseCursorOptions()) if err != nil { closeImplicitSession(sess) return nil, replaceErrors(err) @@ -362,10 +362,13 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE) + + cursorOpts := db.client.createBaseCursorOptions() if lco.NameOnly != nil { op = op.NameOnly(*lco.NameOnly) } if lco.BatchSize != nil { + cursorOpts.BatchSize = *lco.BatchSize op = op.BatchSize(*lco.BatchSize) } @@ -381,7 +384,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt return nil, replaceErrors(err) } - bc, err := op.Result(driver.CursorOptions{Crypt: db.client.cryptFLE}) + bc, err := op.Result(cursorOpts) if err != nil { closeImplicitSession(sess) return nil, replaceErrors(err) diff --git a/mongo/index_view.go b/mongo/index_view.go index 4b0998bc10..e7def1eb2d 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -97,7 +97,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment) - var cursorOpts driver.CursorOptions + cursorOpts := iv.coll.client.createBaseCursorOptions() lio := options.MergeListIndexesOptions(opts...) if lio.BatchSize != nil { op = op.BatchSize(*lio.BatchSize) diff --git a/mongo/integration/change_stream_test.go b/mongo/integration/change_stream_test.go index 3e9205b721..ac39917ba0 100644 --- a/mongo/integration/change_stream_test.go +++ b/mongo/integration/change_stream_test.go @@ -609,6 +609,30 @@ func TestChangeStream_ReplicaSet(t *testing.T) { // next call to cs.Next should return False since cursor is closed assert.False(mt, cs.Next(mtest.Background), "expected to return false, but returned true") }) + mt.Run("getMore commands are monitored", func(mt *mtest.T) { + cs, err := mt.Coll.Watch(mtest.Background, mongo.Pipeline{}) + assert.Nil(mt, err, "Watch error: %v", err) + defer closeStream(cs) + + _, err = mt.Coll.InsertOne(mtest.Background, bson.M{"x": 1}) + assert.Nil(mt, err, "InsertOne error: %v", err) + + mt.ClearEvents() + assert.True(mt, cs.Next(mtest.Background), "Next returned false with error %v", cs.Err()) + evt := mt.GetStartedEvent() + assert.Equal(mt, "getMore", evt.CommandName, "expected command 'getMore', got %q", evt.CommandName) + }) + mt.Run("killCursors commands are monitored", func(mt *mtest.T) { + cs, err := mt.Coll.Watch(mtest.Background, mongo.Pipeline{}) + assert.Nil(mt, err, "Watch error: %v", err) + defer closeStream(cs) + + mt.ClearEvents() + err = cs.Close(mtest.Background) + assert.Nil(mt, err, "Close error: %v", err) + evt := mt.GetStartedEvent() + assert.Equal(mt, "killCursors", evt.CommandName, "expected command 'killCursors', got %q", evt.CommandName) + }) } func closeStream(cs *mongo.ChangeStream) { diff --git a/mongo/integration/collection_test.go b/mongo/integration/collection_test.go index 06a396f52c..60acf3251f 100644 --- a/mongo/integration/collection_test.go +++ b/mongo/integration/collection_test.go @@ -787,6 +787,18 @@ func TestCollection(t *testing.T) { _, ok := err.(mongo.WriteConcernError) assert.True(mt, ok, "expected error type %v, got %v", mongo.WriteConcernError{}, err) }) + mt.Run("getMore commands are monitored", func(mt *mtest.T) { + initCollection(mt, mt.Coll) + assertGetMoreCommandsAreMonitored(mt, "aggregate", func() (*mongo.Cursor, error) { + return mt.Coll.Aggregate(mtest.Background, mongo.Pipeline{}, options.Aggregate().SetBatchSize(3)) + }) + }) + mt.Run("killCursors commands are monitored", func(mt *mtest.T) { + initCollection(mt, mt.Coll) + assertKillCursorsCommandsAreMonitored(mt, "aggregate", func() (*mongo.Cursor, error) { + return mt.Coll.Aggregate(mtest.Background, mongo.Pipeline{}, options.Aggregate().SetBatchSize(3)) + }) + }) }) mt.RunOpts("count documents", noClientOpts, func(mt *mtest.T) { mt.Run("success", func(mt *mtest.T) { @@ -1054,6 +1066,18 @@ func TestCollection(t *testing.T) { }) } }) + mt.Run("getMore commands are monitored", func(mt *mtest.T) { + initCollection(mt, mt.Coll) + assertGetMoreCommandsAreMonitored(mt, "find", func() (*mongo.Cursor, error) { + return mt.Coll.Find(mtest.Background, bson.D{}, options.Find().SetBatchSize(3)) + }) + }) + mt.Run("killCursors commands are monitored", func(mt *mtest.T) { + initCollection(mt, mt.Coll) + assertKillCursorsCommandsAreMonitored(mt, "find", func() (*mongo.Cursor, error) { + return mt.Coll.Find(mtest.Background, bson.D{}, options.Find().SetBatchSize(3)) + }) + }) }) mt.RunOpts("find one", noClientOpts, func(mt *mtest.T) { mt.Run("limit", func(mt *mtest.T) { @@ -1886,3 +1910,41 @@ func create16MBDocument(mt *mtest.T) bsoncore.Document { assert.Equal(mt, targetDocSize, len(doc), "expected document length %v, got %v", targetDocSize, len(doc)) return doc } + +// This is a helper function to ensure that sending getMore commands for a cursor results in command monitoring events +// being published. The cursorFn parameter should be a function that yields a cursor which is open on the server and +// requires at least one getMore to be fully iterated. +func assertGetMoreCommandsAreMonitored(mt *mtest.T, cmdName string, cursorFn func() (*mongo.Cursor, error)) { + mt.Helper() + mt.ClearEvents() + + cursor, err := cursorFn() + assert.Nil(mt, err, "error creating cursor: %v", err) + var docs []bson.D + err = cursor.All(mtest.Background, &docs) + assert.Nil(mt, err, "All error: %v", err) + + // Only assert that the initial command and at least one getMore were sent. The exact number of getMore's required + // is not important. + evt := mt.GetStartedEvent() + assert.Equal(mt, cmdName, evt.CommandName, "expected command %q, got %q", cmdName, evt.CommandName) + evt = mt.GetStartedEvent() + assert.Equal(mt, "getMore", evt.CommandName, "expected command 'getMore', got %q", evt.CommandName) +} + +// This is a helper function to ensure that sending killCursors commands for a cursor results in command monitoring +// events being published. The cursorFn parameter should be a function that yields a cursor which is open on the server. +func assertKillCursorsCommandsAreMonitored(mt *mtest.T, cmdName string, cursorFn func() (*mongo.Cursor, error)) { + mt.Helper() + mt.ClearEvents() + + cursor, err := cursorFn() + assert.Nil(mt, err, "error creating cursor: %v", err) + err = cursor.Close(mtest.Background) + assert.Nil(mt, err, "Close error: %v", err) + + evt := mt.GetStartedEvent() + assert.Equal(mt, cmdName, evt.CommandName, "expected command %q, got %q", cmdName, evt.CommandName) + evt = mt.GetStartedEvent() + assert.Equal(mt, "killCursors", evt.CommandName, "expected command 'killCursors', got %q", evt.CommandName) +} diff --git a/mongo/integration/database_test.go b/mongo/integration/database_test.go index cbeead3849..184a18d036 100644 --- a/mongo/integration/database_test.go +++ b/mongo/integration/database_test.go @@ -170,6 +170,16 @@ func TestDatabase(t *testing.T) { }) mt.RunOpts("list collections", noClientOpts, func(mt *mtest.T) { + createCollections := func(mt *mtest.T, numCollections int) { + mt.Helper() + + for i := 0; i < numCollections; i++ { + mt.CreateCollection(mtest.Collection{ + Name: fmt.Sprintf("list-collections-test-%d", i), + }, true) + } + } + mt.RunOpts("verify results", noClientOpts, func(mt *mtest.T) { testCases := []struct { name string @@ -213,11 +223,7 @@ func TestDatabase(t *testing.T) { }) mt.RunOpts("batch size", mtest.NewOptions().MinServerVersion("3.0"), func(mt *mtest.T) { // Create two new collections so there will be three total. - for i := 0; i < 2; i++ { - mt.CreateCollection(mtest.Collection{ - Name: fmt.Sprintf("list-collections-batchSize-%d", i), - }, true) - } + createCollections(mt, 2) mt.ClearEvents() lcOpts := options.ListCollections().SetBatchSize(2) @@ -230,6 +236,22 @@ func TestDatabase(t *testing.T) { _, err = evt.Command.LookupErr("cursor", "batchSize") assert.Nil(mt, err, "expected command %s to contain key 'batchSize'", evt.Command) }) + + // The BatchSize option is not honored for ListCollections operations on server version 2.6 due to an + // inconsistency in the legacy OP_QUERY code path (GODRIVER-1937). + cmdMonitoringMtOpts := mtest.NewOptions().MinServerVersion("3.0") + mt.RunOpts("getMore commands are monitored", cmdMonitoringMtOpts, func(mt *mtest.T) { + createCollections(mt, 2) + assertGetMoreCommandsAreMonitored(mt, "listCollections", func() (*mongo.Cursor, error) { + return mt.DB.ListCollections(mtest.Background, bson.D{}, options.ListCollections().SetBatchSize(2)) + }) + }) + mt.RunOpts("killCursors commands are monitored", cmdMonitoringMtOpts, func(mt *mtest.T) { + createCollections(mt, 2) + assertKillCursorsCommandsAreMonitored(mt, "listCollections", func() (*mongo.Cursor, error) { + return mt.DB.ListCollections(mtest.Background, bson.D{}, options.ListCollections().SetBatchSize(2)) + }) + }) }) mt.RunOpts("list collection specifications", noClientOpts, func(mt *mtest.T) { @@ -356,6 +378,29 @@ func TestDatabase(t *testing.T) { assert.Equal(mt, tc.numExpected, count, "expected document count %v, got %v", tc.numExpected, count) }) } + + // The find command does not exist on server versions below 3.2. + cmdMonitoringMtOpts := mtest.NewOptions().MinServerVersion("3.2") + mt.RunOpts("getMore commands are monitored", cmdMonitoringMtOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + assertGetMoreCommandsAreMonitored(mt, "find", func() (*mongo.Cursor, error) { + findCmd := bson.D{ + {"find", mt.Coll.Name()}, + {"batchSize", 2}, + } + return mt.DB.RunCommandCursor(mtest.Background, findCmd) + }) + }) + mt.RunOpts("killCursors commands are monitored", cmdMonitoringMtOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + assertKillCursorsCommandsAreMonitored(mt, "find", func() (*mongo.Cursor, error) { + findCmd := bson.D{ + {"find", mt.Coll.Name()}, + {"batchSize", 2}, + } + return mt.DB.RunCommandCursor(mtest.Background, findCmd) + }) + }) }) mt.RunOpts("create collection", noClientOpts, func(mt *mtest.T) { diff --git a/mongo/integration/index_view_test.go b/mongo/integration/index_view_test.go index 2b585bb2e5..a82af3f83b 100644 --- a/mongo/integration/index_view_test.go +++ b/mongo/integration/index_view_test.go @@ -30,9 +30,44 @@ func TestIndexView(t *testing.T) { defer mt.Close() mt.Run("list", func(mt *mtest.T) { - verifyIndexExists(mt, mt.Coll.Indexes(), index{ - Key: bson.D{{"_id", int32(1)}}, - Name: "_id_", + createIndexes := func(mt *mtest.T, numIndexes int) { + mt.Helper() + + models := make([]mongo.IndexModel, 0, numIndexes) + for i, key := 0, 'a'; i < numIndexes; i, key = i+1, key+1 { + models = append(models, mongo.IndexModel{ + Keys: bson.M{string(key): 1}, + }) + } + + _, err := mt.Coll.Indexes().CreateMany(mtest.Background, models) + assert.Nil(mt, err, "CreateMany error: %v", err) + } + + // For server versions below 3.0, we internally execute List() as a legacy OP_QUERY against the system.indexes + // collection. Command monitoring upconversions translate this to a "find" command rather than "listIndexes". + cmdName := "listIndexes" + if mtest.CompareServerVersions(mtest.ServerVersion(), "3.0") < 0 { + cmdName = "find" + } + + mt.Run("_id index is always listed", func(mt *mtest.T) { + verifyIndexExists(mt, mt.Coll.Indexes(), index{ + Key: bson.D{{"_id", int32(1)}}, + Name: "_id_", + }) + }) + mt.Run("getMore commands are monitored", func(mt *mtest.T) { + createIndexes(mt, 2) + assertGetMoreCommandsAreMonitored(mt, cmdName, func() (*mongo.Cursor, error) { + return mt.Coll.Indexes().List(mtest.Background, options.ListIndexes().SetBatchSize(2)) + }) + }) + mt.Run("killCursors commands are monitored", func(mt *mtest.T) { + createIndexes(mt, 2) + assertKillCursorsCommandsAreMonitored(mt, cmdName, func() (*mongo.Cursor, error) { + return mt.Coll.Indexes().List(mtest.Background, options.ListIndexes().SetBatchSize(2)) + }) }) }) mt.RunOpts("create one", noClientOpts, func(mt *mtest.T) {