diff --git a/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java b/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java new file mode 100644 index 00000000..75aef494 --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/connection/ConnectionOptionsHelper.java @@ -0,0 +1,41 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import com.google.api.core.InternalApi; +import com.google.cloud.spanner.connection.StatementExecutor.StatementExecutorType; + +@InternalApi +public class ConnectionOptionsHelper { + + @InternalApi + public static ConnectionOptions.Builder useDirectExecutorIfNotUseVirtualThreads( + String uri, ConnectionOptions.Builder builder) { + ConnectionState connectionState = new ConnectionState(ConnectionProperties.parseValues(uri)); + if (!connectionState.getValue(ConnectionProperties.USE_VIRTUAL_THREADS).getValue()) { + return builder.setStatementExecutorType(StatementExecutorType.DIRECT_EXECUTOR); + } + return builder; + } + + @InternalApi + public static boolean usesDirectExecutor(ConnectionOptions options) { + return options.getStatementExecutorType() == StatementExecutorType.DIRECT_EXECUTOR; + } + + private ConnectionOptionsHelper() {} +} diff --git a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcConnection.java b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcConnection.java index d790cf85..122d4896 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcConnection.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcConnection.java @@ -21,6 +21,7 @@ import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.ConnectionOptions; +import com.google.cloud.spanner.connection.ConnectionOptionsHelper; import com.google.common.annotations.VisibleForTesting; import com.google.rpc.Code; import java.sql.CallableStatement; @@ -53,6 +54,7 @@ abstract class AbstractJdbcConnection extends AbstractJdbcWrapper private final ConnectionOptions options; private final com.google.cloud.spanner.connection.Connection spanner; private final Properties clientInfo; + private final boolean usesDirectExecutor; private AbstractStatementParser parser; private SQLWarning firstWarning = null; @@ -63,6 +65,7 @@ abstract class AbstractJdbcConnection extends AbstractJdbcWrapper this.options = options; this.spanner = options.getConnection(); this.clientInfo = new Properties(JdbcDatabaseMetaData.getDefaultClientInfoProperties()); + this.usesDirectExecutor = ConnectionOptionsHelper.usesDirectExecutor(options); } /** Return the corresponding {@link com.google.cloud.spanner.connection.Connection} */ @@ -83,6 +86,10 @@ Spanner getSpanner() { return this.spanner.getSpanner(); } + boolean usesDirectExecutor() { + return this.usesDirectExecutor; + } + @Override public Dialect getDialect() { return spanner.getDialect(); diff --git a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcStatement.java b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcStatement.java index 3152445b..5c15f336 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcStatement.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcStatement.java @@ -34,6 +34,8 @@ import java.time.Duration; import java.util.Arrays; import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.Nonnull; @@ -43,6 +45,8 @@ abstract class AbstractJdbcStatement extends AbstractJdbcWrapper implements Stat private static final String CURSORS_NOT_SUPPORTED = "Cursors are not supported"; private static final String ONLY_FETCH_FORWARD_SUPPORTED = "Only fetch_forward is supported"; final AbstractStatementParser parser; + private final Lock executingLock; + private volatile Thread executingThread; private boolean closed; private boolean closeOnCompletion; private boolean poolable; @@ -52,6 +56,11 @@ abstract class AbstractJdbcStatement extends AbstractJdbcWrapper implements Stat AbstractJdbcStatement(JdbcConnection connection) throws SQLException { this.connection = connection; this.parser = connection.getParser(); + if (connection.usesDirectExecutor()) { + this.executingLock = new ReentrantLock(); + } else { + this.executingLock = null; + } } @Override @@ -239,6 +248,10 @@ private T doWithStatementTimeout( Supplier runnable, Function shouldResetTimeout) throws SQLException { StatementTimeout originalTimeout = setTemporaryStatementTimeout(); T result = null; + if (this.executingLock != null) { + this.executingLock.lock(); + this.executingThread = Thread.currentThread(); + } try { Stopwatch stopwatch = Stopwatch.createStarted(); result = runnable.get(); @@ -248,6 +261,10 @@ private T doWithStatementTimeout( } catch (SpannerException spannerException) { throw JdbcSqlExceptionFactory.of(spannerException); } finally { + if (this.executingLock != null) { + this.executingThread = null; + this.executingLock.unlock(); + } if (shouldResetTimeout.apply(result)) { resetStatementTimeout(originalTimeout); } @@ -353,7 +370,16 @@ void setQueryTimeout(@Nonnull Duration duration) throws SQLException { @Override public void cancel() throws SQLException { checkClosed(); - connection.getSpannerConnection().cancel(); + if (this.executingThread != null) { + // This is a best-effort operation. It could be that the executing thread is set to null + // between the if-check and the actual execution. Just ignore if that happens. + try { + this.executingThread.interrupt(); + } catch (NullPointerException ignore) { + } + } else { + connection.getSpannerConnection().cancel(); + } } @Override diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcDriver.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcDriver.java index 1c499abf..6bb70283 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcDriver.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcDriver.java @@ -22,6 +22,7 @@ import com.google.cloud.spanner.SessionPoolOptionsHelper; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.connection.ConnectionOptions; +import com.google.cloud.spanner.connection.ConnectionOptionsHelper; import com.google.cloud.spanner.connection.ConnectionPropertiesHelper; import com.google.cloud.spanner.connection.ConnectionProperty; import com.google.rpc.Code; @@ -245,6 +246,9 @@ private ConnectionOptions buildConnectionOptions(String connectionUrl, Propertie // Enable multiplexed sessions by default for the JDBC driver. builder.setSessionPoolOptions( SessionPoolOptionsHelper.useMultiplexedSessions(SessionPoolOptions.newBuilder()).build()); + // Enable direct executor for JDBC, as we don't use the async API. + builder = + ConnectionOptionsHelper.useDirectExecutorIfNotUseVirtualThreads(connectionUrl, builder); return builder.build(); } diff --git a/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTimeoutTest.java b/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTimeoutTest.java index eb876c23..2c8c43ca 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTimeoutTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/JdbcStatementTimeoutTest.java @@ -19,24 +19,45 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.jdbc.JdbcSqlExceptionFactory.JdbcSqlExceptionImpl; import com.google.cloud.spanner.jdbc.JdbcSqlExceptionFactory.JdbcSqlTimeoutException; +import com.google.rpc.Code; +import com.google.spanner.v1.ExecuteSqlRequest; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.time.Duration; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Tests setting a statement timeout. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class JdbcStatementTimeoutTest extends AbstractMockServerTest { + @Parameter public boolean useVirtualThreads; + + @Parameters(name = "useVirtualThreads = {0}") + public static Object[] data() { + return new Boolean[] {false, true}; + } + + @Override + protected String getBaseUrl() { + return super.getBaseUrl() + ";useVirtualThreads=" + this.useVirtualThreads; + } + @After public void resetExecutionTimes() { mockSpanner.removeAllExecutionTimes(); @@ -122,4 +143,35 @@ public void testExecuteBatchTimeout() throws SQLException { } } } + + @Test + public void testCancel() throws Exception { + ExecutorService service = Executors.newSingleThreadExecutor(); + String sql = INSERT_STATEMENT.getSql(); + + try (java.sql.Connection connection = createJdbcConnection(); + Statement statement = connection.createStatement()) { + mockSpanner.freeze(); + Future future = + service.submit( + () -> { + // Wait until the request has landed on the server and then cancel the statement. + mockSpanner.waitForRequestsToContain( + message -> + message instanceof ExecuteSqlRequest + && ((ExecuteSqlRequest) message).getSql().equals(sql), + 5000L); + System.out.println("Cancelling statement"); + statement.cancel(); + return null; + }); + JdbcSqlExceptionImpl exception = + assertThrows(JdbcSqlExceptionImpl.class, () -> statement.execute(sql)); + assertEquals(Code.CANCELLED, exception.getCode()); + assertNull(future.get()); + } finally { + mockSpanner.unfreeze(); + service.shutdown(); + } + } }