diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index 3b3e1051dce7b0..bee979da2f557b 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -394,16 +394,14 @@ private ListenableFuture requestRead( CountingOutputStream rawOut, @Nullable Supplier digestSupplier, Channel channel) { + boolean compressed = shouldCompress(digest); String resourceName = getResourceName( - options.remoteInstanceName, - digest, - options.cacheCompression, - digestUtil.getDigestFunction()); + options.remoteInstanceName, digest, compressed, digestUtil.getDigestFunction()); SettableFuture future = SettableFuture.create(); OutputStream out; try { - out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut; + out = compressed ? new ZstdDecompressingOutputStream(rawOut) : rawOut; } catch (IOException e) { return Futures.immediateFailedFuture(e); } @@ -499,7 +497,7 @@ public ListenableFuture uploadFile( digest, Chunker.builder() .setInput(digest.getSizeBytes(), path) - .setCompressed(options.cacheCompression) + .setCompressed(shouldCompress(digest)) .build()); } @@ -511,7 +509,7 @@ public ListenableFuture uploadBlob( digest, Chunker.builder() .setInput(data.toByteArray()) - .setCompressed(options.cacheCompression) + .setCompressed(shouldCompress(digest)) .build()); } @@ -535,6 +533,10 @@ Retrier getRetrier() { return this.retrier; } + private boolean shouldCompress(Digest digest) { + return options.cacheCompression && digest.getSizeBytes() >= options.cacheCompressionThreshold; + } + public ReferenceCountedChannel getChannel() { return channel; } diff --git a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java index e87e78509d8353..37e0b643c2dc7c 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java +++ b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java @@ -430,9 +430,21 @@ public RemoteBuildEventUploadModeConverter() { defaultValue = "false", documentationCategory = OptionDocumentationCategory.REMOTE, effectTags = {OptionEffectTag.UNKNOWN}, - help = "If enabled, compress/decompress cache blobs with zstd.") + help = + "If enabled, compress/decompress cache blobs with zstd when their size is at least" + + " --experimental_remote_cache_compression_threshold.") public boolean cacheCompression; + @Option( + name = "experimental_remote_cache_compression_threshold", + defaultValue = "0", + documentationCategory = OptionDocumentationCategory.REMOTE, + effectTags = {OptionEffectTag.UNKNOWN}, + help = + "The minimum blob size required to compress/decompress with zstd. Ineffectual unless" + + " --remote_cache_compression is set.") + public int cacheCompressionThreshold; + @Option( name = "build_event_upload_max_threads", defaultValue = "100", diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java index bbaad1bd15ac1a..f46a0be51741df 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java @@ -91,6 +91,8 @@ import com.google.devtools.common.options.Options; import com.google.gson.JsonObject; import com.google.protobuf.ByteString; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; import io.grpc.BindableService; import io.grpc.CallCredentials; import io.grpc.CallOptions; @@ -129,14 +131,13 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; /** Tests for {@link GrpcCacheClient}. */ -@RunWith(JUnit4.class) +@RunWith(TestParameterInjector.class) public class GrpcCacheClientTest { private static final DigestUtil DIGEST_UTIL = new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256); @@ -1271,36 +1272,42 @@ public void read(ReadRequest request, StreamObserver responseObser } @Test - public void testCompressedDownload() throws IOException, InterruptedException { + public void testCompressedDownload(@TestParameter boolean overThreshold) + throws IOException, InterruptedException { RemoteOptions options = Options.getDefaults(RemoteOptions.class); options.cacheCompression = true; + options.cacheCompressionThreshold = 100; final GrpcCacheClient client = newClient(options); - final byte[] data = "abcdefg".getBytes(UTF_8); + final byte[] data = + overThreshold ? "0123456789".repeat(10).getBytes(UTF_8) : "0123456789".getBytes(UTF_8); final Digest digest = DIGEST_UTIL.compute(data); - final byte[] compressed = Zstd.compress(data); + final byte[] bytes = overThreshold ? Zstd.compress(data) : data; serviceRegistry.addService( new ByteStreamImplBase() { @Override public void read(ReadRequest request, StreamObserver responseObserver) { assertThat(request.getResourceName()).contains(digest.getHash()); + if (overThreshold) { + assertThat(request.getResourceName()).contains("compressed-blobs/zstd"); + } else { + assertThat(request.getResourceName()).doesNotContain("compressed-blobs/zstd"); + } responseObserver.onNext( ReadResponse.newBuilder() - .setData(ByteString.copyFrom(Arrays.copyOf(compressed, compressed.length / 3))) + .setData(ByteString.copyFrom(Arrays.copyOf(bytes, bytes.length / 3))) .build()); responseObserver.onNext( ReadResponse.newBuilder() .setData( ByteString.copyFrom( - Arrays.copyOfRange( - compressed, compressed.length / 3, compressed.length / 3 * 2))) + Arrays.copyOfRange(bytes, bytes.length / 3, bytes.length / 3 * 2))) .build()); responseObserver.onNext( ReadResponse.newBuilder() .setData( ByteString.copyFrom( - Arrays.copyOfRange( - compressed, compressed.length / 3 * 2, compressed.length))) + Arrays.copyOfRange(bytes, bytes.length / 3 * 2, bytes.length))) .build()); responseObserver.onCompleted(); }