diff --git a/build.gradle b/build.gradle index f1441cefb17b..0eae57bd10f3 100644 --- a/build.gradle +++ b/build.gradle @@ -119,7 +119,7 @@ apply from: file('gradle/ide/eclipse.gradle') // (java, tests) apply from: file('gradle/java/folder-layout.gradle') apply from: file('gradle/java/javac.gradle') -apply from: file('gradle/java/memorysegment-mrjar.gradle') +apply from: file('gradle/java/core-mrjar.gradle') apply from: file('gradle/testing/defaults-tests.gradle') apply from: file('gradle/testing/randomization.gradle') apply from: file('gradle/testing/fail-on-no-tests.gradle') @@ -158,7 +158,7 @@ apply from: file('gradle/generation/javacc.gradle') apply from: file('gradle/generation/forUtil.gradle') apply from: file('gradle/generation/antlr.gradle') apply from: file('gradle/generation/unicode-test-classes.gradle') -apply from: file('gradle/generation/panama-foreign.gradle') +apply from: file('gradle/generation/extract-jdk-apis.gradle') apply from: file('gradle/datasets/external-datasets.gradle') diff --git a/gradle/generation/panama-foreign.gradle b/gradle/generation/extract-jdk-apis.gradle similarity index 76% rename from gradle/generation/panama-foreign.gradle rename to gradle/generation/extract-jdk-apis.gradle index 119dc8fad008..834a490cc0d9 100644 --- a/gradle/generation/panama-foreign.gradle +++ b/gradle/generation/extract-jdk-apis.gradle @@ -17,10 +17,17 @@ def resources = scriptResources(buildscript) +configure(rootProject) { + ext { + // also change this in extractor tool: ExtractForeignAPI + vectorIncubatorJavaVersions = [ JavaVersion.VERSION_20 ] as Set + } +} + configure(project(":lucene:core")) { ext { apijars = file('src/generated/jdk'); - panamaJavaVersions = [ 19, 20 ] + mrjarJavaVersions = [ 19, 20 ] } configurations { @@ -31,9 +38,9 @@ configure(project(":lucene:core")) { apiextractor "org.ow2.asm:asm:${scriptDepVersions['asm']}" } - for (jdkVersion : panamaJavaVersions) { - def task = tasks.create(name: "generatePanamaForeignApiJar${jdkVersion}", type: JavaExec) { - description "Regenerate the API-only JAR file with public Panama Foreign API from JDK ${jdkVersion}" + for (jdkVersion : mrjarJavaVersions) { + def task = tasks.create(name: "generateJdkApiJar${jdkVersion}", type: JavaExec) { + description "Regenerate the API-only JAR file with public Panama Foreign & Vector API from JDK ${jdkVersion}" group "generation" javaLauncher = javaToolchains.launcherFor { @@ -45,21 +52,21 @@ configure(project(":lucene:core")) { javaLauncher.get() return true } catch (Exception e) { - logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign API JAR.', jdkVersion) + logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign & Vector API JAR.', jdkVersion) logger.warn('Error: {}', e.cause?.message) logger.warn("Please make sure to point env 'JAVA{}_HOME' to exactly JDK version {} or enable Gradle toolchain auto-download.", jdkVersion, jdkVersion) return false } } - + classpath = configurations.apiextractor - mainClass = file("${resources}/ExtractForeignAPI.java") as String + mainClass = file("${resources}/ExtractJdkApis.java") as String systemProperties = [ 'user.timezone': 'UTC' ] args = [ jdkVersion, - new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar"), + new File(apijars, "jdk${jdkVersion}.apijar"), ] } diff --git a/gradle/generation/extract-jdk-apis/ExtractJdkApis.java b/gradle/generation/extract-jdk-apis/ExtractJdkApis.java new file mode 100644 index 000000000000..b7e99cb83cc8 --- /dev/null +++ b/gradle/generation/extract-jdk-apis/ExtractJdkApis.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.PathMatcher; +import java.nio.file.Paths; +import java.nio.file.attribute.FileTime; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +import org.objectweb.asm.AnnotationVisitor; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.FieldVisitor; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +public final class ExtractJdkApis { + + private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z")); + + private static final String PATTERN_PANAMA_FOREIGN = "java.base/java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}"; + private static final String PATTERN_VECTOR_INCUBATOR = "jdk.incubator.vector/jdk/incubator/vector/*"; + private static final String PATTERN_VECTOR_VM_INTERNALS = "java.base/jdk/internal/vm/vector/VectorSupport{,$Vector,$VectorMask,$VectorPayload,$VectorShuffle}"; + + static final Map> CLASSFILE_PATTERNS = Map.of( + 19, List.of(PATTERN_PANAMA_FOREIGN), + 20, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR) + ); + + public static void main(String... args) throws IOException { + if (args.length != 2) { + throw new IllegalArgumentException("Need two parameters: java version, output file"); + } + Integer jdk = Integer.valueOf(args[0]); + if (jdk.intValue() != Runtime.version().feature()) { + throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature()); + } + if (!CLASSFILE_PATTERNS.containsKey(jdk)) { + throw new IllegalArgumentException("No support to extract stubs from java version: " + jdk); + } + var outputPath = Paths.get(args[1]); + + // create JRT filesystem and build a combined FileMatcher: + var jrtPath = Paths.get(URI.create("jrt:/")).toRealPath(); + var patterns = CLASSFILE_PATTERNS.get(jdk).stream() + .map(pattern -> jrtPath.getFileSystem().getPathMatcher("glob:" + pattern + ".class")) + .toArray(PathMatcher[]::new); + PathMatcher pattern = p -> Arrays.stream(patterns).anyMatch(matcher -> matcher.matches(p)); + + // Collect all files to process: + final List filesToExtract; + try (var stream = Files.walk(jrtPath)) { + filesToExtract = stream.filter(p -> pattern.matches(jrtPath.relativize(p))).collect(Collectors.toList()); + } + + // Process all class files: + try (var out = new ZipOutputStream(Files.newOutputStream(outputPath))) { + process(filesToExtract, out); + } + } + + private static void process(List filesToExtract, ZipOutputStream out) throws IOException { + var classesToInclude = new HashSet(); + var references = new HashMap(); + var processed = new TreeMap(); + System.out.println("Transforming " + filesToExtract.size() + " class files..."); + for (Path p : filesToExtract) { + try (var in = Files.newInputStream(p)) { + var reader = new ClassReader(in); + var cw = new ClassWriter(0); + var cleaner = new Cleaner(cw, classesToInclude, references); + reader.accept(cleaner, ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); + processed.put(reader.getClassName(), cw.toByteArray()); + } + } + // recursively add all superclasses / interfaces of visible classes to classesToInclude: + for (Set a = classesToInclude; !a.isEmpty();) { + a = a.stream().map(references::get).filter(Objects::nonNull).flatMap(Arrays::stream).collect(Collectors.toSet()); + classesToInclude.addAll(a); + } + // remove all non-visible or not referenced classes: + processed.keySet().removeIf(Predicate.not(classesToInclude::contains)); + System.out.println("Writing " + processed.size() + " visible classes..."); + for (var cls : processed.entrySet()) { + String cn = cls.getKey(); + System.out.println("Writing stub for class: " + cn); + out.putNextEntry(new ZipEntry(cn.concat(".class")).setLastModifiedTime(FIXED_FILEDATE)); + out.write(cls.getValue()); + out.closeEntry(); + } + classesToInclude.removeIf(processed.keySet()::contains); + System.out.println("Referenced classes not included: " + classesToInclude); + } + + static boolean isVisible(int access) { + return (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) != 0; + } + + static class Cleaner extends ClassVisitor { + private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature"; + private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor(); + + private final Set classesToInclude; + private final Map references; + + Cleaner(ClassWriter out, Set classesToInclude, Map references) { + super(Opcodes.ASM9, out); + this.classesToInclude = classesToInclude; + this.references = references; + } + + @Override + public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { + super.visit(Opcodes.V11, access, name, signature, superName, interfaces); + if (isVisible(access)) { + classesToInclude.add(name); + } + references.put(name, Stream.concat(Stream.of(superName), Arrays.stream(interfaces)).toArray(String[]::new)); + } + + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); + } + + @Override + public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) { + if (!isVisible(access)) { + return null; + } + return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) { + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); + } + }; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { + if (!isVisible(access)) { + return null; + } + return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) { + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); + } + }; + } + + @Override + public void visitInnerClass(String name, String outerName, String innerName, int access) { + if (!Objects.equals(outerName, PREVIEW_ANN)) { + super.visitInnerClass(name, outerName, innerName, access); + } + } + + @Override + public void visitPermittedSubclass​(String c) { + } + + } + +} diff --git a/gradle/generation/panama-foreign/ExtractForeignAPI.java b/gradle/generation/panama-foreign/ExtractForeignAPI.java deleted file mode 100644 index 44253ea0122b..000000000000 --- a/gradle/generation/panama-foreign/ExtractForeignAPI.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import java.io.IOException; -import java.net.URI; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.attribute.FileTime; -import java.time.Instant; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.zip.ZipEntry; -import java.util.zip.ZipOutputStream; - -import org.objectweb.asm.AnnotationVisitor; -import org.objectweb.asm.ClassReader; -import org.objectweb.asm.ClassVisitor; -import org.objectweb.asm.ClassWriter; -import org.objectweb.asm.FieldVisitor; -import org.objectweb.asm.MethodVisitor; -import org.objectweb.asm.Opcodes; -import org.objectweb.asm.Type; - -public final class ExtractForeignAPI { - - private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z")); - - public static void main(String... args) throws IOException { - if (args.length != 2) { - throw new IllegalArgumentException("Need two parameters: java version, output file"); - } - if (Integer.parseInt(args[0]) != Runtime.version().feature()) { - throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature()); - } - var outputPath = Paths.get(args[1]); - var javaBaseModule = Paths.get(URI.create("jrt:/")).resolve("java.base").toRealPath(); - var fileMatcher = javaBaseModule.getFileSystem().getPathMatcher("glob:java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}.class"); - try (var out = new ZipOutputStream(Files.newOutputStream(outputPath)); var stream = Files.walk(javaBaseModule)) { - var filesToExtract = stream.map(javaBaseModule::relativize).filter(fileMatcher::matches).sorted().collect(Collectors.toList()); - for (Path relative : filesToExtract) { - System.out.println("Processing class file: " + relative); - try (var in = Files.newInputStream(javaBaseModule.resolve(relative))) { - final var reader = new ClassReader(in); - final var cw = new ClassWriter(0); - reader.accept(new Cleaner(cw), ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); - out.putNextEntry(new ZipEntry(relative.toString()).setLastModifiedTime(FIXED_FILEDATE)); - out.write(cw.toByteArray()); - out.closeEntry(); - } - } - } - } - - static class Cleaner extends ClassVisitor { - private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature"; - private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor(); - - private boolean completelyHidden = false; - - Cleaner(ClassWriter out) { - super(Opcodes.ASM9, out); - } - - private boolean isHidden(int access) { - return completelyHidden || (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) == 0; - } - - @Override - public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { - super.visit(Opcodes.V11, access, name, signature, superName, interfaces); - completelyHidden = isHidden(access); - } - - @Override - public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { - return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); - } - - @Override - public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) { - if (isHidden(access)) { - return null; - } - return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) { - @Override - public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { - return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); - } - }; - } - - @Override - public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { - if (isHidden(access)) { - return null; - } - return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) { - @Override - public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { - return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); - } - }; - } - - @Override - public void visitInnerClass(String name, String outerName, String innerName, int access) { - if (!Objects.equals(outerName, PREVIEW_ANN)) { - super.visitInnerClass(name, outerName, innerName, access); - } - } - - @Override - public void visitPermittedSubclass​(String c) { - } - - } - -} diff --git a/gradle/java/memorysegment-mrjar.gradle b/gradle/java/core-mrjar.gradle similarity index 82% rename from gradle/java/memorysegment-mrjar.gradle rename to gradle/java/core-mrjar.gradle index 137f8a3c567d..5715e782f000 100644 --- a/gradle/java/memorysegment-mrjar.gradle +++ b/gradle/java/core-mrjar.gradle @@ -15,11 +15,11 @@ * limitations under the License. */ -// Produce an MR-JAR with Java 19+ MemorySegment implementation for MMapDirectory +// Produce an MR-JAR with Java 19+ foreign and vector implementations configure(project(":lucene:core")) { plugins.withType(JavaPlugin) { - for (jdkVersion : panamaJavaVersions) { + for (jdkVersion : mrjarJavaVersions) { sourceSets.create("main${jdkVersion}") { java { srcDirs = ["src/java${jdkVersion}"] @@ -29,7 +29,7 @@ configure(project(":lucene:core")) { dependencies.add("main${jdkVersion}Implementation", sourceSets.main.output) tasks.named("compileMain${jdkVersion}Java").configure { - def apijar = new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar") + def apijar = new File(apijars, "jdk${jdkVersion}.apijar") inputs.file(apijar) @@ -40,12 +40,14 @@ configure(project(":lucene:core")) { "-Xlint:-options", "--patch-module", "java.base=${apijar}", "--add-exports", "java.base/java.lang.foreign=ALL-UNNAMED", + // for compilation we patch the incubator packages into java.base, this has no effect on resulting class files: + "--add-exports", "java.base/jdk.incubator.vector=ALL-UNNAMED", ] } } tasks.named('jar').configure { - for (jdkVersion : panamaJavaVersions) { + for (jdkVersion : mrjarJavaVersions) { into("META-INF/versions/${jdkVersion}") { from sourceSets["main${jdkVersion}"].output } diff --git a/gradle/testing/defaults-tests.gradle b/gradle/testing/defaults-tests.gradle index 9f50cda8ca79..f7a348f0b66c 100644 --- a/gradle/testing/defaults-tests.gradle +++ b/gradle/testing/defaults-tests.gradle @@ -47,7 +47,7 @@ allprojects { description: "Number of forked test JVMs"], [propName: 'tests.haltonfailure', value: true, description: "Halt processing on test failure."], [propName: 'tests.jvmargs', - value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") }, + value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", isCIBuild ? "" : "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") }, description: "Arguments passed to each forked JVM."], // Other settings. [propName: 'tests.neverUpToDate', value: true, @@ -119,11 +119,16 @@ allprojects { if (rootProject.runtimeJavaVersion < JavaVersion.VERSION_16) { jvmArgs '--illegal-access=deny' } - + // Lucene needs to optional modules at runtime, which we want to enforce for testing // (if the runner JVM does not support them, it will fail tests): jvmArgs '--add-modules', 'jdk.unsupported,jdk.management' + // Enable the vector incubator module on supported Java versions: + if (rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion)) { + jvmArgs '--add-modules', 'jdk.incubator.vector' + } + def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties") def tempDir = layout.projectDirectory.dir(testsTmpDir.toString()) jvmArgumentProviders.add( diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index ebd8486df156..89d645fa9bf9 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -20,6 +20,15 @@ New Features * GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph to be searched in a thread-safety manner. (Patrick Zhai) +* GITHUB#12302, GITHUB#12311: Add vectorized implementations of VectorUtil.dotProduct(), + squareDistance(), cosine() with Java 20 jdk.incubator.vector APIs. Applications started + with command line parameter "java --add-modules jdk.incubator.vector" on exactly Java 20 + will automatically use the new vectorized implementations if running on a supported platform + (x86 AVX2 or later, ARM SVE or later). This is an opt-in feature and requires explicit Java + command line flag! When enabled, Lucene logs a notice using java.util.logging. Please test + thoroughly and report bugs/slowness to Lucene's mailing list. + (Chris Hegarty, Robert Muir, Uwe Schindler) + Improvements --------------------- diff --git a/lucene/core/src/generated/jdk/README.md b/lucene/core/src/generated/jdk/README.md index 371bbebf8518..48a014b992b8 100644 --- a/lucene/core/src/generated/jdk/README.md +++ b/lucene/core/src/generated/jdk/README.md @@ -40,4 +40,4 @@ to point the Lucene build system to missing JDK versions. The regeneration task a warning if a specific JDK is missing, leaving the already existing `.apijar` file untouched. -The extraction is done with the ASM library, see `ExtractForeignAPI.java` source code. +The extraction is done with the ASM library, see `ExtractJdkApis.java` source code. diff --git a/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar b/lucene/core/src/generated/jdk/jdk19.apijar similarity index 54% rename from lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar rename to lucene/core/src/generated/jdk/jdk19.apijar index f1672b81fcb0..4a04f1440e4a 100644 Binary files a/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar and b/lucene/core/src/generated/jdk/jdk19.apijar differ diff --git a/lucene/core/src/generated/jdk/jdk20.apijar b/lucene/core/src/generated/jdk/jdk20.apijar new file mode 100644 index 000000000000..942ddef057b7 Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk20.apijar differ diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index 3646cf65584b..8a515cb79fc9 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -90,8 +90,8 @@ public float compare(byte[] v1, byte[] v2) { /** * Calculates a similarity score between the two vectors with a specified function. Higher - * similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs - * determine the vector data that is compared. Each (signed) byte represents a vector dimension. + * similarity scores correspond to closer vectors. Each (signed) byte represents a vector + * dimension. * * @param v1 a vector * @param v2 another vector, of the same dimension diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 2a08436ec0b0..068a6edc035b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -20,6 +20,9 @@ /** Utilities for computations with numeric arrays */ public final class VectorUtil { + // visible for testing + static final VectorUtilProvider PROVIDER = VectorUtilProvider.lookup(); + private VectorUtil() {} /** @@ -31,68 +34,7 @@ public static float dotProduct(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - float res = 0f; - /* - * If length of vector is larger than 8, we use unrolled dot product to accelerate the - * calculation. - */ - int i; - for (i = 0; i < a.length % 8; i++) { - res += b[i] * a[i]; - } - if (a.length < 8) { - return res; - } - for (; i + 31 < a.length; i += 32) { - res += - b[i + 0] * a[i + 0] - + b[i + 1] * a[i + 1] - + b[i + 2] * a[i + 2] - + b[i + 3] * a[i + 3] - + b[i + 4] * a[i + 4] - + b[i + 5] * a[i + 5] - + b[i + 6] * a[i + 6] - + b[i + 7] * a[i + 7]; - res += - b[i + 8] * a[i + 8] - + b[i + 9] * a[i + 9] - + b[i + 10] * a[i + 10] - + b[i + 11] * a[i + 11] - + b[i + 12] * a[i + 12] - + b[i + 13] * a[i + 13] - + b[i + 14] * a[i + 14] - + b[i + 15] * a[i + 15]; - res += - b[i + 16] * a[i + 16] - + b[i + 17] * a[i + 17] - + b[i + 18] * a[i + 18] - + b[i + 19] * a[i + 19] - + b[i + 20] * a[i + 20] - + b[i + 21] * a[i + 21] - + b[i + 22] * a[i + 22] - + b[i + 23] * a[i + 23]; - res += - b[i + 24] * a[i + 24] - + b[i + 25] * a[i + 25] - + b[i + 26] * a[i + 26] - + b[i + 27] * a[i + 27] - + b[i + 28] * a[i + 28] - + b[i + 29] * a[i + 29] - + b[i + 30] * a[i + 30] - + b[i + 31] * a[i + 31]; - } - for (; i + 7 < a.length; i += 8) { - res += - b[i + 0] * a[i + 0] - + b[i + 1] * a[i + 1] - + b[i + 2] * a[i + 2] - + b[i + 3] * a[i + 3] - + b[i + 4] * a[i + 4] - + b[i + 5] * a[i + 5] - + b[i + 6] * a[i + 6] - + b[i + 7] * a[i + 7]; - } - return res; + return PROVIDER.dotProduct(a, b); } /** @@ -100,42 +42,19 @@ public static float dotProduct(float[] a, float[] b) { * * @throws IllegalArgumentException if the vectors' dimensions differ. */ - public static float cosine(float[] v1, float[] v2) { - if (v1.length != v2.length) { - throw new IllegalArgumentException( - "vector dimensions differ: " + v1.length + "!=" + v2.length); - } - - float sum = 0.0f; - float norm1 = 0.0f; - float norm2 = 0.0f; - int dim = v1.length; - - for (int i = 0; i < dim; i++) { - float elem1 = v1[i]; - float elem2 = v2[i]; - sum += elem1 * elem2; - norm1 += elem1 * elem1; - norm2 += elem2 * elem2; + public static float cosine(float[] a, float[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return (float) (sum / Math.sqrt(norm1 * norm2)); + return PROVIDER.cosine(a, b); } /** Returns the cosine similarity between the two vectors. */ public static float cosine(byte[] a, byte[] b) { - // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. - int sum = 0; - int norm1 = 0; - int norm2 = 0; - - for (int i = 0; i < a.length; i++) { - byte elem1 = a[i]; - byte elem2 = b[i]; - sum += elem1 * elem2; - norm1 += elem1 * elem1; - norm2 += elem2 * elem2; + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + return PROVIDER.cosine(a, b); } /** @@ -143,52 +62,19 @@ public static float cosine(byte[] a, byte[] b) { * * @throws IllegalArgumentException if the vectors' dimensions differ. */ - public static float squareDistance(float[] v1, float[] v2) { - if (v1.length != v2.length) { - throw new IllegalArgumentException( - "vector dimensions differ: " + v1.length + "!=" + v2.length); - } - float squareSum = 0.0f; - int dim = v1.length; - int i; - for (i = 0; i + 8 <= dim; i += 8) { - squareSum += squareDistanceUnrolled(v1, v2, i); - } - for (; i < dim; i++) { - float diff = v1[i] - v2[i]; - squareSum += diff * diff; + public static float squareDistance(float[] a, float[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return squareSum; - } - - private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { - float diff0 = v1[index + 0] - v2[index + 0]; - float diff1 = v1[index + 1] - v2[index + 1]; - float diff2 = v1[index + 2] - v2[index + 2]; - float diff3 = v1[index + 3] - v2[index + 3]; - float diff4 = v1[index + 4] - v2[index + 4]; - float diff5 = v1[index + 5] - v2[index + 5]; - float diff6 = v1[index + 6] - v2[index + 6]; - float diff7 = v1[index + 7] - v2[index + 7]; - return diff0 * diff0 - + diff1 * diff1 - + diff2 * diff2 - + diff3 * diff3 - + diff4 * diff4 - + diff5 * diff5 - + diff6 * diff6 - + diff7 * diff7; + return PROVIDER.squareDistance(a, b); } /** Returns the sum of squared differences of the two vectors. */ public static int squareDistance(byte[] a, byte[] b) { - // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. - int squareSum = 0; - for (int i = 0; i < a.length; i++) { - int diff = a[i] - b[i]; - squareSum += diff * diff; + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return squareSum; + return PROVIDER.squareDistance(a, b); } /** @@ -250,12 +136,10 @@ public static void add(float[] u, float[] v) { * @return the value of the dot product of the two vectors */ public static int dotProduct(byte[] a, byte[] b) { - assert a.length == b.length; - int total = 0; - for (int i = 0; i < a.length; i++) { - total += a[i] * b[i]; + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return total; + return PROVIDER.dotProduct(a, b); } /** diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java new file mode 100644 index 000000000000..da8483ed04de --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +/** The default VectorUtil provider implementation. */ +final class VectorUtilDefaultProvider implements VectorUtilProvider { + + VectorUtilDefaultProvider() {} + + @Override + public float dotProduct(float[] a, float[] b) { + float res = 0f; + /* + * If length of vector is larger than 8, we use unrolled dot product to accelerate the + * calculation. + */ + int i; + for (i = 0; i < a.length % 8; i++) { + res += b[i] * a[i]; + } + if (a.length < 8) { + return res; + } + for (; i + 31 < a.length; i += 32) { + res += + b[i + 0] * a[i + 0] + + b[i + 1] * a[i + 1] + + b[i + 2] * a[i + 2] + + b[i + 3] * a[i + 3] + + b[i + 4] * a[i + 4] + + b[i + 5] * a[i + 5] + + b[i + 6] * a[i + 6] + + b[i + 7] * a[i + 7]; + res += + b[i + 8] * a[i + 8] + + b[i + 9] * a[i + 9] + + b[i + 10] * a[i + 10] + + b[i + 11] * a[i + 11] + + b[i + 12] * a[i + 12] + + b[i + 13] * a[i + 13] + + b[i + 14] * a[i + 14] + + b[i + 15] * a[i + 15]; + res += + b[i + 16] * a[i + 16] + + b[i + 17] * a[i + 17] + + b[i + 18] * a[i + 18] + + b[i + 19] * a[i + 19] + + b[i + 20] * a[i + 20] + + b[i + 21] * a[i + 21] + + b[i + 22] * a[i + 22] + + b[i + 23] * a[i + 23]; + res += + b[i + 24] * a[i + 24] + + b[i + 25] * a[i + 25] + + b[i + 26] * a[i + 26] + + b[i + 27] * a[i + 27] + + b[i + 28] * a[i + 28] + + b[i + 29] * a[i + 29] + + b[i + 30] * a[i + 30] + + b[i + 31] * a[i + 31]; + } + for (; i + 7 < a.length; i += 8) { + res += + b[i + 0] * a[i + 0] + + b[i + 1] * a[i + 1] + + b[i + 2] * a[i + 2] + + b[i + 3] * a[i + 3] + + b[i + 4] * a[i + 4] + + b[i + 5] * a[i + 5] + + b[i + 6] * a[i + 6] + + b[i + 7] * a[i + 7]; + } + return res; + } + + @Override + public float cosine(float[] a, float[] b) { + float sum = 0.0f; + float norm1 = 0.0f; + float norm2 = 0.0f; + int dim = a.length; + + for (int i = 0; i < dim; i++) { + float elem1 = a[i]; + float elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt(norm1 * norm2)); + } + + @Override + public float squareDistance(float[] a, float[] b) { + float squareSum = 0.0f; + int dim = a.length; + int i; + for (i = 0; i + 8 <= dim; i += 8) { + squareSum += squareDistanceUnrolled(a, b, i); + } + for (; i < dim; i++) { + float diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } + + private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { + float diff0 = v1[index + 0] - v2[index + 0]; + float diff1 = v1[index + 1] - v2[index + 1]; + float diff2 = v1[index + 2] - v2[index + 2]; + float diff3 = v1[index + 3] - v2[index + 3]; + float diff4 = v1[index + 4] - v2[index + 4]; + float diff5 = v1[index + 5] - v2[index + 5]; + float diff6 = v1[index + 6] - v2[index + 6]; + float diff7 = v1[index + 7] - v2[index + 7]; + return diff0 * diff0 + + diff1 * diff1 + + diff2 * diff2 + + diff3 * diff3 + + diff4 * diff4 + + diff5 * diff5 + + diff6 * diff6 + + diff7 * diff7; + } + + @Override + public int dotProduct(byte[] a, byte[] b) { + int total = 0; + for (int i = 0; i < a.length; i++) { + total += a[i] * b[i]; + } + return total; + } + + @Override + public float cosine(byte[] a, byte[] b) { + // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. + int sum = 0; + int norm1 = 0; + int norm2 = 0; + + for (int i = 0; i < a.length; i++) { + byte elem1 = a[i]; + byte elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + } + + @Override + public int squareDistance(byte[] a, byte[] b) { + // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. + int squareSum = 0; + for (int i = 0; i < a.length; i++) { + int diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java new file mode 100644 index 000000000000..813f0e470ff0 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +import java.lang.Runtime.Version; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Locale; +import java.util.Objects; +import java.util.logging.Logger; + +/** A provider of VectorUtil implementations. */ +interface VectorUtilProvider { + + /** Calculates the dot product of the given float arrays. */ + float dotProduct(float[] a, float[] b); + + /** Returns the cosine similarity between the two vectors. */ + float cosine(float[] v1, float[] v2); + + /** Returns the sum of squared differences of the two vectors. */ + float squareDistance(float[] a, float[] b); + + /** Returns the dot product computed over signed bytes. */ + int dotProduct(byte[] a, byte[] b); + + /** Returns the cosine similarity between the two byte vectors. */ + float cosine(byte[] a, byte[] b); + + /** Returns the sum of squared differences of the two byte vectors. */ + int squareDistance(byte[] a, byte[] b); + + // -- provider lookup mechanism + + static final Logger LOG = Logger.getLogger(VectorUtilProvider.class.getName()); + + /** The minimal version of Java that has the bugfix for JDK-8301190. */ + static final Version VERSION_JDK8301190_FIXED = Version.parse("20.0.2"); + + static VectorUtilProvider lookup() { + final int runtimeVersion = Runtime.version().feature(); + if (runtimeVersion == 20) { + // is locale sane (only buggy in Java 20) + if (isAffectedByJDK8301190()) { + LOG.warning( + "Java runtime is using a buggy default locale; Java vector incubator API can't be enabled: " + + Locale.getDefault()); + return new VectorUtilDefaultProvider(); + } + // is the incubator module present and readable (JVM providers may to exclude them or it is + // build with jlink) + if (!vectorModulePresentAndReadable()) { + LOG.warning( + "Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API."); + return new VectorUtilDefaultProvider(); + } + if (isClientVM()) { + LOG.warning("C2 compiler is disabled; Java vector incubator API can't be enabled"); + return new VectorUtilDefaultProvider(); + } + try { + // we use method handles with lookup, so we do not need to deal with setAccessible as we + // have private access through the lookup: + final var lookup = MethodHandles.lookup(); + final var cls = lookup.findClass("org.apache.lucene.util.VectorUtilPanamaProvider"); + final var constr = lookup.findConstructor(cls, MethodType.methodType(void.class)); + try { + return (VectorUtilProvider) constr.invoke(); + } catch (UnsupportedOperationException uoe) { + // not supported because preferred vector size too small or similar + LOG.warning("Java vector incubator API was not enabled. " + uoe.getMessage()); + return new VectorUtilDefaultProvider(); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable th) { + throw new AssertionError(th); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new LinkageError( + "VectorUtilPanamaProvider is missing correctly typed constructor", e); + } catch (ClassNotFoundException cnfe) { + throw new LinkageError("VectorUtilPanamaProvider is missing in Lucene JAR file", cnfe); + } + } else if (runtimeVersion >= 21) { + LOG.warning( + "You are running with Java 21 or later. To make full use of the Vector API, please update Apache Lucene."); + } + return new VectorUtilDefaultProvider(); + } + + private static boolean vectorModulePresentAndReadable() { + var opt = + ModuleLayer.boot().modules().stream() + .filter(m -> m.getName().equals("jdk.incubator.vector")) + .findFirst(); + if (opt.isPresent()) { + VectorUtilProvider.class.getModule().addReads(opt.get()); + return true; + } + return false; + } + + /** + * Check if runtime is affected by JDK-8301190 (avoids assertion when default language is say + * "tr"). + */ + private static boolean isAffectedByJDK8301190() { + return VERSION_JDK8301190_FIXED.compareToIgnoreOptional(Runtime.version()) > 0 + && !Objects.equals("I", "i".toUpperCase(Locale.getDefault())); + } + + @SuppressWarnings("removal") + @SuppressForbidden(reason = "security manager") + private static boolean isClientVM() { + try { + final PrivilegedAction action = + () -> System.getProperty("java.vm.info", "").contains("emulated-client"); + return AccessController.doPrivileged(action); + } catch ( + @SuppressWarnings("unused") + SecurityException e) { + LOG.warning( + "SecurityManager denies permission to 'java.vm.info' system property, so state of C2 compiler can't be detected. " + + "In case of performance issues allow access to this property."); + return false; + } + } +} diff --git a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java new file mode 100644 index 000000000000..2d8a22380088 --- /dev/null +++ b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java @@ -0,0 +1,477 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import java.util.logging.Logger; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +/** A VectorUtil provider implementation that leverages the Panama Vector API. */ +final class VectorUtilPanamaProvider implements VectorUtilProvider { + + /** + * The bit size of the preferred species (this field is package private to allow the lookup to + * load it). + */ + static final int INT_SPECIES_PREF_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize(); + + private static final VectorSpecies PREF_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; + private static final VectorSpecies PREF_BYTE_SPECIES; + private static final VectorSpecies PREF_SHORT_SPECIES; + + /** + * x86 and less than 256-bit vectors. + * + *

it could be that it has only AVX1 and integer vectors are fast. it could also be that it has + * no AVX and integer vectors are extremely slow. don't use integer vectors to avoid landmines. + */ + private static final boolean IS_AMD64_WITHOUT_AVX2 = + Constants.OS_ARCH.equals("amd64") && INT_SPECIES_PREF_BIT_SIZE < 256; + + static { + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + PREF_BYTE_SPECIES = + ByteVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2)); + PREF_SHORT_SPECIES = + ShortVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1)); + } else { + PREF_BYTE_SPECIES = null; + PREF_SHORT_SPECIES = null; + } + } + + VectorUtilPanamaProvider() { + if (INT_SPECIES_PREF_BIT_SIZE < 128) { + throw new UnsupportedOperationException( + "Vector bit size is less than 128: " + INT_SPECIES_PREF_BIT_SIZE); + } + var log = Logger.getLogger(getClass().getName()); + log.info( + "Java vector incubator API enabled; uses preferredBitSize=" + INT_SPECIES_PREF_BIT_SIZE); + } + + @Override + public float dotProduct(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * PREF_FLOAT_SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES); + int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length()); + for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + FloatVector vc = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length()); + FloatVector vd = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length()); + acc2 = acc2.add(vc.mul(vd)); + FloatVector ve = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector vf = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length()); + acc3 = acc3.add(ve.mul(vf)); + FloatVector vg = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector vh = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length()); + acc4 = acc4.add(vg.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = PREF_FLOAT_SPECIES.loopBound(a.length); + for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(float[] a, float[] b) { + int i = 0; + float sum = 0; + float norm1 = 0; + float norm2 = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * PREF_FLOAT_SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector sum1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector sum2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector sum3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector sum4 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_4 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_4 = FloatVector.zero(PREF_FLOAT_SPECIES); + int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length()); + for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + FloatVector vc = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length()); + FloatVector vd = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length()); + sum2 = sum2.add(vc.mul(vd)); + norm1_2 = norm1_2.add(vc.mul(vc)); + norm2_2 = norm2_2.add(vd.mul(vd)); + FloatVector ve = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector vf = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length()); + sum3 = sum3.add(ve.mul(vf)); + norm1_3 = norm1_3.add(ve.mul(ve)); + norm2_3 = norm2_3.add(vf.mul(vf)); + FloatVector vg = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector vh = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length()); + sum4 = sum4.add(vg.mul(vh)); + norm1_4 = norm1_4.add(vg.mul(vg)); + norm2_4 = norm2_4.add(vh.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = PREF_FLOAT_SPECIES.loopBound(a.length); + for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + } + // reduce + FloatVector sumres1 = sum1.add(sum2); + FloatVector sumres2 = sum3.add(sum4); + FloatVector norm1res1 = norm1_1.add(norm1_2); + FloatVector norm1res2 = norm1_3.add(norm1_4); + FloatVector norm2res1 = norm2_1.add(norm2_2); + FloatVector norm2res2 = norm2_3.add(norm2_4); + sum += sumres1.add(sumres2).reduceLanes(VectorOperators.ADD); + norm1 += norm1res1.add(norm1res2).reduceLanes(VectorOperators.ADD); + norm2 += norm2res1.add(norm2res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float elem1 = a[i]; + float elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt(norm1 * norm2)); + } + + @Override + public float squareDistance(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * PREF_FLOAT_SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES); + int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length()); + for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + FloatVector diff1 = va.sub(vb); + acc1 = acc1.add(diff1.mul(diff1)); + FloatVector vc = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length()); + FloatVector vd = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length()); + FloatVector diff2 = vc.sub(vd); + acc2 = acc2.add(diff2.mul(diff2)); + FloatVector ve = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector vf = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector diff3 = ve.sub(vf); + acc3 = acc3.add(diff3.mul(diff3)); + FloatVector vg = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector vh = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector diff4 = vg.sub(vh); + acc4 = acc4.add(diff4.mul(diff4)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = PREF_FLOAT_SPECIES.loopBound(a.length); + for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + FloatVector diff = va.sub(vb); + acc1 = acc1.add(diff.mul(diff)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } + + // Binary functions, these all follow a general pattern like this: + // + // short intermediate = a * b; + // int accumulator = accumulator + intermediate; + // + // 256 or 512 bit vectors can process 64 or 128 bits at a time, respectively + // intermediate results use 128 or 256 bit vectors, respectively + // final accumulator uses 256 or 512 bit vectors, respectively + // + // We also support 128 bit vectors, using two 128 bit accumulators. + // This is slower but still faster than not vectorizing at all. + + @Override + public int dotProduct(byte[] a, byte[] b) { + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (a.length >= 16 && IS_AMD64_WITHOUT_AVX2 == false) { + // compute vectorized dot product consistent with VPDPBUSD instruction + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREF_BYTE_SPECIES.loopBound(a.length); + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i); + Vector va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector prod16 = va16.mul(vb16); + Vector prod32 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(prod32); + } + // reduce + res += acc.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); + IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and multiply + Vector va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector prod16 = va16.mul(vb16); + // split each short vector into two int vectors and add + Vector prod32_1 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector prod32_2 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + acc1 = acc1.add(prod32_1); + acc2 = acc2.add(prod32_2); + } + // reduce + res += acc1.add(acc2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(byte[] a, byte[] b) { + int i = 0; + int sum = 0; + int norm1 = 0; + int norm2 = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (a.length >= 16 && IS_AMD64_WITHOUT_AVX2 == false) { + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREF_BYTE_SPECIES.loopBound(a.length); + IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED); + IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED); + IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i); + Vector va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector prod16 = va16.mul(vb16); + Vector norm1_16 = va16.mul(va16); + Vector norm2_16 = vb16.mul(vb16); + Vector prod32 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + Vector norm1_32 = + norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + Vector norm2_32 = + norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + accSum = accSum.add(prod32); + accNorm1 = accNorm1.add(norm1_32); + accNorm2 = accNorm2.add(norm2_32); + } + // reduce + sum += accSum.reduceLanes(VectorOperators.ADD); + norm1 += accNorm1.reduceLanes(VectorOperators.ADD); + norm2 += accNorm2.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector accSum1 = IntVector.zero(IntVector.SPECIES_128); + IntVector accSum2 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm1_1 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm1_2 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm2_1 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm2_2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and perform multiplications + Vector va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector prod16 = va16.mul(vb16); + Vector norm1_16 = va16.mul(va16); + Vector norm2_16 = vb16.mul(vb16); + // split each short vector into two int vectors and add + Vector prod32_1 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector prod32_2 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + Vector norm1_32_1 = + norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector norm1_32_2 = + norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + Vector norm2_32_1 = + norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector norm2_32_2 = + norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + accSum1 = accSum1.add(prod32_1); + accSum2 = accSum2.add(prod32_2); + accNorm1_1 = accNorm1_1.add(norm1_32_1); + accNorm1_2 = accNorm1_2.add(norm1_32_2); + accNorm2_1 = accNorm2_1.add(norm2_32_1); + accNorm2_2 = accNorm2_2.add(norm2_32_2); + } + // reduce + sum += accSum1.add(accSum2).reduceLanes(VectorOperators.ADD); + norm1 += accNorm1_1.add(accNorm1_2).reduceLanes(VectorOperators.ADD); + norm2 += accNorm2_1.add(accNorm2_2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + byte elem1 = a[i]; + byte elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + } + + @Override + public int squareDistance(byte[] a, byte[] b) { + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (a.length >= 16 && IS_AMD64_WITHOUT_AVX2 == false) { + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREF_BYTE_SPECIES.loopBound(a.length); + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i); + Vector va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector diff16 = va16.sub(vb16); + Vector diff32 = + diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(diff32.mul(diff32)); + } + // reduce + res += acc.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); + IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and subtract + Vector va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector diff16 = va16.sub(vb16); + // split each short vector into two int vectors, square, and add + Vector diff32_1 = + diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector diff32_2 = + diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + acc1 = acc1.add(diff32_1.mul(diff32_1)); + acc2 = acc2.add(diff32_2.mul(diff32_2)); + } + // reduce + res += acc1.add(acc2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + int diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java new file mode 100644 index 000000000000..c8443b964e39 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import java.util.function.ToDoubleFunction; +import java.util.function.ToIntFunction; +import java.util.stream.IntStream; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.BeforeClass; + +public class TestVectorUtilProviders extends LuceneTestCase { + + private static final double DELTA = 1e-4; + private static final VectorUtilProvider LUCENE_PROVIDER = new VectorUtilDefaultProvider(); + private static final VectorUtilProvider JDK_PROVIDER = VectorUtil.PROVIDER; + + private static final int[] VECTOR_SIZES = { + 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024 + }; + + private final int size; + + public TestVectorUtilProviders(int size) { + this.size = size; + } + + @ParametersFactory + public static Iterable parametersFactory() { + return () -> IntStream.of(VECTOR_SIZES).boxed().map(i -> new Object[] {i}).iterator(); + } + + @BeforeClass + public static void beforeClass() throws Exception { + assumeFalse( + "Test only works when JDK's vector incubator module is enabled.", + JDK_PROVIDER instanceof VectorUtilDefaultProvider); + } + + public void testFloatVectors() { + var a = new float[size]; + var b = new float[size]; + for (int i = 0; i < size; ++i) { + a[i] = random().nextFloat(); + b[i] = random().nextFloat(); + } + assertFloatReturningProviders(p -> p.dotProduct(a, b)); + assertFloatReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + } + + public void testBinaryVectors() { + var a = new byte[size]; + var b = new byte[size]; + random().nextBytes(a); + random().nextBytes(b); + assertIntReturningProviders(p -> p.dotProduct(a, b)); + assertIntReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + } + + private void assertFloatReturningProviders(ToDoubleFunction func) { + assertEquals(func.applyAsDouble(LUCENE_PROVIDER), func.applyAsDouble(JDK_PROVIDER), DELTA); + } + + private void assertIntReturningProviders(ToIntFunction func) { + assertEquals(func.applyAsInt(LUCENE_PROVIDER), func.applyAsInt(JDK_PROVIDER)); + } +}