diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
index 24c93850c2c5f2..d38c39206b3776 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
@@ -94,8 +94,21 @@ public void ensureInputsPresent(
 
     Flowable<TransferResult> uploads =
         createUploadTasks(context, merkleTree, additionalInputs, allDigests, force)
-            .flatMap(uploadTasks -> findMissingBlobs(context, uploadTasks))
-            .flatMapPublisher(this::waitForUploadTasks);
+            .flatMapPublisher(
+                result ->
+                    Flowable.using(
+                        () -> result,
+                        uploadTasks ->
+                            findMissingBlobs(context, uploadTasks)
+                                .flatMapPublisher(this::waitForUploadTasks),
+                        uploadTasks -> {
+                          for (UploadTask uploadTask : uploadTasks) {
+                            Disposable d = uploadTask.disposable.getAndSet(null);
+                            if (d != null) {
+                              d.dispose();
+                            }
+                          }
+                        }));
 
     try {
       mergeBulkTransfer(uploads).blockingAwait();
@@ -175,15 +188,7 @@ private Maybe<UploadTask> maybeCreateUploadTask(
           UploadTask uploadTask = new UploadTask();
           uploadTask.digest = digest;
           uploadTask.disposable = new AtomicReference<>();
-          uploadTask.completion =
-              Completable.fromObservable(
-                  completion.doOnDispose(
-                      () -> {
-                        Disposable d = uploadTask.disposable.getAndSet(null);
-                        if (d != null) {
-                          d.dispose();
-                        }
-                      }));
+          uploadTask.completion = Completable.fromObservable(completion);
           Completable upload =
               casUploadCache.execute(
                   digest,
@@ -238,44 +243,34 @@ private Single<List<UploadTask>> findMissingBlobs(
         () -> Profiler.instance().profile("findMissingDigests"),
         ignored ->
             Single.fromObservable(
-                    Observable.fromSingle(
-                            toSingle(
-                                    () -> {
-                                      ImmutableList<Digest> digestsToQuery =
-                                          uploadTasks.stream()
-                                              .filter(uploadTask -> uploadTask.continuation != null)
-                                              .map(uploadTask -> uploadTask.digest)
-                                              .collect(toImmutableList());
-                                      if (digestsToQuery.isEmpty()) {
-                                        return immediateFuture(ImmutableSet.of());
-                                      }
-                                      return findMissingDigests(context, digestsToQuery);
-                                    },
-                                    directExecutor())
-                                .map(
-                                    missingDigests -> {
-                                      for (UploadTask uploadTask : uploadTasks) {
-                                        if (uploadTask.continuation != null) {
-                                          uploadTask.continuation.onSuccess(
-                                              missingDigests.contains(uploadTask.digest));
-                                        }
-                                      }
-                                      return uploadTasks;
-                                    }))
-                        // Use AsyncSubject so that if downstream is disposed, the
-                        // findMissingDigests call is not cancelled (because it may be needed by
-                        // other
-                        // threads).
-                        .subscribeWith(AsyncSubject.create()))
-                .doOnDispose(
-                    () -> {
-                      for (UploadTask uploadTask : uploadTasks) {
-                        Disposable d = uploadTask.disposable.getAndSet(null);
-                        if (d != null) {
-                          d.dispose();
-                        }
-                      }
-                    }),
+                Observable.fromSingle(
+                        toSingle(
+                                () -> {
+                                  ImmutableList<Digest> digestsToQuery =
+                                      uploadTasks.stream()
+                                          .filter(uploadTask -> uploadTask.continuation != null)
+                                          .map(uploadTask -> uploadTask.digest)
+                                          .collect(toImmutableList());
+                                  if (digestsToQuery.isEmpty()) {
+                                    return immediateFuture(ImmutableSet.of());
+                                  }
+                                  return findMissingDigests(context, digestsToQuery);
+                                },
+                                directExecutor())
+                            .map(
+                                missingDigests -> {
+                                  for (UploadTask uploadTask : uploadTasks) {
+                                    if (uploadTask.continuation != null) {
+                                      uploadTask.continuation.onSuccess(
+                                          missingDigests.contains(uploadTask.digest));
+                                    }
+                                  }
+                                  return uploadTasks;
+                                }))
+                    // Use AsyncSubject so that if downstream is disposed, the
+                    // findMissingDigests call is not cancelled (because it may be needed by
+                    // other threads).
+                    .subscribeWith(AsyncSubject.create())),
         SilentCloseable::close);
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java
index 5bd130756d5460..ff70a90b36ca4f 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteCacheTest.java
@@ -27,6 +27,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
@@ -64,9 +65,8 @@
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
 import java.util.Deque;
-import java.util.List;
+import java.util.Map;
 import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.concurrent.ConcurrentHashMap;
@@ -308,7 +308,7 @@ public void ensureInputsPresent_interruptedDuringUploadBlobs_cancelInProgressUpl
     // arrange
     RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient());
     RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol));
-    List<SettableFuture<Void>> futures = new ArrayList<>();
+    Deque<SettableFuture<Void>> futures = new ConcurrentLinkedDeque<>();
     CountDownLatch uploadBlobCalls = new CountDownLatch(2);
     doAnswer(
             invocationOnMock -> {
@@ -460,12 +460,14 @@ public void ensureInputsPresent_interruptedDuringUploadBlobs_cancelInProgressUpl
     // arrange
     RemoteCacheClient cacheProtocol = spy(new InMemoryCacheClient());
     RemoteExecutionCache remoteCache = spy(newRemoteExecutionCache(cacheProtocol));
-    List<SettableFuture<Void>> futures = new ArrayList<>();
+    ConcurrentLinkedDeque<SettableFuture<Void>> uploadBlobFutures = new ConcurrentLinkedDeque<>();
+    Map<Path, SettableFuture<Void>> uploadFileFutures = Maps.newConcurrentMap();
     CountDownLatch uploadBlobCalls = new CountDownLatch(2);
+    CountDownLatch uploadFileCalls = new CountDownLatch(3);
     doAnswer(
             invocationOnMock -> {
               SettableFuture<Void> future = SettableFuture.create();
-              futures.add(future);
+              uploadBlobFutures.add(future);
               uploadBlobCalls.countDown();
               return future;
             })
@@ -473,60 +475,93 @@ public void ensureInputsPresent_interruptedDuringUploadBlobs_cancelInProgressUpl
         .uploadBlob(any(), any(), any());
     doAnswer(
             invocationOnMock -> {
+              Path file = invocationOnMock.getArgument(2, Path.class);
               SettableFuture<Void> future = SettableFuture.create();
-              futures.add(future);
-              uploadBlobCalls.countDown();
+              uploadFileFutures.put(file, future);
+              uploadFileCalls.countDown();
               return future;
             })
         .when(cacheProtocol)
         .uploadFile(any(), any(), any());
 
-    Path path = fs.getPath("/execroot/foo");
-    FileSystemUtils.writeContentAsLatin1(path, "bar");
-    SortedMap<PathFragment, Path> inputs = new TreeMap<>();
-    inputs.put(PathFragment.create("foo"), path);
-    MerkleTree merkleTree = MerkleTree.build(inputs, digestUtil);
+    Path foo = fs.getPath("/execroot/foo");
+    FileSystemUtils.writeContentAsLatin1(foo, "foo");
+    Path bar = fs.getPath("/execroot/bar");
+    FileSystemUtils.writeContentAsLatin1(bar, "bar");
+    Path qux = fs.getPath("/execroot/qux");
+    FileSystemUtils.writeContentAsLatin1(qux, "qux");
+
+    SortedMap<PathFragment, Path> input1 = new TreeMap<>();
+    input1.put(PathFragment.create("foo"), foo);
+    input1.put(PathFragment.create("bar"), bar);
+    MerkleTree merkleTree1 = MerkleTree.build(input1, digestUtil);
+
+    SortedMap<PathFragment, Path> input2 = new TreeMap<>();
+    input2.put(PathFragment.create("bar"), bar);
+    input2.put(PathFragment.create("qux"), qux);
+    MerkleTree merkleTree2 = MerkleTree.build(input2, digestUtil);
 
     CountDownLatch ensureInputsPresentReturned = new CountDownLatch(2);
     CountDownLatch ensureInterrupted = new CountDownLatch(1);
-    Runnable work =
-        () -> {
-          try {
-            remoteCache.ensureInputsPresent(context, merkleTree, ImmutableMap.of(), false);
-          } catch (IOException ignored) {
-            // ignored
-          } catch (InterruptedException e) {
-            ensureInterrupted.countDown();
-          } finally {
-            ensureInputsPresentReturned.countDown();
-          }
-        };
-    Thread thread1 = new Thread(work);
-    Thread thread2 = new Thread(work);
+    Thread thread1 =
+        new Thread(
+            () -> {
+              try {
+                remoteCache.ensureInputsPresent(context, merkleTree1, ImmutableMap.of(), false);
+              } catch (IOException ignored) {
+                // ignored
+              } catch (InterruptedException e) {
+                ensureInterrupted.countDown();
+              } finally {
+                ensureInputsPresentReturned.countDown();
+              }
+            });
+    Thread thread2 =
+        new Thread(
+            () -> {
+              try {
+                remoteCache.ensureInputsPresent(context, merkleTree2, ImmutableMap.of(), false);
+              } catch (InterruptedException | IOException ignored) {
+                // ignored
+              } finally {
+                ensureInputsPresentReturned.countDown();
+              }
+            });
 
     // act
     thread1.start();
     thread2.start();
     uploadBlobCalls.await();
-    assertThat(futures).hasSize(2);
-    assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(2);
+    uploadFileCalls.await();
+    assertThat(uploadBlobFutures).hasSize(2);
+    assertThat(uploadFileFutures).hasSize(3);
+    assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(5);
     thread1.interrupt();
     ensureInterrupted.await();
 
     // assert
-    assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(2);
+    assertThat(remoteCache.casUploadCache.getInProgressTasks()).hasSize(3);
     assertThat(remoteCache.casUploadCache.getFinishedTasks()).isEmpty();
-    for (SettableFuture<Void> future : futures) {
-      assertThat(future.isCancelled()).isFalse();
+    for (Map.Entry<Path, SettableFuture<Void>> entry : uploadFileFutures.entrySet()) {
+      Path file = entry.getKey();
+      SettableFuture<Void> future = entry.getValue();
+      if (file.equals(foo)) {
+        assertThat(future.isCancelled()).isTrue();
+      } else {
+        assertThat(future.isCancelled()).isFalse();
+      }
     }
 
-    for (SettableFuture<Void> future : futures) {
+    for (SettableFuture<Void> future : uploadBlobFutures) {
+      future.set(null);
+    }
+    for (SettableFuture<Void> future : uploadFileFutures.values()) {
       future.set(null);
     }
     ensureInputsPresentReturned.await();
     assertThat(remoteCache.casUploadCache.getInProgressTasks()).isEmpty();
-    assertThat(remoteCache.casUploadCache.getFinishedTasks()).hasSize(2);
+    assertThat(remoteCache.casUploadCache.getFinishedTasks()).hasSize(3);
   }
 
   @Test
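Note (not part of the patch): the fix above leans on `Flowable.using`, which ties a resource's cleanup to the lifetime of a single subscription. The disposer passed as the third argument runs once when that subscription completes, errors, or is cancelled (for example when `blockingAwait()` is interrupted). The sketch below is a minimal illustration of that pattern; `UploadTask` here is a simplified stand-in, not the real Bazel class.

```java
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/** Minimal sketch of the Flowable.using cleanup pattern; UploadTask is a stand-in. */
final class FlowableUsingSketch {

  static final class UploadTask {
    final AtomicReference<Disposable> disposable = new AtomicReference<>();
  }

  static Flowable<String> uploadAll(List<UploadTask> tasks) {
    return Flowable.using(
        // Resource supplier: the tasks owned by this particular subscription.
        () -> tasks,
        // Source built from the resource; stands in for findMissingBlobs + waitForUploadTasks.
        ts -> Flowable.fromIterable(ts).map(task -> "uploaded " + task),
        // Disposer: runs once per subscription, on completion, error, or cancellation,
        // mirroring the cleanup block the patch adds in ensureInputsPresent.
        ts -> {
          for (UploadTask task : ts) {
            Disposable d = task.disposable.getAndSet(null);
            if (d != null) {
              d.dispose();
            }
          }
        });
  }
}
```

Because the disposer is scoped to a single subscription, cancelling one subscriber runs cleanup only for that subscriber's resource; other subscriptions built from their own resources are unaffected.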