Skip to content

Commit

Permalink
Allow repo rules to download multiple things in parallel.
Browse files Browse the repository at this point in the history
This is accomplished by adding a "block" argument to the "download" call. If set,
the call returns a "pending download" object with one single method that waits for the download to finish.

Appropriate care is taken that downloads don't hang around after the repository function finishes running.

Fixes bazelbuild#19674.

RELNOTES: None.
PiperOrigin-RevId: 588320352
Change-Id: Ib0f48b6c7c2a07e93a4af602b0045120bd418829
  • Loading branch information
lberki authored and bazel-io committed Jan 11, 2024
1 parent eb51790 commit 546db45
Show file tree
Hide file tree
Showing 6 changed files with 487 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/clock",
"//src/main/java/com/google/devtools/build/lib/concurrent",
"//src/main/java/com/google/devtools/build/lib/events",
"//src/main/java/com/google/devtools/build/lib/profiler",
"//src/main/java/com/google/devtools/build/lib/remote/util",
"//src/main/java/com/google/devtools/build/lib/util",
"//src/main/java/com/google/devtools/build/lib/util:os",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@
import com.google.auth.Credentials;
import com.google.common.base.MoreObjects;
import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.devtools.build.lib.authandtls.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCacheHitEvent;
import com.google.devtools.build.lib.bazel.repository.downloader.UrlRewriter.RewrittenURL;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.profiler.Profiler;
import com.google.devtools.build.lib.profiler.SilentCloseable;
import com.google.devtools.build.lib.vfs.FileSystemUtils;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.vfs.PathFragment;
Expand All @@ -42,6 +46,10 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import javax.annotation.Nullable;

/**
Expand All @@ -51,6 +59,16 @@
* to disk.
*/
public class DownloadManager {
// Executor on which the tasks submitted by startDownload() run. Being static, it is shared by
// every DownloadManager instance.
private static final ExecutorService DOWNLOAD_EXECUTOR =
Executors.newFixedThreadPool(
// There is also GrpcRemoteDownloader, so if we set the thread pool to the same size as the
// allowed number of HTTP downloads, it might unnecessarily block. No, this is not a very
// principled approach; ideally, we'd grow the thread pool as needed with some generous
// upper limit.
2 * HttpDownloader.MAX_PARALLEL_DOWNLOADS,
new ThreadFactoryBuilder().setNameFormat("download-manager-%d").build());

private final RepositoryCache repositoryCache;
private List<Path> distdir = ImmutableList.of();
Expand Down Expand Up @@ -96,6 +114,69 @@ public void setCredentialFactory(CredentialFactory credentialFactory) {
this.credentialFactory = credentialFactory;
}

/**
 * Submits a download to {@code DOWNLOAD_EXECUTOR} and returns without waiting for it.
 *
 * <p>All arguments are passed through unchanged to {@code downloadInExecutor}, which is run
 * inside a profiler span labelled {@code "fetching: " + context}.
 *
 * @return a future that yields the value returned by {@code downloadInExecutor}; pass it to
 *     {@link #finalizeDownload} to wait for the result
 */
public Future<Path> startDownload(
    List<URL> originalUrls,
    Map<URI, Map<String, List<String>>> authHeaders,
    Optional<Checksum> checksum,
    String canonicalId,
    Optional<String> type,
    Path output,
    ExtendedEventHandler eventHandler,
    Map<String, String> clientEnv,
    String context) {
  String profilerLabel = "fetching: " + context;
  return DOWNLOAD_EXECUTOR.submit(
      () -> {
        try (SilentCloseable profilerSpan = Profiler.instance().profile(profilerLabel)) {
          Path downloaded =
              downloadInExecutor(
                  originalUrls,
                  authHeaders,
                  checksum,
                  canonicalId,
                  type,
                  output,
                  eventHandler,
                  clientEnv,
                  context);
          return downloaded;
        }
      });
}

/**
 * Waits for a download started with {@link #startDownload} and returns its result.
 *
 * <p>If the task failed, the failure is unwrapped from the {@link ExecutionException} so that
 * callers see the original exception type where possible.
 *
 * @param download the future to wait on
 * @return the value the future completes with
 * @throws IOException if the task failed with an {@code IOException}
 * @throws InterruptedException if the task failed with an {@code InterruptedException}, or if
 *     this thread is interrupted while waiting
 * @throws IllegalStateException if the task failed with a checked exception of any other type
 */
public Path finalizeDownload(Future<Path> download) throws IOException, InterruptedException {
  try {
    return download.get();
  } catch (ExecutionException wrapped) {
    Throwable failure = wrapped.getCause();
    // Rethrow the declared checked types first, then anything unchecked, as-is.
    Throwables.throwIfInstanceOf(failure, IOException.class);
    Throwables.throwIfInstanceOf(failure, InterruptedException.class);
    Throwables.throwIfUnchecked(failure);
    // Only an undeclared checked exception can reach this point.
    throw new IllegalStateException(wrapped);
  }
}

/**
 * Downloads a file and blocks until the download has finished.
 *
 * <p>This is {@link #startDownload} immediately followed by {@link #finalizeDownload}; all
 * arguments are passed through unchanged.
 *
 * <p>The download itself runs on an executor thread, so interrupting the calling thread only
 * breaks the wait. To keep the task from running on unobserved (and still writing to {@code
 * output}), it is cancelled before the {@code InterruptedException} is propagated.
 *
 * @return the value produced by the download task
 * @throws IOException if the download task failed with an {@code IOException}
 * @throws InterruptedException if the download task or the wait for it was interrupted
 */
public Path download(
    List<URL> originalUrls,
    Map<URI, Map<String, List<String>>> authHeaders,
    Optional<Checksum> checksum,
    String canonicalId,
    Optional<String> type,
    Path output,
    ExtendedEventHandler eventHandler,
    Map<String, String> clientEnv,
    String context)
    throws IOException, InterruptedException {
  Future<Path> future =
      startDownload(
          originalUrls,
          authHeaders,
          checksum,
          canonicalId,
          type,
          output,
          eventHandler,
          clientEnv,
          context);
  try {
    return finalizeDownload(future);
  } catch (InterruptedException e) {
    // The future never escapes this method, so nobody else can stop the task. cancel() is a
    // no-op if the task has already completed (e.g. it was the one that got interrupted).
    future.cancel(/* mayInterruptIfRunning= */ true);
    throw e;
  }
}

/**
* Downloads file to disk and returns path.
*
Expand All @@ -114,7 +195,7 @@ public void setCredentialFactory(CredentialFactory credentialFactory) {
* @throws IOException if download was attempted and ended up failing
* @throws InterruptedException if this thread is being cast into oblivion
*/
public Path download(
private Path downloadInExecutor(
List<URL> originalUrls,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<Checksum> checksum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
* file to disk.
*/
public class HttpDownloader implements Downloader {
private static final int MAX_PARALLEL_DOWNLOADS = 8;
static final int MAX_PARALLEL_DOWNLOADS = 8;

private static final Semaphore SEMAPHORE = new Semaphore(MAX_PARALLEL_DOWNLOADS, true);
private static final Clock CLOCK = new JavaClock();
private static final Sleeper SLEEPER = new JavaSleeper();
Expand Down
Loading

0 comments on commit 546db45

Please sign in to comment.