diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD
index 34f7a7863ecf3b..40a3c64b00e995 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD
@@ -18,6 +18,7 @@ java_library(
deps = [
"//src/main/java/com/google/devtools/build/lib/concurrent",
"//third_party:guava",
+ "//third_party:jsr305",
"//third_party:rxjava3",
"//third_party/grpc:grpc-jar",
],
diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java
index ee513847873e24..2326e3189b379c 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java
@@ -13,6 +13,7 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;
+import io.reactivex.rxjava3.core.Single;
import java.io.Closeable;
import java.io.IOException;
@@ -24,6 +25,13 @@
*
Connections must be closed with {@link Connection#close()} in order to be reused later.
*/
public interface ConnectionPool extends ConnectionFactory, Closeable {
+ /**
+ * Reuses a {@link Connection} in the pool and will potentially create a new connection depends on
+ * implementation.
+ */
+ @Override
+ Single extends Connection> create();
+
/** Closes the connection pool and closes all the underlying connections */
@Override
void close() throws IOException;
diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java
new file mode 100644
index 00000000000000..d731574bbf44a8
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java
@@ -0,0 +1,170 @@
+// Copyright 2021 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote.grpc;
+
+import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
+import io.grpc.CallOptions;
+import io.grpc.ClientCall;
+import io.grpc.MethodDescriptor;
+import io.reactivex.rxjava3.core.Single;
+import io.reactivex.rxjava3.disposables.Disposable;
+import io.reactivex.rxjava3.functions.Action;
+import io.reactivex.rxjava3.subjects.AsyncSubject;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.ReentrantLock;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+
+/**
+ * A {@link ConnectionPool} that creates one connection using provided {@link ConnectionFactory} and
+ * shares the connection upto {@code maxConcurrency}.
+ *
+ *
This is useful if underlying connection maintains a connection pool internally. (such as
+ * {@code Channel} in gRPC)
+ *
+ *
Connections must be closed with {@link Connection#close()} in order to be reused later.
+ */
+@ThreadSafe
+public class SharedConnectionFactory implements ConnectionPool {
+ private final TokenBucket tokenBucket;
+ private final ConnectionFactory factory;
+
+ @Nullable
+ @GuardedBy("connectionLock")
+ private AsyncSubject connectionAsyncSubject = null;
+
+ private final ReentrantLock connectionLock = new ReentrantLock();
+ private final AtomicReference connectionCreationDisposable =
+ new AtomicReference<>(null);
+
+ public SharedConnectionFactory(ConnectionFactory factory, int maxConcurrency) {
+ this.factory = factory;
+
+ List initialTokens = new ArrayList<>(maxConcurrency);
+ for (int i = 0; i < maxConcurrency; ++i) {
+ initialTokens.add(i);
+ }
+ this.tokenBucket = new TokenBucket<>(initialTokens);
+ }
+
+ @Override
+ public void close() throws IOException {
+ tokenBucket.close();
+
+ Disposable d = connectionCreationDisposable.getAndSet(null);
+ if (d != null && !d.isDisposed()) {
+ d.dispose();
+ }
+
+ try {
+ connectionLock.lockInterruptibly();
+
+ if (connectionAsyncSubject != null) {
+ Connection connection = connectionAsyncSubject.getValue();
+ if (connection != null) {
+ connection.close();
+ }
+
+ if (!connectionAsyncSubject.hasComplete()) {
+ connectionAsyncSubject.onError(new IllegalStateException("closed"));
+ }
+ }
+ } catch (InterruptedException e) {
+ throw new IOException(e);
+ } finally {
+ connectionLock.unlock();
+ }
+ }
+
+ private AsyncSubject createUnderlyingConnectionIfNot() throws InterruptedException {
+ connectionLock.lockInterruptibly();
+ try {
+ if (connectionAsyncSubject == null || connectionAsyncSubject.hasThrowable()) {
+ connectionAsyncSubject =
+ factory
+ .create()
+ .doOnSubscribe(connectionCreationDisposable::set)
+ .toObservable()
+ .subscribeWith(AsyncSubject.create());
+ }
+
+ return connectionAsyncSubject;
+ } finally {
+ connectionLock.unlock();
+ }
+ }
+
+ private Single extends Connection> acquireConnection() {
+ return Single.fromCallable(this::createUnderlyingConnectionIfNot)
+ .flatMap(Single::fromObservable);
+ }
+
+ /**
+ * Reuses the underlying {@link Connection} and wait for it to be released if is exceeding {@code
+ * maxConcurrency}.
+ */
+ @Override
+ public Single create() {
+ return tokenBucket
+ .acquireToken()
+ .flatMap(
+ token ->
+ acquireConnection()
+ .doOnError(ignored -> tokenBucket.addToken(token))
+ .doOnDispose(() -> tokenBucket.addToken(token))
+ .map(
+ conn ->
+ new SharedConnection(
+ conn, /* onClose= */ () -> tokenBucket.addToken(token))));
+ }
+
+ /** Returns current number of available connections. */
+ public int numAvailableConnections() {
+ return tokenBucket.size();
+ }
+
+ /** A {@link Connection} which wraps an underlying connection and is shared between consumers. */
+ public static class SharedConnection implements Connection {
+ private final Connection connection;
+ private final Action onClose;
+
+ public SharedConnection(Connection connection, Action onClose) {
+ this.connection = connection;
+ this.onClose = onClose;
+ }
+
+ @Override
+ public ClientCall call(
+ MethodDescriptor method, CallOptions options) {
+ return connection.call(method, options);
+ }
+
+ @Override
+ public void close() throws IOException {
+ try {
+ onClose.run();
+ } catch (Throwable t) {
+ throw new IOException(t);
+ }
+ }
+
+ /** Returns the underlying connection this shared connection built on */
+ public Connection getUnderlyingConnection() {
+ return connection;
+ }
+ }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java
new file mode 100644
index 00000000000000..124b79e4a2488d
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java
@@ -0,0 +1,354 @@
+// Copyright 2021 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
+import io.reactivex.rxjava3.core.Single;
+import io.reactivex.rxjava3.observers.TestObserver;
+import io.reactivex.rxjava3.plugins.RxJavaPlugins;
+import java.io.IOException;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+/** Tests for {@link SharedConnectionFactory}. */
+@RunWith(JUnit4.class)
+public class SharedConnectionFactoryTest {
+ @Rule public final MockitoRule mockito = MockitoJUnit.rule();
+
+ private final AtomicReference rxGlobalThrowable = new AtomicReference<>(null);
+
+ @Mock private Connection connection;
+ @Mock private ConnectionFactory connectionFactory;
+
+ @Before
+ public void setUp() {
+ RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set);
+
+ when(connectionFactory.create()).thenAnswer(invocation -> Single.just(connection));
+ }
+
+ @After
+ public void tearDown() throws Throwable {
+ // Make sure rxjava didn't receive global errors
+ Throwable t = rxGlobalThrowable.getAndSet(null);
+ if (t != null) {
+ throw t;
+ }
+ }
+
+ @Test
+ public void create_smoke() {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ assertThat(factory.numAvailableConnections()).isEqualTo(1);
+
+ TestObserver observer = factory.create().test();
+
+ observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+ verify(connectionFactory, times(1)).create();
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+ }
+
+ @Test
+ public void create_noConnectionCreationBeforeSubscription() {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+
+ factory.create();
+
+ verify(connectionFactory, times(0)).create();
+ }
+
+ @Test
+ public void create_exceedingMaxConcurrency_waiting() {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ TestObserver observer1 = factory.create().test();
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+ observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+
+ TestObserver observer2 = factory.create().test();
+ observer2.assertEmpty();
+ }
+
+ @Test
+ public void create_afterConnectionClosed_shareConnections() throws IOException {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ TestObserver observer1 = factory.create().test();
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+ observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+ TestObserver observer2 = factory.create().test();
+
+ observer1.values().get(0).close();
+
+ observer2.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+ }
+
+ @Test
+ public void create_belowMaxConcurrency_shareConnections() {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2);
+
+ TestObserver observer1 = factory.create().test();
+ assertThat(factory.numAvailableConnections()).isEqualTo(1);
+ observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+
+ TestObserver observer2 = factory.create().test();
+ observer2.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+ }
+
+ @Test
+ public void create_concurrentCreate_shareConnections() throws InterruptedException {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2);
+ Semaphore semaphore = new Semaphore(0);
+ AtomicBoolean finished = new AtomicBoolean(false);
+ Thread t =
+ new Thread(
+ () -> {
+ factory
+ .create()
+ .doOnSuccess(
+ conn -> {
+ assertThat(conn.getUnderlyingConnection()).isEqualTo(connection);
+ semaphore.release();
+ Thread.sleep(Integer.MAX_VALUE);
+ finished.set(true);
+ })
+ .blockingSubscribe();
+
+ finished.set(true);
+ });
+ t.start();
+ semaphore.acquire();
+
+ TestObserver observer = factory.create().test();
+
+ observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
+ assertThat(finished.get()).isFalse();
+ }
+
+ @Test
+ public void create_afterLastFailed_success() {
+ AtomicInteger times = new AtomicInteger(0);
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ when(connectionFactory.create())
+ .thenAnswer(
+ invocation -> {
+ if (times.getAndIncrement() == 0) {
+ return Single.error(new IllegalStateException("error"));
+ }
+
+ return Single.just(connection);
+ });
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ Single connectionSingle = factory.create();
+
+ connectionSingle
+ .test()
+ .assertError(IllegalStateException.class)
+ .assertError(e -> e.getMessage().contains("error"));
+ assertThat(factory.numAvailableConnections()).isEqualTo(1);
+ connectionSingle
+ .test()
+ .assertValue(conn -> conn.getUnderlyingConnection() == connection)
+ .assertComplete();
+
+ assertThat(times.get()).isEqualTo(2);
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+ }
+
+ @Test
+ public void create_disposeWhenWaitingForConnectionCreation_doNotCancelCreation()
+ throws InterruptedException {
+ AtomicBoolean canceled = new AtomicBoolean(false);
+ AtomicBoolean finished = new AtomicBoolean(false);
+ Semaphore disposed = new Semaphore(0);
+ Semaphore terminated = new Semaphore(0);
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ when(connectionFactory.create())
+ .thenAnswer(
+ invocation ->
+ Single.create(
+ emitter ->
+ new Thread(
+ () -> {
+ try {
+ disposed.acquire();
+ finished.set(true);
+ emitter.onSuccess(connection);
+ } catch (InterruptedException e) {
+ emitter.onError(e);
+ }
+ terminated.release();
+ })
+ .start())
+ .doOnDispose(() -> canceled.set(true)));
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ TestObserver observer = factory.create().test();
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+
+ observer.assertEmpty().dispose();
+ disposed.release();
+
+ terminated.acquire();
+ assertThat(canceled.get()).isFalse();
+ assertThat(finished.get()).isTrue();
+ assertThat(factory.numAvailableConnections()).isEqualTo(1);
+ }
+
+ @Test
+ public void create_interrupt_terminate() throws InterruptedException {
+ AtomicBoolean finished = new AtomicBoolean(false);
+ AtomicBoolean interrupted = new AtomicBoolean(true);
+ Semaphore threadTerminatedSemaphore = new Semaphore(0);
+ Semaphore connectionCreationSemaphore = new Semaphore(0);
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ when(connectionFactory.create())
+ .thenAnswer(
+ invocation ->
+ Single.create(
+ emitter ->
+ new Thread(
+ () -> {
+ try {
+ Thread.sleep(Integer.MAX_VALUE);
+ finished.set(true);
+ emitter.onSuccess(connectionFactory);
+ } catch (InterruptedException e) {
+ emitter.onError(e);
+ }
+ })
+ .start()));
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2);
+ factory.create().test().assertEmpty();
+ Thread t =
+ new Thread(
+ () -> {
+ try {
+ TestObserver observer = factory.create().test();
+ connectionCreationSemaphore.release();
+ observer.await();
+ } catch (InterruptedException e) {
+ interrupted.set(true);
+ }
+
+ threadTerminatedSemaphore.release();
+ });
+ t.start();
+
+ connectionCreationSemaphore.acquire();
+ t.interrupt();
+ threadTerminatedSemaphore.acquire();
+
+ assertThat(finished.get()).isFalse();
+ assertThat(interrupted.get()).isTrue();
+ }
+
+ @Test
+ public void closeConnection_connectionBecomeAvailable() throws IOException {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ TestObserver observer = factory.create().test();
+ observer.assertComplete();
+ SharedConnection conn = observer.values().get(0);
+ assertThat(factory.numAvailableConnections()).isEqualTo(0);
+
+ conn.close();
+
+ assertThat(factory.numAvailableConnections()).isEqualTo(1);
+ verify(connection, times(0)).close();
+ }
+
+ @Test
+ public void closeFactory_closeUnderlyingConnection() throws IOException {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ TestObserver observer = factory.create().test();
+ observer.assertComplete();
+
+ factory.close();
+
+ verify(connection, times(1)).close();
+ }
+
+ @Test
+ public void closeFactory_noNewConnectionAllowed() throws IOException {
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ factory.close();
+
+ TestObserver observer = factory.create().test();
+
+ observer
+ .assertError(IllegalStateException.class)
+ .assertError(e -> e.getMessage().contains("closed"));
+ }
+
+ @Test
+ public void closeFactory_pendingConnectionCreation_closedError()
+ throws IOException, InterruptedException {
+ AtomicBoolean canceled = new AtomicBoolean(false);
+ AtomicBoolean finished = new AtomicBoolean(false);
+ Semaphore terminated = new Semaphore(0);
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ when(connectionFactory.create())
+ .thenAnswer(
+ invocation ->
+ Single.create(
+ emitter -> {
+ Thread t =
+ new Thread(
+ () -> {
+ try {
+ Thread.sleep(Integer.MAX_VALUE);
+ finished.set(true);
+ emitter.onSuccess(connection);
+ } catch (InterruptedException ignored) {
+ /* no-op */
+ }
+
+ terminated.release();
+ });
+ t.start();
+
+ emitter.setCancellable(t::interrupt);
+ })
+ .doOnDispose(() -> canceled.set(true)));
+ SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
+ TestObserver observer = factory.create().test();
+ observer.assertEmpty();
+
+ assertThat(canceled.get()).isFalse();
+ factory.close();
+
+ terminated.acquire();
+ observer
+ .assertError(IllegalStateException.class)
+ .assertError(e -> e.getMessage().contains("closed"));
+ assertThat(canceled.get()).isTrue();
+ assertThat(finished.get()).isFalse();
+ }
+}