From bfd25ddfb1497c17dacb852fc534de8b7eaff370 Mon Sep 17 00:00:00 2001 From: Valentin Kovalenko Date: Fri, 30 Jul 2021 11:45:26 -0600 Subject: [PATCH] Support calling `QueryBatchCursor.close` concurrently with other `QueryBatchCursor` methods (#765) JAVA-4183 --- .evergreen/run-load-balancer-tests.sh | 11 + .../com/mongodb/assertions/Assertions.java | 14 +- .../internal/async/AsyncBatchCursor.java | 4 +- .../internal/binding/ConnectionSource.java | 3 + .../connection/DefaultConnectionPool.java | 4 +- .../internal/connection/DefaultServer.java | 3 + .../connection/LoadBalancedServer.java | 7 + .../internal/operation/BatchCursor.java | 11 + .../operation/ChangeStreamBatchCursor.java | 7 +- .../internal/operation/QueryBatchCursor.java | 473 +++++++++++++----- .../internal/connection/ServerHelper.java | 25 +- .../AggregateOperationSpecification.groovy | 1 + ...ountDocumentsOperationSpecification.groovy | 1 + .../DistinctOperationSpecification.groovy | 1 + .../FindOperationSpecification.groovy | 3 + ...InlineResultsOperationSpecification.groovy | 1 + ...yBatchCursorFunctionalSpecification.groovy | 11 +- .../AsyncQueryBatchCursorSpecification.groovy | 89 ++-- .../FindOperationUnitSpecification.groovy | 2 + .../QueryBatchCursorSpecification.groovy | 241 ++++++++- .../client/MongoChangeStreamCursor.java | 8 +- .../main/com/mongodb/client/MongoCursor.java | 13 +- 22 files changed, 732 insertions(+), 201 deletions(-) diff --git a/.evergreen/run-load-balancer-tests.sh b/.evergreen/run-load-balancer-tests.sh index 69fbeda5d6..641155431a 100755 --- a/.evergreen/run-load-balancer-tests.sh +++ b/.evergreen/run-load-balancer-tests.sh @@ -74,11 +74,22 @@ echo $first --tests UnifiedTransactionsTest \ --tests InitialDnsSeedlistDiscoveryTest second=$? +echo $second + +./gradlew -PjdkHome=/opt/java/${JDK} \ + -Dorg.mongodb.test.uri=${SINGLE_MONGOS_LB_URI} \ + -Dorg.mongodb.test.transaction.uri=${MULTI_MONGOS_LB_URI} \ + ${GRADLE_EXTRA_VARS} --stacktrace --info --continue driver-core:test \ + --tests QueryBatchCursorFunctionalSpecification +third=$? +echo $third if [ $first -ne 0 ]; then exit $first elif [ $second -ne 0 ]; then exit $second +elif [ $third -ne 0 ]; then + exit $third else exit 0 fi diff --git a/driver-core/src/main/com/mongodb/assertions/Assertions.java b/driver-core/src/main/com/mongodb/assertions/Assertions.java index 465a2b2cb4..3969d25a65 100644 --- a/driver-core/src/main/com/mongodb/assertions/Assertions.java +++ b/driver-core/src/main/com/mongodb/assertions/Assertions.java @@ -179,11 +179,23 @@ public static boolean assertFalse(final boolean value) throws AssertionError { /** * @throws AssertionError Always + * @return Never completes normally. The return type is {@link AssertionError} to allow writing {@code throw fail()}. + * This may be helpful in non-{@code void} methods. */ - public static void fail() throws AssertionError { + public static AssertionError fail() throws AssertionError { throw new AssertionError(); } + /** + * @param msg The failure message. + * @throws AssertionError Always + * @return Never completes normally. The return type is {@link AssertionError} to allow writing {@code throw fail("failure message")}. + * This may be helpful in non-{@code void} methods. + */ + public static AssertionError fail(final String msg) throws AssertionError { + throw new AssertionError(assertNotNull(msg)); + } + private Assertions() { } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncBatchCursor.java b/driver-core/src/main/com/mongodb/internal/async/AsyncBatchCursor.java index 2cee6ef9b1..c2d969c28b 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncBatchCursor.java @@ -66,10 +66,8 @@ public interface AsyncBatchCursor extends Closeable { * To help making such code simpler, this method is required to be idempotent. *

* Another quirk is that this method is allowed to release resources "eventually", - * i.e., not before (in the happens before order) returning. + * i.e., not before (in the happens-before order) returning. * Nevertheless, {@link #isClosed()} called after (in the happens-before order) {@link #close()} must return {@code true}. - * - * @see #close() */ @Override void close(); diff --git a/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java b/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java index 682ef20909..1b48960a50 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java +++ b/driver-core/src/main/com/mongodb/internal/binding/ConnectionSource.java @@ -20,6 +20,7 @@ import com.mongodb.connection.ServerDescription; import com.mongodb.internal.connection.Connection; import com.mongodb.internal.session.SessionContext; +import com.mongodb.lang.Nullable; /** * A source of connections to a single MongoDB server. @@ -42,8 +43,10 @@ public interface ConnectionSource extends ReferenceCounted { * * @since 3.6 */ + @Nullable SessionContext getSessionContext(); + @Nullable ServerApi getServerApi(); /** diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java index 243b880a6c..ac8286d76d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultConnectionPool.java @@ -333,7 +333,9 @@ private MongoTimeoutException createTimeoutException(final Timeout timeout) { } } - + /** + * Is package-access for the purpose of testing and must not be used for any other purpose outside of this class. + */ ConcurrentPool getPool() { return pool; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java index 379319e49f..725e3db41d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServer.java @@ -209,6 +209,9 @@ public void connect() { serverMonitor.connect(); } + /** + * Is package-access for the purpose of testing and must not be used for any other purpose outside of this class. + */ ConnectionPool getConnectionPool() { return connectionPool; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java b/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java index 9b8721248e..108b0e53b3 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/LoadBalancedServer.java @@ -144,6 +144,13 @@ public void getConnectionAsync(final SingleResultCallback callb }); } + /** + * Is package-access for the purpose of testing and must not be used for any other purpose outside of this class. + */ + ConnectionPool getConnectionPool() { + return connectionPool; + } + private class LoadBalancedServerProtocolExecutor implements ProtocolExecutor { @SuppressWarnings("unchecked") @Override diff --git a/driver-core/src/main/com/mongodb/internal/operation/BatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/BatchCursor.java index 93cf6b742b..6b26b19e96 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/BatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/BatchCursor.java @@ -19,6 +19,7 @@ import com.mongodb.ServerAddress; import com.mongodb.ServerCursor; import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.lang.Nullable; import java.io.Closeable; import java.util.Iterator; @@ -37,6 +38,15 @@ */ @NotThreadSafe public interface BatchCursor extends Iterator>, Closeable { + /** + * Despite this interface being {@linkplain NotThreadSafe non-thread-safe}, + * {@link #close()} is allowed to be called concurrently with any method of the cursor, including itself. + * This is useful to cancel blocked {@link #hasNext()}, {@link #next()}. + * This method is idempotent. + *

+ * Another quirk is that this method is allowed to release resources "eventually", + * i.e., not before (in the happens-before order) returning. + */ @Override void close(); @@ -85,6 +95,7 @@ public interface BatchCursor extends Iterator>, Closeable { * * @return ServerCursor */ + @Nullable ServerCursor getServerCursor(); /** diff --git a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java index ab938c341e..18c174f5d5 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java @@ -16,6 +16,7 @@ package com.mongodb.internal.operation; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import com.mongodb.MongoChangeStreamException; import com.mongodb.MongoException; @@ -44,7 +45,7 @@ final class ChangeStreamBatchCursor implements AggregateResponseBatchCursor wrapped; private BsonDocument resumeToken; - private volatile boolean closed; + private final AtomicBoolean closed; ChangeStreamBatchCursor(final ChangeStreamOperation changeStreamOperation, final AggregateResponseBatchCursor wrapped, @@ -56,6 +57,7 @@ final class ChangeStreamBatchCursor implements AggregateResponseBatchCursor getWrapped() { @@ -108,8 +110,7 @@ public List apply(final AggregateResponseBatchCursor queryBa @Override public void close() { - if (!closed) { - closed = true; + if (!closed.getAndSet(true)) { wrapped.close(); binding.release(); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java index 486c2e3748..dc41267502 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java @@ -22,12 +22,16 @@ import com.mongodb.MongoSocketException; import com.mongodb.ReadPreference; import com.mongodb.ServerAddress; +import com.mongodb.ServerApi; import com.mongodb.ServerCursor; +import com.mongodb.annotations.ThreadSafe; import com.mongodb.connection.ServerType; import com.mongodb.internal.binding.ConnectionSource; import com.mongodb.internal.connection.Connection; import com.mongodb.internal.connection.QueryResult; +import com.mongodb.internal.session.SessionContext; import com.mongodb.internal.validator.NoOpFieldNameValidator; +import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonInt32; @@ -40,8 +44,15 @@ import java.util.List; import java.util.NoSuchElementException; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.StampedLock; +import java.util.function.Consumer; +import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.assertions.Assertions.fail; import static com.mongodb.assertions.Assertions.isTrueArgument; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.operation.CursorHelper.getNumberToReturn; @@ -55,24 +66,24 @@ class QueryBatchCursor implements AggregateResponseBatchCursor { private static final String CURSOR = "cursor"; private static final String POST_BATCH_RESUME_TOKEN = "postBatchResumeToken"; private static final String OPERATION_TIME = "operationTime"; + private static final String MESSAGE_IF_CLOSED_AS_CURSOR = "Cursor has been closed"; + private static final String MESSAGE_IF_CLOSED_AS_ITERATOR = "Iterator has been closed"; private final MongoNamespace namespace; + @Nullable + private final ServerApi serverApi; private final ServerAddress serverAddress; private final int limit; private final Decoder decoder; private final long maxTimeMS; private int batchSize; - private ConnectionSource connectionSource; - private Connection connection; - private ServerCursor serverCursor; private List nextBatch; private int count; - private volatile boolean closed; private BsonDocument postBatchResumeToken; private BsonTimestamp operationTime; private final boolean firstBatchEmpty; private int maxWireVersion = 0; - private boolean killCursorOnClose = true; + private final ResourceManager resourceManager; QueryBatchCursor(final QueryResult firstQueryResult, final int limit, final int batchSize, final Decoder decoder) { this(firstQueryResult, limit, batchSize, decoder, null); @@ -94,6 +105,7 @@ class QueryBatchCursor implements AggregateResponseBatchCursor { isTrueArgument("maxTimeMS >= 0", maxTimeMS >= 0); this.maxTimeMS = maxTimeMS; this.namespace = firstQueryResult.getNamespace(); + this.serverApi = connectionSource == null ? null : connectionSource.getServerApi(); this.serverAddress = firstQueryResult.getAddress(); this.limit = limit; this.batchSize = batchSize; @@ -102,37 +114,36 @@ class QueryBatchCursor implements AggregateResponseBatchCursor { this.operationTime = result.getTimestamp(OPERATION_TIME, null); this.postBatchResumeToken = getPostBatchResumeTokenFromResponse(result); } - if (firstQueryResult.getCursor() != null) { + ServerCursor serverCursor = initFromQueryResult(firstQueryResult); + if (serverCursor != null) { notNull("connectionSource", connectionSource); } - if (connectionSource != null) { - this.connectionSource = connectionSource.retain(); - } - - initFromQueryResult(firstQueryResult); firstBatchEmpty = firstQueryResult.getResults().isEmpty(); - + Connection connectionToPin = null; + boolean releaseServerAndResources = false; if (connection != null) { this.maxWireVersion = connection.getDescription().getMaxWireVersion(); if (limitReached()) { - killCursor(connection); + releaseServerAndResources = true; } else { assertNotNull(connectionSource); if (connectionSource.getServerDescription().getType() == ServerType.LOAD_BALANCER) { - this.connection = connection.retain(); - this.connection.markAsPinned(Connection.PinningMode.CURSOR); + connectionToPin = connection; } } } - releaseConnectionAndSourceIfNoServerCursor(); + resourceManager = new ResourceManager(connectionSource, connectionToPin, serverCursor); + if (releaseServerAndResources) { + resourceManager.releaseServerAndClientResources(assertNotNull(connection)); + } } @Override public boolean hasNext() { - if (closed) { - throw new IllegalStateException("Cursor has been closed"); - } + return assertNotNull(resourceManager.execute(MESSAGE_IF_CLOSED_AS_CURSOR, this::doHasNext)); + } + private boolean doHasNext() { if (nextBatch != null) { return true; } @@ -141,10 +152,10 @@ public boolean hasNext() { return false; } - while (serverCursor != null) { + while (resourceManager.serverCursor() != null) { getMore(); - if (closed) { - throw new IllegalStateException("Cursor has been closed"); + if (!resourceManager.operable()) { + throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_CURSOR); } if (nextBatch != null) { return true; @@ -156,11 +167,11 @@ public boolean hasNext() { @Override public List next() { - if (closed) { - throw new IllegalStateException("Iterator has been closed"); - } + return assertNotNull(resourceManager.execute(MESSAGE_IF_CLOSED_AS_ITERATOR, this::doNext)); + } - if (!hasNext()) { + private List doNext() { + if (!doHasNext()) { throw new NoSuchElementException(); } @@ -186,29 +197,20 @@ public void remove() { @Override public void close() { - if (!closed) { - closed = true; - try { - killCursor(); - } finally { - releaseConnectionAndSource(); - } - } + resourceManager.close(); } @Override public List tryNext() { - if (closed) { - throw new IllegalStateException("Cursor has been closed"); - } - - if (!tryHasNext()) { - return null; - } - return next(); + return resourceManager.execute(MESSAGE_IF_CLOSED_AS_CURSOR, () -> { + if (!tryHasNext()) { + return null; + } + return doNext(); + }); } - boolean tryHasNext() { + private boolean tryHasNext() { if (nextBatch != null) { return true; } @@ -217,7 +219,7 @@ boolean tryHasNext() { return false; } - if (serverCursor != null) { + if (resourceManager.serverCursor() != null) { getMore(); } @@ -225,18 +227,19 @@ boolean tryHasNext() { } @Override + @Nullable public ServerCursor getServerCursor() { - if (closed) { - throw new IllegalStateException("Iterator has been closed"); + if (!resourceManager.operable()) { + throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_ITERATOR); } - return serverCursor; + return resourceManager.serverCursor(); } @Override public ServerAddress getServerAddress() { - if (closed) { - throw new IllegalStateException("Iterator has been closed"); + if (!resourceManager.operable()) { + throw new IllegalStateException(MESSAGE_IF_CLOSED_AS_ITERATOR); } return serverAddress; @@ -263,49 +266,34 @@ public int getMaxWireVersion() { } private void getMore() { - Connection connection = getConnection(); - try { + ServerCursor serverCursor = assertNotNull(resourceManager.serverCursor()); + resourceManager.executeWithConnection(connection -> { + ServerCursor nextServerCursor; if (serverIsAtLeastVersionThreeDotTwo(connection.getDescription())) { try { - initFromCommandResult(connection.command(namespace.getDatabaseName(), - asGetMoreCommandDocument(), + nextServerCursor = initFromCommandResult(connection.command(namespace.getDatabaseName(), + asGetMoreCommandDocument(serverCursor), NO_OP_FIELD_NAME_VALIDATOR, ReadPreference.primary(), CommandResultDocumentCodec.create(decoder, "nextBatch"), - connectionSource.getSessionContext(), - connectionSource.getServerApi())); + resourceManager.sessionContext(), + serverApi)); } catch (MongoCommandException e) { throw translateCommandException(e, serverCursor); } } else { QueryResult getMore = connection.getMore(namespace, serverCursor.getId(), getNumberToReturn(limit, batchSize, count), decoder); - initFromQueryResult(getMore); + nextServerCursor = initFromQueryResult(getMore); } + resourceManager.setServerCursor(nextServerCursor); if (limitReached()) { - killCursor(connection); - } - } catch (MongoSocketException e) { - // If connection is pinned, don't attempt to kill the cursor on close, since the connection is in a bad state - if (this.connection != null) { - killCursorOnClose = false; + resourceManager.releaseServerAndClientResources(connection); } - throw e; - } finally { - connection.release(); - releaseConnectionAndSourceIfNoServerCursor(); - } - } - - private Connection getConnection() { - if (connection == null) { - return connectionSource.getConnection(); - } else { - return connection.retain(); - } + }); } - private BsonDocument asGetMoreCommandDocument() { + private BsonDocument asGetMoreCommandDocument(final ServerCursor serverCursor) { BsonDocument document = new BsonDocument("getMore", new BsonInt64(serverCursor.getId())) .append("collection", new BsonString(namespace.getCollectionName())); @@ -320,83 +308,326 @@ private BsonDocument asGetMoreCommandDocument() { return document; } - private void initFromQueryResult(final QueryResult queryResult) { - serverCursor = queryResult.getCursor(); + @Nullable + private ServerCursor initFromQueryResult(final QueryResult queryResult) { nextBatch = queryResult.getResults().isEmpty() ? null : queryResult.getResults(); count += queryResult.getResults().size(); + return queryResult.getCursor(); } - private void initFromCommandResult(final BsonDocument getMoreCommandResultDocument) { - QueryResult queryResult = getMoreCursorDocumentToQueryResult(getMoreCommandResultDocument.getDocument(CURSOR), - connectionSource.getServerDescription().getAddress()); + @Nullable + private ServerCursor initFromCommandResult(final BsonDocument getMoreCommandResultDocument) { + QueryResult queryResult = getMoreCursorDocumentToQueryResult(getMoreCommandResultDocument.getDocument(CURSOR), serverAddress); postBatchResumeToken = getPostBatchResumeTokenFromResponse(getMoreCommandResultDocument); operationTime = getMoreCommandResultDocument.getTimestamp(OPERATION_TIME, null); - initFromQueryResult(queryResult); + return initFromQueryResult(queryResult); } private boolean limitReached() { return Math.abs(limit) != 0 && count >= Math.abs(limit); } - private void killCursor() { - if (serverCursor != null && killCursorOnClose) { + private BsonDocument getPostBatchResumeTokenFromResponse(final BsonDocument result) { + BsonDocument cursor = result.getDocument(CURSOR, null); + if (cursor != null) { + return cursor.getDocument(POST_BATCH_RESUME_TOKEN, null); + } + return null; + } + + /** + * This class maintains all resources that must be released in {@link QueryBatchCursor#close()}. + * It also implements a {@linkplain #doClose() deferred close action} such that it is totally ordered with other operations of + * {@link QueryBatchCursor} (methods {@link #tryStartOperation()}/{@link #endOperation()} must be used properly to enforce the order) + * despite the method {@link QueryBatchCursor#close()} being called concurrently with those operations. + * This total order induces the happens-before order. + *

+ * The deferred close action does not violate externally observable idempotence of {@link QueryBatchCursor#close()}, + * because {@link QueryBatchCursor#close()} is allowed to release resources "eventually". + *

+ * Only methods explicitly documented as thread-safe are thread-safe, + * others are not and rely on the total order mentioned above. + */ + @ThreadSafe + private final class ResourceManager { + private final Lock lock; + private volatile State state; + @Nullable + private volatile ConnectionSource connectionSource; + @Nullable + private volatile Connection pinnedConnection; + @Nullable + private volatile ServerCursor serverCursor; + private volatile boolean skipReleasingServerResourcesOnClose; + + ResourceManager(@Nullable final ConnectionSource connectionSource, + @Nullable final Connection connectionToPin, @Nullable final ServerCursor serverCursor) { + lock = new StampedLock().asWriteLock(); + state = State.IDLE; + if (serverCursor != null) { + this.connectionSource = (assertNotNull(connectionSource)).retain(); + if (connectionToPin != null) { + this.pinnedConnection = connectionToPin.retain(); + connectionToPin.markAsPinned(Connection.PinningMode.CURSOR); + } + } + skipReleasingServerResourcesOnClose = false; + this.serverCursor = serverCursor; + } + + /** + * Thread-safe. + */ + boolean operable() { + return state.operable(); + } + + /** + * Thread-safe. + * Executes {@code operation} within the {@link #tryStartOperation()}/{@link #endOperation()} bounds. + * + * @throws IllegalStateException If {@linkplain QueryBatchCursor#close() closed}. + */ + @Nullable + R execute(final String exceptionMessageIfClosed, final Supplier operation) throws IllegalStateException { + if (!tryStartOperation()) { + throw new IllegalStateException(exceptionMessageIfClosed); + } try { - Connection connection = getConnection(); - try { - killCursor(connection); - } finally { - connection.release(); + return operation.get(); + } finally { + endOperation(); + } + } + + /** + * Thread-safe. + * Returns {@code true} iff started an operation. + * If {@linkplain #operable() closed}, then returns false, otherwise completes abruptly. + * @throws IllegalStateException Iff another operation is in progress. + */ + private boolean tryStartOperation() throws IllegalStateException { + lock.lock(); + try { + State localState = state; + if (!localState.operable()) { + return false; + } else if (localState == State.IDLE) { + state = State.OPERATION_IN_PROGRESS; + return true; + } else if (localState == State.OPERATION_IN_PROGRESS) { + throw new IllegalStateException("Another operation is currently in progress, concurrent operations are not supported"); + } else { + throw fail(state.toString()); } - } catch (MongoException e) { - // Ignore exceptions from calling killCursor + } finally { + lock.unlock(); } } - } - private void killCursor(final Connection connection) { - if (serverCursor != null) { - notNull("connection", connection); + /** + * Thread-safe. + */ + private void endOperation() { + boolean doClose = false; + lock.lock(); try { - if (serverIsAtLeastVersionThreeDotTwo(connection.getDescription())) { - connection.command(namespace.getDatabaseName(), asKillCursorsCommandDocument(), NO_OP_FIELD_NAME_VALIDATOR, - ReadPreference.primary(), new BsonDocumentCodec(), connectionSource.getSessionContext(), - connectionSource.getServerApi()); + State localState = state; + if (localState == State.OPERATION_IN_PROGRESS) { + state = State.IDLE; + } else if (localState == State.CLOSE_PENDING) { + state = State.CLOSED; + doClose = true; } else { - connection.killCursor(namespace, singletonList(serverCursor.getId())); + fail(localState.toString()); + } + } finally { + lock.unlock(); + } + if (doClose) { + doClose(); + } + } + + /** + * Thread-safe. + */ + void close() { + boolean doClose = false; + lock.lock(); + try { + State localState = state; + if (localState == State.OPERATION_IN_PROGRESS) { + state = State.CLOSE_PENDING; + } else if (localState != State.CLOSED) { + state = State.CLOSED; + doClose = true; } } finally { + lock.unlock(); + } + if (doClose) { + doClose(); + } + } + + /** + * This method is never executed concurrently with either itself or other operations + * demarcated by {@link #tryStartOperation()}/{@link #endOperation()}. + */ + private void doClose() { + try { + if (skipReleasingServerResourcesOnClose) { + serverCursor = null; + } else if (serverCursor != null) { + Connection connection = connection(); + try { + releaseServerResources(connection); + } finally { + connection.release(); + } + } + } catch (MongoException e) { + // ignore exceptions when releasing server resources + } finally { + // guarantee that regardless of exceptions, `serverCursor` is null and client resources are released serverCursor = null; + releaseClientResources(); } } - } - private void releaseConnectionAndSourceIfNoServerCursor() { - if (serverCursor == null) { - releaseConnectionAndSource(); + void onCorruptedConnection(final Connection corruptedConnection) { + assertTrue(state.inProgress()); + // if `pinnedConnection` is corrupted, then we cannot kill `serverCursor` via such a connection + Connection localPinnedConnection = pinnedConnection; + if (localPinnedConnection != null) { + assertTrue(corruptedConnection == localPinnedConnection); + skipReleasingServerResourcesOnClose = true; + } } - } - private void releaseConnectionAndSource() { - if (connectionSource != null) { - connectionSource.release(); - connectionSource = null; + void executeWithConnection(final Consumer action) { + Connection connection = connection(); + try { + action.accept(connection); + } catch (MongoSocketException e) { + try { + onCorruptedConnection(connection); + } catch (RuntimeException suppressed) { + e.addSuppressed(suppressed); + } + throw e; + } finally { + connection.release(); + } } - if (connection != null) { - connection.release(); - connection = null; + + private Connection connection() { + assertTrue(state != State.IDLE); + if (pinnedConnection == null) { + return assertNotNull(connectionSource).getConnection(); + } else { + return assertNotNull(pinnedConnection).retain(); + } + } + + /** + * Thread-safe. + */ + @Nullable + ServerCursor serverCursor() { + return serverCursor; + } + + void setServerCursor(@Nullable final ServerCursor serverCursor) { + assertTrue(state.inProgress()); + assertNotNull(this.serverCursor); + // without `connectionSource` we will not be able to kill `serverCursor` later + assertNotNull(connectionSource); + this.serverCursor = serverCursor; + if (serverCursor == null) { + releaseClientResources(); + } + } + + @Nullable + SessionContext sessionContext() { + return assertNotNull(connectionSource).getSessionContext(); } - } - private BsonDocument asKillCursorsCommandDocument() { - return new BsonDocument("killCursors", new BsonString(namespace.getCollectionName())) - .append("cursors", new BsonArray(singletonList(new BsonInt64(serverCursor.getId())))); + void releaseServerAndClientResources(final Connection connection) { + try { + releaseServerResources(assertNotNull(connection)); + } finally { + releaseClientResources(); + } + } + + private void releaseServerResources(final Connection connection) { + try { + ServerCursor localServerCursor = serverCursor; + if (localServerCursor != null) { + killServerCursor(namespace, localServerCursor, sessionContext(), serverApi, assertNotNull(connection)); + } + } finally { + serverCursor = null; + } + } + + private void killServerCursor(final MongoNamespace namespace, final ServerCursor serverCursor, + @Nullable final SessionContext sessionContext, @Nullable final ServerApi serverApi, final Connection connection) { + long cursorId = serverCursor.getId(); + if (serverIsAtLeastVersionThreeDotTwo(connection.getDescription())) { + connection.command(namespace.getDatabaseName(), asKillCursorsCommandDocument(namespace, serverCursor), + NO_OP_FIELD_NAME_VALIDATOR, ReadPreference.primary(), new BsonDocumentCodec(), sessionContext, serverApi); + } else { + connection.killCursor(namespace, singletonList(cursorId)); + } + } + + private BsonDocument asKillCursorsCommandDocument(final MongoNamespace namespace, final ServerCursor serverCursor) { + return new BsonDocument("killCursors", new BsonString(namespace.getCollectionName())) + .append("cursors", new BsonArray(singletonList(new BsonInt64(serverCursor.getId())))); + } + + private void releaseClientResources() { + assertNull(serverCursor); + ConnectionSource localConnectionSource = connectionSource; + if (localConnectionSource != null) { + localConnectionSource.release(); + connectionSource = null; + } + Connection localPinnedConnection = pinnedConnection; + if (localPinnedConnection != null) { + localPinnedConnection.release(); + pinnedConnection = null; + } + } } - private BsonDocument getPostBatchResumeTokenFromResponse(final BsonDocument result) { - BsonDocument cursor = result.getDocument(CURSOR, null); - if (cursor != null) { - return cursor.getDocument(POST_BATCH_RESUME_TOKEN, null); + private enum State { + IDLE(true, false), + OPERATION_IN_PROGRESS(true, true), + /** + * Implies {@link #OPERATION_IN_PROGRESS}. + */ + CLOSE_PENDING(false, true), + CLOSED(false, false); + + private final boolean operable; + private final boolean inProgress; + + State(final boolean operable, final boolean inProgress) { + this.operable = operable; + this.inProgress = inProgress; + } + + boolean operable() { + return operable; + } + + boolean inProgress() { + return inProgress; } - return null; } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ServerHelper.java b/driver-core/src/test/functional/com/mongodb/internal/connection/ServerHelper.java index 8c6877b376..1b2d71c920 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/ServerHelper.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/ServerHelper.java @@ -26,6 +26,7 @@ import static com.mongodb.ClusterFixture.getAsyncCluster; import static com.mongodb.ClusterFixture.getCluster; +import static com.mongodb.assertions.Assertions.fail; import static java.lang.Thread.sleep; public final class ServerHelper { @@ -43,10 +44,8 @@ public static void waitForLastRelease(final Cluster cluster) { } public static void waitForLastRelease(final ServerAddress address, final Cluster cluster) { - DefaultServer server = (DefaultServer) cluster.selectServer(new ServerAddressSelector(address)) - .getServer(); - DefaultConnectionPool connectionProvider = (DefaultConnectionPool) server.getConnectionPool(); - ConcurrentPool pool = connectionProvider.getPool(); + ConcurrentPool pool = connectionPool( + cluster.selectServer(new ServerAddressSelector(address)).getServer()); long startTime = System.currentTimeMillis(); while (pool.getInUseCount() > 0) { try { @@ -62,15 +61,25 @@ public static void waitForLastRelease(final ServerAddress address, final Cluster } private static void checkPool(final ServerAddress address, final Cluster cluster) { - DefaultServer server = (DefaultServer) cluster.selectServer(new ServerAddressSelector(address)) - .getServer(); - DefaultConnectionPool connectionProvider = (DefaultConnectionPool) server.getConnectionPool(); - ConcurrentPool pool = connectionProvider.getPool(); + ConcurrentPool pool = connectionPool( + cluster.selectServer(new ServerAddressSelector(address)).getServer()); if (pool.getInUseCount() > 0) { throw new IllegalStateException("Connection pool in use count is " + pool.getInUseCount()); } } + private static ConcurrentPool connectionPool(final Server server) { + ConnectionPool connectionPool; + if (server instanceof DefaultServer) { + connectionPool = ((DefaultServer) server).getConnectionPool(); + } else if (server instanceof LoadBalancedServer) { + connectionPool = ((LoadBalancedServer) server).getConnectionPool(); + } else { + throw fail(server.getClass().toString()); + } + return ((DefaultConnectionPool) connectionPool).getPool(); + } + public static void waitForRelease(final AsyncConnectionSource connectionSource, final int expectedCount) { long startTime = System.currentTimeMillis(); while (connectionSource.getCount() > expectedCount) { diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy index 228921a9ca..11d9cf4a8d 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AggregateOperationSpecification.groovy @@ -452,6 +452,7 @@ class AggregateOperationSpecification extends OperationFunctionalSpecification { binding.sessionContext >> sessionContext source.connection >> connection source.retain() >> source + source.getServerApi() >> null def commandDocument = new BsonDocument('aggregate', new BsonString(getCollectionName())) .append('pipeline', new BsonArray()) .append('cursor', new BsonDocument()) diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy index 5d8111d06b..074fb96845 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/CountDocumentsOperationSpecification.groovy @@ -324,6 +324,7 @@ class CountDocumentsOperationSpecification extends OperationFunctionalSpecificat binding.sessionContext >> sessionContext source.connection >> connection source.retain() >> source + source.getServerApi() >> null def pipeline = new BsonArray([BsonDocument.parse('{ $match: {}}'), BsonDocument.parse('{$group: {_id: 1, n: {$sum: 1}}}')]) def commandDocument = new BsonDocument('aggregate', new BsonString(getCollectionName())) .append('pipeline', pipeline) diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy index 93c545367b..13cccefe02 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/DistinctOperationSpecification.groovy @@ -295,6 +295,7 @@ class DistinctOperationSpecification extends OperationFunctionalSpecification { binding.sessionContext >> sessionContext source.connection >> connection source.retain() >> source + source.getServerApi() >> null def commandDocument = new BsonDocument('distinct', new BsonString(getCollectionName())) .append('key', new BsonString('str')) appendReadConcernToCommand(sessionContext, MIN_WIRE_VERSION, commandDocument) diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy index e2f1a1a730..ba97260efc 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/FindOperationSpecification.groovy @@ -529,6 +529,7 @@ class FindOperationSpecification extends OperationFunctionalSpecification { binding.sessionContext >> sessionContext source.connection >> connection source.retain() >> source + source.getServerApi() >> null def commandDocument = new BsonDocument('find', new BsonString(getCollectionName())) appendReadConcernToCommand(sessionContext, MIN_WIRE_VERSION, commandDocument) @@ -609,6 +610,7 @@ class FindOperationSpecification extends OperationFunctionalSpecification { binding.sessionContext >> sessionContext source.connection >> connection source.retain() >> source + source.getServerApi() >> null def commandDocument = new BsonDocument('find', new BsonString(getCollectionName())).append('allowDiskUse', BsonBoolean.TRUE) appendReadConcernToCommand(sessionContext, MIN_WIRE_VERSION, commandDocument) @@ -694,6 +696,7 @@ class FindOperationSpecification extends OperationFunctionalSpecification { binding.readConnectionSource >> source source.connection >> connection source.retain() >> source + source.getServerApi() >> null when: operation.execute(binding) diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy index dcdfffa9a8..bf1c0602b6 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/MapReduceWithInlineResultsOperationSpecification.groovy @@ -265,6 +265,7 @@ class MapReduceWithInlineResultsOperationSpecification extends OperationFunction binding.sessionContext >> sessionContext source.connection >> connection source.retain() >> source + source.getServerApi() >> null def commandDocument = BsonDocument.parse(''' { "mapreduce" : "coll", "map" : { "$code" : "function(){ }" }, diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/QueryBatchCursorFunctionalSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/QueryBatchCursorFunctionalSpecification.groovy index c584e11264..04eb38b992 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/QueryBatchCursorFunctionalSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/QueryBatchCursorFunctionalSpecification.groovy @@ -268,18 +268,19 @@ class QueryBatchCursorFunctionalSpecification extends OperationFunctionalSpecifi @Slow def 'hasNext should throw when cursor is closed in another thread'() { + Connection conn = connectionSource.getConnection() collectionHelper.create(collectionName, new CreateCollectionOptions().capped(true).sizeInBytes(1000)) collectionHelper.insertDocuments(new DocumentCodec(), new Document('_id', 1).append('ts', new BsonTimestamp(5, 0))) def firstBatch = executeQuery(new BsonDocument('ts', new BsonDocument('$gte', new BsonTimestamp(5, 0))), 0, 2, true, true); - cursor = new QueryBatchCursor(firstBatch, 0, 2, new DocumentCodec(), connectionSource) + cursor = new QueryBatchCursor(firstBatch, 0, 2, 0, new DocumentCodec(), connectionSource, conn) cursor.next() - def latch = new CountDownLatch(1) + def closeCompleted = new CountDownLatch(1) // wait a second then close the cursor new Thread({ sleep(1000) cursor.close() - latch.countDown() + closeCompleted.countDown() } as Runnable).start() when: @@ -287,9 +288,11 @@ class QueryBatchCursorFunctionalSpecification extends OperationFunctionalSpecifi then: thrown(Exception) + closeCompleted.await(5, TimeUnit.SECONDS) + conn.getCount() == 1 cleanup: - latch.await(5, TimeUnit.SECONDS) // wait for cursor.close to complete + conn.release() } @IgnoreIf({ !serverVersionAtLeast(3, 2) || isSharded() }) diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncQueryBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncQueryBatchCursorSpecification.groovy index 93a2bd8abe..6a68a58963 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncQueryBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncQueryBatchCursorSpecification.groovy @@ -27,6 +27,7 @@ import com.mongodb.connection.ServerConnectionState import com.mongodb.connection.ServerDescription import com.mongodb.connection.ServerType import com.mongodb.connection.ServerVersion +import com.mongodb.internal.async.SingleResultCallback import com.mongodb.internal.binding.AsyncConnectionSource import com.mongodb.internal.connection.AsyncConnection import com.mongodb.internal.connection.QueryResult @@ -365,8 +366,11 @@ class AsyncQueryBatchCursorSpecification extends Specification { given: def connectionA = referenceCountedAsyncConnection(serverVersion) def connectionB = referenceCountedAsyncConnection(serverVersion) - def connectionSource = getAsyncConnectionSource(connectionA, connectionB) + def connectionSource = getAsyncConnectionSource(serverType, connectionA, connectionB) def initialResult = queryResult() + Object getMoreResponse = useCommand + ? documentResponse([], getMoreResponseHasCursor ? 42 : 0) + : queryResult([], getMoreResponseHasCursor ? 42 : 0) when: def cursor = new AsyncQueryBatchCursor(initialResult, 0, 0, 0, CODEC, connectionSource, connectionA) @@ -379,21 +383,24 @@ class AsyncQueryBatchCursorSpecification extends Specification { nextBatch(cursor) then: - if (commandAsync) { - _ * connectionA.commandAsync(_, _, _, _, _, _, _, _) >> { - // Simulate the user calling close while the getMore is in flight + // simulate the user calling `close` while `getMore` is in flight + if (useCommand) { + // in LB mode the same connection is used to execute both `getMore` and `killCursors` + int numberOfInvocations = serverType == ServerType.LOAD_BALANCER + ? getMoreResponseHasCursor ? 2 : 1 + : 1 + numberOfInvocations * connectionA.commandAsync(_, _, _, _, _, _, _, _) >> { + // `getMore` command cursor.close() - it[7].onResult(response, null) + ((SingleResultCallback) it[7]).onResult(getMoreResponse, null) } >> { - it[7].onResult(response2, null) + // `killCursors` command + ((SingleResultCallback) it[7]).onResult(null, null) } } else { - _ * connectionA.getMoreAsync(_, _, _, _, _) >> { - // Simulate the user calling close while the getMore is in flight + 1 * connectionA.getMoreAsync(_, _, _, _, _) >> { cursor.close() - it[4].onResult(response, null) - } >> { - it[4].onResult(response2, null) + ((SingleResultCallback) it[4]).onResult(getMoreResponse, null) } } @@ -402,25 +409,23 @@ class AsyncQueryBatchCursorSpecification extends Specification { then: connectionA.getCount() == 0 - if (response2 == null) { //otherwise connectionSource is released asynchronously, which is not easy to verify - connectionSource.getCount() == 0 - } cursor.isClosed() where: - serverVersion | commandAsync | response | response2 - new ServerVersion([3, 2, 0]) | true | documentResponse([]) | documentResponse([], 0) - new ServerVersion([3, 2, 0]) | true | documentResponse([], 0) | null - new ServerVersion([3, 0, 0]) | false | new QueryResult(NAMESPACE, [], 42, SERVER_ADDRESS) | - new QueryResult(NAMESPACE, [], 0, SERVER_ADDRESS) - new ServerVersion([3, 0, 0]) | false | new QueryResult(NAMESPACE, [], 0, SERVER_ADDRESS) | null + serverVersion | useCommand | getMoreResponseHasCursor | serverType + new ServerVersion([5, 0, 0]) | true | true | ServerType.LOAD_BALANCER + new ServerVersion([5, 0, 0]) | true | false | ServerType.LOAD_BALANCER + new ServerVersion([3, 2, 0]) | true | true | ServerType.STANDALONE + new ServerVersion([3, 2, 0]) | true | false | ServerType.STANDALONE + new ServerVersion([3, 0, 0]) | false | true | ServerType.STANDALONE + new ServerVersion([3, 0, 0]) | false | false | ServerType.STANDALONE } def 'should close cursor after getMore finishes if cursor was closed while getMore was in progress and getMore throws exception'() { given: def connectionA = referenceCountedAsyncConnection(serverVersion) def connectionB = referenceCountedAsyncConnection(serverVersion) - def connectionSource = getAsyncConnectionSource(connectionA, connectionB) + def connectionSource = getAsyncConnectionSource(serverType, connectionA, connectionB) def initialResult = queryResult() when: @@ -434,17 +439,22 @@ class AsyncQueryBatchCursorSpecification extends Specification { nextBatch(cursor) then: + // simulate the user calling `close` while `getMore` is throwing a `MongoException` if (commandAsync) { - 1 * connectionA.commandAsync(_, _, _, _, _, _, _, _) >> { - // Simulate the user calling close while the getMore is throwing a MongoException + // in LB mode the same connection is used to execute both `getMore` and `killCursors` + int numberOfInvocations = serverType == ServerType.LOAD_BALANCER ? 2 : 1 + numberOfInvocations * connectionA.commandAsync(_, _, _, _, _, _, _, _) >> { + // `getMore` command cursor.close() - it[7].onResult(null, MONGO_EXCEPTION) + ((SingleResultCallback) it[7]).onResult(null, MONGO_EXCEPTION) + } >> { + // `killCursors` command + ((SingleResultCallback) it[7]).onResult(null, null) } } else { 1 * connectionA.getMoreAsync(_, _, _, _, _) >> { - // Simulate the user calling close while the getMore is throwing a MongoException cursor.close() - it[4].onResult(null, MONGO_EXCEPTION) + ((SingleResultCallback) it[4]).onResult(null, MONGO_EXCEPTION) } } @@ -456,15 +466,16 @@ class AsyncQueryBatchCursorSpecification extends Specification { cursor.isClosed() where: - serverVersion | commandAsync - new ServerVersion([3, 2, 0]) | true - new ServerVersion([3, 0, 0]) | false + serverVersion | commandAsync | serverType + new ServerVersion([5, 0, 0]) | true | ServerType.LOAD_BALANCER + new ServerVersion([3, 2, 0]) | true | ServerType.STANDALONE + new ServerVersion([3, 0, 0]) | false | ServerType.STANDALONE } def 'should handle errors when calling close'() { given: def connection = referenceCountedAsyncConnection() - def connectionSource = getAsyncConnectionSourceWithResult { [null, MONGO_EXCEPTION] } + def connectionSource = getAsyncConnectionSourceWithResult(ServerType.STANDALONE) { [null, MONGO_EXCEPTION] } def cursor = new AsyncQueryBatchCursor(queryResult(), 0, 0, 0, CODEC, connectionSource, connection) when: @@ -484,7 +495,7 @@ class AsyncQueryBatchCursorSpecification extends Specification { def 'should handle errors when getting a connection for getMore'() { given: def connection = referenceCountedAsyncConnection() - def connectionSource = getAsyncConnectionSourceWithResult { [null, MONGO_EXCEPTION] } + def connectionSource = getAsyncConnectionSourceWithResult(ServerType.STANDALONE) { [null, MONGO_EXCEPTION] } when: def cursor = new AsyncQueryBatchCursor(queryResult(), 0, 0, 0, CODEC, connectionSource, connection) @@ -576,14 +587,14 @@ class AsyncQueryBatchCursorSpecification extends Specification { private static final COMMAND_EXCEPTION = new MongoCommandException(BsonDocument.parse('{"ok": false, "errmsg": "error"}'), SERVER_ADDRESS) - def documentResponse(results, cursorId = 42) { + private static BsonDocument documentResponse(results, cursorId = 42) { new BsonDocument('ok', new BsonInt32(1)).append('cursor', new BsonDocument('id', new BsonInt64(cursorId)).append('ns', new BsonString(NAMESPACE.getFullName())) .append('nextBatch', new BsonArrayWrapper(results))) } - def queryResult(results = FIRST_BATCH, cursorId = 42) { + private static QueryResult queryResult(results = FIRST_BATCH, cursorId = 42) { new QueryResult(NAMESPACE, results, cursorId, SERVER_ADDRESS) } @@ -619,19 +630,23 @@ class AsyncQueryBatchCursorSpecification extends Specification { mock } - def getAsyncConnectionSource(AsyncConnection... connections) { + AsyncConnectionSource getAsyncConnectionSource(AsyncConnection... connections) { + getAsyncConnectionSource(ServerType.STANDALONE, connections) + } + + AsyncConnectionSource getAsyncConnectionSource(ServerType serverType, AsyncConnection... connections) { def index = -1 - getAsyncConnectionSourceWithResult { index += 1; [connections.toList().get(index).retain(), null] } + getAsyncConnectionSourceWithResult(serverType) { index += 1; [connections.toList().get(index).retain(), null] } } - def getAsyncConnectionSourceWithResult(connectionCallbackResults) { + def getAsyncConnectionSourceWithResult(ServerType serverType, Closure connectionCallbackResults) { def released = false int counter = 0 def mock = Mock(AsyncConnectionSource) mock.getServerDescription() >> { ServerDescription.builder() .address(new ServerAddress()) - .type(ServerType.STANDALONE) + .type(serverType) .state(ServerConnectionState.CONNECTED) .build() } diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/FindOperationUnitSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/FindOperationUnitSpecification.groovy index f206b84e1e..19f26cb47b 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/FindOperationUnitSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/FindOperationUnitSpecification.groovy @@ -58,6 +58,7 @@ class FindOperationUnitSpecification extends OperationUnitSpecification { } def connectionSource = Stub(ConnectionSource) { getConnection() >> connection + getServerApi() >> null } def readBinding = Stub(ReadBinding) { getReadPreference() >> readPreference @@ -204,6 +205,7 @@ class FindOperationUnitSpecification extends OperationUnitSpecification { def readBinding = Stub(ReadBinding) { getReadConnectionSource() >> Stub(ConnectionSource) { getConnection() >> connection + getServerApi() >> null } getReadPreference() >> Stub(ReadPreference) { isSlaveOk() >> slaveOk diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/QueryBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/QueryBatchCursorSpecification.groovy index 291670691b..bb4a01695b 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/QueryBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/QueryBatchCursorSpecification.groovy @@ -16,11 +16,16 @@ package com.mongodb.internal.operation +import com.mongodb.MongoException import com.mongodb.MongoNamespace import com.mongodb.MongoSocketException import com.mongodb.MongoSocketOpenException import com.mongodb.ServerAddress import com.mongodb.connection.ConnectionDescription +import com.mongodb.connection.ServerConnectionState +import com.mongodb.connection.ServerDescription +import com.mongodb.connection.ServerType +import com.mongodb.connection.ServerVersion import com.mongodb.internal.binding.ConnectionSource import com.mongodb.internal.connection.Connection import com.mongodb.internal.connection.QueryResult @@ -30,9 +35,15 @@ import org.bson.BsonInt64 import org.bson.BsonString import org.bson.Document import org.bson.codecs.BsonDocumentCodec +import org.bson.codecs.DocumentCodec import spock.lang.Specification +import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion + class QueryBatchCursorSpecification extends Specification { + private static final MongoNamespace NAMESPACE = new MongoNamespace('db', 'coll') + private static final ServerAddress SERVER_ADDRESS = new ServerAddress() + def 'should generate expected command with batchSize and maxTimeMS'() { given: def connection = Mock(Connection) { @@ -46,16 +57,13 @@ class QueryBatchCursorSpecification extends Specification { } connectionSource.retain() >> connectionSource - def database = 'test' - def collection = 'QueryBatchCursorSpecification' def cursorId = 42 - def namespace = new MongoNamespace(database, collection) - def firstBatch = new QueryResult(namespace, [], cursorId, new ServerAddress()) + def firstBatch = new QueryResult(NAMESPACE, [], cursorId, SERVER_ADDRESS) def cursor = new QueryBatchCursor(firstBatch, 0, batchSize, maxTimeMS, new BsonDocumentCodec(), connectionSource, connection) def expectedCommand = new BsonDocument('getMore': new BsonInt64(cursorId)) - .append('collection', new BsonString(collection)) + .append('collection', new BsonString(NAMESPACE.getCollectionName())) if (batchSize != 0) { expectedCommand.append('batchSize', new BsonInt32(batchSize)) } @@ -66,14 +74,14 @@ class QueryBatchCursorSpecification extends Specification { def reply = new BsonDocument('ok', new BsonInt32(1)) .append('cursor', new BsonDocument('id', new BsonInt64(0)) - .append('ns', new BsonString(namespace.getFullName())) + .append('ns', new BsonString(NAMESPACE.getFullName())) .append('nextBatch', new BsonArrayWrapper([]))) when: cursor.hasNext() then: - 1 * connection.command(database, expectedCommand, _, _, _, _, null) >> { + 1 * connection.command(NAMESPACE.getDatabaseName(), expectedCommand, _, _, _, _, null) >> { reply } 1 * connection.release() @@ -87,13 +95,12 @@ class QueryBatchCursorSpecification extends Specification { def 'should handle exceptions when closing'() { given: - def serverAddress = new ServerAddress() def connection = Mock(Connection) { _ * getDescription() >> Stub(ConnectionDescription) { getMaxWireVersion() >> 4 } - _ * killCursor(_, _) >> { throw new MongoSocketException('No MongoD', serverAddress) } - _ * command(_, _, _, _, _) >> { throw new MongoSocketException('No MongoD', serverAddress) } + _ * killCursor(_, _) >> { throw new MongoSocketException('No MongoD', SERVER_ADDRESS) } + _ * command(_, _, _, _, _) >> { throw new MongoSocketException('No MongoD', SERVER_ADDRESS) } } def connectionSource = Stub(ConnectionSource) { getServerApi() >> null @@ -101,8 +108,7 @@ class QueryBatchCursorSpecification extends Specification { } connectionSource.retain() >> connectionSource - def namespace = new MongoNamespace('test', 'QueryBatchCursorSpecification') - def firstBatch = new QueryResult(namespace, [], 42, serverAddress) + def firstBatch = new QueryResult(NAMESPACE, [], 42, SERVER_ADDRESS) def cursor = new QueryBatchCursor(firstBatch, 0, 2, 100, new BsonDocumentCodec(), connectionSource, connection) when: @@ -120,19 +126,18 @@ class QueryBatchCursorSpecification extends Specification { def 'should handle exceptions when killing cursor and a connection can not be obtained'() { given: - def serverAddress = new ServerAddress() def connection = Mock(Connection) { _ * getDescription() >> Stub(ConnectionDescription) { getMaxWireVersion() >> 4 } } def connectionSource = Stub(ConnectionSource) { - getConnection() >> { throw new MongoSocketOpenException("can't open socket", serverAddress, new IOException()) } + getConnection() >> { throw new MongoSocketOpenException("can't open socket", SERVER_ADDRESS, new IOException()) } + getServerApi() >> null } connectionSource.retain() >> connectionSource - def namespace = new MongoNamespace('test', 'QueryBatchCursorSpecification') - def firstBatch = new QueryResult(namespace, [], 42, serverAddress) + def firstBatch = new QueryResult(NAMESPACE, [], 42, SERVER_ADDRESS) def cursor = new QueryBatchCursor(firstBatch, 0, 2, 100, new BsonDocumentCodec(), connectionSource, connection) when: @@ -147,4 +152,208 @@ class QueryBatchCursorSpecification extends Specification { then: notThrown(Exception) } + + def 'should close cursor after getMore finishes if cursor was closed while getMore was in progress and getMore returns a response'() { + given: + Connection conn = mockConnection(serverVersion) + ConnectionSource connSource + if (serverType == ServerType.LOAD_BALANCER) { + connSource = mockConnectionSource(SERVER_ADDRESS, serverType) + } else { + connSource = mockConnectionSource(SERVER_ADDRESS, serverType, conn, mockConnection(serverVersion)) + } + List firstBatch = [new Document()] + QueryResult initialResult = new QueryResult<>(NAMESPACE, firstBatch, 1, SERVER_ADDRESS) + Object getMoreResponse = useCommand + ? emptyGetMoreCommandResponse(NAMESPACE, getMoreResponseHasCursor ? 42 : 0) + : emptyGetMoreQueryResponse(NAMESPACE, SERVER_ADDRESS, getMoreResponseHasCursor ? 42 : 0) + + when: + QueryBatchCursor cursor = new QueryBatchCursor<>(initialResult, 0, 0, 0, new DocumentCodec(), connSource, conn) + List batch = cursor.next() + + then: + batch == firstBatch + + when: + cursor.next() + + then: + // simulate the user calling `close` while `getMore` is in flight + if (useCommand) { + // in LB mode the same connection is used to execute both `getMore` and `killCursors` + int numberOfInvocations = serverType == ServerType.LOAD_BALANCER + ? getMoreResponseHasCursor ? 2 : 1 + : 1 + numberOfInvocations * conn.command(*_) >> { + // `getMore` command + cursor.close() + getMoreResponse + } >> { + // `killCursors` command + null + } + } else { + 1 * conn.getMore(*_) >> { + cursor.close() + getMoreResponse + } + } + + then: + IllegalStateException e = thrown() + e.getMessage() == 'Cursor has been closed' + + then: + conn.getCount() == 1 + connSource.getCount() == 1 + + where: + serverVersion | useCommand | getMoreResponseHasCursor | serverType + new ServerVersion([5, 0, 0]) | true | true | ServerType.LOAD_BALANCER + new ServerVersion([5, 0, 0]) | true | false | ServerType.LOAD_BALANCER + new ServerVersion([3, 2, 0]) | true | true | ServerType.STANDALONE + new ServerVersion([3, 2, 0]) | true | false | ServerType.STANDALONE + new ServerVersion([3, 0, 0]) | false | true | ServerType.STANDALONE + new ServerVersion([3, 0, 0]) | false | false | ServerType.STANDALONE + } + + def 'should close cursor after getMore finishes if cursor was closed while getMore was in progress and getMore throws exception'() { + given: + Connection conn = mockConnection(serverVersion) + ConnectionSource connSource + if (serverType == ServerType.LOAD_BALANCER) { + connSource = mockConnectionSource(SERVER_ADDRESS, serverType) + } else { + connSource = mockConnectionSource(SERVER_ADDRESS, serverType, conn, mockConnection(serverVersion)) + } + List firstBatch = [new Document()] + QueryResult initialResult = new QueryResult<>(NAMESPACE, firstBatch, 1, SERVER_ADDRESS) + String exceptionMessage = 'test' + + when: + QueryBatchCursor cursor = new QueryBatchCursor<>(initialResult, 0, 0, 0, new DocumentCodec(), connSource, conn) + List batch = cursor.next() + + then: + batch == firstBatch + + when: + cursor.next() + + then: + // simulate the user calling `close` while `getMore` is in flight + if (useCommand) { + // in LB mode the same connection is used to execute both `getMore` and `killCursors` + int numberOfInvocations = serverType == ServerType.LOAD_BALANCER ? 2 : 1 + numberOfInvocations * conn.command(*_) >> { + // `getMore` command + cursor.close() + throw new MongoException(exceptionMessage) + } >> { + // `killCursors` command + null + } + } else { + 1 * conn.getMore(*_) >> { + cursor.close() + throw new MongoException(exceptionMessage) + } + } + + then: + MongoException e = thrown() + e.getMessage() == exceptionMessage + + then: + conn.getCount() == 1 + connSource.getCount() == 1 + + where: + serverVersion | useCommand | serverType + new ServerVersion([5, 0, 0]) | true | ServerType.LOAD_BALANCER + new ServerVersion([3, 2, 0]) | true | ServerType.STANDALONE + new ServerVersion([3, 0, 0]) | false | ServerType.STANDALONE + } + + /** + * Creates a {@link Connection} with {@link Connection#getCount()} returning 1. + */ + private Connection mockConnection(ServerVersion serverVersion) { + int refCounter = 1 + Connection mockConn = Mock(Connection) { + getDescription() >> Stub(ConnectionDescription) { + getMaxWireVersion() >> getMaxWireVersionForServerVersion(serverVersion.getVersionList()) + } + } + mockConn.retain() >> { + if (refCounter == 0) { + throw new IllegalStateException('Tried to retain Connection when already released') + } else { + refCounter += 1 + } + mockConn + } + mockConn.release() >> { + refCounter -= 1 + if (refCounter < 0) { + throw new IllegalStateException('Tried to release Connection below 0') + } + } + mockConn.getCount() >> { refCounter } + mockConn + } + + private ConnectionSource mockConnectionSource(ServerAddress serverAddress, ServerType serverType, Connection... connections) { + int connIdx = 0 + int refCounter = 1 + ConnectionSource mockConnectionSource = Mock(ConnectionSource) + mockConnectionSource.getServerDescription() >> { + ServerDescription.builder() + .address(serverAddress) + .type(serverType) + .state(ServerConnectionState.CONNECTED) + .build() + } + mockConnectionSource.retain() >> { + if (refCounter == 0) { + throw new IllegalStateException('Tried to retain ConnectionSource when already released') + } else { + refCounter += 1 + } + mockConnectionSource + } + mockConnectionSource.release() >> { + refCounter -= 1 + if (refCounter < 0) { + throw new IllegalStateException('Tried to release ConnectionSource below 0') + } + } + mockConnectionSource.getCount() >> { refCounter } + mockConnectionSource.getConnection() >> { + if (refCounter == 0) { + throw new IllegalStateException('Tried to use released ConnectionSource') + } + Connection conn + if (connIdx < connections.length) { + conn = connections[connIdx] + } else { + throw new IllegalStateException('Requested more than maxConnections=' + maxConnections) + } + connIdx++ + conn.retain() + } + mockConnectionSource + } + + private static BsonDocument emptyGetMoreCommandResponse(MongoNamespace namespace, long cursorId) { + new BsonDocument('ok', new BsonInt32(1)) + .append('cursor', new BsonDocument('id', new BsonInt64(cursorId)) + .append('ns', new BsonString(namespace.getFullName())) + .append('nextBatch', new BsonArrayWrapper([]))) + } + + private static QueryResult emptyGetMoreQueryResponse(MongoNamespace namespace, ServerAddress serverAddress, long cursorId) { + new QueryResult(namespace, [], cursorId, serverAddress) + } } diff --git a/driver-sync/src/main/com/mongodb/client/MongoChangeStreamCursor.java b/driver-sync/src/main/com/mongodb/client/MongoChangeStreamCursor.java index 16a3bb5f11..38e33c8ae8 100644 --- a/driver-sync/src/main/com/mongodb/client/MongoChangeStreamCursor.java +++ b/driver-sync/src/main/com/mongodb/client/MongoChangeStreamCursor.java @@ -16,6 +16,7 @@ package com.mongodb.client; +import com.mongodb.annotations.NotThreadSafe; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -24,17 +25,18 @@ *

* An application should ensure that a cursor is closed in all circumstances, e.g. using a try-with-resources statement: *

- *
- * try (MongoChangeStreamCursor<Document> cursor = collection.find().cursor()) {
+ * 
{@code
+ * try (MongoChangeStreamCursor> cursor = collection.watch().cursor()) {
  *     while (cursor.hasNext()) {
  *         System.out.println(cursor.next());
  *     }
  * }
- * 
+ * } * * @since 3.11 * @param The type of documents the cursor contains */ +@NotThreadSafe public interface MongoChangeStreamCursor extends MongoCursor { /** * Returns the resume token. If a batch has been iterated to the last change stream document in the batch diff --git a/driver-sync/src/main/com/mongodb/client/MongoCursor.java b/driver-sync/src/main/com/mongodb/client/MongoCursor.java index 2fa28ac75a..241532faf1 100644 --- a/driver-sync/src/main/com/mongodb/client/MongoCursor.java +++ b/driver-sync/src/main/com/mongodb/client/MongoCursor.java @@ -29,20 +29,25 @@ *

* An application should ensure that a cursor is closed in all circumstances, e.g. using a try-with-resources statement: * - *

- * try (MongoCursor<Document> cursor = collection.find().iterator()) {
+ * 
{@code
+ * try (MongoCursor cursor = collection.find().cursor()) {
  *     while (cursor.hasNext()) {
  *         System.out.println(cursor.next());
  *     }
  * }
- * 
+ * } * * @since 3.0 * @param The type of documents the cursor contains */ @NotThreadSafe public interface MongoCursor extends Iterator, Closeable { - + /** + * Despite this interface being {@linkplain NotThreadSafe non-thread-safe}, + * {@link #close()} is allowed to be called concurrently with any method of the cursor, including itself. + * This is useful to cancel blocked {@link #hasNext()}, {@link #next()}. + * This method is idempotent. + */ @Override void close();