Skip to content

Commit

Permalink
[7.4.1] Remove DownloadManager instance from RegistryFactoryImpl. (#2…
Browse files Browse the repository at this point in the history
…4228)

This PR removes the DownloadManager from the registry factory
implementation. As the registry created by the factory is cached by the
SkyFunction mechanism, the DownloadManager instance was living too long
- it is supposed to be re-created for every command instantiation to
respect changes in command line options, but for the registry, it
ignored those changes.

Instead, the DownloadManager is set directly into the affected
SkyFunctions that require access to it. This way, the per-command
DownloadManager instance is correctly used.

This fixes #24166.

Note for reviewers: This is my first time touching code with
SkyFunctions, so I don't really know what I'm doing.

Closes #24212.

PiperOrigin-RevId: 693644409
Change-Id: I7b16684e52673043615290d114f078ab7ab99fcf

Co-authored-by: Cornelius Riemenschneider <[email protected]>
  • Loading branch information
meteorcloudy and criemen authored Nov 6, 2024
1 parent bcc4e9f commit 24caa1d
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ public class BazelRepositoryModule extends BlazeModule {
private List<String> allowedYankedVersions = ImmutableList.of();
private boolean disableNativeRepoRules;
private SingleExtensionEvalFunction singleExtensionEvalFunction;
private ModuleFileFunction moduleFileFunction;
private RepoSpecFunction repoSpecFunction;
private YankedVersionsFunction yankedVersionsFunction;

private final VendorCommand vendorCommand = new VendorCommand(clientEnvironmentSupplier);
private final RegistryFactoryImpl registryFactory =
Expand Down Expand Up @@ -255,14 +258,17 @@ public void workspaceInit(
builtinModules = ModuleFileFunction.getBuiltinModules(directories.getEmbeddedBinariesRoot());
}

moduleFileFunction =
new ModuleFileFunction(
runtime.getRuleClassProvider().getBazelStarlarkEnvironment(),
directories.getWorkspace(),
builtinModules);
repoSpecFunction = new RepoSpecFunction();
yankedVersionsFunction = new YankedVersionsFunction();

builder
.addSkyFunction(SkyFunctions.REPOSITORY_DIRECTORY, repositoryDelegatorFunction)
.addSkyFunction(
SkyFunctions.MODULE_FILE,
new ModuleFileFunction(
runtime.getRuleClassProvider().getBazelStarlarkEnvironment(),
directories.getWorkspace(),
builtinModules))
.addSkyFunction(SkyFunctions.MODULE_FILE, moduleFileFunction)
.addSkyFunction(SkyFunctions.BAZEL_DEP_GRAPH, new BazelDepGraphFunction())
.addSkyFunction(
SkyFunctions.BAZEL_LOCK_FILE, new BazelLockFileFunction(directories.getWorkspace()))
Expand All @@ -276,8 +282,8 @@ SkyFunctions.BAZEL_LOCK_FILE, new BazelLockFileFunction(directories.getWorkspace
.addSkyFunction(
SkyFunctions.REGISTRY,
new RegistryFunction(registryFactory, directories.getWorkspace()))
.addSkyFunction(SkyFunctions.REPO_SPEC, new RepoSpecFunction())
.addSkyFunction(SkyFunctions.YANKED_VERSIONS, new YankedVersionsFunction())
.addSkyFunction(SkyFunctions.REPO_SPEC, repoSpecFunction)
.addSkyFunction(SkyFunctions.YANKED_VERSIONS, yankedVersionsFunction)
.addSkyFunction(
SkyFunctions.VENDOR_FILE,
new VendorFileFunction(runtime.getRuleClassProvider().getBazelStarlarkEnvironment()))
Expand Down Expand Up @@ -312,8 +318,10 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException {
DownloadManager downloadManager =
new DownloadManager(repositoryCache, env.getDownloaderDelegate(), env.getHttpDownloader());
this.starlarkRepositoryFunction.setDownloadManager(downloadManager);
this.moduleFileFunction.setDownloadManager(downloadManager);
this.repoSpecFunction.setDownloadManager(downloadManager);
this.yankedVersionsFunction.setDownloadManager(downloadManager);
this.vendorCommand.setDownloadManager(downloadManager);
this.registryFactory.setDownloadManager(downloadManager);

clientEnvironmentSupplier.set(env.getRepoEnv());
PackageOptions pkgOptions = env.getOptions().getOptions(PackageOptions.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ public enum KnownFileHashesMode {
}

private final URI uri;
private final DownloadManager downloadManager;
private final Map<String, String> clientEnv;
private final Gson gson;
private final ImmutableMap<String, Optional<Checksum>> knownFileHashes;
Expand All @@ -99,14 +98,12 @@ public enum KnownFileHashesMode {

public IndexRegistry(
URI uri,
DownloadManager downloadManager,
Map<String, String> clientEnv,
ImmutableMap<String, Optional<Checksum>> knownFileHashes,
KnownFileHashesMode knownFileHashesMode,
ImmutableMap<ModuleKey, String> previouslySelectedYankedVersions,
Optional<Path> vendorDir) {
this.uri = uri;
this.downloadManager = downloadManager;
this.clientEnv = clientEnv;
this.gson =
new GsonBuilder()
Expand Down Expand Up @@ -136,9 +133,12 @@ private String constructUrl(String base, String... segments) {

/** Grabs a file from the given URL. Returns {@link Optional#empty} if the file doesn't exist. */
private Optional<byte[]> grabFile(
String url, ExtendedEventHandler eventHandler, boolean useChecksum)
String url,
ExtendedEventHandler eventHandler,
DownloadManager downloadManager,
boolean useChecksum)
throws IOException, InterruptedException {
var maybeContent = doGrabFile(url, eventHandler, useChecksum);
var maybeContent = doGrabFile(downloadManager, url, eventHandler, useChecksum);
if ((knownFileHashesMode == KnownFileHashesMode.USE_AND_UPDATE
|| knownFileHashesMode == KnownFileHashesMode.USE_IMMUTABLE_AND_UPDATE)
&& useChecksum) {
Expand All @@ -148,7 +148,10 @@ private Optional<byte[]> grabFile(
}

private Optional<byte[]> doGrabFile(
String rawUrl, ExtendedEventHandler eventHandler, boolean useChecksum)
DownloadManager downloadManager,
String rawUrl,
ExtendedEventHandler eventHandler,
boolean useChecksum)
throws IOException, InterruptedException {
Optional<Checksum> checksum;
if (knownFileHashesMode != KnownFileHashesMode.IGNORE && useChecksum) {
Expand Down Expand Up @@ -225,12 +228,14 @@ private Optional<byte[]> doGrabFile(
}

@Override
public Optional<ModuleFile> getModuleFile(ModuleKey key, ExtendedEventHandler eventHandler)
public Optional<ModuleFile> getModuleFile(
ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException {
String url =
constructUrl(
getUrl(), "modules", key.getName(), key.getVersion().toString(), "MODULE.bazel");
Optional<byte[]> maybeContent = grabFile(url, eventHandler, /* useChecksum= */ true);
Optional<byte[]> maybeContent =
grabFile(url, eventHandler, downloadManager, /* useChecksum= */ true);
return maybeContent.map(content -> ModuleFile.create(content, url));
}

Expand Down Expand Up @@ -277,19 +282,27 @@ private static class GitRepoSourceJson {
* if the file doesn't exist.
*/
private Optional<String> grabJsonFile(
String url, ExtendedEventHandler eventHandler, boolean useChecksum)
String url,
ExtendedEventHandler eventHandler,
DownloadManager downloadManager,
boolean useChecksum)
throws IOException, InterruptedException {
return grabFile(url, eventHandler, useChecksum).map(value -> new String(value, UTF_8));
return grabFile(url, eventHandler, downloadManager, useChecksum)
.map(value -> new String(value, UTF_8));
}

/**
* Grabs a JSON file from the given URL, and returns it as a parsed object with fields in {@code
* T}. Returns {@link Optional#empty} if the file doesn't exist.
*/
private <T> Optional<T> grabJson(
String url, Class<T> klass, ExtendedEventHandler eventHandler, boolean useChecksum)
String url,
Class<T> klass,
ExtendedEventHandler eventHandler,
DownloadManager downloadManager,
boolean useChecksum)
throws IOException, InterruptedException {
Optional<String> jsonString = grabJsonFile(url, eventHandler, useChecksum);
Optional<String> jsonString = grabJsonFile(url, eventHandler, downloadManager, useChecksum);
if (jsonString.isEmpty() || jsonString.get().isBlank()) {
return Optional.empty();
}
Expand All @@ -307,38 +320,37 @@ private <T> T parseJson(String jsonString, String url, Class<T> klass) throws IO
}

@Override
public RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler)
public RepoSpec getRepoSpec(
ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException {
String jsonUrl = getSourceJsonUrl(key);
Optional<String> jsonString = grabJsonFile(jsonUrl, eventHandler, /* useChecksum= */ true);
Optional<String> jsonString =
grabJsonFile(jsonUrl, eventHandler, downloadManager, /* useChecksum= */ true);
if (jsonString.isEmpty()) {
throw new FileNotFoundException(
String.format(
"Module %s's %s not found in registry %s", key, SOURCE_JSON_FILENAME, getUrl()));
}
SourceJson sourceJson = parseJson(jsonString.get(), jsonUrl, SourceJson.class);
switch (sourceJson.type) {
case "archive":
{
ArchiveSourceJson typedSourceJson =
parseJson(jsonString.get(), jsonUrl, ArchiveSourceJson.class);
return createArchiveRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler), key);
}
case "local_path":
{
LocalPathSourceJson typedSourceJson =
parseJson(jsonString.get(), jsonUrl, LocalPathSourceJson.class);
return createLocalPathRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler), key);
}
case "git_repository":
{
GitRepoSourceJson typedSourceJson =
parseJson(jsonString.get(), jsonUrl, GitRepoSourceJson.class);
return createGitRepoSpec(typedSourceJson);
}
default:
throw new IOException(
String.format("Invalid source type \"%s\" for module %s", sourceJson.type, key));
case "archive" -> {
ArchiveSourceJson typedSourceJson =
parseJson(jsonString.get(), jsonUrl, ArchiveSourceJson.class);
return createArchiveRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler, downloadManager), key);
}
case "local_path" -> {
LocalPathSourceJson typedSourceJson =
parseJson(jsonString.get(), jsonUrl, LocalPathSourceJson.class);
return createLocalPathRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler, downloadManager), key);
}
case "git_repository" -> {
GitRepoSourceJson typedSourceJson =
parseJson(jsonString.get(), jsonUrl, GitRepoSourceJson.class);
return createGitRepoSpec(typedSourceJson);
}
default ->
throw new IOException(
String.format("Invalid source type \"%s\" for module %s", sourceJson.type, key));
}
}

Expand All @@ -347,7 +359,8 @@ private String getSourceJsonUrl(ModuleKey key) {
getUrl(), "modules", key.getName(), key.getVersion().toString(), SOURCE_JSON_FILENAME);
}

private Optional<BazelRegistryJson> getBazelRegistryJson(ExtendedEventHandler eventHandler)
private Optional<BazelRegistryJson> getBazelRegistryJson(
ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException {
if (bazelRegistryJson == null || bazelRegistryJsonEvents == null) {
synchronized (this) {
Expand All @@ -359,6 +372,7 @@ private Optional<BazelRegistryJson> getBazelRegistryJson(ExtendedEventHandler ev
constructUrl(getUrl(), "bazel_registry.json"),
BazelRegistryJson.class,
storedEventHandler,
downloadManager,
/* useChecksum= */ true);
bazelRegistryJsonEvents = storedEventHandler;
}
Expand Down Expand Up @@ -488,13 +502,14 @@ private RepoSpec createGitRepoSpec(GitRepoSourceJson sourceJson) {

@Override
public Optional<ImmutableMap<Version, String>> getYankedVersions(
String moduleName, ExtendedEventHandler eventHandler)
String moduleName, ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException {
Optional<MetadataJson> metadataJson =
grabJson(
constructUrl(getUrl(), "modules", moduleName, "metadata.json"),
MetadataJson.class,
eventHandler,
downloadManager,
// metadata.json is not immutable
/* useChecksum= */ false);
if (metadataJson.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.devtools.build.lib.bazel.bzlmod.ModuleFileValue.RootModuleFileValue;
import com.google.devtools.build.lib.bazel.repository.PatchUtil;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum.MissingChecksumException;
import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.LabelConstants;
import com.google.devtools.build.lib.cmdline.LabelSyntaxException;
Expand Down Expand Up @@ -98,6 +99,7 @@ public class ModuleFileFunction implements SkyFunction {
private final BazelStarlarkEnvironment starlarkEnv;
private final Path workspaceRoot;
private final ImmutableMap<String, NonRegistryOverride> builtinModules;
@Nullable private DownloadManager downloadManager;

private static final String BZLMOD_REMINDER =
"""
Expand Down Expand Up @@ -230,6 +232,10 @@ public SkyValue compute(SkyKey skyKey, Environment env)
getModuleFileResult.downloadEventHandler.getPosts()));
}

public void setDownloadManager(DownloadManager downloadManager) {
this.downloadManager = downloadManager;
}

@Nullable
private SkyValue computeForRootModule(StarlarkSemantics starlarkSemantics, Environment env)
throws ModuleFileFunctionException, InterruptedException {
Expand Down Expand Up @@ -618,7 +624,8 @@ private GetModuleFileResult getModuleFile(
StoredEventHandler downloadEventHandler = new StoredEventHandler();
for (Registry registry : registryObjects) {
try {
Optional<ModuleFile> maybeModuleFile = registry.getModuleFile(key, downloadEventHandler);
Optional<ModuleFile> maybeModuleFile =
registry.getModuleFile(key, downloadEventHandler, this.downloadManager);
if (maybeModuleFile.isEmpty()) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.google.devtools.build.lib.bazel.bzlmod;

import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.skyframe.NotComparableSkyValue;
import java.io.IOException;
Expand All @@ -31,22 +32,24 @@ public interface Registry extends NotComparableSkyValue {
* Retrieves the contents of the module file of the module identified by {@code key} from the
* registry. Returns {@code Optional.empty()} when the module is not found in this registry.
*/
Optional<ModuleFile> getModuleFile(ModuleKey key, ExtendedEventHandler eventHandler)
Optional<ModuleFile> getModuleFile(
ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException;

/**
* Retrieves the {@link RepoSpec} object that indicates how the contents of the module identified
* by {@code key} should be materialized as a repo.
*/
RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler)
RepoSpec getRepoSpec(
ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException;

/**
* Retrieves yanked versions of the module identified by {@code key.getName()} from the registry.
* Returns {@code Optional.empty()} when the information is not found in the registry.
*/
Optional<ImmutableMap<Version, String>> getYankedVersions(
String moduleName, ExtendedEventHandler eventHandler)
String moduleName, ExtendedEventHandler eventHandler, DownloadManager downloadManager)
throws IOException, InterruptedException;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,21 @@
import com.google.devtools.build.lib.bazel.bzlmod.IndexRegistry.KnownFileHashesMode;
import com.google.devtools.build.lib.bazel.repository.RepositoryOptions.LockfileMode;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum;
import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager;
import com.google.devtools.build.lib.vfs.Path;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import javax.annotation.Nullable;

/** Prod implementation of {@link RegistryFactory}. */
public class RegistryFactoryImpl implements RegistryFactory {
@Nullable private DownloadManager downloadManager;
private final Supplier<Map<String, String>> clientEnvironmentSupplier;

public RegistryFactoryImpl(Supplier<Map<String, String>> clientEnvironmentSupplier) {
this.clientEnvironmentSupplier = clientEnvironmentSupplier;
}

public void setDownloadManager(DownloadManager downloadManager) {
this.downloadManager = downloadManager;
}

@Override
public Registry createRegistry(
String url,
Expand Down Expand Up @@ -76,7 +69,6 @@ public Registry createRegistry(
};
return new IndexRegistry(
uri,
downloadManager,
clientEnvironmentSupplier.get(),
knownFileHashes,
knownFileHashesMode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.google.devtools.build.lib.bazel.bzlmod;

import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager;
import com.google.devtools.build.lib.events.StoredEventHandler;
import com.google.devtools.build.lib.profiler.Profiler;
import com.google.devtools.build.lib.profiler.ProfilerTask;
Expand All @@ -32,6 +33,7 @@
* fetching required information from its {@link Registry}.
*/
public class RepoSpecFunction implements SkyFunction {
@Nullable private DownloadManager downloadManager;

@Override
@Nullable
Expand All @@ -49,7 +51,7 @@ public SkyValue compute(SkyKey skyKey, Environment env)
try (SilentCloseable c =
Profiler.instance()
.profile(ProfilerTask.BZLMOD, () -> "compute repo spec: " + key.getModuleKey())) {
repoSpec = registry.getRepoSpec(key.getModuleKey(), downloadEvents);
repoSpec = registry.getRepoSpec(key.getModuleKey(), downloadEvents, this.downloadManager);
} catch (IOException e) {
throw new RepoSpecException(
ExternalDepsException.withCauseAndMessage(
Expand All @@ -63,6 +65,10 @@ public SkyValue compute(SkyKey skyKey, Environment env)
repoSpec, RegistryFileDownloadEvent.collectToMap(downloadEvents.getPosts()));
}

public void setDownloadManager(DownloadManager downloadManager) {
this.downloadManager = downloadManager;
}

static final class RepoSpecException extends SkyFunctionException {

RepoSpecException(ExternalDepsException cause) {
Expand Down
Loading

0 comments on commit 24caa1d

Please sign in to comment.