From 24caa1d30a7cce01a61fcd70f6a65378bf9c261e Mon Sep 17 00:00:00 2001 From: Yun Peng Date: Wed, 6 Nov 2024 19:21:59 +0100 Subject: [PATCH] [7.4.1] Remove DownloadManager instance from RegistryFactoryImpl. (#24228) This PR removes the DownloadManager from the registry factory implementation. As the registry created by the factory is cached by the SkyFunction mechanism, the DownloadManager instance was living too long - it is supposed to be re-created for every command instantiation to respect changes in command line options, but for the registry, it ignored those changes. Instead, the DownloadManager is set directly into the affected SkyFunctions that require access to it. This way, the per-command DownloadManager instance is correctly used. This fixes https://github.com/bazelbuild/bazel/issues/24166. Note for reviewers: This is my first time touching code with SkyFunctions, so I don't really know what I'm doing. Closes #24212. PiperOrigin-RevId: 693644409 Change-Id: I7b16684e52673043615290d114f078ab7ab99fcf Co-authored-by: Cornelius Riemenschneider --- .../lib/bazel/BazelRepositoryModule.java | 26 ++++-- .../build/lib/bazel/bzlmod/IndexRegistry.java | 89 +++++++++++-------- .../lib/bazel/bzlmod/ModuleFileFunction.java | 9 +- .../build/lib/bazel/bzlmod/Registry.java | 9 +- .../lib/bazel/bzlmod/RegistryFactoryImpl.java | 8 -- .../lib/bazel/bzlmod/RepoSpecFunction.java | 8 +- .../bazel/bzlmod/YankedVersionsFunction.java | 8 +- .../skyframe/packages/BazelPackageLoader.java | 35 ++++---- .../build/lib/bazel/bzlmod/FakeRegistry.java | 9 +- .../lib/bazel/bzlmod/IndexRegistryTest.java | 62 +++++++------ 10 files changed, 158 insertions(+), 105 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/BazelRepositoryModule.java b/src/main/java/com/google/devtools/build/lib/bazel/BazelRepositoryModule.java index 1dc8e1ef1df341..b86c9a884a3a4e 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/BazelRepositoryModule.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/BazelRepositoryModule.java @@ -166,6 +166,9 @@ public class BazelRepositoryModule extends BlazeModule { private List allowedYankedVersions = ImmutableList.of(); private boolean disableNativeRepoRules; private SingleExtensionEvalFunction singleExtensionEvalFunction; + private ModuleFileFunction moduleFileFunction; + private RepoSpecFunction repoSpecFunction; + private YankedVersionsFunction yankedVersionsFunction; private final VendorCommand vendorCommand = new VendorCommand(clientEnvironmentSupplier); private final RegistryFactoryImpl registryFactory = @@ -255,14 +258,17 @@ public void workspaceInit( builtinModules = ModuleFileFunction.getBuiltinModules(directories.getEmbeddedBinariesRoot()); } + moduleFileFunction = + new ModuleFileFunction( + runtime.getRuleClassProvider().getBazelStarlarkEnvironment(), + directories.getWorkspace(), + builtinModules); + repoSpecFunction = new RepoSpecFunction(); + yankedVersionsFunction = new YankedVersionsFunction(); + builder .addSkyFunction(SkyFunctions.REPOSITORY_DIRECTORY, repositoryDelegatorFunction) - .addSkyFunction( - SkyFunctions.MODULE_FILE, - new ModuleFileFunction( - runtime.getRuleClassProvider().getBazelStarlarkEnvironment(), - directories.getWorkspace(), - builtinModules)) + .addSkyFunction(SkyFunctions.MODULE_FILE, moduleFileFunction) .addSkyFunction(SkyFunctions.BAZEL_DEP_GRAPH, new BazelDepGraphFunction()) .addSkyFunction( SkyFunctions.BAZEL_LOCK_FILE, new BazelLockFileFunction(directories.getWorkspace())) @@ -276,8 +282,8 @@ SkyFunctions.BAZEL_LOCK_FILE, new BazelLockFileFunction(directories.getWorkspace .addSkyFunction( SkyFunctions.REGISTRY, new RegistryFunction(registryFactory, directories.getWorkspace())) - .addSkyFunction(SkyFunctions.REPO_SPEC, new RepoSpecFunction()) - .addSkyFunction(SkyFunctions.YANKED_VERSIONS, new YankedVersionsFunction()) + .addSkyFunction(SkyFunctions.REPO_SPEC, repoSpecFunction) + .addSkyFunction(SkyFunctions.YANKED_VERSIONS, yankedVersionsFunction) .addSkyFunction( SkyFunctions.VENDOR_FILE, new VendorFileFunction(runtime.getRuleClassProvider().getBazelStarlarkEnvironment())) @@ -312,8 +318,10 @@ public void beforeCommand(CommandEnvironment env) throws AbruptExitException { DownloadManager downloadManager = new DownloadManager(repositoryCache, env.getDownloaderDelegate(), env.getHttpDownloader()); this.starlarkRepositoryFunction.setDownloadManager(downloadManager); + this.moduleFileFunction.setDownloadManager(downloadManager); + this.repoSpecFunction.setDownloadManager(downloadManager); + this.yankedVersionsFunction.setDownloadManager(downloadManager); this.vendorCommand.setDownloadManager(downloadManager); - this.registryFactory.setDownloadManager(downloadManager); clientEnvironmentSupplier.set(env.getRepoEnv()); PackageOptions pkgOptions = env.getOptions().getOptions(PackageOptions.class); diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistry.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistry.java index b9cb637ad6d906..98bc2cd784609d 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistry.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistry.java @@ -85,7 +85,6 @@ public enum KnownFileHashesMode { } private final URI uri; - private final DownloadManager downloadManager; private final Map clientEnv; private final Gson gson; private final ImmutableMap> knownFileHashes; @@ -99,14 +98,12 @@ public enum KnownFileHashesMode { public IndexRegistry( URI uri, - DownloadManager downloadManager, Map clientEnv, ImmutableMap> knownFileHashes, KnownFileHashesMode knownFileHashesMode, ImmutableMap previouslySelectedYankedVersions, Optional vendorDir) { this.uri = uri; - this.downloadManager = downloadManager; this.clientEnv = clientEnv; this.gson = new GsonBuilder() @@ -136,9 +133,12 @@ private String constructUrl(String base, String... segments) { /** Grabs a file from the given URL. Returns {@link Optional#empty} if the file doesn't exist. */ private Optional grabFile( - String url, ExtendedEventHandler eventHandler, boolean useChecksum) + String url, + ExtendedEventHandler eventHandler, + DownloadManager downloadManager, + boolean useChecksum) throws IOException, InterruptedException { - var maybeContent = doGrabFile(url, eventHandler, useChecksum); + var maybeContent = doGrabFile(downloadManager, url, eventHandler, useChecksum); if ((knownFileHashesMode == KnownFileHashesMode.USE_AND_UPDATE || knownFileHashesMode == KnownFileHashesMode.USE_IMMUTABLE_AND_UPDATE) && useChecksum) { @@ -148,7 +148,10 @@ private Optional grabFile( } private Optional doGrabFile( - String rawUrl, ExtendedEventHandler eventHandler, boolean useChecksum) + DownloadManager downloadManager, + String rawUrl, + ExtendedEventHandler eventHandler, + boolean useChecksum) throws IOException, InterruptedException { Optional checksum; if (knownFileHashesMode != KnownFileHashesMode.IGNORE && useChecksum) { @@ -225,12 +228,14 @@ private Optional doGrabFile( } @Override - public Optional getModuleFile(ModuleKey key, ExtendedEventHandler eventHandler) + public Optional getModuleFile( + ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException { String url = constructUrl( getUrl(), "modules", key.getName(), key.getVersion().toString(), "MODULE.bazel"); - Optional maybeContent = grabFile(url, eventHandler, /* useChecksum= */ true); + Optional maybeContent = + grabFile(url, eventHandler, downloadManager, /* useChecksum= */ true); return maybeContent.map(content -> ModuleFile.create(content, url)); } @@ -277,9 +282,13 @@ private static class GitRepoSourceJson { * if the file doesn't exist. */ private Optional grabJsonFile( - String url, ExtendedEventHandler eventHandler, boolean useChecksum) + String url, + ExtendedEventHandler eventHandler, + DownloadManager downloadManager, + boolean useChecksum) throws IOException, InterruptedException { - return grabFile(url, eventHandler, useChecksum).map(value -> new String(value, UTF_8)); + return grabFile(url, eventHandler, downloadManager, useChecksum) + .map(value -> new String(value, UTF_8)); } /** @@ -287,9 +296,13 @@ private Optional grabJsonFile( * T}. Returns {@link Optional#empty} if the file doesn't exist. */ private Optional grabJson( - String url, Class klass, ExtendedEventHandler eventHandler, boolean useChecksum) + String url, + Class klass, + ExtendedEventHandler eventHandler, + DownloadManager downloadManager, + boolean useChecksum) throws IOException, InterruptedException { - Optional jsonString = grabJsonFile(url, eventHandler, useChecksum); + Optional jsonString = grabJsonFile(url, eventHandler, downloadManager, useChecksum); if (jsonString.isEmpty() || jsonString.get().isBlank()) { return Optional.empty(); } @@ -307,10 +320,12 @@ private T parseJson(String jsonString, String url, Class klass) throws IO } @Override - public RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler) + public RepoSpec getRepoSpec( + ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException { String jsonUrl = getSourceJsonUrl(key); - Optional jsonString = grabJsonFile(jsonUrl, eventHandler, /* useChecksum= */ true); + Optional jsonString = + grabJsonFile(jsonUrl, eventHandler, downloadManager, /* useChecksum= */ true); if (jsonString.isEmpty()) { throw new FileNotFoundException( String.format( @@ -318,27 +333,24 @@ public RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler) } SourceJson sourceJson = parseJson(jsonString.get(), jsonUrl, SourceJson.class); switch (sourceJson.type) { - case "archive": - { - ArchiveSourceJson typedSourceJson = - parseJson(jsonString.get(), jsonUrl, ArchiveSourceJson.class); - return createArchiveRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler), key); - } - case "local_path": - { - LocalPathSourceJson typedSourceJson = - parseJson(jsonString.get(), jsonUrl, LocalPathSourceJson.class); - return createLocalPathRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler), key); - } - case "git_repository": - { - GitRepoSourceJson typedSourceJson = - parseJson(jsonString.get(), jsonUrl, GitRepoSourceJson.class); - return createGitRepoSpec(typedSourceJson); - } - default: - throw new IOException( - String.format("Invalid source type \"%s\" for module %s", sourceJson.type, key)); + case "archive" -> { + ArchiveSourceJson typedSourceJson = + parseJson(jsonString.get(), jsonUrl, ArchiveSourceJson.class); + return createArchiveRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler, downloadManager), key); + } + case "local_path" -> { + LocalPathSourceJson typedSourceJson = + parseJson(jsonString.get(), jsonUrl, LocalPathSourceJson.class); + return createLocalPathRepoSpec(typedSourceJson, getBazelRegistryJson(eventHandler, downloadManager), key); + } + case "git_repository" -> { + GitRepoSourceJson typedSourceJson = + parseJson(jsonString.get(), jsonUrl, GitRepoSourceJson.class); + return createGitRepoSpec(typedSourceJson); + } + default -> + throw new IOException( + String.format("Invalid source type \"%s\" for module %s", sourceJson.type, key)); } } @@ -347,7 +359,8 @@ private String getSourceJsonUrl(ModuleKey key) { getUrl(), "modules", key.getName(), key.getVersion().toString(), SOURCE_JSON_FILENAME); } - private Optional getBazelRegistryJson(ExtendedEventHandler eventHandler) + private Optional getBazelRegistryJson( + ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException { if (bazelRegistryJson == null || bazelRegistryJsonEvents == null) { synchronized (this) { @@ -359,6 +372,7 @@ private Optional getBazelRegistryJson(ExtendedEventHandler ev constructUrl(getUrl(), "bazel_registry.json"), BazelRegistryJson.class, storedEventHandler, + downloadManager, /* useChecksum= */ true); bazelRegistryJsonEvents = storedEventHandler; } @@ -488,13 +502,14 @@ private RepoSpec createGitRepoSpec(GitRepoSourceJson sourceJson) { @Override public Optional> getYankedVersions( - String moduleName, ExtendedEventHandler eventHandler) + String moduleName, ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException { Optional metadataJson = grabJson( constructUrl(getUrl(), "modules", moduleName, "metadata.json"), MetadataJson.class, eventHandler, + downloadManager, // metadata.json is not immutable /* useChecksum= */ false); if (metadataJson.isEmpty()) { diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java index 7d1c977074a486..f1ec5b976dba60 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java @@ -32,6 +32,7 @@ import com.google.devtools.build.lib.bazel.bzlmod.ModuleFileValue.RootModuleFileValue; import com.google.devtools.build.lib.bazel.repository.PatchUtil; import com.google.devtools.build.lib.bazel.repository.downloader.Checksum.MissingChecksumException; +import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager; import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.build.lib.cmdline.LabelConstants; import com.google.devtools.build.lib.cmdline.LabelSyntaxException; @@ -98,6 +99,7 @@ public class ModuleFileFunction implements SkyFunction { private final BazelStarlarkEnvironment starlarkEnv; private final Path workspaceRoot; private final ImmutableMap builtinModules; + @Nullable private DownloadManager downloadManager; private static final String BZLMOD_REMINDER = """ @@ -230,6 +232,10 @@ public SkyValue compute(SkyKey skyKey, Environment env) getModuleFileResult.downloadEventHandler.getPosts())); } + public void setDownloadManager(DownloadManager downloadManager) { + this.downloadManager = downloadManager; + } + @Nullable private SkyValue computeForRootModule(StarlarkSemantics starlarkSemantics, Environment env) throws ModuleFileFunctionException, InterruptedException { @@ -618,7 +624,8 @@ private GetModuleFileResult getModuleFile( StoredEventHandler downloadEventHandler = new StoredEventHandler(); for (Registry registry : registryObjects) { try { - Optional maybeModuleFile = registry.getModuleFile(key, downloadEventHandler); + Optional maybeModuleFile = + registry.getModuleFile(key, downloadEventHandler, this.downloadManager); if (maybeModuleFile.isEmpty()) { continue; } diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Registry.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Registry.java index 0fc178244a09ff..32a55321efc240 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Registry.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Registry.java @@ -16,6 +16,7 @@ package com.google.devtools.build.lib.bazel.bzlmod; import com.google.common.collect.ImmutableMap; +import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager; import com.google.devtools.build.lib.events.ExtendedEventHandler; import com.google.devtools.build.skyframe.NotComparableSkyValue; import java.io.IOException; @@ -31,14 +32,16 @@ public interface Registry extends NotComparableSkyValue { * Retrieves the contents of the module file of the module identified by {@code key} from the * registry. Returns {@code Optional.empty()} when the module is not found in this registry. */ - Optional getModuleFile(ModuleKey key, ExtendedEventHandler eventHandler) + Optional getModuleFile( + ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException; /** * Retrieves the {@link RepoSpec} object that indicates how the contents of the module identified * by {@code key} should be materialized as a repo. */ - RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler) + RepoSpec getRepoSpec( + ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException; /** @@ -46,7 +49,7 @@ RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler) * Returns {@code Optional.empty()} when the information is not found in the registry. */ Optional> getYankedVersions( - String moduleName, ExtendedEventHandler eventHandler) + String moduleName, ExtendedEventHandler eventHandler, DownloadManager downloadManager) throws IOException, InterruptedException; /** diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RegistryFactoryImpl.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RegistryFactoryImpl.java index c15c847771327d..8ca7b5871467ec 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RegistryFactoryImpl.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RegistryFactoryImpl.java @@ -19,28 +19,21 @@ import com.google.devtools.build.lib.bazel.bzlmod.IndexRegistry.KnownFileHashesMode; import com.google.devtools.build.lib.bazel.repository.RepositoryOptions.LockfileMode; import com.google.devtools.build.lib.bazel.repository.downloader.Checksum; -import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager; import com.google.devtools.build.lib.vfs.Path; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; import java.util.Optional; import java.util.function.Supplier; -import javax.annotation.Nullable; /** Prod implementation of {@link RegistryFactory}. */ public class RegistryFactoryImpl implements RegistryFactory { - @Nullable private DownloadManager downloadManager; private final Supplier> clientEnvironmentSupplier; public RegistryFactoryImpl(Supplier> clientEnvironmentSupplier) { this.clientEnvironmentSupplier = clientEnvironmentSupplier; } - public void setDownloadManager(DownloadManager downloadManager) { - this.downloadManager = downloadManager; - } - @Override public Registry createRegistry( String url, @@ -76,7 +69,6 @@ public Registry createRegistry( }; return new IndexRegistry( uri, - downloadManager, clientEnvironmentSupplier.get(), knownFileHashes, knownFileHashesMode, diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RepoSpecFunction.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RepoSpecFunction.java index 88d7664ca13b36..748481372f3d66 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RepoSpecFunction.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/RepoSpecFunction.java @@ -15,6 +15,7 @@ package com.google.devtools.build.lib.bazel.bzlmod; +import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager; import com.google.devtools.build.lib.events.StoredEventHandler; import com.google.devtools.build.lib.profiler.Profiler; import com.google.devtools.build.lib.profiler.ProfilerTask; @@ -32,6 +33,7 @@ * fetching required information from its {@link Registry}. */ public class RepoSpecFunction implements SkyFunction { + @Nullable private DownloadManager downloadManager; @Override @Nullable @@ -49,7 +51,7 @@ public SkyValue compute(SkyKey skyKey, Environment env) try (SilentCloseable c = Profiler.instance() .profile(ProfilerTask.BZLMOD, () -> "compute repo spec: " + key.getModuleKey())) { - repoSpec = registry.getRepoSpec(key.getModuleKey(), downloadEvents); + repoSpec = registry.getRepoSpec(key.getModuleKey(), downloadEvents, this.downloadManager); } catch (IOException e) { throw new RepoSpecException( ExternalDepsException.withCauseAndMessage( @@ -63,6 +65,10 @@ public SkyValue compute(SkyKey skyKey, Environment env) repoSpec, RegistryFileDownloadEvent.collectToMap(downloadEvents.getPosts())); } + public void setDownloadManager(DownloadManager downloadManager) { + this.downloadManager = downloadManager; + } + static final class RepoSpecException extends SkyFunctionException { RepoSpecException(ExternalDepsException cause) { diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/YankedVersionsFunction.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/YankedVersionsFunction.java index 57a2344f6657fd..8117dd12bf754f 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/YankedVersionsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/YankedVersionsFunction.java @@ -15,6 +15,7 @@ package com.google.devtools.build.lib.bazel.bzlmod; +import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager; import com.google.devtools.build.lib.events.Event; import com.google.devtools.build.lib.profiler.Profiler; import com.google.devtools.build.lib.profiler.ProfilerTask; @@ -31,6 +32,7 @@ * Registry}. */ public class YankedVersionsFunction implements SkyFunction { + @Nullable private DownloadManager downloadManager; @Override @Nullable @@ -47,7 +49,7 @@ public SkyValue compute(SkyKey skyKey, Environment env) throws InterruptedExcept .profile( ProfilerTask.BZLMOD, () -> "getting yanked versions: " + key.getModuleName())) { return YankedVersionsValue.create( - registry.getYankedVersions(key.getModuleName(), env.getListener())); + registry.getYankedVersions(key.getModuleName(), env.getListener(), downloadManager)); } catch (IOException e) { env.getListener() .handle( @@ -60,4 +62,8 @@ public SkyValue compute(SkyKey skyKey, Environment env) throws InterruptedExcept return YankedVersionsValue.create(Optional.empty()); } } + + public void setDownloadManager(DownloadManager downloadManager) { + this.downloadManager = downloadManager; + } } diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/packages/BazelPackageLoader.java b/src/main/java/com/google/devtools/build/lib/skyframe/packages/BazelPackageLoader.java index 067333494e99b2..b29736d1137658 100644 --- a/src/main/java/com/google/devtools/build/lib/skyframe/packages/BazelPackageLoader.java +++ b/src/main/java/com/google/devtools/build/lib/skyframe/packages/BazelPackageLoader.java @@ -140,23 +140,22 @@ public BazelPackageLoader buildImpl() { new DownloadManager(repositoryCache, httpDownloader, httpDownloader); RegistryFactoryImpl registryFactory = new RegistryFactoryImpl(Suppliers.ofInstance(ImmutableMap.of())); - registryFactory.setDownloadManager(downloadManager); // Allow tests to override the following functions to use fake registry or custom built-in // modules if (!this.extraSkyFunctions.containsKey(SkyFunctions.MODULE_FILE)) { - addExtraSkyFunctions( - ImmutableMap.of( - SkyFunctions.MODULE_FILE, - new ModuleFileFunction( - ruleClassProvider.getBazelStarlarkEnvironment(), - directories.getWorkspace(), - ModuleFileFunction.getBuiltinModules(directories.getEmbeddedBinariesRoot()) - .entrySet() - .stream() - .filter(e -> e.getKey().equals("bazel_tools")) - .collect( - ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))))); + ModuleFileFunction moduleFileFunction = + new ModuleFileFunction( + ruleClassProvider.getBazelStarlarkEnvironment(), + directories.getWorkspace(), + ModuleFileFunction.getBuiltinModules(directories.getEmbeddedBinariesRoot()) + .entrySet() + .stream() + .filter(e -> e.getKey().equals("bazel_tools")) + .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))); + + addExtraSkyFunctions(ImmutableMap.of(SkyFunctions.MODULE_FILE, moduleFileFunction)); + moduleFileFunction.setDownloadManager(downloadManager); } if (!this.extraSkyFunctions.containsKey(SkyFunctions.REGISTRY)) { addExtraSkyFunctions( @@ -167,6 +166,12 @@ public BazelPackageLoader buildImpl() { StarlarkRepositoryFunction starlarkRepositoryFunction = new StarlarkRepositoryFunction(); starlarkRepositoryFunction.setDownloadManager(downloadManager); + RepoSpecFunction repoSpecFunction = new RepoSpecFunction(); + repoSpecFunction.setDownloadManager(downloadManager); + + YankedVersionsFunction yankedVersionsFunction = new YankedVersionsFunction(); + yankedVersionsFunction.setDownloadManager(downloadManager); + addExtraSkyFunctions( ImmutableMap.builder() .put( @@ -194,8 +199,8 @@ public BazelPackageLoader buildImpl() { new BazelLockFileFunction(directories.getWorkspace())) .put(SkyFunctions.BAZEL_DEP_GRAPH, new BazelDepGraphFunction()) .put(SkyFunctions.BAZEL_MODULE_RESOLUTION, new BazelModuleResolutionFunction()) - .put(SkyFunctions.REPO_SPEC, new RepoSpecFunction()) - .put(SkyFunctions.YANKED_VERSIONS, new YankedVersionsFunction()) + .put(SkyFunctions.REPO_SPEC, repoSpecFunction) + .put(SkyFunctions.YANKED_VERSIONS, yankedVersionsFunction) .buildOrThrow()); return new BazelPackageLoader(this); diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/FakeRegistry.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/FakeRegistry.java index 1f8ba8e32d4a83..3aa8e7c0eaa2b9 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/FakeRegistry.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/FakeRegistry.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableMap; import com.google.devtools.build.lib.bazel.repository.RepositoryOptions.LockfileMode; import com.google.devtools.build.lib.bazel.repository.downloader.Checksum; +import com.google.devtools.build.lib.bazel.repository.downloader.DownloadManager; import com.google.devtools.build.lib.events.ExtendedEventHandler; import com.google.devtools.build.lib.vfs.Path; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -65,7 +66,8 @@ public String getUrl() { } @Override - public Optional getModuleFile(ModuleKey key, ExtendedEventHandler eventHandler) { + public Optional getModuleFile( + ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager) { String uri = String.format("%s/modules/%s/%s/MODULE.bazel", url, key.getName(), key.getVersion()); var maybeContent = Optional.ofNullable(modules.get(key)).map(value -> value.getBytes(UTF_8)); @@ -74,7 +76,8 @@ public Optional getModuleFile(ModuleKey key, ExtendedEventHandler ev } @Override - public RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler) { + public RepoSpec getRepoSpec( + ModuleKey key, ExtendedEventHandler eventHandler, DownloadManager downloadManager) { RepoSpec repoSpec = RepoSpec.builder() .setRuleClassName("local_repository") @@ -99,7 +102,7 @@ public RepoSpec getRepoSpec(ModuleKey key, ExtendedEventHandler eventHandler) { @Override public Optional> getYankedVersions( - String moduleName, ExtendedEventHandler eventHandler) { + String moduleName, ExtendedEventHandler eventHandler, DownloadManager downloadManager) { return Optional.ofNullable(yankedVersionMap.get(moduleName)); } diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistryTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistryTest.java index 1aac36dee9e7c1..f2732b954cb836 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistryTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/IndexRegistryTest.java @@ -89,7 +89,6 @@ public void setUp() throws Exception { HttpDownloader httpDownloader = new HttpDownloader(); downloadManager = new DownloadManager(repositoryCache, httpDownloader, httpDownloader); registryFactory = new RegistryFactoryImpl(Suppliers.ofInstance(ImmutableMap.of())); - registryFactory.setDownloadManager(downloadManager); } @Test @@ -104,11 +103,12 @@ public void testHttpUrl() throws Exception { ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); - assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)) .hasValue( ModuleFile.create( "lol".getBytes(UTF_8), server.getUrl() + "/myreg/modules/foo/1.0/MODULE.bazel")); - assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter)).isEmpty(); + assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter, downloadManager)) + .isEmpty(); } @Test @@ -130,7 +130,7 @@ public void testHttpUrlWithNetrcCreds() throws Exception { var e = assertThrows( IOException.class, - () -> registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)); + () -> registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)); assertThat(e) .hasMessageThat() .isEqualTo( @@ -138,11 +138,12 @@ public void testHttpUrlWithNetrcCreds() throws Exception { .formatted(server.getUrl() + "/myreg/modules/foo/1.0/MODULE.bazel")); downloadManager.setNetrcCreds(new NetrcCredentials(netrc)); - assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)) .hasValue( ModuleFile.create( "lol".getBytes(UTF_8), server.getUrl() + "/myreg/modules/foo/1.0/MODULE.bazel")); - assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter)).isEmpty(); + assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter, downloadManager)) + .isEmpty(); } @Test @@ -160,9 +161,10 @@ public void testFileUrl() throws Exception { ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); - assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)) .hasValue(ModuleFile.create("lol".getBytes(UTF_8), file.toURI().toString())); - assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter)).isEmpty(); + assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter, downloadManager)) + .isEmpty(); } @Test @@ -213,7 +215,7 @@ public void testGetArchiveRepoSpec() throws Exception { ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); - assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)) .isEqualTo( new ArchiveRepoSpecBuilder() .setUrls( @@ -227,7 +229,7 @@ public void testGetArchiveRepoSpec() throws Exception { .setOverlay(ImmutableMap.of()) .setRemotePatchStrip(0) .build()); - assertThat(registry.getRepoSpec(createModuleKey("bar", "2.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("bar", "2.0"), reporter, downloadManager)) .isEqualTo( new ArchiveRepoSpecBuilder() .setUrls( @@ -245,7 +247,7 @@ public void testGetArchiveRepoSpec() throws Exception { .setRemotePatchStrip(3) .setOverlay(ImmutableMap.of()) .build()); - assertThat(registry.getRepoSpec(createModuleKey("baz", "3.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("baz", "3.0"), reporter, downloadManager)) .isEqualTo( new ArchiveRepoSpecBuilder() .setUrls( @@ -286,7 +288,7 @@ public void testGetLocalPathRepoSpec() throws Exception { ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); - assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)) .isEqualTo( RepoSpec.builder() .setRuleClassName("local_repository") @@ -314,7 +316,7 @@ public void testGetRepoInvalidRegistryJsonSpec() throws Exception { ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); - assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)) .isEqualTo( new ArchiveRepoSpecBuilder() .setUrls(ImmutableList.of("http://mysite.com/thing.zip")) @@ -353,7 +355,8 @@ public void testGetRepoInvalidModuleJsonSpec() throws Exception { ImmutableMap.of(), Optional.empty()); assertThrows( - IOException.class, () -> registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)); + IOException.class, + () -> registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)); } @Test @@ -386,7 +389,7 @@ public void testGetYankedVersion() throws Exception { ImmutableMap.of(), Optional.empty()); Optional> yankedVersion = - registry.getYankedVersions("red-pill", reporter); + registry.getYankedVersions("red-pill", reporter, downloadManager); assertThat(yankedVersion) .hasValue( ImmutableMap.of( @@ -412,7 +415,8 @@ public void testArchiveWithExplicitType() throws Exception { ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); - assertThat(registry.getRepoSpec(createModuleKey("archive_type", "1.0"), reporter)) + assertThat( + registry.getRepoSpec(createModuleKey("archive_type", "1.0"), reporter, downloadManager)) .isEqualTo( new ArchiveRepoSpecBuilder() .setUrls(ImmutableList.of("https://mysite.com/thing?format=zip")) @@ -446,15 +450,16 @@ public void testGetModuleFileChecksums() throws Exception { knownFiles, ImmutableMap.of(), Optional.empty()); - assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)) .hasValue( ModuleFile.create( "old".getBytes(UTF_8), server.getUrl() + "/myreg/modules/foo/1.0/MODULE.bazel")); - assertThat(registry.getModuleFile(createModuleKey("foo", "2.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "2.0"), reporter, downloadManager)) .hasValue( ModuleFile.create( "new".getBytes(UTF_8), server.getUrl() + "/myreg/modules/foo/2.0/MODULE.bazel")); - assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter)).isEmpty(); + assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter, downloadManager)) + .isEmpty(); var recordedChecksums = eventRecorder.getRecordedHashes(); assertThat( @@ -481,15 +486,16 @@ public void testGetModuleFileChecksums() throws Exception { server.unserve("/myreg/modules/foo/1.0/MODULE.bazel"); server.unserve("/myreg/modules/foo/2.0/MODULE.bazel"); server.serve("/myreg/modules/bar/1.0/MODULE.bazel", "no longer 404"); - assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)) .hasValue( ModuleFile.create( "old".getBytes(UTF_8), server.getUrl() + "/myreg/modules/foo/1.0/MODULE.bazel")); - assertThat(registry.getModuleFile(createModuleKey("foo", "2.0"), reporter)) + assertThat(registry.getModuleFile(createModuleKey("foo", "2.0"), reporter, downloadManager)) .hasValue( ModuleFile.create( "new".getBytes(UTF_8), server.getUrl() + "/myreg/modules/foo/2.0/MODULE.bazel")); - assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter)).isEmpty(); + assertThat(registry.getModuleFile(createModuleKey("bar", "1.0"), reporter, downloadManager)) + .isEmpty(); } @Test @@ -513,7 +519,7 @@ public void testGetModuleFileChecksumMismatch() throws Exception { var e = assertThrows( IOException.class, - () -> registry.getModuleFile(createModuleKey("foo", "1.0"), reporter)); + () -> registry.getModuleFile(createModuleKey("foo", "1.0"), reporter, downloadManager)); assertThat(e) .hasMessageThat() .isEqualTo( @@ -551,7 +557,7 @@ public void testGetRepoSpecChecksum() throws Exception { Registry registry = registryFactory.createRegistry( server.getUrl(), LockfileMode.UPDATE, knownFiles, ImmutableMap.of(), Optional.empty()); - assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)) .isEqualTo( RepoSpec.builder() .setRuleClassName("local_repository") @@ -579,7 +585,7 @@ public void testGetRepoSpecChecksum() throws Exception { // changes. server.unserve("/bazel_registry.json"); server.unserve("/modules/foo/1.0/source.json"); - assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)) + assertThat(registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)) .isEqualTo( RepoSpec.builder() .setRuleClassName("local_repository") @@ -621,7 +627,8 @@ public void testGetRepoSpecChecksumMismatch() throws Exception { server.getUrl(), LockfileMode.UPDATE, knownFiles, ImmutableMap.of(), Optional.empty()); var e = assertThrows( - IOException.class, () -> registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)); + IOException.class, + () -> registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)); assertThat(e) .hasMessageThat() .isEqualTo( @@ -665,7 +672,8 @@ public void testBazelRegistryChecksumMismatch() throws Exception { server.getUrl(), LockfileMode.UPDATE, knownFiles, ImmutableMap.of(), Optional.empty()); var e = assertThrows( - IOException.class, () -> registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter)); + IOException.class, + () -> registry.getRepoSpec(createModuleKey("foo", "1.0"), reporter, downloadManager)); assertThat(e) .hasMessageThat() .isEqualTo(