From 65d3f63df47b14dbf1e6839fa3ebd7e269a315e9 Mon Sep 17 00:00:00 2001
From: Phillip Webb <pwebb@vmware.com>
Date: Tue, 27 Sep 2022 20:37:04 -0700
Subject: [PATCH] Refactor common code in NativeImagePlugin

Refactor the `configureJvmReachabilityConfigurationDirectories`
and `configureJvmReachabilityExcludeConfigArgs` to use a common
method since they share the same logic.
---
 .../buildtools/gradle/NativeImagePlugin.java  | 108 ++++++++----------
 1 file changed, 48 insertions(+), 60 deletions(-)

diff --git a/native-gradle-plugin/src/main/java/org/graalvm/buildtools/gradle/NativeImagePlugin.java b/native-gradle-plugin/src/main/java/org/graalvm/buildtools/gradle/NativeImagePlugin.java
index 12c099e9c..4e5c07a55 100644
--- a/native-gradle-plugin/src/main/java/org/graalvm/buildtools/gradle/NativeImagePlugin.java
+++ b/native-gradle-plugin/src/main/java/org/graalvm/buildtools/gradle/NativeImagePlugin.java
@@ -77,6 +77,7 @@
 import org.gradle.api.artifacts.ConfigurationContainer;
 import org.gradle.api.artifacts.ModuleVersionIdentifier;
 import org.gradle.api.artifacts.component.ModuleComponentIdentifier;
+import org.gradle.api.artifacts.result.ResolvedComponentResult;
 import org.gradle.api.attributes.Attribute;
 import org.gradle.api.attributes.AttributeContainer;
 import org.gradle.api.file.ArchiveOperations;
@@ -128,10 +129,13 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.function.BiFunction;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.regex.Pattern;
+import java.util.stream.Collector;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.graalvm.buildtools.gradle.internal.ConfigurationCacheSupport.serializablePredicateOf;
 import static org.graalvm.buildtools.gradle.internal.ConfigurationCacheSupport.serializableSupplierOf;
@@ -340,69 +344,52 @@ private void configureAutomaticTaskCreation(Project project,
         });
     }
 
-    private void configureJvmReachabilityConfigurationDirectories(Project project, GraalVMExtension graalExtension, NativeImageOptions options, SourceSet sourceSet) {
-        GraalVMReachabilityMetadataRepositoryExtension repositoryExtension = reachabilityExtensionOn(graalExtension);
-        Provider<GraalVMReachabilityMetadataService> serviceProvider = graalVMReachabilityMetadataService(project, repositoryExtension);
-        options.getConfigurationFileDirectories().from(repositoryExtension.getEnabled().flatMap(enabled -> {
-            if (enabled) {
-                if (repositoryExtension.getUri().isPresent()) {
-                    Configuration classpath = project.getConfigurations().getByName(sourceSet.getRuntimeClasspathConfigurationName());
-                    Set<String> excludedModules = repositoryExtension.getExcludedModules().getOrElse(Collections.emptySet());
-                    Map<String, String> forcedVersions = repositoryExtension.getModuleToConfigVersion().getOrElse(Collections.emptyMap());
-                    return serviceProvider.map(repo -> repo.findConfigurationsFor(query -> classpath.getIncoming().getResolutionResult().allComponents(component -> {
-                        ModuleVersionIdentifier moduleVersion = component.getModuleVersion();
-                        String module = Objects.requireNonNull(moduleVersion).getGroup() + ":" + moduleVersion.getName();
-                        if (!excludedModules.contains(module)) {
-                            query.forArtifact(artifact -> {
-                                artifact.gav(module + ":" + moduleVersion.getVersion());
-                                if (forcedVersions.containsKey(module)) {
-                                    artifact.forceConfigVersion(forcedVersions.get(module));
-                                }
-                            });
-                        }
-                        query.useLatestConfigWhenVersionIsUntested();
-                    })).stream()
-                            .map(configuration -> configuration.getDirectory().toAbsolutePath())
-                            .map(Path::toFile)
-                            .collect(Collectors.toList()));
-                }
-            }
-            return project.getProviders().provider(Collections::emptySet);
-        }));
+    private void configureJvmReachabilityConfigurationDirectories(Project project, GraalVMExtension graalExtension,
+            NativeImageOptions options, SourceSet sourceSet) {
+        options.getConfigurationFileDirectories().from(graalVMReachabilityQuery(project, graalExtension, sourceSet,
+                        configuration -> true, this::getConfigurationDirectory,
+                        Collectors.toList()));
+    }
+
+    private File getConfigurationDirectory(ModuleVersionIdentifier moduleVersion,
+            DirectoryConfiguration configuration) {
+        return configuration.getDirectory().toAbsolutePath().toFile();
+    }
+
+    private void configureJvmReachabilityExcludeConfigArgs(Project project, GraalVMExtension graalExtension,
+            NativeImageOptions options, SourceSet sourceSet) {
+        options.getExcludeConfig().putAll(graalVMReachabilityQuery(project, graalExtension, sourceSet,
+                        DirectoryConfiguration::isOverride, this::getExclusionConfig,
+                        Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
     }
 
-    private void configureJvmReachabilityExcludeConfigArgs(Project project, GraalVMExtension graalExtension, NativeImageOptions options, SourceSet sourceSet) {
-        GraalVMReachabilityMetadataRepositoryExtension repositoryExtension = reachabilityExtensionOn(graalExtension);
-        Provider<GraalVMReachabilityMetadataService> serviceProvider = graalVMReachabilityMetadataService(project, repositoryExtension);
-        options.getExcludeConfig().putAll(repositoryExtension.getEnabled().flatMap(enabled -> {
-            if (enabled) {
-                if (repositoryExtension.getUri().isPresent()) {
-                    Configuration classpath = project.getConfigurations().getByName(sourceSet.getRuntimeClasspathConfigurationName());
-                    Set<String> excludedModules = repositoryExtension.getExcludedModules().getOrElse(Collections.emptySet());
-                    Map<String, String> forcedVersions = repositoryExtension.getModuleToConfigVersion().getOrElse(Collections.emptyMap());
-                    return serviceProvider.map(repo -> classpath.getIncoming().getResolutionResult().getAllComponents().stream().flatMap(component -> {
+    private Map.Entry<String, List<String>> getExclusionConfig(ModuleVersionIdentifier moduleVersion,
+            DirectoryConfiguration configuration) {
+        String gav = moduleVersion.getGroup() + ":" + moduleVersion.getName() + ":" + moduleVersion.getVersion();
+        return new AbstractMap.SimpleEntry<>(gav, Arrays.asList("^/META-INF/native-image/.*"));
+    }
+
+    private <T, A, R> Provider<R> graalVMReachabilityQuery(Project project, GraalVMExtension graalExtension,
+            SourceSet sourceSet, Predicate<DirectoryConfiguration> filter,
+            BiFunction<ModuleVersionIdentifier, DirectoryConfiguration, T> mapper, Collector<T, A, R> collector) {
+        GraalVMReachabilityMetadataRepositoryExtension extension = reachabilityExtensionOn(graalExtension);
+        return extension.getEnabled().flatMap(enabled -> {
+            if (enabled && extension.getUri().isPresent()) {
+                Configuration classpath = project.getConfigurations().getByName(sourceSet.getRuntimeClasspathConfigurationName());
+                Set<String> excludedModules = extension.getExcludedModules().getOrElse(Collections.emptySet());
+                Map<String, String> forcedVersions = extension.getModuleToConfigVersion().getOrElse(Collections.emptyMap());
+                return graalVMReachabilityMetadataService(project, extension).map(service -> {
+                    Set<ResolvedComponentResult> components = classpath.getIncoming().getResolutionResult().getAllComponents();
+                    Stream<T> mapped = components.stream().flatMap(component -> {
                         ModuleVersionIdentifier moduleVersion = component.getModuleVersion();
-                        return repo.findConfigurationsFor(query -> {
-                                    String module = Objects.requireNonNull(moduleVersion).getGroup() + ":" + moduleVersion.getName();
-                                    if (!excludedModules.contains(module)) {
-                                        query.forArtifact(artifact -> {
-                                            artifact.gav(module + ":" + moduleVersion.getVersion());
-                                            if (forcedVersions.containsKey(module)) {
-                                                artifact.forceConfigVersion(forcedVersions.get(module));
-                                            }
-                                        });
-                                    }
-                                    query.useLatestConfigWhenVersionIsUntested();
-                                }).stream()
-                                .filter(DirectoryConfiguration::isOverride)
-                                .map(configuration -> new AbstractMap.SimpleEntry<>(
-                                        moduleVersion.getGroup() + ":" + moduleVersion.getName() + ":" + moduleVersion.getVersion(),
-                                        Arrays.asList("^/META-INF/native-image/.*")));
-                    }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
-                }
+                        Set<DirectoryConfiguration> configurations = service.findConfigurationsFor(excludedModules, forcedVersions, moduleVersion);
+                        return configurations.stream().filter(filter).map(configuration -> mapper.apply(moduleVersion, configuration));
+                    });
+                    return mapped.collect(collector);
+                });
             }
-            return project.getProviders().provider(Collections::emptyMap);
-        }));
+            return project.getProviders().provider(() -> Stream.<T>empty().collect(collector));
+        });
     }
 
     private Provider<GraalVMReachabilityMetadataService> graalVMReachabilityMetadataService(Project project,
@@ -413,7 +400,8 @@ private Provider<GraalVMReachabilityMetadataService> graalVMReachabilityMetadata
                     LogLevel logLevel = determineLogLevel();
                     spec.getParameters().getLogLevel().set(logLevel);
                     spec.getParameters().getUri().set(repositoryExtension.getUri());
-                    spec.getParameters().getCacheDir().set(new File(project.getGradle().getGradleUserHomeDir(), "native-build-tools/repositories"));
+                    spec.getParameters().getCacheDir().set(
+                            new File(project.getGradle().getGradleUserHomeDir(), "native-build-tools/repositories"));
                 });
     }