diff --git a/.buildkite/pipelines/periodic-packaging.template.yml b/.buildkite/pipelines/periodic-packaging.template.yml
index b00ef51ec8f07..8b2c3c04aa34e 100644
--- a/.buildkite/pipelines/periodic-packaging.template.yml
+++ b/.buildkite/pipelines/periodic-packaging.template.yml
@@ -10,9 +10,11 @@ steps:
           - debian-12
           - opensuse-leap-15
           - oraclelinux-8
+          - oraclelinux-9
           - sles-15
           - ubuntu-2004
           - ubuntu-2204
+          - ubuntu-2404
           - rocky-8
           - rocky-9
           - rhel-8
diff --git a/.buildkite/pipelines/periodic-packaging.yml b/.buildkite/pipelines/periodic-packaging.yml
index 9619de3c2c98b..abd11068e7a65 100644
--- a/.buildkite/pipelines/periodic-packaging.yml
+++ b/.buildkite/pipelines/periodic-packaging.yml
@@ -11,9 +11,11 @@ steps:
           - debian-12
           - opensuse-leap-15
           - oraclelinux-8
+          - oraclelinux-9
           - sles-15
           - ubuntu-2004
           - ubuntu-2204
+          - ubuntu-2404
           - rocky-8
           - rocky-9
           - rhel-8
diff --git a/.buildkite/pipelines/periodic-platform-support.yml b/.buildkite/pipelines/periodic-platform-support.yml
index 33e9a040b22cf..c5846a763f5e8 100644
--- a/.buildkite/pipelines/periodic-platform-support.yml
+++ b/.buildkite/pipelines/periodic-platform-support.yml
@@ -10,9 +10,11 @@ steps:
          - debian-12
           - opensuse-leap-15
           - oraclelinux-8
+          - oraclelinux-9
           - sles-15
           - ubuntu-2004
           - ubuntu-2204
+          - ubuntu-2404
           - rocky-8
           - rocky-9
           - rhel-8
diff --git a/.buildkite/pipelines/pull-request/packaging-tests-unix.yml b/.buildkite/pipelines/pull-request/packaging-tests-unix.yml
index 3b7973d05a338..2b12cb641b955 100644
--- a/.buildkite/pipelines/pull-request/packaging-tests-unix.yml
+++ b/.buildkite/pipelines/pull-request/packaging-tests-unix.yml
@@ -13,9 +13,11 @@ steps:
           - debian-12
           - opensuse-leap-15
           - oraclelinux-8
+          - oraclelinux-9
           - sles-15
           - ubuntu-2004
           - ubuntu-2204
+          - ubuntu-2404
           - rocky-8
           - rocky-9
           - rhel-8
diff --git a/.buildkite/pipelines/pull-request/part-1.yml b/.buildkite/pipelines/pull-request/part-1.yml
index 3d467c6c41e43..7a09e2a162ff8 100644
--- a/.buildkite/pipelines/pull-request/part-1.yml
+++ b/.buildkite/pipelines/pull-request/part-1.yml
@@ -1,6 +1,8 @@
 steps:
   - label: part-1
-    command: .ci/scripts/run-gradle.sh -Dignore.tests.seed checkPart1
+    command: |
+      .buildkite/scripts/spotless.sh # This doesn't have to be part of part-1; it was just a convenient place to put it
+      .ci/scripts/run-gradle.sh -Dignore.tests.seed checkPart1
     timeout_in_minutes: 300
     agents:
       provider: gcp
diff --git a/.buildkite/pipelines/pull-request/precommit.yml b/.buildkite/pipelines/pull-request/precommit.yml
index f6548dfeed9b2..1763758932581 100644
--- a/.buildkite/pipelines/pull-request/precommit.yml
+++ b/.buildkite/pipelines/pull-request/precommit.yml
@@ -3,7 +3,9 @@ config:
   skip-labels: []
 steps:
   - label: precommit
-    command: .ci/scripts/run-gradle.sh -Dignore.tests.seed precommit
+    command: |
+      .buildkite/scripts/spotless.sh
+      .ci/scripts/run-gradle.sh -Dignore.tests.seed precommit
     timeout_in_minutes: 300
     agents:
       provider: gcp
diff --git a/.buildkite/scripts/gradle-configuration-cache-validation.sh b/.buildkite/scripts/gradle-configuration-cache-validation.sh
index 55a4b18a1e887..2257b4cfb17a5 100755
--- a/.buildkite/scripts/gradle-configuration-cache-validation.sh
+++ b/.buildkite/scripts/gradle-configuration-cache-validation.sh
@@ -3,16 +3,16 @@
 set -euo pipefail
 
 # This is a workaround for https://github.com/gradle/gradle/issues/28159
-.ci/scripts/run-gradle.sh --no-daemon precommit
+.ci/scripts/run-gradle.sh --no-daemon $@
 
-.ci/scripts/run-gradle.sh --configuration-cache precommit -Dorg.gradle.configuration-cache.inputs.unsafe.ignore.file-system-checks=build/*.tar.bz2
+.ci/scripts/run-gradle.sh --configuration-cache -Dorg.gradle.configuration-cache.inputs.unsafe.ignore.file-system-checks=build/*.tar.bz2 $@
 
 # Create a temporary file
 tmpOutputFile=$(mktemp)
 trap "rm $tmpOutputFile" EXIT
 
 echo "2nd run"
-.ci/scripts/run-gradle.sh --configuration-cache precommit -Dorg.gradle.configuration-cache.inputs.unsafe.ignore.file-system-checks=build/*.tar.bz2 | tee $tmpOutputFile
+.ci/scripts/run-gradle.sh --configuration-cache -Dorg.gradle.configuration-cache.inputs.unsafe.ignore.file-system-checks=build/*.tar.bz2 $@ | tee $tmpOutputFile
 
 # Check if the command was successful
 if grep -q "Configuration cache entry reused." $tmpOutputFile; then
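Note on the change above: with the hardcoded precommit task replaced by $@, all three Gradle invocations now run whatever tasks the caller passes. A minimal invocation sketch — hypothetical, since no caller of this script appears in this diff:

    # validate the configuration cache against the precommit tasks, as before
    .buildkite/scripts/gradle-configuration-cache-validation.sh precommit

    # or against any other task list
    .buildkite/scripts/gradle-configuration-cache-validation.sh checkPart1

Since $@ is unquoted, arguments containing spaces would be word-split; that is harmless for plain Gradle task names like these.
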
diff --git a/.buildkite/scripts/spotless.sh b/.buildkite/scripts/spotless.sh
new file mode 100755
index 0000000000000..b9e6094edb2c7
--- /dev/null
+++ b/.buildkite/scripts/spotless.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+if [[ -z "${BUILDKITE_PULL_REQUEST:-}" ]]; then
+  echo "Not a pull request, skipping spotless"
+  exit 0
+fi
+
+if ! git diff --exit-code; then
+  echo "Changes are present before running spotless, not running"
+  git status
+  exit 0
+fi
+
+NEW_COMMIT_MESSAGE="[CI] Auto commit changes from spotless"
+PREVIOUS_COMMIT_MESSAGE="$(git log -1 --pretty=%B)"
+
+echo "--- Running spotless"
+.ci/scripts/run-gradle.sh -Dscan.tag.NESTED spotlessApply
+
+if git diff --exit-code; then
+  echo "No changes found after running spotless. Don't need to auto commit."
+  exit 0
+fi
+
+if [[ "$NEW_COMMIT_MESSAGE" == "$PREVIOUS_COMMIT_MESSAGE" ]]; then
+  echo "Changes found after running spotless"
+  echo "CI already attempted to commit these changes, but the file(s) seem to have changed again."
+  echo "Please review and fix manually."
+  exit 1
+fi
+
+git config --global user.name elasticsearchmachine
+git config --global user.email 'infra-root+elasticsearchmachine@elastic.co'
+
+gh pr checkout "${BUILDKITE_PULL_REQUEST}"
+git add -u .
+git commit -m "$NEW_COMMIT_MESSAGE"
+git push
+
+# After the git push, the new commit will trigger a new build within a few seconds, and that build should cancel this one
+# So just sleep, giving the cancellation time to happen before this build finishes
+# If this build doesn't get cancelled for some reason, exit with an error anyway: we don't want it to be green, we just didn't want it to fail before the cancellation had a chance
+sleep 300
+exit 1
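The script above only auto-commits once per formatting problem: if the previous commit already carries the auto-commit message and the tree is still dirty after spotlessApply, it fails instead of looping. The same formatting pass can be reproduced locally; a sketch, assuming the standard Gradle wrapper invocation of the Spotless task that the script runs through the run-gradle.sh wrapper:

    # apply formatting locally instead of waiting for CI to auto-commit
    ./gradlew spotlessApply

    # confirm a clean tree, mirroring the script's `git diff --exit-code` guard
    git diff --exit-code
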
diff --git a/.ci/scripts/run-gradle.sh b/.ci/scripts/run-gradle.sh
index af5db8a6b4063..b091c4a7c7d89 100755
--- a/.ci/scripts/run-gradle.sh
+++ b/.ci/scripts/run-gradle.sh
@@ -5,35 +5,39 @@ mkdir -p ~/.gradle/init.d && cp -v $WORKSPACE/.ci/init.gradle ~/.gradle/init.d
 
 MAX_WORKERS=4
 
-# Don't run this stuff on Windows
-if ! uname -a | grep -q MING; then
-  # drop page cache and kernel slab objects on linux
-  [[ -x /usr/local/sbin/drop-caches ]] && sudo /usr/local/sbin/drop-caches
+if [ -z "$MAX_WORKER_ENV" ]; then
+  # Don't run this stuff on Windows
+  if ! uname -a | grep -q MING; then
+    # drop page cache and kernel slab objects on linux
+    [[ -x /usr/local/sbin/drop-caches ]] && sudo /usr/local/sbin/drop-caches
 
-  if [ "$(uname -m)" = "arm64" ] || [ "$(uname -m)" = "aarch64" ]; then
-    MAX_WORKERS=16
-  elif [ -f /proc/cpuinfo ]; then
-    MAX_WORKERS=`grep '^cpu\scores' /proc/cpuinfo | uniq | sed 's/\s\+//g' | cut -d':' -f 2`
-  else
-    if [[ "$OSTYPE" == "darwin"* ]]; then
-      MAX_WORKERS=`sysctl -n hw.physicalcpu | sed 's/\s\+//g'`
-    else
-      echo "Unsupported OS Type: $OSTYPE"
-      exit 1
-    fi
-  fi
-  if pwd | grep -v -q ^/dev/shm ; then
-    echo "Not running on a ramdisk, reducing number of workers"
-    MAX_WORKERS=$(($MAX_WORKERS*2/3))
-  fi
+    if [ "$(uname -m)" = "arm64" ] || [ "$(uname -m)" = "aarch64" ]; then
+      MAX_WORKERS=16
+    elif [ -f /proc/cpuinfo ]; then
+      MAX_WORKERS=`grep '^cpu\scores' /proc/cpuinfo | uniq | sed 's/\s\+//g' | cut -d':' -f 2`
+    else
+      if [[ "$OSTYPE" == "darwin"* ]]; then
+        MAX_WORKERS=`sysctl -n hw.physicalcpu | sed 's/\s\+//g'`
+      else
+        echo "Unsupported OS Type: $OSTYPE"
+        exit 1
+      fi
+    fi
+    if pwd | grep -v -q ^/dev/shm ; then
+      echo "Not running on a ramdisk, reducing number of workers"
+      MAX_WORKERS=$(($MAX_WORKERS*2/3))
+    fi
 
-  # Export glibc version as environment variable since some BWC tests are incompatible with later versions
-  export GLIBC_VERSION=$(ldd --version | grep '^ldd' | sed 's/.* \([1-9]\.[0-9]*\).*/\1/')
-fi
+    # Export glibc version as environment variable since some BWC tests are incompatible with later versions
+    export GLIBC_VERSION=$(ldd --version | grep '^ldd' | sed 's/.* \([1-9]\.[0-9]*\).*/\1/')
+  fi
 
-# Running on 2-core machines without ramdisk can make this value be 0
-if [[ "$MAX_WORKERS" == "0" ]]; then
-  MAX_WORKERS=1
+  # Running on 2-core machines without ramdisk can make this value be 0
+  if [[ "$MAX_WORKERS" == "0" ]]; then
+    MAX_WORKERS=1
+  fi
+else
+  MAX_WORKERS=$MAX_WORKER_ENV
 fi
 
 set -e
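The new MAX_WORKER_ENV branch lets a caller pin the Gradle worker count and bypass the CPU/ramdisk heuristics entirely; the detection logic only runs when the variable is empty or unset. A usage sketch — hypothetical, as no caller that sets the variable appears in this diff:

    # pin Gradle to 8 workers regardless of detected cores or ramdisk state
    MAX_WORKER_ENV=8 .ci/scripts/run-gradle.sh -Dignore.tests.seed precommit

One reviewer caution: [ -z "$MAX_WORKER_ENV" ] would trip `set -u` if this script ever enabled nounset; "${MAX_WORKER_ENV:-}" would be the defensive spelling.
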
diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java
index 5992a40275b46..6050e5c9a60d4 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java
@@ -16,6 +16,7 @@
 import org.gradle.api.Action;
 import org.gradle.api.GradleException;
 import org.gradle.api.Project;
+import org.gradle.api.Task;
 import org.gradle.api.logging.LogLevel;
 import org.gradle.api.model.ObjectFactory;
 import org.gradle.api.provider.Property;
@@ -103,10 +104,16 @@ private static TaskProvider createRunBwcGradleTask(
             loggedExec.dependsOn("checkoutBwcBranch");
             loggedExec.getWorkingDir().set(checkoutDir.get());
 
-            loggedExec.getNonTrackedEnvironment().put("JAVA_HOME", providerFactory.of(JavaHomeValueSource.class, spec -> {
-                spec.getParameters().getVersion().set(unreleasedVersionInfo.map(it -> it.version()));
-                spec.getParameters().getCheckoutDir().set(checkoutDir);
-            }).flatMap(s -> getJavaHome(objectFactory, toolChainService, Integer.parseInt(s))));
+            loggedExec.doFirst(new Action<Task>() {
+                @Override
+                public void execute(Task task) {
+                    Provider<String> minimumCompilerVersionValueSource = providerFactory.of(JavaHomeValueSource.class, spec -> {
+                        spec.getParameters().getVersion().set(unreleasedVersionInfo.map(it -> it.version()));
+                        spec.getParameters().getCheckoutDir().set(checkoutDir);
+                    }).flatMap(s -> getJavaHome(objectFactory, toolChainService, Integer.parseInt(s)));
+                    loggedExec.getNonTrackedEnvironment().put("JAVA_HOME", minimumCompilerVersionValueSource.get());
+                }
+            });
 
             if (isCi && OS.current() != OS.WINDOWS) {
                 // TODO: Disabled for now until we can figure out why files are getting corrupted
@@ -169,14 +176,6 @@ private static Provider getJavaHome(ObjectFactory objectFactory, JavaToo
             .map(launcher -> launcher.getMetadata().getInstallationPath().getAsFile().getAbsolutePath());
     }
 
-    private static String readFromFile(File file) {
-        try {
-            return FileUtils.readFileToString(file).trim();
-        } catch (IOException ioException) {
-            throw new GradleException("Cannot read java properties file.", ioException);
-        }
-    }
-
     public abstract static class JavaHomeValueSource implements ValueSource {
 
         private String minimumCompilerVersionPath(Version bwcVersion) {
@@ -185,6 +184,14 @@ private String minimumCompilerVersionPath(Version bwcVersion) {
                 : "buildSrc/" + MINIMUM_COMPILER_VERSION_PATH;
         }
 
+        private static String readFromFile(File file) {
+            try {
+                return FileUtils.readFileToString(file).trim();
+            } catch (IOException ioException) {
+                throw new GradleException("Cannot read java properties file.", ioException);
+            }
+        }
+
         @Override
         public String obtain() {
             return readFromFile(
diff --git a/docs/changelog/118454.yaml b/docs/changelog/118454.yaml
new file mode 100644
index 0000000000000..9a19ede64d705
--- /dev/null
+++ b/docs/changelog/118454.yaml
@@ -0,0 +1,5 @@
+pr: 118454
+summary: Fix RLIKE folding with (unsupported) case insensitive pattern
+area: ES|QL
+type: bug
+issues: []
diff --git a/docs/changelog/118704.yaml b/docs/changelog/118704.yaml
new file mode 100644
index 0000000000000..c89735664f25e
--- /dev/null
+++ b/docs/changelog/118704.yaml
@@ -0,0 +1,6 @@
+pr: 118704
+summary: Avoid updating settings version in `MetadataMigrateToDataStreamService` when
+  settings have not changed
+area: Data streams
+type: bug
+issues: []
diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java
index 9e23d2c0412c3..9f10739486238 100644
--- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java
+++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java
@@ -29,15 +29,13 @@ public class InstrumentationServiceImpl implements InstrumentationService {
 
     @Override
-    public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
-        return InstrumenterImpl.create(checkMethods);
+    public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods) {
+        return InstrumenterImpl.create(clazz, methods);
     }
 
     @Override
-    public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
-        IOException {
+    public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
         var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
-        var checkerClass = Class.forName(entitlementCheckerClassName);
         var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
         ClassReader reader = new ClassReader(classFileInfo.bytecodes());
         ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java
index 57e30c01c5c28..00efab829b2bb 100644
--- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java
+++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java
@@ -58,30 +58,14 @@ public class InstrumenterImpl implements Instrumenter {
         this.checkMethods = checkMethods;
     }
 
-    static String getCheckerClassName() {
-        int javaVersion = Runtime.version().feature();
-        final String classNamePrefix;
-        if (javaVersion >= 23) {
-            classNamePrefix = "Java23";
-        } else {
-            classNamePrefix = "";
-        }
-        return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
-    }
-
-    public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
-        String checkerClass = getCheckerClassName();
-        String handleClass = checkerClass + "Handle";
-        String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
+    public static InstrumenterImpl create(Class<?> checkerClass, Map<MethodKey, CheckMethod> checkMethods) {
+        Type checkerClassType = Type.getType(checkerClass);
+        String handleClass = checkerClassType.getInternalName() + "Handle";
+        String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(checkerClassType);
         return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
     }
 
-    public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
-        ClassFileInfo initial = getClassFileInfo(clazz);
-        return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes()));
-    }
-
-    public static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
+    static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
         String internalName = Type.getInternalName(clazz);
         String fileName = "/" + internalName + ".class";
         byte[] originalBytecodes;
@@ -306,5 +290,5 @@ protected void pushEntitlementChecker(MethodVisitor mv) {
         mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
     }
 
-    public record ClassFileInfo(String fileName, byte[] bytecodes) {}
+    record ClassFileInfo(String fileName, byte[] bytecodes) {}
 }
diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java
index 9ccb72637d463..e3285cec8f883 100644
--- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java
+++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java
@@ -51,8 +51,8 @@ interface TestCheckerCtors {
         void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
     }
 
-    public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
+    public void testInstrumentationTargetLookup() throws IOException {
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
 
         assertThat(checkMethods, aMapWithSize(3));
         assertThat(
@@ -116,8 +116,8 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE
         );
     }
 
-    public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
+    public void testInstrumentationTargetLookupWithOverloads() throws IOException {
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);
 
         assertThat(checkMethods, aMapWithSize(2));
         assertThat(
@@ -148,8 +148,8 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException, C
         );
     }
 
-    public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
+    public void testInstrumentationTargetLookupWithCtors() throws IOException {
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);
 
         assertThat(checkMethods, aMapWithSize(2));
         assertThat(
diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java
index 75102b0bf260d..e9af1d152dd35 100644
--- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java
+++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java
@@ -12,31 +12,64 @@
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
+import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 import org.elasticsearch.test.ESTestCase;
+import org.junit.Before;
 import org.objectweb.asm.Type;
 
+import java.io.IOException;
+import java.lang.reflect.AccessFlag;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Executable;
 import java.lang.reflect.InvocationTargetException;
-import java.util.List;
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Map;
+import java.util.Set;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
 import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
-import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
-import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
-import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.startsWith;
+import static org.hamcrest.Matchers.equalTo;
 
 /**
- * This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
- * some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
+ * This tests that {@link InstrumenterImpl} can instrument various method signatures
+ * (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
  */
 @ESTestCase.WithoutSecurityManager
 public class InstrumenterTests extends ESTestCase {
     private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
 
+    static class TestLoader extends ClassLoader {
+        final byte[] testClassBytes;
+        final Class<?> testClass;
+
+        TestLoader(String testClassName, byte[] testClassBytes) {
+            super(InstrumenterTests.class.getClassLoader());
+            this.testClassBytes = testClassBytes;
+            this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length);
+        }
+
+        Method getSameMethod(Method method) {
+            try {
+                return testClass.getMethod(method.getName(), method.getParameterTypes());
+            } catch (NoSuchMethodException e) {
+                throw new AssertionError(e);
+            }
+        }
+
+        Constructor<?> getSameConstructor(Constructor<?> ctor) {
+            try {
+                return testClass.getConstructor(ctor.getParameterTypes());
+            } catch (NoSuchMethodException e) {
+                throw new AssertionError(e);
+            }
+        }
+    }
+
     /**
      * Contains all the virtual methods from {@link TestClassToInstrument},
      * allowing this test to call them on the dynamically loaded instrumented class.
@@ -80,13 +113,15 @@ public static void anotherStaticMethod(int arg) {}
     public interface MockEntitlementChecker {
         void checkSomeStaticMethod(Class<?> clazz, int arg);
 
-        void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
+        void checkSomeStaticMethodOverload(Class<?> clazz, int arg, String anotherArg);
+
+        void checkAnotherStaticMethod(Class<?> clazz, int arg);
 
         void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
 
         void checkCtor(Class<?> clazz);
 
-        void checkCtor(Class<?> clazz, int arg);
+        void checkCtorOverload(Class<?> clazz, int arg);
     }
 
     public static class TestEntitlementCheckerHolder {
@@ -105,6 +140,7 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
         volatile boolean isActive;
 
         int checkSomeStaticMethodIntCallCount = 0;
+        int checkAnotherStaticMethodIntCallCount = 0;
         int checkSomeStaticMethodIntStringCallCount = 0;
         int checkSomeInstanceMethodCallCount = 0;
 
@@ -120,28 +156,33 @@ private void throwIfActive() {
         @Override
         public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
             checkSomeStaticMethodIntCallCount++;
-            assertSame(TestMethodUtils.class, callerClass);
+            assertSame(InstrumenterTests.class, callerClass);
             assertEquals(123, arg);
             throwIfActive();
         }
 
         @Override
-        public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
+        public void checkSomeStaticMethodOverload(Class<?> callerClass, int arg, String anotherArg) {
             checkSomeStaticMethodIntStringCallCount++;
-            assertSame(TestMethodUtils.class, callerClass);
+            assertSame(InstrumenterTests.class, callerClass);
             assertEquals(123, arg);
             assertEquals("abc", anotherArg);
             throwIfActive();
         }
 
+        @Override
+        public void checkAnotherStaticMethod(Class<?> callerClass, int arg) {
+            checkAnotherStaticMethodIntCallCount++;
+            assertSame(InstrumenterTests.class, callerClass);
+            assertEquals(123, arg);
+            throwIfActive();
+        }
+
         @Override
         public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
             checkSomeInstanceMethodCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
-            assertThat(
-                that.getClass().getName(),
-                startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument")
-            );
+            assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName()));
             assertEquals(123, arg);
             assertEquals("def", anotherArg);
             throwIfActive();
@@ -155,7 +196,7 @@ public void checkCtor(Class<?> callerClass) {
         }
 
         @Override
-        public void checkCtor(Class<?> callerClass, int arg) {
+        public void checkCtorOverload(Class<?> callerClass, int arg) {
             checkCtorIntCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
             assertEquals(123, arg);
@@ -163,206 +204,83 @@ public void checkCtor(Class<?> callerClass, int arg) {
         }
     }
 
-    public void testClassIsInstrumented() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            checkMethod
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
+    @Before
+    public void resetInstance() {
+        TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker();
+    }
 
-        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
+    public void testStaticMethod() throws Exception {
+        Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)));
 
         // Before checking is active, nothing should throw
-        callStaticMethod(newClass, "someStaticMethod", 123);
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-
+        assertStaticMethod(loader, targetMethod, 123);
         // After checking is activated, everything should throw
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertStaticMethodThrows(loader, targetMethod, 123);
     }
 
-    public void testClassIsNotInstrumentedTwice() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            checkMethod
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
-        var internalClassName = Type.getInternalName(classToInstrument);
-
-        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
-        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
+    public void testNotInstrumentedTwice() throws Exception {
+        Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod));
 
-        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
+        var loader1 = instrumentTestClass(instrumenter);
+        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes);
         logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
+        var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode);
 
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW_NEW",
-            instrumentedTwiceBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertStaticMethodThrows(loader2, targetMethod, 123);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
     }
 
-    public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            checkMethod,
-            methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
-            checkMethod
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
-        var internalClassName = Type.getInternalName(classToInstrument);
+    public void testMultipleMethods() throws Exception {
+        Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class);
 
-        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
-        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
+        var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2));
+        var loader = instrumentTestClass(instrumenter);
 
-        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
-        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW_NEW",
-            instrumentedTwiceBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertStaticMethodThrows(loader, targetMethod1, 123);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
-        assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
+        assertStaticMethodThrows(loader, targetMethod2, 123);
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount);
     }
 
-    public void testInstrumenterWorksWithOverloads() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
-            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
+    public void testStaticMethodOverload() throws Exception {
+        Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class);
+        var instrumenter = createInstrumenter(
+            Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2)
         );
+        var loader = instrumentTestClass(instrumenter);
 
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
-
-        // After checking is activated, everything should throw
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
-
+        assertStaticMethodThrows(loader, targetMethod1, 123);
+        assertStaticMethodThrows(loader, targetMethod2, 123, "abc");
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
     }
 
-    public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
-            getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
+    public void testInstanceMethodOverload() throws Exception {
+        Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class);
+        var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod));
+        var loader = instrumentTestClass(instrumenter);
 
         TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
-
-        Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
+        Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance());
 
         // This overload is not instrumented, so it will not throw
         testTargetClass.someMethod(123);
-        assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
+        expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
 
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
     }
 
-    public void testInstrumenterWorksWithConstructors() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
-            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
-            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
-            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-
-        var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
-        assertThat(ex.getCause(), instanceOf(TestException.class));
-        var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
-        assertThat(ex2.getCause(), instanceOf(TestException.class));
+    public void testConstructors() throws Exception {
+        Constructor<?> ctor1 = TestClassToInstrument.class.getConstructor();
+        Constructor<?> ctor2 = TestClassToInstrument.class.getConstructor(int.class);
+        var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2)));
 
+        assertCtorThrows(loader, ctor1);
+        assertCtorThrows(loader, ctor2, 123);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
     }
@@ -373,11 +291,107 @@ public void testInstrumenterWorksWithConstructors() throws Exception {
      * MethodKey and instrumentationMethod with slightly different signatures (using the common interface
      * Testable) which is not what would happen when it's run by the agent.
      */
-    private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
+    private static InstrumenterImpl createInstrumenter(Map<String, Executable> methods) throws NoSuchMethodException {
+        Map<MethodKey, CheckMethod> checkMethods = new HashMap<>();
+        for (var entry : methods.entrySet()) {
+            checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue()));
+        }
         String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class);
         String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class);
         String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
-        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
+        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
+    }
+
+    private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException {
+        var clazz = TestClassToInstrument.class;
+        ClassFileInfo initial = getClassFileInfo(clazz);
+        byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes());
+        if (logger.isTraceEnabled()) {
+            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
+        }
+        return new TestLoader(clazz.getName(), newBytecode);
+    }
+
+    private static MethodKey getMethodKey(Executable method) {
+        logger.info("method key: {}", method.getName());
+        String methodName = method instanceof Constructor<?> ? "<init>" : method.getName();
+        return new MethodKey(
+            Type.getInternalName(method.getDeclaringClass()),
+            methodName,
+            Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList()
+        );
+    }
+
+    private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException {
+        Set<AccessFlag> flags = targetMethod.accessFlags();
+        boolean isInstance = flags.contains(AccessFlag.STATIC) == false && targetMethod instanceof Method;
+        int extraArgs = 1; // caller class
+        if (isInstance) {
+            ++extraArgs;
+        }
+        Class<?>[] targetParameterTypes = targetMethod.getParameterTypes();
+        Class<?>[] checkParameterTypes = new Class<?>[targetParameterTypes.length + extraArgs];
+        checkParameterTypes[0] = Class.class;
+        if (isInstance) {
+            checkParameterTypes[1] = Testable.class;
+        }
+        System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length);
+        var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes);
+        return new CheckMethod(
+            Type.getInternalName(MockEntitlementChecker.class),
+            checkMethod.getName(),
+            Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList()
+        );
+    }
+
+    private static void unwrapInvocationException(InvocationTargetException e) {
+        Throwable cause = e.getCause();
+        if (cause instanceof TestException n) {
+            // Sometimes we're expecting this one!
+            throw n;
+        } else {
+            throw new AssertionError(cause);
+        }
+    }
+
+    /**
+     * Calling a static method of a dynamically loaded class is significantly more cumbersome
+     * than calling a virtual method.
+     */
+    static void callStaticMethod(Method method, Object... args) {
+        try {
+            method.invoke(null, args);
+        } catch (InvocationTargetException e) {
+            unwrapInvocationException(e);
+        } catch (IllegalAccessException e) {
+            throw new AssertionError(e);
+        }
+    }
+
+    private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) {
+        Method testMethod = loader.getSameMethod(method);
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        expectThrows(TestException.class, () -> callStaticMethod(testMethod, args));
+    }
+
+    private void assertStaticMethod(TestLoader loader, Method method, Object... args) {
+        Method testMethod = loader.getSameMethod(method);
+        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
+        callStaticMethod(testMethod, args);
+    }
+
+    private void assertCtorThrows(TestLoader loader, Constructor<?> ctor, Object... args) {
+        Constructor<?> testCtor = loader.getSameConstructor(ctor);
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        expectThrows(TestException.class, () -> {
+            try {
+                testCtor.newInstance(args);
+            } catch (InvocationTargetException e) {
+                unwrapInvocationException(e);
+            } catch (IllegalAccessException | InstantiationException e) {
+                throw new AssertionError(e);
+            }
+        });
     }
 }
diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java
deleted file mode 100644
index 9eb8e9328ecba..0000000000000
--- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java
+++ /dev/null
@@ -1,20 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */
-
-package org.elasticsearch.entitlement.instrumentation.impl;
-
-class TestLoader extends ClassLoader {
-    TestLoader(ClassLoader parent) {
-        super(parent);
-    }
-
-    public Class<?> defineClassFromBytes(String name, byte[] bytes) {
-        return defineClass(name, bytes, 0, bytes.length);
-    }
-}
diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java
deleted file mode 100644
index de7822fea926e..0000000000000
--- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */
-
-package org.elasticsearch.entitlement.instrumentation.impl;
-
-import org.elasticsearch.entitlement.instrumentation.CheckMethod;
-import org.elasticsearch.entitlement.instrumentation.MethodKey;
-import org.objectweb.asm.Type;
-
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.Arrays;
-import java.util.List;
-import java.util.stream.Stream;
-
-class TestMethodUtils {
-
-    /**
-     * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
-     */
-    static MethodKey methodKeyForTarget(Method targetMethod) {
-        Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
-        return new MethodKey(
-            Type.getInternalName(targetMethod.getDeclaringClass()),
-            targetMethod.getName(),
-            Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
-        );
-    }
-
-    static MethodKey methodKeyForConstructor(Class<?> classToInstrument, List<String> params) {
-        return new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", params);
-    }
-
-    static CheckMethod getCheckMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) throws NoSuchMethodException {
-        var method = clazz.getMethod(methodName, parameterTypes);
-        return new CheckMethod(
-            Type.getInternalName(clazz),
-            method.getName(),
-            Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
-        );
-    }
-
-    /**
-     * Calling a static method of a dynamically loaded class is significantly more cumbersome
-     * than calling a virtual method.
-     */
-    static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
-        try {
-            c.getMethod(methodName, int.class).invoke(null, arg);
-        } catch (InvocationTargetException e) {
-            Throwable cause = e.getCause();
-            if (cause instanceof TestException n) {
-                // Sometimes we're expecting this one!
-                throw n;
-            } else {
-                throw new AssertionError(cause);
-            }
-        }
-    }
-
-    static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
-        IllegalAccessException {
-        try {
-            c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
-        } catch (InvocationTargetException e) {
-            Throwable cause = e.getCause();
-            if (cause instanceof TestException n) {
-                // Sometimes we're expecting this one!
-                throw n;
-            } else {
-                throw new AssertionError(cause);
-            }
-        }
-    }
-}
diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java
index 2956efa8eec31..9118f67cdc145 100644
--- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java
+++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.entitlement.bridge.EntitlementChecker;
 import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
+import org.elasticsearch.entitlement.instrumentation.Instrumenter;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
 import org.elasticsearch.entitlement.instrumentation.Transformer;
 import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker;
@@ -64,13 +65,12 @@ public static EntitlementChecker checker() {
     public static void initialize(Instrumentation inst) throws Exception {
         manager = initChecker();
 
-        Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
-            "org.elasticsearch.entitlement.bridge.EntitlementChecker"
-        );
+        Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethods(EntitlementChecker.class);
 
         var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
 
-        inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true);
+        Instrumenter instrumenter = INSTRUMENTER_FACTORY.newInstrumenter(EntitlementChecker.class, checkMethods);
+        inst.addTransformer(new Transformer(instrumenter, classesToTransform), true);
         // TODO: should we limit this array somehow?
         var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class<?>[]::new);
         inst.retransformClasses(classesToRetransform);
diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java
index d0331d756d2b2..66d8ad9488cfa 100644
--- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java
+++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java
@@ -16,7 +16,7 @@
  * The SPI service entry point for instrumentation.
*/ public interface InstrumentationService { - Instrumenter newInstrumenter(Map checkMethods); + Instrumenter newInstrumenter(Class clazz, Map methods); - Map lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException; + Map lookupMethods(Class clazz) throws IOException; } diff --git a/muted-tests.yml b/muted-tests.yml index 1255739e818be..2689f02cc92cd 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -79,9 +79,6 @@ tests: - class: org.elasticsearch.xpack.restart.MLModelDeploymentFullClusterRestartIT method: testDeploymentSurvivesRestart {cluster=UPGRADED} issue: https://github.com/elastic/elasticsearch/issues/115528 -- class: org.elasticsearch.action.update.UpdateResponseTests - method: testToAndFromXContent - issue: https://github.com/elastic/elasticsearch/issues/115689 - class: org.elasticsearch.xpack.shutdown.NodeShutdownIT method: testStalledShardMigrationProperlyDetected issue: https://github.com/elastic/elasticsearch/issues/115697 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 388123e86c882..f5e581a81a37c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -55,7 +55,6 @@ static TransportVersion def(int id) { public static final TransportVersion V_7_3_0 = def(7_03_00_99); public static final TransportVersion V_7_4_0 = def(7_04_00_99); public static final TransportVersion V_7_6_0 = def(7_06_00_99); - public static final TransportVersion V_7_7_0 = def(7_07_00_99); public static final TransportVersion V_7_8_0 = def(7_08_00_99); public static final TransportVersion V_7_8_1 = def(7_08_01_99); public static final TransportVersion V_7_9_0 = def(7_09_00_99); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ComponentTemplateMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ComponentTemplateMetadata.java index 1151a99a24403..2aedfcef272ed 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ComponentTemplateMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ComponentTemplateMetadata.java @@ -89,7 +89,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } @Override @@ -166,7 +166,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplateMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplateMetadata.java index e798b0f6add4f..f34375d9c8a6d 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplateMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplateMetadata.java @@ -94,7 +94,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } @Override @@ -167,7 +167,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } } } diff --git 
a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamMetadata.java index 3f5e7a2e0c4aa..e49b991c21887 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamMetadata.java @@ -217,7 +217,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } @Override @@ -311,7 +311,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamService.java index aee60c3eda57f..39acc6d3f6311 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamService.java @@ -29,6 +29,7 @@ import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DataStreamTimestampFieldMapper; import org.elasticsearch.index.mapper.DocumentMapper; @@ -250,9 +251,11 @@ static void prepareBackingIndex( DataStreamFailureStoreDefinition.applyFailureStoreSettings(nodeSettings, settingsUpdate); } - imb.settings(settingsUpdate.build()) - .settingsVersion(im.getSettingsVersion() + 1) - .mappingVersion(im.getMappingVersion() + 1) + Settings maybeUpdatedSettings = settingsUpdate.build(); + if (IndexSettings.same(im.getSettings(), maybeUpdatedSettings) == false) { + imb.settings(maybeUpdatedSettings).settingsVersion(im.getSettingsVersion() + 1); + } + imb.mappingVersion(im.getMappingVersion() + 1) .mappingsUpdatedVersion(IndexVersion.current()) .putMapping(new MappingMetadata(mapper)); b.put(imb); diff --git a/server/src/main/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerService.java b/server/src/main/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerService.java index e5c9b14cf90fc..ba2382208b705 100644 --- a/server/src/main/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerService.java +++ b/server/src/main/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerService.java @@ -58,7 +58,7 @@ public CircuitBreakerStats stats(String name) { } @Override - protected void doClose() { + public void close() { preallocated.close(); } diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 0aaae2104576a..fd321f6256194 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -133,6 +133,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion TIME_BASED_K_ORDERED_DOC_ID_BACKPORT = def(8_520_00_0, Version.LUCENE_9_12_0); public static final IndexVersion V8_DEPRECATE_SOURCE_MODE_MAPPER = def(8_521_00_0, Version.LUCENE_9_12_0); public static final IndexVersion USE_SYNTHETIC_SOURCE_FOR_RECOVERY_BACKPORT = def(8_522_00_0, 
Version.LUCENE_9_12_0); + public static final IndexVersion UPGRADE_TO_LUCENE_9_12_1 = def(8_523_00_0, parseUnchecked("9.12.1")); public static final IndexVersion UPGRADE_TO_LUCENE_10_0_0 = def(9_000_00_0, Version.LUCENE_10_0_0); public static final IndexVersion LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT = def(9_001_00_0, Version.LUCENE_10_0_0); public static final IndexVersion TIME_BASED_K_ORDERED_DOC_ID = def(9_002_00_0, Version.LUCENE_10_0_0); diff --git a/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java b/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java index 8dee39f7050cb..49e0ae0587085 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/NoOpEngine.java @@ -48,7 +48,7 @@ public final class NoOpEngine extends ReadOnlyEngine { public NoOpEngine(EngineConfig config) { this( config, - config.isPromotableToPrimary() ? null : new TranslogStats(0, 0, 0, 0, 0), + config.isPromotableToPrimary() && config.getTranslogConfig().hasTranslog() ? null : new TranslogStats(0, 0, 0, 0, 0), config.isPromotableToPrimary() ? null : new SeqNoStats( diff --git a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java index c1d11223fa55e..1d032c1f400ef 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java @@ -98,7 +98,7 @@ public class ReadOnlyEngine extends Engine { public ReadOnlyEngine( EngineConfig config, SeqNoStats seqNoStats, - TranslogStats translogStats, + @Nullable TranslogStats translogStats, boolean obtainLock, Function readerWrapperFunction, boolean requireCompleteHistory, @@ -251,6 +251,7 @@ private static SeqNoStats buildSeqNoStats(EngineConfig config, SegmentInfos info } private static TranslogStats translogStats(final EngineConfig config, final SegmentInfos infos) throws IOException { + assert config.getTranslogConfig().hasTranslog(); final String translogUuid = infos.getUserData().get(Translog.TRANSLOG_UUID_KEY); if (translogUuid == null) { throw new IllegalStateException("commit doesn't contain translog unique id"); diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index a76feff84e61b..966764d2797c9 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -1487,6 +1487,27 @@ public void flush(FlushRequest request, ActionListener listener) { }); } + /** + * @return true the shard has a translog. + */ + public boolean hasTranslog() { + return translogConfig.hasTranslog(); + } + + /** + * Reads the global checkpoint from the translog checkpoint file if the shard has a translog. Otherwise, reads the local checkpoint from + * the provided commit user data. + * + * @return the global checkpoint to use for recovery + * @throws IOException + */ + public long readGlobalCheckpointForRecovery(Map commitUserData) throws IOException { + if (hasTranslog()) { + return Translog.readGlobalCheckpoint(translogConfig.getTranslogPath(), commitUserData.get(Translog.TRANSLOG_UUID_KEY)); + } + return Long.parseLong(commitUserData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); + } + /** * checks and removes translog files that no longer need to be retained. 
See * {@link org.elasticsearch.index.translog.TranslogDeletionPolicy} for details @@ -1859,8 +1880,7 @@ public void recoverLocallyUpToGlobalCheckpoint(ActionListener recoveryStar } assert routingEntry().recoverySource().getType() == RecoverySource.Type.PEER : "not a peer recovery [" + routingEntry() + "]"; try { - final var translogUUID = store.readLastCommittedSegmentsInfo().getUserData().get(Translog.TRANSLOG_UUID_KEY); - final var globalCheckpoint = Translog.readGlobalCheckpoint(translogConfig.getTranslogPath(), translogUUID); + final var globalCheckpoint = readGlobalCheckpointForRecovery(store.readLastCommittedSegmentsInfo().getUserData()); final var safeCommit = store.findSafeIndexCommit(globalCheckpoint); ActionListener.run(recoveryStartingSeqNoListener.delegateResponse((l, e) -> { logger.debug(() -> format("failed to recover shard locally up to global checkpoint %s", globalCheckpoint), e); @@ -2084,8 +2104,7 @@ private void loadGlobalCheckpointToReplicationTracker() throws IOException { // we have to set it before we open an engine and recover from the translog because // acquiring a snapshot from the translog causes a sync which causes the global checkpoint to be pulled in, // and an engine can be forced to close in ctor which also causes the global checkpoint to be pulled in. - final String translogUUID = store.readLastCommittedSegmentsInfo().getUserData().get(Translog.TRANSLOG_UUID_KEY); - final long globalCheckpoint = Translog.readGlobalCheckpoint(translogConfig.getTranslogPath(), translogUUID); + final long globalCheckpoint = readGlobalCheckpointForRecovery(store.readLastCommittedSegmentsInfo().getUserData()); replicationTracker.updateGlobalCheckpointOnReplica(globalCheckpoint, "read from translog checkpoint"); } else { replicationTracker.updateGlobalCheckpointOnReplica(globalCheckPointIfUnpromotable, "from CleanFilesRequest"); @@ -2162,7 +2181,7 @@ private void innerOpenEngineAndTranslog(LongSupplier globalCheckpointSupplier) t // time elapses after the engine is created above (pulling the config settings) until we set the engine reference, during // which settings changes could possibly have happened, so here we forcefully push any config changes to the new engine. onSettingsChanged(); - assert assertSequenceNumbersInCommit(); + assert assertLatestCommitUserData(); recoveryState.validateCurrentStage(RecoveryState.Stage.TRANSLOG); checkAndCallWaitForEngineOrClosedShardListeners(); } @@ -2183,9 +2202,13 @@ private Engine createEngine(EngineConfig config) { } } - private boolean assertSequenceNumbersInCommit() throws IOException { + /** + * Asserts that the latest Lucene commit contains expected information about sequence numbers or ES version.
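Both recovery paths above now funnel through the new readGlobalCheckpointForRecovery helper: when the shard has a translog, the global checkpoint comes from the translog checkpoint file, and when it does not (for example searchable snapshots), the local checkpoint stored in the Lucene commit user data is used instead. A minimal standalone sketch of that selection logic follows; the file layout and names here are illustrative stand-ins, not the real Elasticsearch APIs.

    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.Map;

    // Hedged sketch of the checkpoint-selection contract described above.
    // "local_checkpoint" mirrors the SequenceNumbers.LOCAL_CHECKPOINT_KEY commit key;
    // the translog checkpoint read is stubbed out as a plain file read.
    final class RecoveryCheckpointSketch {
        static long globalCheckpointForRecovery(boolean hasTranslog, Path translogDir, Map<String, String> commitUserData)
            throws IOException {
            if (hasTranslog) {
                // With a translog, recovery trusts the translog checkpoint file.
                return Long.parseLong(Files.readString(translogDir.resolve("checkpoint.sketch")).trim());
            }
            // Without a translog, fall back to the last commit's local checkpoint.
            return Long.parseLong(commitUserData.get("local_checkpoint"));
        }
    }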
+ */ + private boolean assertLatestCommitUserData() throws IOException { final SegmentInfos segmentCommitInfos = SegmentInfos.readLatestCommit(store.directory()); final Map userData = segmentCommitInfos.getUserData(); + // Ensure sequence numbers are present in commit data assert userData.containsKey(SequenceNumbers.LOCAL_CHECKPOINT_KEY) : "commit point doesn't contain a local checkpoint"; assert userData.containsKey(SequenceNumbers.MAX_SEQ_NO) : "commit point doesn't contain a maximum sequence number"; assert userData.containsKey(Engine.HISTORY_UUID_KEY) : "commit point doesn't contain a history uuid"; @@ -2195,10 +2218,16 @@ private boolean assertSequenceNumbersInCommit() throws IOException { + "] is different than engine [" + getHistoryUUID() + "]"; + assert userData.containsKey(Engine.MAX_UNSAFE_AUTO_ID_TIMESTAMP_COMMIT_ID) : "opening index which was created post 5.5.0 but " + Engine.MAX_UNSAFE_AUTO_ID_TIMESTAMP_COMMIT_ID + " is not found in commit"; + + // From 7.16.0, the ES version is included in the Lucene commit user data as well as in the snapshot metadata in the repository. + // This is used during primary failover to detect if the latest snapshot can be used to recover the new primary, because the failed + // primary may have created new segments in a more recent Lucene version, that may have been later snapshotted, meaning that the + // snapshotted files cannot be recovered on a node with a less recent Lucene version. Note that for versions <= 7.15 this assertion + // relies on the previous minor having a different lucene version. final org.apache.lucene.util.Version commitLuceneVersion = segmentCommitInfos.getCommitLuceneVersion(); - // This relies in the previous minor having another lucene version assert commitLuceneVersion.onOrAfter(RecoverySettings.SEQ_NO_SNAPSHOT_RECOVERIES_SUPPORTED_VERSION.luceneVersion()) == false || userData.containsKey(Engine.ES_VERSION) && Engine.readIndexVersion(userData.get(Engine.ES_VERSION)).onOrBefore(IndexVersion.current()) diff --git a/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java b/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java index 42f62cf86545b..06f9b3e6c8943 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java +++ b/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java @@ -315,7 +315,7 @@ void recoverFromRepository(final IndexShard indexShard, Repository repository, A RecoverySource.Type recoveryType = indexShard.recoveryState().getRecoverySource().getType(); assert recoveryType == RecoverySource.Type.SNAPSHOT : "expected snapshot recovery type: " + recoveryType; SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) indexShard.recoveryState().getRecoverySource(); - restore(indexShard, repository, recoverySource, recoveryListener(indexShard, listener).map(ignored -> true)); + recoverFromRepository(indexShard, repository, recoverySource, recoveryListener(indexShard, listener).map(ignored -> true)); } else { listener.onResponse(false); } @@ -459,7 +459,7 @@ private void internalRecoverFromStore(IndexShard indexShard, ActionListener outerListener ) { + assert indexShard.shardRouting.primary() : "only primary shards can recover from snapshot"; logger.debug("restoring from {} ...", indexShard.recoveryState().getRecoverySource()); record ShardAndIndexIds(IndexId indexId, ShardId shardId) {} @@ -538,13 +539,13 @@ record ShardAndIndexIds(IndexId indexId, ShardId shardId) {} .newForked(indexShard::preRecovery)
.andThen(shardAndIndexIdsListener -> { - final RecoveryState.Translog translogState = indexShard.recoveryState().getTranslog(); if (restoreSource == null) { throw new IndexShardRestoreFailedException(shardId, "empty restore source"); } if (logger.isTraceEnabled()) { logger.trace("[{}] restoring shard [{}]", restoreSource.snapshot(), shardId); } + final RecoveryState.Translog translogState = indexShard.recoveryState().getTranslog(); translogState.totalOperations(0); translogState.totalOperationsOnStart(0); indexShard.prepareForIndexRecovery(); @@ -588,9 +589,7 @@ record ShardAndIndexIds(IndexId indexId, ShardId shardId) {} .andThen(l -> { indexShard.getIndexEventListener().afterFilesRestoredFromRepository(indexShard); - final Store store = indexShard.store(); - bootstrap(indexShard, store); - assert indexShard.shardRouting.primary() : "only primary shards can recover from store"; + bootstrap(indexShard); writeEmptyRetentionLeasesFile(indexShard); indexShard.openEngineAndRecoverFromTranslog(l); }) @@ -610,19 +609,37 @@ record ShardAndIndexIds(IndexId indexId, ShardId shardId) {} })); } + /** + * @deprecated use {@link #bootstrap(IndexShard)} instead + */ + @Deprecated(forRemoval = true) public static void bootstrap(final IndexShard indexShard, final Store store) throws IOException { - if (indexShard.indexSettings.getIndexMetadata().isSearchableSnapshot() == false) { - // not bootstrapping new history for searchable snapshots (which are read-only) allows sequence-number based peer recoveries + assert indexShard.store() == store; + bootstrap(indexShard); + } + + private static void bootstrap(final IndexShard indexShard) throws IOException { + assert indexShard.routingEntry().primary(); + final var store = indexShard.store(); + store.incRef(); + try { + final var translogLocation = indexShard.shardPath().resolveTranslog(); + if (indexShard.hasTranslog() == false) { + Translog.deleteAll(translogLocation); + return; + } store.bootstrapNewHistory(); + final SegmentInfos segmentInfos = store.readLastCommittedSegmentsInfo(); + final long localCheckpoint = Long.parseLong(segmentInfos.userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); + final String translogUUID = Translog.createEmptyTranslog( + translogLocation, + localCheckpoint, + indexShard.shardId(), + indexShard.getPendingPrimaryTerm() + ); + store.associateIndexWithNewTranslog(translogUUID); + } finally { + store.decRef(); } - final SegmentInfos segmentInfos = store.readLastCommittedSegmentsInfo(); - final long localCheckpoint = Long.parseLong(segmentInfos.userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); - final String translogUUID = Translog.createEmptyTranslog( - indexShard.shardPath().resolveTranslog(), - localCheckpoint, - indexShard.shardId(), - indexShard.getPendingPrimaryTerm() - ); - store.associateIndexWithNewTranslog(translogUUID); } } diff --git a/server/src/main/java/org/elasticsearch/index/store/ByteSizeCachingDirectory.java b/server/src/main/java/org/elasticsearch/index/store/ByteSizeCachingDirectory.java index 166f0eadc62b4..033859bd62128 100644 --- a/server/src/main/java/org/elasticsearch/index/store/ByteSizeCachingDirectory.java +++ b/server/src/main/java/org/elasticsearch/index/store/ByteSizeCachingDirectory.java @@ -10,6 +10,7 @@ package org.elasticsearch.index.store; import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexOutput; import org.elasticsearch.common.lucene.store.FilterIndexOutput; @@ -19,7 +20,7 
@@ import java.io.IOException; import java.io.UncheckedIOException; -final class ByteSizeCachingDirectory extends ByteSizeDirectory { +public final class ByteSizeCachingDirectory extends ByteSizeDirectory { private static class SizeAndModCount { final long size; @@ -174,9 +175,29 @@ public void deleteFile(String name) throws IOException { try { super.deleteFile(name); } finally { - synchronized (this) { - modCount++; + markEstimatedSizeAsStale(); + } + } + + /** + * Mark the cached size as stale so that it is guaranteed to be refreshed the next time. + */ + public void markEstimatedSizeAsStale() { + synchronized (this) { + modCount++; + } + } + + public static ByteSizeCachingDirectory unwrapDirectory(Directory dir) { + while (dir != null) { + if (dir instanceof ByteSizeCachingDirectory) { + return (ByteSizeCachingDirectory) dir; + } else if (dir instanceof FilterDirectory) { + dir = ((FilterDirectory) dir).getDelegate(); + } else { + dir = null; } } + return null; } } diff --git a/server/src/main/java/org/elasticsearch/index/store/Store.java b/server/src/main/java/org/elasticsearch/index/store/Store.java index e6b499c07f189..322064f09cf77 100644 --- a/server/src/main/java/org/elasticsearch/index/store/Store.java +++ b/server/src/main/java/org/elasticsearch/index/store/Store.java @@ -1449,6 +1449,7 @@ public void bootstrapNewHistory() throws IOException { * @see SequenceNumbers#MAX_SEQ_NO */ public void bootstrapNewHistory(long localCheckpoint, long maxSeqNo) throws IOException { + assert indexSettings.getIndexMetadata().isSearchableSnapshot() == false; metadataLock.writeLock().lock(); try (IndexWriter writer = newTemporaryAppendingIndexWriter(directory, null)) { final Map map = new HashMap<>(); @@ -1572,6 +1573,7 @@ private IndexWriter newTemporaryEmptyIndexWriter(final Directory dir, final Vers } private IndexWriterConfig newTemporaryIndexWriterConfig() { + assert indexSettings.getIndexMetadata().isSearchableSnapshot() == false; // this config is only used for temporary IndexWriter instances, used to initialize the index or update the commit data, // so we don't want any merges to happen var iwc = indexWriterConfigWithNoMerging(null).setSoftDeletesField(Lucene.SOFT_DELETES_FIELD).setCommitOnClose(false); diff --git a/server/src/main/java/org/elasticsearch/index/translog/Translog.java b/server/src/main/java/org/elasticsearch/index/translog/Translog.java index 2d3e2d8c20256..75f38ed5f7342 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/Translog.java +++ b/server/src/main/java/org/elasticsearch/index/translog/Translog.java @@ -220,6 +220,10 @@ public Translog( } } + public static void deleteAll(Path translogLocation) throws IOException { + IOUtils.rm(translogLocation); + } + /** recover all translog files found on disk */ private ArrayList recoverFromFiles(Checkpoint checkpoint) throws IOException { boolean success = false; diff --git a/server/src/main/java/org/elasticsearch/index/translog/TranslogConfig.java b/server/src/main/java/org/elasticsearch/index/translog/TranslogConfig.java index 66018092089ce..280e319335b12 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/TranslogConfig.java +++ b/server/src/main/java/org/elasticsearch/index/translog/TranslogConfig.java @@ -143,4 +143,13 @@ public OperationListener getOperationListener() { public boolean fsync() { return fsync; } + + /** + * @return {@code true} if the configuration allows the Translog files to exist, {@code false} otherwise. If there is no
+ translog, the shard is not writeable. + */ + public boolean hasTranslog() { + // Expect no translog files to exist for searchable snapshots + return false == indexSettings.getIndexMetadata().isSearchableSnapshot(); + } } diff --git a/server/src/main/java/org/elasticsearch/indices/breaker/CircuitBreakerService.java b/server/src/main/java/org/elasticsearch/indices/breaker/CircuitBreakerService.java index c8d7fa0775f8f..be89dd97aab89 100644 --- a/server/src/main/java/org/elasticsearch/indices/breaker/CircuitBreakerService.java +++ b/server/src/main/java/org/elasticsearch/indices/breaker/CircuitBreakerService.java @@ -10,13 +10,12 @@ package org.elasticsearch.indices.breaker; import org.elasticsearch.common.breaker.CircuitBreaker; -import org.elasticsearch.common.component.AbstractLifecycleComponent; /** * Interface for Circuit Breaker services, which provide breakers to classes * that load field data. */ -public abstract class CircuitBreakerService extends AbstractLifecycleComponent { +public abstract class CircuitBreakerService { protected CircuitBreakerService() {} @@ -35,13 +34,4 @@ protected CircuitBreakerService() {} */ public abstract CircuitBreakerStats stats(String name); - @Override - protected void doStart() {} - - @Override - protected void doStop() {} - - @Override - protected void doClose() {} - } diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java b/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java index c8d31d2060caf..d069717a66ad0 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java @@ -49,9 +49,7 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardLongFieldRange; import org.elasticsearch.index.shard.ShardNotFoundException; -import org.elasticsearch.index.shard.StoreRecovery; import org.elasticsearch.index.store.Store; -import org.elasticsearch.index.translog.Translog; import org.elasticsearch.index.translog.TranslogCorruptedException; import org.elasticsearch.indices.recovery.RecoveriesCollection.RecoveryRef; import org.elasticsearch.tasks.Task; @@ -385,15 +383,8 @@ record StartRecoveryRequestToSend(StartRecoveryRequest startRecoveryRequest, Str logger.trace("{} preparing shard for peer recovery", recoveryTarget.shardId()); indexShard.prepareForIndexRecovery(); if (indexShard.indexSettings().getIndexMetadata().isSearchableSnapshot()) { - // for searchable snapshots, peer recovery is treated similarly to recovery from snapshot + // for archive indices mounted as searchable snapshots, we still need to notify the index event listener indexShard.getIndexEventListener().afterFilesRestoredFromRepository(indexShard); - final Store store = indexShard.store(); - store.incRef(); - try { - StoreRecovery.bootstrap(indexShard, store); - } finally { - store.decRef(); - } } indexShard.recoverLocallyUpToGlobalCheckpoint(ActionListener.assertOnce(l)); }) @@ -488,8 +479,8 @@ public static StartRecoveryRequest getStartRecoveryRequest( // Make sure that the current translog is consistent with the Lucene index; otherwise, we have to throw away the Lucene // index.
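Taken together with the StoreRecovery and RecoveryTarget changes in this diff, hasTranslog() gives every bootstrap path the same two-way branch: shards without a translog (searchable snapshots) get their translog directory deleted outright, while writable shards get a fresh empty translog associated with the index. A compact sketch of that branch follows; the TranslogOps interface is a hypothetical stand-in for the real Translog.deleteAll, Translog.createEmptyTranslog and Store.associateIndexWithNewTranslog methods.

    import java.io.IOException;
    import java.nio.file.Path;

    // Hedged sketch of the bootstrap branching described above; not the real
    // Elasticsearch API surface.
    final class TranslogBootstrapSketch {
        interface TranslogOps {
            void deleteAll(Path translogLocation) throws IOException;
            String createEmptyTranslog(Path translogLocation, long checkpoint, long primaryTerm) throws IOException;
            void associateIndexWithNewTranslog(String translogUuid) throws IOException;
        }

        static void bootstrap(TranslogOps ops, Path translogLocation, boolean hasTranslog, long checkpoint, long primaryTerm)
            throws IOException {
            if (hasTranslog == false) {
                // Read-only shards must not carry translog files on disk.
                ops.deleteAll(translogLocation);
                return;
            }
            // Writable shards start from an empty translog seeded with the recovery
            // checkpoint, and the Lucene index is pointed at it via the new UUID.
            String translogUuid = ops.createEmptyTranslog(translogLocation, checkpoint, primaryTerm);
            ops.associateIndexWithNewTranslog(translogUuid);
        }
    }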
try { - final String expectedTranslogUUID = metadataSnapshot.commitUserData().get(Translog.TRANSLOG_UUID_KEY); - final long globalCheckpoint = Translog.readGlobalCheckpoint(recoveryTarget.translogLocation(), expectedTranslogUUID); + final long globalCheckpoint = recoveryTarget.indexShard() + .readGlobalCheckpointForRecovery(metadataSnapshot.commitUserData()); assert globalCheckpoint + 1 >= startingSeqNo : "invalid startingSeqNo " + startingSeqNo + " >= " + globalCheckpoint; } catch (IOException | TranslogCorruptedException e) { logGlobalCheckpointWarning(logger, startingSeqNo, e); diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java index ea485a411143e..362a62c838e3b 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java @@ -45,7 +45,6 @@ import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; -import java.nio.file.Path; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; @@ -516,13 +515,7 @@ public void cleanFiles( try { if (indexShard.routingEntry().isPromotableToPrimary()) { store.cleanupAndVerify("recovery CleanFilesRequestHandler", sourceMetadata); - final String translogUUID = Translog.createEmptyTranslog( - indexShard.shardPath().resolveTranslog(), - globalCheckpoint, - shardId, - indexShard.getPendingPrimaryTerm() - ); - store.associateIndexWithNewTranslog(translogUUID); + bootstrap(indexShard, globalCheckpoint); } else { indexShard.setGlobalCheckpointIfUnpromotable(globalCheckpoint); } @@ -634,7 +627,33 @@ public String getTempNameForFile(String origFile) { return multiFileWriter.getTempNameForFile(origFile); } - Path translogLocation() { - return indexShard().shardPath().resolveTranslog(); + private static void bootstrap(final IndexShard indexShard, long globalCheckpoint) throws IOException { + assert indexShard.routingEntry().isPromotableToPrimary(); + final var store = indexShard.store(); + store.incRef(); + try { + final var translogLocation = indexShard.shardPath().resolveTranslog(); + if (indexShard.hasTranslog() == false) { + if (Assertions.ENABLED) { + if (indexShard.indexSettings().getIndexMetadata().isSearchableSnapshot()) { + long localCheckpoint = Long.parseLong( + store.readLastCommittedSegmentsInfo().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY) + ); + assert localCheckpoint == globalCheckpoint : localCheckpoint + " != " + globalCheckpoint; + } + } + Translog.deleteAll(translogLocation); + return; + } + final String translogUUID = Translog.createEmptyTranslog( + indexShard.shardPath().resolveTranslog(), + globalCheckpoint, + indexShard.shardId(), + indexShard.getPendingPrimaryTerm() + ); + store.associateIndexWithNewTranslog(translogUUID); + } finally { + store.decRef(); + } } } diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/ChunkedInference.java similarity index 73% rename from server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java rename to server/src/main/java/org/elasticsearch/inference/ChunkedInference.java index 10e00e9860200..c54e5a98d56cc 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java +++ 
b/server/src/main/java/org/elasticsearch/inference/ChunkedInference.java @@ -12,23 +12,27 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.xcontent.XContent; +import java.io.IOException; import java.util.Iterator; -public interface ChunkedInferenceServiceResults extends InferenceServiceResults { +public interface ChunkedInference { /** * Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields. - * The iterator iterates over all the chunks stored in the {@link ChunkedInferenceServiceResults}. * * @param xcontent provided by the SemanticTextField * @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding). */ - Iterator chunksAsMatchedTextAndByteReference(XContent xcontent); + Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException; /** - * A chunk of inference results containing matched text and the bytes reference. + * A chunk of inference results containing matched text, the substring location + * in the original text and the bytes reference. * @param matchedText + * @param textOffset * @param bytesReference */ - record Chunk(String matchedText, BytesReference bytesReference) {} + record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {} + + record TextOffset(int start, int end) {} } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index c2d690d8160ac..c8ed9e6b230ce 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -144,7 +144,7 @@ void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ); /** diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index aec8eb0c3ca67..17e56a392daff 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -1459,7 +1459,6 @@ private CircuitBreakerService createCircuitBreakerService( case "none" -> new NoneCircuitBreakerService(); default -> throw new IllegalArgumentException("Unknown circuit breaker type [" + type + "]"); }; - resourcesToClose.add(circuitBreakerService); modules.bindToInstance(CircuitBreakerService.class, circuitBreakerService); pluginBreakers.forEach(t -> { diff --git a/server/src/main/java/org/elasticsearch/script/ScriptException.java b/server/src/main/java/org/elasticsearch/script/ScriptException.java index 72349e408210e..ab8d804b82ce8 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptException.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptException.java @@ -10,7 +10,6 @@ package org.elasticsearch.script; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -79,7 +78,7 @@ public ScriptException(StreamInput in) throws IOException { scriptStack = Arrays.asList(in.readStringArray()); script = in.readString(); lang = in.readString(); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0) && 
in.readBoolean()) { + if (in.readBoolean()) { pos = new Position(in); } else { pos = null; @@ -92,13 +91,11 @@ protected void writeTo(StreamOutput out, Writer nestedExceptionsWrite out.writeStringCollection(scriptStack); out.writeString(script); out.writeString(lang); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - if (pos == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - pos.writeTo(out); - } + if (pos == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + pos.writeTo(out); } } diff --git a/server/src/main/java/org/elasticsearch/search/DocValueFormat.java b/server/src/main/java/org/elasticsearch/search/DocValueFormat.java index f8d161ef1f5e5..dca3da9419380 100644 --- a/server/src/main/java/org/elasticsearch/search/DocValueFormat.java +++ b/server/src/main/java/org/elasticsearch/search/DocValueFormat.java @@ -271,13 +271,6 @@ private DateTime(StreamInput in) throws IOException { this.formatter = DateFormatter.forPattern(formatterPattern).withZone(this.timeZone).withLocale(locale); this.parser = formatter.toDateMathParser(); this.resolution = DateFieldMapper.Resolution.ofOrdinal(in.readVInt()); - if (in.getTransportVersion().between(TransportVersions.V_7_7_0, TransportVersions.V_8_0_0)) { - /* when deserialising from 7.7+ nodes expect a flag indicating if a pattern is of joda style - This is only used to support joda style indices in 7.x, in 8 we no longer support this. - All indices in 8 should use java style pattern. Hence we can ignore this flag. - */ - in.readBoolean(); - } this.formatSortValues = in.readBoolean(); } @@ -302,13 +295,6 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeString(timeZone.getId()); out.writeVInt(resolution.ordinal()); - if (out.getTransportVersion().between(TransportVersions.V_7_7_0, TransportVersions.V_8_0_0)) { - /* when serializing to 7.7+ send out a flag indicating if a pattern is of joda style - This is only used to support joda style indices in 7.x, in 8 we no longer support this. - All indices in 8 should use java style pattern. Hence this flag is always false. - */ - out.writeBoolean(false); - } out.writeBoolean(formatSortValues); } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index d457607d0bc7f..32a71bbe8a26a 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -333,11 +333,7 @@ private BanParentTaskRequest(StreamInput in) throws IOException { parentTaskId = TaskId.readFromStream(in); ban = in.readBoolean(); reason = ban ? 
in.readString() : null; - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_8_0)) { - waitForCompletion = in.readBoolean(); - } else { - waitForCompletion = false; - } + waitForCompletion = in.readBoolean(); } @Override @@ -348,9 +344,7 @@ public void writeTo(StreamOutput out) throws IOException { if (ban) { out.writeString(reason); } - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_8_0)) { - out.writeBoolean(waitForCompletion); - } + out.writeBoolean(waitForCompletion); } } diff --git a/server/src/main/java/org/elasticsearch/transport/ProxyConnectionStrategy.java b/server/src/main/java/org/elasticsearch/transport/ProxyConnectionStrategy.java index 4b3678d75af7c..d5047a61e4606 100644 --- a/server/src/main/java/org/elasticsearch/transport/ProxyConnectionStrategy.java +++ b/server/src/main/java/org/elasticsearch/transport/ProxyConnectionStrategy.java @@ -9,7 +9,6 @@ package org.elasticsearch.transport; -import org.elasticsearch.TransportVersions; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterName; @@ -355,11 +354,7 @@ public ProxyModeInfo(String address, String serverName, int maxSocketConnections private ProxyModeInfo(StreamInput input) throws IOException { address = input.readString(); - if (input.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - serverName = input.readString(); - } else { - serverName = null; - } + serverName = input.readString(); maxSocketConnections = input.readVInt(); numSocketsConnected = input.readVInt(); } @@ -376,9 +371,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(address); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - out.writeString(serverName); - } + out.writeString(serverName); out.writeVInt(maxSocketConnections); out.writeVInt(numSocketsConnected); } diff --git a/server/src/test/java/org/elasticsearch/action/OriginalIndicesTests.java b/server/src/test/java/org/elasticsearch/action/OriginalIndicesTests.java index 4a2fb698468b4..1db1b62facc9d 100644 --- a/server/src/test/java/org/elasticsearch/action/OriginalIndicesTests.java +++ b/server/src/test/java/org/elasticsearch/action/OriginalIndicesTests.java @@ -9,7 +9,6 @@ package org.elasticsearch.action; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; @@ -43,15 +42,7 @@ public void testOriginalIndicesSerialization() throws IOException { OriginalIndices originalIndices2 = OriginalIndices.readOriginalIndices(in); assertThat(originalIndices2.indices(), equalTo(originalIndices.indices())); - // indices options are not equivalent when sent to an older version and re-read due - // to the addition of hidden indices as expand to hidden indices is always true when - // read from a prior version - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0) - || originalIndices.indicesOptions().expandWildcardsHidden()) { - assertThat(originalIndices2.indicesOptions(), equalTo(originalIndices.indicesOptions())); - } else if (originalIndices.indicesOptions().expandWildcardsHidden()) { - assertThat(originalIndices2.indicesOptions(), equalTo(originalIndices.indicesOptions())); - } + assertThat(originalIndices2.indicesOptions(), equalTo(originalIndices.indicesOptions())); } } diff --git 
a/server/src/test/java/org/elasticsearch/action/admin/cluster/shards/ClusterSearchShardsRequestTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/shards/ClusterSearchShardsRequestTests.java index 7bf2cc600dd3b..d4a3d29ea68f2 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/shards/ClusterSearchShardsRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/shards/ClusterSearchShardsRequestTests.java @@ -10,7 +10,6 @@ package org.elasticsearch.action.admin.cluster.shards; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; @@ -54,12 +53,7 @@ public void testSerialization() throws Exception { in.setTransportVersion(version); ClusterSearchShardsRequest deserialized = new ClusterSearchShardsRequest(in); assertArrayEquals(request.indices(), deserialized.indices()); - // indices options are not equivalent when sent to an older version and re-read due - // to the addition of hidden indices as expand to hidden indices is always true when - // read from a prior version - if (version.onOrAfter(TransportVersions.V_7_7_0) || request.indicesOptions().expandWildcardsHidden()) { - assertEquals(request.indicesOptions(), deserialized.indicesOptions()); - } + assertEquals(request.indicesOptions(), deserialized.indicesOptions()); assertEquals(request.routing(), deserialized.routing()); assertEquals(request.preference(), deserialized.preference()); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java index dc267c7c1d80b..e45e940334a9e 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java @@ -56,14 +56,7 @@ public void testBwcSerialization() throws Exception { assertEquals(request.ackTimeout(), in.readTimeValue()); assertArrayEquals(request.indices(), in.readStringArray()); final IndicesOptions indicesOptions = IndicesOptions.readIndicesOptions(in); - // indices options are not equivalent when sent to an older version and re-read due - // to the addition of hidden indices as expand to hidden indices is always true when - // read from a prior version - // TODO update version on backport! 
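The serialization cleanups above all follow the same pattern: once the oldest wire-compatible peer is guaranteed to be at or beyond the gated transport version, the conditional read collapses into an unconditional one and the fallback branch becomes dead code. A hedged before/after sketch of the pattern, using plain DataInput rather than Elasticsearch's StreamInput and an invented stand-in for a TransportVersions constant:

    import java.io.DataInput;
    import java.io.IOException;

    // Hedged sketch of the version-gate removal pattern; GATE_VERSION is a
    // hypothetical stand-in for an entry such as TransportVersions.V_7_8_0.
    final class VersionGateSketch {
        static final int GATE_VERSION = 7_08_00_99;

        // Before: the flag only exists on the wire for peers at or after the gate.
        static boolean readFlagGated(DataInput in, int peerVersion) throws IOException {
            if (peerVersion >= GATE_VERSION) {
                return in.readBoolean();
            }
            return false; // older peers never sent it
        }

        // After: with no supported peer older than the gate, read unconditionally.
        static boolean readFlag(DataInput in) throws IOException {
            return in.readBoolean();
        }
    }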
- if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0) - || request.indicesOptions().expandWildcardsHidden()) { - assertEquals(request.indicesOptions(), indicesOptions); - } + assertEquals(request.indicesOptions(), indicesOptions); assertEquals(request.waitForActiveShards(), ActiveShardCount.readFrom(in)); } } @@ -92,13 +85,7 @@ public void testBwcSerialization() throws Exception { assertEquals(sample.masterNodeTimeout(), deserializedRequest.masterNodeTimeout()); assertEquals(sample.ackTimeout(), deserializedRequest.ackTimeout()); assertArrayEquals(sample.indices(), deserializedRequest.indices()); - // indices options are not equivalent when sent to an older version and re-read due - // to the addition of hidden indices as expand to hidden indices is always true when - // read from a prior version - // TODO change version on backport - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0) || sample.indicesOptions().expandWildcardsHidden()) { - assertEquals(sample.indicesOptions(), deserializedRequest.indicesOptions()); - } + assertEquals(sample.indicesOptions(), deserializedRequest.indicesOptions()); assertEquals(sample.waitForActiveShards(), deserializedRequest.waitForActiveShards()); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamServiceTests.java index 8be0f4de15500..63e92835ba8db 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataMigrateToDataStreamServiceTests.java @@ -24,11 +24,13 @@ import java.io.IOException; import java.util.List; import java.util.Set; +import java.util.function.Function; import static org.elasticsearch.cluster.metadata.DataStreamTestHelper.generateMapping; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -425,6 +427,86 @@ public void testCreateDataStreamWithoutSuppliedWriteIndex() { assertThat(e.getMessage(), containsString("alias [" + dataStreamName + "] must specify a write index")); } + public void testSettingsVersion() throws IOException { + /* + * This tests that applyFailureStoreSettings updates the settings version when the settings have been modified, and does not change + * it otherwise. Incrementing the settings version when the settings have not changed can result in an assertion failing in + * IndexService::updateMetadata. + */ + String indexName = randomAlphaOfLength(30); + String dataStreamName = randomAlphaOfLength(50); + Function mapperSupplier = this::getMapperService; + boolean removeAlias = randomBoolean(); + boolean failureStore = randomBoolean(); + Settings nodeSettings = Settings.EMPTY; + + { + /* + * Here the input indexMetadata will have the index.hidden setting set to true. 
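The guard exercised by this test (added to prepareBackingIndex earlier in the diff) bumps the settings version only when the rebuilt settings actually differ from the originals, since an unconditional bump can trip the assertion in IndexService::updateMetadata. A tiny sketch of the idea, with plain Map equality standing in for IndexSettings.same:

    import java.util.Map;

    // Hedged sketch: increment the settings version only on a real change,
    // mirroring the IndexSettings.same(...) guard described above.
    final class SettingsVersionSketch {
        Map<String, String> settings = Map.of();
        long settingsVersion = 0;

        void applyUpdate(Map<String, String> updated) {
            if (settings.equals(updated) == false) { // IndexSettings.same(...) stand-in
                settings = updated;
                settingsVersion++;
            }
            // Bumping the version without changing the settings would violate the
            // "version changed implies settings changed" invariant consumers assert on.
        }
    }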
So we expect no change to the settings, and + * for the settings version to remain the same + */ + Metadata.Builder metadataBuilder = Metadata.builder(); + Settings indexMetadataSettings = Settings.builder() + .put(IndexMetadata.SETTING_INDEX_HIDDEN, true) + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .build(); + IndexMetadata indexMetadata = IndexMetadata.builder(indexName) + .settings(indexMetadataSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putMapping(getTestMappingWithTimestamp()) + .build(); + MetadataMigrateToDataStreamService.prepareBackingIndex( + metadataBuilder, + indexMetadata, + dataStreamName, + mapperSupplier, + removeAlias, + failureStore, + nodeSettings + ); + Metadata metadata = metadataBuilder.build(); + assertThat(indexMetadata.getSettings(), equalTo(metadata.index(indexName).getSettings())); + assertThat(metadata.index(indexName).getSettingsVersion(), equalTo(indexMetadata.getSettingsVersion())); + } + { + /* + * Here the input indexMetadata will not have the index.hidden setting set to true. So prepareBackingIndex will add that, + * meaning that the settings and settings version will change. + */ + Metadata.Builder metadataBuilder = Metadata.builder(); + Settings indexMetadataSettings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); + IndexMetadata indexMetadata = IndexMetadata.builder(indexName) + .settings(indexMetadataSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putMapping(getTestMappingWithTimestamp()) + .build(); + MetadataMigrateToDataStreamService.prepareBackingIndex( + metadataBuilder, + indexMetadata, + dataStreamName, + mapperSupplier, + removeAlias, + failureStore, + nodeSettings + ); + Metadata metadata = metadataBuilder.build(); + assertThat(indexMetadata.getSettings(), not(equalTo(metadata.index(indexName).getSettings()))); + assertThat(metadata.index(indexName).getSettingsVersion(), equalTo(indexMetadata.getSettingsVersion() + 1)); + } + } + + private String getTestMappingWithTimestamp() { + return """ + { + "properties": { + "@timestamp": {"type": "date"} + } + } + """; + } + private MapperService getMapperService(IndexMetadata im) { try { return createMapperService("{\"_doc\": " + im.mapping().source().toString() + "}"); diff --git a/server/src/test/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerServiceTests.java b/server/src/test/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerServiceTests.java index 742c03b564307..52a30e2494974 100644 --- a/server/src/test/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerServiceTests.java +++ b/server/src/test/java/org/elasticsearch/common/breaker/PreallocatedCircuitBreakerServiceTests.java @@ -25,100 +25,94 @@ public class PreallocatedCircuitBreakerServiceTests extends ESTestCase { public void testUseNotPreallocated() { - try (HierarchyCircuitBreakerService real = real()) { - try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { - CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); - b.addEstimateBytesAndMaybeBreak(100, "test"); - b.addWithoutBreaking(-100); - } - assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); + HierarchyCircuitBreakerService real = real(); + try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { + CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); + b.addEstimateBytesAndMaybeBreak(100, "test"); + b.addWithoutBreaking(-100); } + 
assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); } public void testUseLessThanPreallocated() { - try (HierarchyCircuitBreakerService real = real()) { - try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { - CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); - b.addEstimateBytesAndMaybeBreak(100, "test"); - b.addWithoutBreaking(-100); - } - assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); + HierarchyCircuitBreakerService real = real(); + try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { + CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); + b.addEstimateBytesAndMaybeBreak(100, "test"); + b.addWithoutBreaking(-100); } + assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); } public void testCloseIsIdempotent() { - try (HierarchyCircuitBreakerService real = real()) { - try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { - CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); - b.addEstimateBytesAndMaybeBreak(100, "test"); - b.addWithoutBreaking(-100); - preallocated.close(); - assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); - } // Closes again which should do nothing + HierarchyCircuitBreakerService real = real(); + try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { + CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); + b.addEstimateBytesAndMaybeBreak(100, "test"); + b.addWithoutBreaking(-100); + preallocated.close(); assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); - } + } // Closes again which should do nothing + assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); } public void testUseMoreThanPreallocated() { - try (HierarchyCircuitBreakerService real = real()) { - try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { - CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); - b.addEstimateBytesAndMaybeBreak(2048, "test"); - b.addWithoutBreaking(-2048); - } - assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); + HierarchyCircuitBreakerService real = real(); + try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, 1024)) { + CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); + b.addEstimateBytesAndMaybeBreak(2048, "test"); + b.addWithoutBreaking(-2048); } + assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); } public void testPreallocateMoreThanRemains() { - try (HierarchyCircuitBreakerService real = real()) { - long limit = real.getBreaker(CircuitBreaker.REQUEST).getLimit(); - Exception e = expectThrows(CircuitBreakingException.class, () -> preallocateRequest(real, limit + 1024)); - assertThat(e.getMessage(), startsWith("[request] Data too large, data for [preallocate[test]] would be [")); - } + HierarchyCircuitBreakerService real = real(); + long limit = real.getBreaker(CircuitBreaker.REQUEST).getLimit(); + Exception e = expectThrows(CircuitBreakingException.class, () -> preallocateRequest(real, limit + 1024)); + assertThat(e.getMessage(), startsWith("[request] Data too large, data for [preallocate[test]] would be [")); } public void testRandom() { - try (HierarchyCircuitBreakerService real = real()) { - CircuitBreaker realBreaker = real.getBreaker(CircuitBreaker.REQUEST); - long preallocatedBytes 
= randomLongBetween(1, (long) (realBreaker.getLimit() * .8)); - try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, preallocatedBytes)) { - CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); - boolean usedPreallocated = false; - long current = 0; - for (int i = 0; i < 10000; i++) { - if (current >= preallocatedBytes) { - usedPreallocated = true; - } - if (usedPreallocated) { - assertThat(realBreaker.getUsed(), equalTo(current)); - } else { - assertThat(realBreaker.getUsed(), equalTo(preallocatedBytes)); - } - if (current > 0 && randomBoolean()) { - long delta = randomLongBetween(-Math.min(current, realBreaker.getLimit() / 100), 0); - b.addWithoutBreaking(delta); - current += delta; - continue; - } - long delta = randomLongBetween(0, realBreaker.getLimit() / 100); - if (randomBoolean()) { - b.addWithoutBreaking(delta); - current += delta; - continue; - } - if (current + delta < realBreaker.getLimit()) { - b.addEstimateBytesAndMaybeBreak(delta, "test"); - current += delta; - continue; - } - Exception e = expectThrows(CircuitBreakingException.class, () -> b.addEstimateBytesAndMaybeBreak(delta, "test")); - assertThat(e.getMessage(), startsWith("[request] Data too large, data for [test] would be [")); + HierarchyCircuitBreakerService real = real(); + CircuitBreaker realBreaker = real.getBreaker(CircuitBreaker.REQUEST); + long preallocatedBytes = randomLongBetween(1, (long) (realBreaker.getLimit() * .8)); + try (PreallocatedCircuitBreakerService preallocated = preallocateRequest(real, preallocatedBytes)) { + CircuitBreaker b = preallocated.getBreaker(CircuitBreaker.REQUEST); + boolean usedPreallocated = false; + long current = 0; + for (int i = 0; i < 10000; i++) { + if (current >= preallocatedBytes) { + usedPreallocated = true; + } + if (usedPreallocated) { + assertThat(realBreaker.getUsed(), equalTo(current)); + } else { + assertThat(realBreaker.getUsed(), equalTo(preallocatedBytes)); + } + if (current > 0 && randomBoolean()) { + long delta = randomLongBetween(-Math.min(current, realBreaker.getLimit() / 100), 0); + b.addWithoutBreaking(delta); + current += delta; + continue; } - b.addWithoutBreaking(-current); + long delta = randomLongBetween(0, realBreaker.getLimit() / 100); + if (randomBoolean()) { + b.addWithoutBreaking(delta); + current += delta; + continue; + } + if (current + delta < realBreaker.getLimit()) { + b.addEstimateBytesAndMaybeBreak(delta, "test"); + current += delta; + continue; + } + Exception e = expectThrows(CircuitBreakingException.class, () -> b.addEstimateBytesAndMaybeBreak(delta, "test")); + assertThat(e.getMessage(), startsWith("[request] Data too large, data for [test] would be [")); } - assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); + b.addWithoutBreaking(-current); } + assertThat(real.getBreaker(CircuitBreaker.REQUEST).getUsed(), equalTo(0L)); } private HierarchyCircuitBreakerService real() { diff --git a/server/src/test/java/org/elasticsearch/common/util/BigArraysTests.java b/server/src/test/java/org/elasticsearch/common/util/BigArraysTests.java index 3ef010f760ab6..1d571c3d9b296 100644 --- a/server/src/test/java/org/elasticsearch/common/util/BigArraysTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/BigArraysTests.java @@ -527,49 +527,46 @@ public void testOverSizeUsesMinPageCount() { */ public void testPreallocate() { ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + HierarchyCircuitBreakerService 
realBreakers = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + Settings.EMPTY, + List.of(), + clusterSettings + ); + BigArrays unPreAllocated = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), realBreakers); + long toPreallocate = randomLongBetween(4000, 10000); + CircuitBreaker realBreaker = realBreakers.getBreaker(CircuitBreaker.REQUEST); + assertThat(realBreaker.getUsed(), equalTo(0L)); try ( - HierarchyCircuitBreakerService realBreakers = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - Settings.EMPTY, - List.of(), - clusterSettings + PreallocatedCircuitBreakerService prealloctedBreakerService = new PreallocatedCircuitBreakerService( + realBreakers, + CircuitBreaker.REQUEST, + toPreallocate, + "test" ) ) { - BigArrays unPreAllocated = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), realBreakers); - long toPreallocate = randomLongBetween(4000, 10000); - CircuitBreaker realBreaker = realBreakers.getBreaker(CircuitBreaker.REQUEST); - assertThat(realBreaker.getUsed(), equalTo(0L)); - try ( - PreallocatedCircuitBreakerService prealloctedBreakerService = new PreallocatedCircuitBreakerService( - realBreakers, - CircuitBreaker.REQUEST, - toPreallocate, - "test" - ) - ) { - assertThat(realBreaker.getUsed(), equalTo(toPreallocate)); - BigArrays preallocated = unPreAllocated.withBreakerService(prealloctedBreakerService); + assertThat(realBreaker.getUsed(), equalTo(toPreallocate)); + BigArrays preallocated = unPreAllocated.withBreakerService(prealloctedBreakerService); - // We don't grab any bytes just making a new BigArrays - assertThat(realBreaker.getUsed(), equalTo(toPreallocate)); + // We don't grab any bytes just making a new BigArrays + assertThat(realBreaker.getUsed(), equalTo(toPreallocate)); - List arrays = new ArrayList<>(); - for (int i = 0; i < 30; i++) { - // We're well under the preallocation so grabbing a little array doesn't allocate anything - arrays.add(preallocated.newLongArray(1)); - assertThat(realBreaker.getUsed(), equalTo(toPreallocate)); - } - - // Allocating a large array *does* allocate some bytes - arrays.add(preallocated.newLongArray(1024)); - long expectedMin = (PageCacheRecycler.LONG_PAGE_SIZE + arrays.size()) * Long.BYTES; - assertThat(realBreaker.getUsed(), greaterThanOrEqualTo(expectedMin)); - // 64 should be enough room for each BigArray object - assertThat(realBreaker.getUsed(), lessThanOrEqualTo(expectedMin + 64 * arrays.size())); - Releasables.close(arrays); + List arrays = new ArrayList<>(); + for (int i = 0; i < 30; i++) { + // We're well under the preallocation so grabbing a little array doesn't allocate anything + arrays.add(preallocated.newLongArray(1)); + assertThat(realBreaker.getUsed(), equalTo(toPreallocate)); } - assertThat(realBreaker.getUsed(), equalTo(0L)); + + // Allocating a large array *does* allocate some bytes + arrays.add(preallocated.newLongArray(1024)); + long expectedMin = (PageCacheRecycler.LONG_PAGE_SIZE + arrays.size()) * Long.BYTES; + assertThat(realBreaker.getUsed(), greaterThanOrEqualTo(expectedMin)); + // 64 should be enough room for each BigArray object + assertThat(realBreaker.getUsed(), lessThanOrEqualTo(expectedMin + 64 * arrays.size())); + Releasables.close(arrays); } + assertThat(realBreaker.getUsed(), equalTo(0L)); } private List bigArrayCreators(final long maxSize, final boolean withBreaking) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java 
b/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java index 2681f70743f56..a161bcbc5d6d2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/ParametrizedMapperTests.java @@ -11,16 +11,13 @@ import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.UpdateForV9; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.analysis.AnalyzerScope; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; @@ -578,25 +575,6 @@ public void testAnalyzers() { ); } - @UpdateForV9(owner = UpdateForV9.Owner.SEARCH_FOUNDATIONS) - @AwaitsFix(bugUrl = "this is testing legacy functionality so can likely be removed in 9.0") - public void testDeprecatedParameters() { - // 'index' is declared explicitly, 'store' is not, but is one of the previously always-accepted params - String mapping = """ - {"type":"test_mapper","index":false,"store":true,"required":"value"}"""; - TestMapper mapper = fromMapping(mapping, IndexVersions.V_7_8_0, TransportVersions.V_7_8_0); - assertWarnings("Parameter [store] has no effect on type [test_mapper] and will be removed in future"); - assertFalse(mapper.index); - assertEquals(""" - {"field":{"type":"test_mapper","index":false,"required":"value"}}""", Strings.toString(mapper)); - - MapperParsingException e = expectThrows( - MapperParsingException.class, - () -> fromMapping(mapping, IndexVersions.V_8_0_0, TransportVersions.V_8_0_0) - ); - assertEquals("unknown parameter [store] on mapper [field] of type [test_mapper]", e.getMessage()); - } - public void testLinkedAnalyzers() throws IOException { String mapping = """ {"type":"test_mapper","analyzer":"_standard","required":"value"}"""; diff --git a/server/src/test/java/org/elasticsearch/indices/breaker/HierarchyCircuitBreakerServiceTests.java b/server/src/test/java/org/elasticsearch/indices/breaker/HierarchyCircuitBreakerServiceTests.java index 0a8fbcf6d56b9..4cb6e1febd5e9 100644 --- a/server/src/test/java/org/elasticsearch/indices/breaker/HierarchyCircuitBreakerServiceTests.java +++ b/server/src/test/java/org/elasticsearch/indices/breaker/HierarchyCircuitBreakerServiceTests.java @@ -200,39 +200,38 @@ public void testBorrowingSiblingBreakerMemory() { .put(HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "150mb") .put(HierarchyCircuitBreakerService.FIELDDATA_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "150mb") .build(); - try ( - CircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - clusterSettings, - Collections.emptyList(), - new ClusterSettings(clusterSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ) - ) { - CircuitBreaker requestCircuitBreaker = service.getBreaker(CircuitBreaker.REQUEST); - CircuitBreaker fieldDataCircuitBreaker = service.getBreaker(CircuitBreaker.FIELDDATA); - - assertEquals(new ByteSizeValue(200, ByteSizeUnit.MB).getBytes(), 
service.stats().getStats(CircuitBreaker.PARENT).getLimit()); - assertEquals(new ByteSizeValue(150, ByteSizeUnit.MB).getBytes(), requestCircuitBreaker.getLimit()); - assertEquals(new ByteSizeValue(150, ByteSizeUnit.MB).getBytes(), fieldDataCircuitBreaker.getLimit()); - - fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should not break"); - assertEquals(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), fieldDataCircuitBreaker.getUsed(), 0.0); - requestCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should not break"); - assertEquals(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), requestCircuitBreaker.getUsed(), 0.0); - requestCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should not break"); - assertEquals(new ByteSizeValue(100, ByteSizeUnit.MB).getBytes(), requestCircuitBreaker.getUsed(), 0.0); - CircuitBreakingException exception = expectThrows( - CircuitBreakingException.class, - () -> requestCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should break") - ); - assertThat(exception.getMessage(), containsString("[parent] Data too large, data for [should break] would be")); - assertThat(exception.getMessage(), containsString("which is larger than the limit of [209715200/200mb]")); - assertThat(exception.getMessage(), containsString("usages [")); - assertThat(exception.getMessage(), containsString("fielddata=54001664/51.5mb")); - assertThat(exception.getMessage(), containsString("inflight_requests=0/0b")); - assertThat(exception.getMessage(), containsString("request=157286400/150mb")); - assertThat(exception.getDurability(), equalTo(CircuitBreaker.Durability.TRANSIENT)); - } + + CircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + clusterSettings, + Collections.emptyList(), + new ClusterSettings(clusterSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + CircuitBreaker requestCircuitBreaker = service.getBreaker(CircuitBreaker.REQUEST); + CircuitBreaker fieldDataCircuitBreaker = service.getBreaker(CircuitBreaker.FIELDDATA); + + assertEquals(new ByteSizeValue(200, ByteSizeUnit.MB).getBytes(), service.stats().getStats(CircuitBreaker.PARENT).getLimit()); + assertEquals(new ByteSizeValue(150, ByteSizeUnit.MB).getBytes(), requestCircuitBreaker.getLimit()); + assertEquals(new ByteSizeValue(150, ByteSizeUnit.MB).getBytes(), fieldDataCircuitBreaker.getLimit()); + + fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should not break"); + assertEquals(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), fieldDataCircuitBreaker.getUsed(), 0.0); + requestCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should not break"); + assertEquals(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), requestCircuitBreaker.getUsed(), 0.0); + requestCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should not break"); + assertEquals(new ByteSizeValue(100, ByteSizeUnit.MB).getBytes(), requestCircuitBreaker.getUsed(), 0.0); + CircuitBreakingException exception = expectThrows( + CircuitBreakingException.class, + () -> requestCircuitBreaker.addEstimateBytesAndMaybeBreak(new ByteSizeValue(50, ByteSizeUnit.MB).getBytes(), "should break") + ); + assertThat(exception.getMessage(), containsString("[parent] Data too large, 
data for [should break] would be")); + assertThat(exception.getMessage(), containsString("which is larger than the limit of [209715200/200mb]")); + assertThat(exception.getMessage(), containsString("usages [")); + assertThat(exception.getMessage(), containsString("fielddata=54001664/51.5mb")); + assertThat(exception.getMessage(), containsString("inflight_requests=0/0b")); + assertThat(exception.getMessage(), containsString("request=157286400/150mb")); + assertThat(exception.getDurability(), equalTo(CircuitBreaker.Durability.TRANSIENT)); assertCircuitBreakerLimitWarning(); } @@ -683,42 +682,41 @@ public void testTrippedCircuitBreakerDurability() { .put(HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "150mb") .put(HierarchyCircuitBreakerService.FIELDDATA_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "150mb") .build(); - try ( - CircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - clusterSettings, - Collections.emptyList(), - new ClusterSettings(clusterSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ) - ) { - CircuitBreaker requestCircuitBreaker = service.getBreaker(CircuitBreaker.REQUEST); - CircuitBreaker fieldDataCircuitBreaker = service.getBreaker(CircuitBreaker.FIELDDATA); - - CircuitBreaker.Durability expectedDurability; - if (randomBoolean()) { - fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(100), "should not break"); - requestCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(70), "should not break"); - expectedDurability = CircuitBreaker.Durability.PERMANENT; - } else { - fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(70), "should not break"); - requestCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(120), "should not break"); - expectedDurability = CircuitBreaker.Durability.TRANSIENT; - } + CircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + clusterSettings, + Collections.emptyList(), + new ClusterSettings(clusterSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + CircuitBreaker requestCircuitBreaker = service.getBreaker(CircuitBreaker.REQUEST); + CircuitBreaker fieldDataCircuitBreaker = service.getBreaker(CircuitBreaker.FIELDDATA); + + CircuitBreaker.Durability expectedDurability; + if (randomBoolean()) { + fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(100), "should not break"); + requestCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(70), "should not break"); + expectedDurability = CircuitBreaker.Durability.PERMANENT; + } else { + fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(70), "should not break"); + requestCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(120), "should not break"); + expectedDurability = CircuitBreaker.Durability.TRANSIENT; + } - CircuitBreakingException exception = expectThrows( - CircuitBreakingException.class, - () -> fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(40), "should break") - ); + CircuitBreakingException exception = expectThrows( + CircuitBreakingException.class, + () -> fieldDataCircuitBreaker.addEstimateBytesAndMaybeBreak(mb(40), "should break") + ); + + assertThat(exception.getMessage(), containsString("[parent] Data too large, data for [should break] would be")); + assertThat(exception.getMessage(), containsString("which is larger than the limit of [209715200/200mb]")); + assertThat( + "Expected [" + expectedDurability + "] due to [" + exception.getMessage() + "]", + exception.getDurability(), + equalTo(expectedDurability) + ); - 
assertThat(exception.getMessage(), containsString("[parent] Data too large, data for [should break] would be")); - assertThat(exception.getMessage(), containsString("which is larger than the limit of [209715200/200mb]")); - assertThat( - "Expected [" + expectedDurability + "] due to [" + exception.getMessage() + "]", - exception.getDurability(), - equalTo(expectedDurability) - ); - } assertCircuitBreakerLimitWarning(); + } public void testAllocationBucketsBreaker() { @@ -727,34 +725,31 @@ public void testAllocationBucketsBreaker() { .put(HierarchyCircuitBreakerService.USE_REAL_MEMORY_USAGE_SETTING.getKey(), "false") .build(); - try ( - HierarchyCircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - clusterSettings, - Collections.emptyList(), - new ClusterSettings(clusterSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ) - ) { + HierarchyCircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + clusterSettings, + Collections.emptyList(), + new ClusterSettings(clusterSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); - long parentLimitBytes = service.getParentLimit(); - assertEquals(new ByteSizeValue(100, ByteSizeUnit.BYTES).getBytes(), parentLimitBytes); + long parentLimitBytes = service.getParentLimit(); + assertEquals(new ByteSizeValue(100, ByteSizeUnit.BYTES).getBytes(), parentLimitBytes); - CircuitBreaker breaker = service.getBreaker(CircuitBreaker.REQUEST); - MultiBucketConsumerService.MultiBucketConsumer multiBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( - 10000, - breaker - ); + CircuitBreaker breaker = service.getBreaker(CircuitBreaker.REQUEST); + MultiBucketConsumerService.MultiBucketConsumer multiBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + 10000, + breaker + ); - // make sure used bytes is greater than the total circuit breaker limit - breaker.addWithoutBreaking(200); - // make sure that we check on the following call - for (int i = 0; i < 1023; i++) { - multiBucketConsumer.accept(0); - } - CircuitBreakingException exception = expectThrows(CircuitBreakingException.class, () -> multiBucketConsumer.accept(1024)); - assertThat(exception.getMessage(), containsString("[parent] Data too large, data for [allocated_buckets] would be")); - assertThat(exception.getMessage(), containsString("which is larger than the limit of [100/100b]")); + // make sure used bytes is greater than the total circuit breaker limit + breaker.addWithoutBreaking(200); + // make sure that we check on the following call + for (int i = 0; i < 1023; i++) { + multiBucketConsumer.accept(0); } + CircuitBreakingException exception = expectThrows(CircuitBreakingException.class, () -> multiBucketConsumer.accept(1024)); + assertThat(exception.getMessage(), containsString("[parent] Data too large, data for [allocated_buckets] would be")); + assertThat(exception.getMessage(), containsString("which is larger than the limit of [100/100b]")); assertCircuitBreakerLimitWarning(); } @@ -790,21 +785,18 @@ public void testRegisterCustomCircuitBreakers_WithDuplicates() { } public void testCustomCircuitBreakers() { - try ( - CircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - Settings.EMPTY, - Arrays.asList(new BreakerSettings("foo", 100, 1.2), new BreakerSettings("bar", 200, 0.1)), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ) - ) { - assertThat(service.getBreaker("foo"), is(not(nullValue()))); - 
assertThat(service.getBreaker("foo").getOverhead(), equalTo(1.2)); - assertThat(service.getBreaker("foo").getLimit(), equalTo(100L)); - assertThat(service.getBreaker("bar"), is(not(nullValue()))); - assertThat(service.getBreaker("bar").getOverhead(), equalTo(0.1)); - assertThat(service.getBreaker("bar").getLimit(), equalTo(200L)); - } + CircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + Settings.EMPTY, + Arrays.asList(new BreakerSettings("foo", 100, 1.2), new BreakerSettings("bar", 200, 0.1)), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + assertThat(service.getBreaker("foo"), is(not(nullValue()))); + assertThat(service.getBreaker("foo").getOverhead(), equalTo(1.2)); + assertThat(service.getBreaker("foo").getLimit(), equalTo(100L)); + assertThat(service.getBreaker("bar"), is(not(nullValue()))); + assertThat(service.getBreaker("bar").getOverhead(), equalTo(0.1)); + assertThat(service.getBreaker("bar").getLimit(), equalTo(200L)); } private static long mb(long size) { @@ -812,28 +804,25 @@ private static long mb(long size) { } public void testUpdatingUseRealMemory() { - try ( - HierarchyCircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - Settings.EMPTY, - Collections.emptyList(), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ) - ) { - // use real memory default true - assertTrue(service.isTrackRealMemoryUsage()); - assertThat(service.getOverLimitStrategy(), instanceOf(HierarchyCircuitBreakerService.G1OverLimitStrategy.class)); - - // update use_real_memory to false - service.updateUseRealMemorySetting(false); - assertFalse(service.isTrackRealMemoryUsage()); - assertThat(service.getOverLimitStrategy(), not(instanceOf(HierarchyCircuitBreakerService.G1OverLimitStrategy.class))); - - // update use_real_memory to true - service.updateUseRealMemorySetting(true); - assertTrue(service.isTrackRealMemoryUsage()); - assertThat(service.getOverLimitStrategy(), instanceOf(HierarchyCircuitBreakerService.G1OverLimitStrategy.class)); - } + HierarchyCircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + Settings.EMPTY, + Collections.emptyList(), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + // use real memory default true + assertTrue(service.isTrackRealMemoryUsage()); + assertThat(service.getOverLimitStrategy(), instanceOf(HierarchyCircuitBreakerService.G1OverLimitStrategy.class)); + + // update use_real_memory to false + service.updateUseRealMemorySetting(false); + assertFalse(service.isTrackRealMemoryUsage()); + assertThat(service.getOverLimitStrategy(), not(instanceOf(HierarchyCircuitBreakerService.G1OverLimitStrategy.class))); + + // update use_real_memory to true + service.updateUseRealMemorySetting(true); + assertTrue(service.isTrackRealMemoryUsage()); + assertThat(service.getOverLimitStrategy(), instanceOf(HierarchyCircuitBreakerService.G1OverLimitStrategy.class)); } public void testApplySettingForUpdatingUseRealMemory() { @@ -842,34 +831,31 @@ public void testApplySettingForUpdatingUseRealMemory() { Settings initialSettings = Settings.builder().put(useRealMemoryUsageSetting, "true").build(); ClusterSettings clusterSettings = new ClusterSettings(initialSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - try ( - HierarchyCircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - Settings.EMPTY, - 
Collections.emptyList(), - clusterSettings - ) - ) { - // total.limit defaults to 95% of the JVM heap if use_real_memory is true - assertEquals( - MemorySizeValue.parseBytesSizeValueOrHeapRatio("95%", totalCircuitBreakerLimitSetting).getBytes(), - service.getParentLimit() - ); + HierarchyCircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + Settings.EMPTY, + Collections.emptyList(), + clusterSettings + ); + // total.limit defaults to 95% of the JVM heap if use_real_memory is true + assertEquals( + MemorySizeValue.parseBytesSizeValueOrHeapRatio("95%", totalCircuitBreakerLimitSetting).getBytes(), + service.getParentLimit() + ); - // total.limit defaults to 70% of the JVM heap if use_real_memory set to false - clusterSettings.applySettings(Settings.builder().put(useRealMemoryUsageSetting, false).build()); - assertEquals( - MemorySizeValue.parseBytesSizeValueOrHeapRatio("70%", totalCircuitBreakerLimitSetting).getBytes(), - service.getParentLimit() - ); + // total.limit defaults to 70% of the JVM heap if use_real_memory set to false + clusterSettings.applySettings(Settings.builder().put(useRealMemoryUsageSetting, false).build()); + assertEquals( + MemorySizeValue.parseBytesSizeValueOrHeapRatio("70%", totalCircuitBreakerLimitSetting).getBytes(), + service.getParentLimit() + ); - // total.limit defaults to 95% of the JVM heap if use_real_memory set to true - clusterSettings.applySettings(Settings.builder().put(useRealMemoryUsageSetting, true).build()); - assertEquals( - MemorySizeValue.parseBytesSizeValueOrHeapRatio("95%", totalCircuitBreakerLimitSetting).getBytes(), - service.getParentLimit() - ); - } + // total.limit defaults to 95% of the JVM heap if use_real_memory set to true + clusterSettings.applySettings(Settings.builder().put(useRealMemoryUsageSetting, true).build()); + assertEquals( + MemorySizeValue.parseBytesSizeValueOrHeapRatio("95%", totalCircuitBreakerLimitSetting).getBytes(), + service.getParentLimit() + ); } public void testSizeBelowMinimumWarning() { diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateReleasingTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateReleasingTests.java index c03578b0017bb..b9f8037efb9ca 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateReleasingTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/TDigestStateReleasingTests.java @@ -102,19 +102,18 @@ public void testReadWithData() throws IOException { */ public void testCircuitBreakerTrip(CheckedFunction tDigestStateFactory) throws E { - try (CrankyCircuitBreakerService circuitBreakerService = new CrankyCircuitBreakerService()) { - CircuitBreaker breaker = circuitBreakerService.getBreaker("test"); + CrankyCircuitBreakerService circuitBreakerService = new CrankyCircuitBreakerService(); + CircuitBreaker breaker = circuitBreakerService.getBreaker("test"); - try (TDigestState state = tDigestStateFactory.apply(breaker)) { - // Add some data to make it trip. It won't work in all digest types - for (int i = 0; i < 10; i++) { - state.add(randomDoubleBetween(-Double.MAX_VALUE, Double.MAX_VALUE, true)); - } - } catch (CircuitBreakingException e) { - // Expected - } finally { - assertThat("unreleased bytes", breaker.getUsed(), equalTo(0L)); + try (TDigestState state = tDigestStateFactory.apply(breaker)) { + // Add some data to make it trip. 
It won't work in all digest types + for (int i = 0; i < 10; i++) { + state.add(randomDoubleBetween(-Double.MAX_VALUE, Double.MAX_VALUE, true)); } + } catch (CircuitBreakingException e) { + // Expected + } finally { + assertThat("unreleased bytes", breaker.getUsed(), equalTo(0L)); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index d491e4bb52fa2..f057d35d6e7a9 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -671,54 +671,49 @@ private A searchAndReduce( Releasables.close(context); } } - - try { - if (aggTestConfig.incrementalReduce() && internalAggs.size() > 1) { - // sometimes do an incremental reduce - int toReduceSize = internalAggs.size(); - Collections.shuffle(internalAggs, random()); - int r = randomIntBetween(1, toReduceSize); - List toReduce = internalAggs.subList(0, r); - AggregationReduceContext reduceContext = new AggregationReduceContext.ForPartial( - bigArraysForReduction, - getMockScriptService(), - () -> false, - builder, - b -> {} - ); - internalAggs = new ArrayList<>(internalAggs.subList(r, toReduceSize)); - internalAggs.add(InternalAggregations.topLevelReduce(toReduce, reduceContext)); - for (InternalAggregations internalAggregation : internalAggs) { - assertRoundTrip(internalAggregation.copyResults()); - } - } - - // now do the final reduce - MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer( - maxBucket, - new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) - ); - AggregationReduceContext reduceContext = new AggregationReduceContext.ForFinal( + if (aggTestConfig.incrementalReduce() && internalAggs.size() > 1) { + // sometimes do an incremental reduce + int toReduceSize = internalAggs.size(); + Collections.shuffle(internalAggs, random()); + int r = randomIntBetween(1, toReduceSize); + List toReduce = internalAggs.subList(0, r); + AggregationReduceContext reduceContext = new AggregationReduceContext.ForPartial( bigArraysForReduction, getMockScriptService(), () -> false, builder, - reduceBucketConsumer + b -> {} ); + internalAggs = new ArrayList<>(internalAggs.subList(r, toReduceSize)); + internalAggs.add(InternalAggregations.topLevelReduce(toReduce, reduceContext)); + for (InternalAggregations internalAggregation : internalAggs) { + assertRoundTrip(internalAggregation.copyResults()); + } + } - @SuppressWarnings("unchecked") - A internalAgg = (A) doInternalAggregationsReduce(internalAggs, reduceContext); - assertRoundTrip(internalAgg); + // now do the final reduce + MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer( + maxBucket, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + AggregationReduceContext reduceContext = new AggregationReduceContext.ForFinal( + bigArraysForReduction, + getMockScriptService(), + () -> false, + builder, + reduceBucketConsumer + ); - doAssertReducedMultiBucketConsumer(internalAgg, reduceBucketConsumer); - assertRoundTrip(internalAgg); - if (aggTestConfig.builder instanceof ValuesSourceAggregationBuilder.MetricsAggregationBuilder) { - verifyMetricNames((ValuesSourceAggregationBuilder.MetricsAggregationBuilder) aggTestConfig.builder, internalAgg); - } - return internalAgg; - } finally { - Releasables.close(breakerService); + @SuppressWarnings("unchecked") + A 
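The try/catch/finally shape above, like the Releasables.close cleanup being dropped from searchAndReduce in the adjacent hunk, is the standard leak check for breaker-backed allocations: allocate against a breaker that fails at random, swallow the expected CircuitBreakingException, then require that every reserved byte was released either way. The skeleton, with a hypothetical factory call standing in for the injected tDigestStateFactory:

CrankyCircuitBreakerService circuitBreakerService = new CrankyCircuitBreakerService();
CircuitBreaker breaker = circuitBreakerService.getBreaker("test");
try (TDigestState state = TDigestState.create(breaker, 100)) { // hypothetical factory call
    state.add(42.0); // any allocation may trip, by design
} catch (CircuitBreakingException e) {
    // expected: the cranky breaker fails on purpose
} finally {
    // trip or no trip, nothing may stay booked against the breaker
    assertThat("unreleased bytes", breaker.getUsed(), equalTo(0L));
}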
internalAgg = (A) doInternalAggregationsReduce(internalAggs, reduceContext); + assertRoundTrip(internalAgg); + + doAssertReducedMultiBucketConsumer(internalAgg, reduceBucketConsumer); + assertRoundTrip(internalAgg); + if (aggTestConfig.builder instanceof ValuesSourceAggregationBuilder.MetricsAggregationBuilder) { + verifyMetricNames((ValuesSourceAggregationBuilder.MetricsAggregationBuilder) aggTestConfig.builder, internalAgg); } + return internalAgg; } private InternalAggregation doReduce(List aggregators, AggregationReduceContext reduceContext) { diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/boxplot/BoxplotAggregationBuilder.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/boxplot/BoxplotAggregationBuilder.java index 61917220f10d1..37e05898dfc48 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/boxplot/BoxplotAggregationBuilder.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/boxplot/BoxplotAggregationBuilder.java @@ -194,6 +194,6 @@ public Optional> getOutputFieldNames() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java index e7e06946a5289..93fd71b55332a 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/topmetrics/TopMetricsAggregationBuilder.java @@ -235,6 +235,6 @@ public Optional> getOutputFieldNames() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java index 7137302459c0b..bbe497718b62a 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestAggregationBuilder.java @@ -182,6 +182,6 @@ public String getType() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_8_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java index 90358014e1ee4..c87057a7cd2e1 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java @@ -107,6 +107,6 @@ public BucketCardinality bucketCardinality() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_7_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/AutoscalingMetadata.java 
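Every getMinimalSupportedVersion() change in this diff follows the same rule: once no 7.x node can appear on the wire, gates on 7.x transport versions are dead code, and TransportVersions.ZERO, meaning supported on any version this node can still talk to, replaces them. The same rule collapses conditional stream reads, as in the IndexerJobStats hunk below:

// before: readers had to tolerate 7.6-era peers that never sent the fields
if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) {
    this.processingTime = in.readVLong();
    this.processingTotal = in.readVLong();
}
// after: every supported peer sends the fields, so the reads are unconditional
this.processingTime = in.readVLong();
this.processingTotal = in.readVLong();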
b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/AutoscalingMetadata.java index 38c654f94fff3..b8911d24e1a28 100644 --- a/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/AutoscalingMetadata.java +++ b/x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/AutoscalingMetadata.java @@ -114,7 +114,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_8_0; + return TransportVersions.ZERO; } @Override @@ -173,7 +173,7 @@ static Diff readFrom(final StreamInput in) throws IOE @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_8_0; + return TransportVersions.ZERO; } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/analytics/action/AnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/analytics/action/AnalyticsStatsAction.java index 1420b1f86f12e..1792ba7aa4637 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/analytics/action/AnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/analytics/action/AnalyticsStatsAction.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.core.analytics.action; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.nodes.BaseNodeResponse; @@ -138,36 +137,13 @@ public NodeResponse(DiscoveryNode node, EnumCounters counters) { public NodeResponse(StreamInput in) throws IOException { super(in); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_8_0)) { - counters = new EnumCounters<>(in, Item.class); - } else { - counters = new EnumCounters<>(Item.class); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - counters.inc(Item.BOXPLOT, in.readVLong()); - } - counters.inc(Item.CUMULATIVE_CARDINALITY, in.readZLong()); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - counters.inc(Item.STRING_STATS, in.readVLong()); - counters.inc(Item.TOP_METRICS, in.readVLong()); - } - } + counters = new EnumCounters<>(in, Item.class); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_8_0)) { - counters.writeTo(out); - } else { - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - out.writeVLong(counters.get(Item.BOXPLOT)); - } - out.writeZLong(counters.get(Item.CUMULATIVE_CARDINALITY)); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - out.writeVLong(counters.get(Item.STRING_STATS)); - out.writeVLong(counters.get(Item.TOP_METRICS)); - } - } + counters.writeTo(out); } public EnumCounters getStats() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/indexing/IndexerJobStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/indexing/IndexerJobStats.java index a13cdf1966811..9f9755b6c9874 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/indexing/IndexerJobStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/indexing/IndexerJobStats.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.core.indexing; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import 
org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -83,11 +82,8 @@ public IndexerJobStats(StreamInput in) throws IOException { this.searchTotal = in.readVLong(); this.indexFailures = in.readVLong(); this.searchFailures = in.readVLong(); - - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - this.processingTime = in.readVLong(); - this.processingTotal = in.readVLong(); - } + this.processingTime = in.readVLong(); + this.processingTotal = in.readVLong(); } public long getNumPages() { @@ -205,10 +201,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(searchTotal); out.writeVLong(indexFailures); out.writeVLong(searchFailures); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_7_0)) { - out.writeVLong(processingTime); - out.writeVLong(processingTotal); - } + out.writeVLong(processingTime); + out.writeVLong(processingTotal); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java new file mode 100644 index 0000000000000..c2f70b0be2916 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +public record ChunkedInferenceEmbeddingByte(List chunks) implements ChunkedInference { + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException { + var asChunk = new ArrayList(); + for (var chunk : chunks) { + asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding()))); + } + return asChunk.iterator(); + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException { + XContentBuilder builder = XContentBuilder.builder(xContent); + builder.startArray(); + for (byte v : value) { + builder.value(v); + } + builder.endArray(); + return BytesReference.bytes(builder); + } + + public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java new file mode 100644 index 0000000000000..651d135b761dd --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
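The new Chunked* records above and below all reduce chunks to the same shape: matched text plus offset plus the embedding rendered through an XContentBuilder. Note that toBytesReference now declares IOException instead of wrapping it in a RuntimeException the way the deleted result classes further down do, which the throws clause on chunksAsMatchedTextAndByteReference makes possible. A rough consumer-side sketch, assuming TextOffset and Chunk are nested in ChunkedInference as the unqualified references suggest:

// a one-chunk result built only from constructors in this diff
var result = new ChunkedInferenceEmbeddingByte(
    List.of(new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk(
        new byte[] { 1, 2, 3 }, "some input text", new ChunkedInference.TextOffset(0, 15))));
Iterator<ChunkedInference.Chunk> chunks = result.chunksAsMatchedTextAndByteReference(XContentType.JSON.xContent());
// each Chunk pairs the matched text and offset with "[1,2,3]" rendered as JSON bytes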
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +public record ChunkedInferenceEmbeddingFloat(List chunks) implements ChunkedInference { + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException { + var asChunk = new ArrayList(); + for (var chunk : chunks) { + asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding()))); + } + return asChunk.iterator(); + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (float v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } + + public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java new file mode 100644 index 0000000000000..37bf92e0dbfce --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; + +public record ChunkedInferenceEmbeddingSparse(List chunks) implements ChunkedInference { + + public static List listOf(List inputs, SparseEmbeddingResults sparseEmbeddingResults) { + validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); + + var results = new ArrayList(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + results.add( + new ChunkedInferenceEmbeddingSparse( + List.of( + new SparseEmbeddingChunk( + sparseEmbeddingResults.embeddings().get(i).tokens(), + inputs.get(i), + new TextOffset(0, inputs.get(i).length()) + ) + ) + ) + ); + } + + return results; + } + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException { + var asChunk = new ArrayList(); + for (var chunk : chunks) { + asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.weightedTokens()))); + } + return asChunk.iterator(); + } + + private static BytesReference toBytesReference(XContent xContent, List tokens) throws IOException { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } + + public record SparseEmbeddingChunk(List weightedTokens, String matchedText, TextOffset offset) {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java new file mode 100644 index 0000000000000..65be9f12d7686 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.XContent; + +import java.util.Iterator; +import java.util.stream.Stream; + +public record ChunkedInferenceError(Exception exception) implements ChunkedInference { + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { + return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java deleted file mode 100644 index 18f88a8ff022a..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContent; - -import java.io.IOException; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Stream; - -public class ErrorChunkedInferenceResults implements ChunkedInferenceServiceResults { - - public static final String NAME = "error_chunked"; - - private final Exception exception; - - public ErrorChunkedInferenceResults(Exception exception) { - this.exception = Objects.requireNonNull(exception); - } - - public ErrorChunkedInferenceResults(StreamInput in) throws IOException { - this.exception = in.readException(); - } - - public Exception getException() { - return exception; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeException(exception); - } - - @Override - public boolean equals(Object object) { - if (object == this) { - return true; - } - if (object == null || getClass() != object.getClass()) { - return false; - } - ErrorChunkedInferenceResults that = (ErrorChunkedInferenceResults) object; - // Just compare the message for serialization test purposes - return Objects.equals(exception.getMessage(), that.exception.getMessage()); - } - - @Override - public int hashCode() { - // Just compare the message for serialization test purposes - return Objects.hash(exception.getMessage()); - } - - @Override - public List transformToCoordinationFormat() { - return null; - } - - @Override - public List transformToLegacyFormat() { - return null; - } - - @Override - public Map asMap() { - Map asMap = new LinkedHashMap<>(); - asMap.put(NAME, exception.getMessage()); - return asMap; - } - - @Override - public String toString() { - return 
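The small record above replaces the roughly 100-line ErrorChunkedInferenceResults class deleted below: an inference failure no longer travels as a NamedWriteable with its own asMap and toXContentChunked plumbing, it is just an exception carried alongside the other ChunkedInference variants and surfaced as a single empty-payload chunk:

ChunkedInference failure = new ChunkedInferenceError(new IllegalStateException("model unavailable"));
Iterator<ChunkedInference.Chunk> it = failure.chunksAsMatchedTextAndByteReference(XContentType.JSON.xContent());
// yields exactly one Chunk: the exception message as its text, TextOffset(0, 0), and BytesArray.EMPTY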
Strings.toString(this); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.field(NAME, exception.getMessage()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return Stream.of(exception).map(e -> new Chunk(e.getMessage(), BytesArray.EMPTY)).iterator(); - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java deleted file mode 100644 index c961050acefdb..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public class InferenceChunkedSparseEmbeddingResults implements ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_sparse_embedding_results"; - public static final String FIELD_NAME = "sparse_embedding_chunk"; - - public static InferenceChunkedSparseEmbeddingResults ofMlResult(MlChunkedTextExpansionResults mlInferenceResults) { - return new InferenceChunkedSparseEmbeddingResults(mlInferenceResults.getChunks()); - } - - /** - * Returns a list of {@link InferenceChunkedSparseEmbeddingResults}. The number of entries in the list will match the input list size. - * Each {@link InferenceChunkedSparseEmbeddingResults} will have a single chunk containing the entire results from the - * {@link SparseEmbeddingResults}. 
- */ - public static List listOf(List inputs, SparseEmbeddingResults sparseEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); - - var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - results.add(ofSingle(inputs.get(i), sparseEmbeddingResults.embeddings().get(i))); - } - - return results; - } - - private static InferenceChunkedSparseEmbeddingResults ofSingle(String input, SparseEmbeddingResults.Embedding embedding) { - var weightedTokens = embedding.tokens() - .stream() - .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight())) - .toList(); - - return new InferenceChunkedSparseEmbeddingResults(List.of(new MlChunkedTextExpansionResults.ChunkedResult(input, weightedTokens))); - } - - private final List chunkedResults; - - public InferenceChunkedSparseEmbeddingResults(List chunks) { - this.chunkedResults = chunks; - } - - public InferenceChunkedSparseEmbeddingResults(StreamInput in) throws IOException { - this.chunkedResults = in.readCollectionAsList(MlChunkedTextExpansionResults.ChunkedResult::new); - } - - public List getChunkedResults() { - return chunkedResults; - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunkedResults.iterator()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunkedResults); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordindated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of( - FIELD_NAME, - chunkedResults.stream().map(MlChunkedTextExpansionResults.ChunkedResult::asMap).collect(Collectors.toList()) - ); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedSparseEmbeddingResults that = (InferenceChunkedSparseEmbeddingResults) o; - return Objects.equals(chunkedResults, that.chunkedResults); - } - - @Override - public int hashCode() { - return Objects.hash(chunkedResults); - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunkedResults.stream() - .map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.weightedTokens()))) - .iterator(); - } - - /** - * Serialises the {@link WeightedToken} list, according to the provided {@link XContent}, - * into a {@link BytesReference}. 
- */ - private static BytesReference toBytesReference(XContent xContent, List tokens) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startObject(); - for (var weightedToken : tokens) { - weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); - } - b.endObject(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java deleted file mode 100644 index 6bd66664068d5..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public record InferenceChunkedTextEmbeddingByteResults(List chunks, boolean isTruncated) - implements - ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_text_embedding_service_byte_results"; - public static final String FIELD_NAME = "text_embedding_byte_chunk"; - - /** - * Returns a list of {@link InferenceChunkedTextEmbeddingByteResults}. The number of entries in the list will match the input list size. - * Each {@link InferenceChunkedTextEmbeddingByteResults} will have a single chunk containing the entire results from the - * {@link InferenceTextEmbeddingByteResults}. 
- */ - public static List listOf(List inputs, InferenceTextEmbeddingByteResults textEmbeddings) { - validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size()); - - var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - results.add(ofSingle(inputs.get(i), textEmbeddings.embeddings().get(i).values())); - } - - return results; - } - - private static InferenceChunkedTextEmbeddingByteResults ofSingle(String input, byte[] byteEmbeddings) { - return new InferenceChunkedTextEmbeddingByteResults(List.of(new InferenceByteEmbeddingChunk(input, byteEmbeddings)), false); - } - - public InferenceChunkedTextEmbeddingByteResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(InferenceByteEmbeddingChunk::new), in.readBoolean()); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunks); - out.writeBoolean(isTruncated); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of(FIELD_NAME, chunks); - } - - @Override - public String getWriteableName() { - return NAME; - } - - public List getChunks() { - return chunks; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedTextEmbeddingByteResults that = (InferenceChunkedTextEmbeddingByteResults) o; - return isTruncated == that.isTruncated && Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks, isTruncated); - } - - public record InferenceByteEmbeddingChunk(String matchedText, byte[] embedding) implements Writeable, ToXContentObject { - - public InferenceByteEmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readByteArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeByteArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (byte value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceByteEmbeddingChunk that = (InferenceByteEmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), 
toBytesReference(xcontent, chunk.embedding()))).iterator(); - } - - private static BytesReference toBytesReference(XContent xContent, byte[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (byte v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java deleted file mode 100644 index 369f22a807913..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.utils.FloatConversionUtils; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public record InferenceChunkedTextEmbeddingFloatResults(List chunks) - implements - ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_text_embedding_service_float_results"; - public static final String FIELD_NAME = "text_embedding_float_chunk"; - - public InferenceChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(InferenceFloatEmbeddingChunk::new)); - } - - /** - * Returns a list of {@link InferenceChunkedTextEmbeddingFloatResults}. - * Each {@link InferenceChunkedTextEmbeddingFloatResults} contain a single chunk with the text and the - * {@link InferenceTextEmbeddingFloatResults}. 
- */ - public static List listOf(List inputs, InferenceTextEmbeddingFloatResults textEmbeddings) { - validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size()); - - var results = new ArrayList(inputs.size()); - - for (int i = 0; i < inputs.size(); i++) { - results.add( - new InferenceChunkedTextEmbeddingFloatResults( - List.of(new InferenceFloatEmbeddingChunk(inputs.get(i), textEmbeddings.embeddings().get(i).values())) - ) - ); - } - - return results; - } - - public static InferenceChunkedTextEmbeddingFloatResults ofMlResults(MlChunkedTextEmbeddingFloatResults mlInferenceResult) { - return new InferenceChunkedTextEmbeddingFloatResults( - mlInferenceResult.getChunks() - .stream() - .map(chunk -> new InferenceFloatEmbeddingChunk(chunk.matchedText(), FloatConversionUtils.floatArrayOf(chunk.embedding()))) - .toList() - ); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - // TODO add isTruncated flag - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunks); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of(FIELD_NAME, chunks); - } - - @Override - public String getWriteableName() { - return NAME; - } - - public List getChunks() { - return chunks; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedTextEmbeddingFloatResults that = (InferenceChunkedTextEmbeddingFloatResults) o; - return Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks); - } - - public record InferenceFloatEmbeddingChunk(String matchedText, float[] embedding) implements Writeable, ToXContentObject { - - public InferenceFloatEmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readFloatArray()); - } - - public static InferenceFloatEmbeddingChunk of(String matchedText, double[] doubleEmbedding) { - return new InferenceFloatEmbeddingChunk(matchedText, FloatConversionUtils.floatArrayOf(doubleEmbedding)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeFloatArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (float value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceFloatEmbeddingChunk that = (InferenceFloatEmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result 
+ Arrays.hashCode(embedding); - return result; - } - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator(); - } - - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, float[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (float v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResults.java index 9c764babe33fc..a5f72bd51c6c6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResults.java @@ -139,7 +139,11 @@ public void writeTo(StreamOutput out) throws IOException { } public Map asMap() { - return Map.of(NAME, Map.of(INDEX, index, RELEVANCE_SCORE, relevanceScore, TEXT, text)); + if (text != null) { + return Map.of(NAME, Map.of(INDEX, index, RELEVANCE_SCORE, relevanceScore, TEXT, text)); + } else { + return Map.of(NAME, Map.of(INDEX, index, RELEVANCE_SCORE, relevanceScore)); + } } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java index 4c68d02264457..cb69f1e403e9c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java @@ -31,7 +31,7 @@ public static int getFirstEmbeddingSize(List embeddings) throws Il * Throws an exception if the number of elements in the input text list is different than the results in text embedding * response. 
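The RankedDocsResults.asMap change a few hunks up is a null-safety fix rather than a behavior tweak: Map.of rejects null keys and values with a NullPointerException, so when the optional text field is absent it must be omitted rather than passed through:

Map.of("a", 1, "b", null); // throws NullPointerException at runtime
// hence the branch:
Map<String, Object> inner = text != null
    ? Map.of(INDEX, index, RELEVANCE_SCORE, relevanceScore, TEXT, text)
    : Map.of(INDEX, index, RELEVANCE_SCORE, relevanceScore);
return Map.of(NAME, inner);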
*/ - static void validateInputSizeAgainstEmbeddings(List inputs, int embeddingSize) { + public static void validateInputSizeAgainstEmbeddings(List inputs, int embeddingSize) { if (inputs.size() != embeddingSize) { throw new IllegalArgumentException( Strings.format("The number of inputs [%s] does not match the embeddings [%s]", inputs.size(), embeddingSize) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index cc66c361925a6..6985769f08cb6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -28,7 +28,6 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); public static final ParseField PREDICTION_FIELD_TYPE = new ParseField("prediction_field_type"); private static final MlConfigVersion MIN_SUPPORTED_VERSION = MlConfigVersion.V_7_6_0; - private static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersions.V_7_6_0; public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig( 0, @@ -227,7 +226,7 @@ public MlConfigVersion getMinimalSupportedMlConfigVersion() { @Override public TransportVersion getMinimalSupportedTransportVersion() { - return requestingImportance() ? TransportVersions.V_7_7_0 : MIN_SUPPORTED_TRANSPORT_VERSION; + return TransportVersions.ZERO; } public static Builder builder() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java index de4004792af7c..1cf1c0ddeb570 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java @@ -210,7 +210,7 @@ public boolean isSupported(InferenceConfig inferenceConfig) { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_8_0; + return TransportVersions.ZERO; } public static class Builder implements InferenceConfigUpdate.Builder { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index 337a1ac693128..403a747d6801f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -23,7 +23,6 @@ public class RegressionConfig implements LenientlyParsedInferenceConfig, Strictl public static final ParseField NAME = new ParseField("regression"); private static final MlConfigVersion MIN_SUPPORTED_VERSION = MlConfigVersion.V_7_6_0; - private static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersions.V_7_6_0; public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = 
new ParseField("num_top_feature_importance_values"); public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null); @@ -160,7 +159,7 @@ public MlConfigVersion getMinimalSupportedMlConfigVersion() { @Override public TransportVersion getMinimalSupportedTransportVersion() { - return requestingImportance() ? TransportVersions.V_7_7_0 : MIN_SUPPORTED_TRANSPORT_VERSION; + return TransportVersions.ZERO; } public static Builder builder() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java index dc1a7bdeef104..ceebe3105eb63 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java @@ -114,7 +114,7 @@ public String getName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_8_0; + return TransportVersions.ZERO; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index 42e0bdac6e83a..749bf786ad44b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -34,6 +34,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou long estimatedNumOperations(); default TransportVersion getMinimalCompatibilityVersion() { - return TransportVersions.V_7_6_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 9afad760e1b7f..94d0431ff470d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -286,7 +286,7 @@ public TransportVersion getMinimalCompatibilityVersion() { return models.stream() .map(TrainedModel::getMinimalCompatibilityVersion) .max(TransportVersion::compareTo) - .orElse(TransportVersions.V_7_6_0); + .orElse(TransportVersions.ZERO); } public static class Builder { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 966db6e785c5c..3e38aee177724 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -9,8 +9,6 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountables; import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import 
org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -287,14 +285,6 @@ public Collection getChildResources() { return Collections.unmodifiableCollection(accountables); } - @Override - public TransportVersion getMinimalCompatibilityVersion() { - if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) { - return TransportVersions.V_7_7_0; - } - return TransportVersions.V_7_6_0; - } - public static final class Builder { private List featureNames; private ArrayList nodes; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/TransformState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/TransformState.java index 3b5ed564bcb47..18925cb4a9eae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/TransformState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/TransformState.java @@ -136,11 +136,7 @@ public TransformState(StreamInput in) throws IOException { reason = in.readOptionalString(); progress = in.readOptionalWriteable(TransformProgress::new); node = in.readOptionalWriteable(NodeAttributes::new); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_6_0)) { - shouldStopAtNextCheckpoint = in.readBoolean(); - } else { - shouldStopAtNextCheckpoint = false; - } + shouldStopAtNextCheckpoint = in.readBoolean(); if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { authState = in.readOptionalWriteable(AuthorizationState::new); } else { @@ -237,9 +233,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(reason); out.writeOptionalWriteable(progress); out.writeOptionalWriteable(node); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_6_0)) { - out.writeBoolean(shouldStopAtNextCheckpoint); - } + out.writeBoolean(shouldStopAtNextCheckpoint); if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { out.writeOptionalWriteable(authState); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java deleted file mode 100644 index 83678cd030bc2..0000000000000 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.INFERENCE; -import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.TEXT; - -public class InferenceChunkedTextEmbeddingFloatResultsTests extends ESTestCase { - /** - * Similar to {@link org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults#asMap()} but it converts the - * embeddings float array into a list of floats to make testing equality easier. - */ - public static Map asMapWithListsInsteadOfArrays(InferenceChunkedTextEmbeddingFloatResults result) { - return Map.of( - InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME, - result.getChunks() - .stream() - .map(InferenceChunkedTextEmbeddingFloatResultsTests::inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays) - .collect(Collectors.toList()) - ); - } - - /** - * Similar to {@link MlChunkedTextEmbeddingFloatResults.EmbeddingChunk#asMap()} but it converts the double array into a list of doubles - * to make testing equality easier. - */ - public static Map inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays( - InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk chunk - ) { - var chunkAsList = new ArrayList(chunk.embedding().length); - for (double embedding : chunk.embedding()) { - chunkAsList.add((float) embedding); - } - var map = new HashMap(); - map.put(TEXT, chunk.matchedText()); - map.put(INFERENCE, chunkAsList); - return map; - } -} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java index 46f10928cad08..ff6f6848f4b69 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java @@ -12,10 +12,12 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.AbstractChunkedBWCSerializationTestCase; +import org.hamcrest.Matchers; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; public class RankedDocsResultsTests extends AbstractChunkedBWCSerializationTestCase { @@ -37,6 +39,16 @@ public static RankedDocsResults.RankedDoc createRandomDoc() { return new RankedDocsResults.RankedDoc(randomIntBetween(0, 100), randomFloat(), randomBoolean() ? 
null : randomAlphaOfLength(10)); } + public void test_asMap() { + var index = randomIntBetween(0, 100); + var score = randomFloat(); + var mapNullText = new RankedDocsResults.RankedDoc(index, score, null).asMap(); + assertThat(mapNullText, Matchers.is(Map.of("ranked_doc", Map.of("index", index, "relevance_score", score)))); + + var mapWithText = new RankedDocsResults.RankedDoc(index, score, "Sample text").asMap(); + assertThat(mapWithText, Matchers.is(Map.of("ranked_doc", Map.of("index", index, "relevance_score", score, "text", "Sample text")))); + } + @Override protected RankedDocsResults mutateInstance(RankedDocsResults instance) throws IOException { List copy = new ArrayList<>(List.copyOf(instance.getRankedDocs())); diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java index 9cd6549b4be2c..dc132659417ff 100644 --- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java +++ b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sample/CircuitBreakerTests.java @@ -101,15 +101,13 @@ public void testMemoryCleared() { } private void testMemoryCleared(boolean fail) { - try ( - CircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - Settings.EMPTY, - Collections.singletonList(EqlTestUtils.circuitBreakerSettings(Settings.EMPTY)), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ); - var threadPool = createThreadPool() - ) { + CircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + Settings.EMPTY, + Collections.singletonList(EqlTestUtils.circuitBreakerSettings(Settings.EMPTY)), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + try (var threadPool = createThreadPool()) { final var esClient = new ESMockClient(threadPool, service.getBreaker(CIRCUIT_BREAKER_NAME)); CircuitBreaker eqlCircuitBreaker = service.getBreaker(CIRCUIT_BREAKER_NAME); IndexResolver indexResolver = new IndexResolver(esClient, "cluster", DefaultDataTypeRegistry.INSTANCE, () -> emptySet()); diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java index ecf5ef61ac49a..fe1fca45364e3 100644 --- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java +++ b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/execution/sequence/CircuitBreakerTests.java @@ -209,15 +209,13 @@ private void assertMemoryCleared( TriFunction esClientSupplier ) { final int searchRequestsExpectedCount = 2; - try ( - CircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - Settings.EMPTY, - breakerSettings(), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ); - var threadPool = createThreadPool() - ) { + CircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + Settings.EMPTY, + breakerSettings(), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + try (var threadPool = createThreadPool()) { final var esClient = esClientSupplier.apply(threadPool, service.getBreaker(CIRCUIT_BREAKER_NAME), 
searchRequestsExpectedCount); CircuitBreaker eqlCircuitBreaker = service.getBreaker(CIRCUIT_BREAKER_NAME); QueryClient eqlClient = buildQueryClient(esClient, eqlCircuitBreaker); @@ -248,16 +246,13 @@ public void testEqlCBCleanedUp_on_ParentCBBreak() { Settings settings = Settings.builder() .put(HierarchyCircuitBreakerService.TOTAL_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "0%") .build(); - - try ( - CircuitBreakerService service = new HierarchyCircuitBreakerService( - CircuitBreakerMetrics.NOOP, - settings, - breakerSettings(), - new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) - ); - var threadPool = createThreadPool() - ) { + CircuitBreakerService service = new HierarchyCircuitBreakerService( + CircuitBreakerMetrics.NOOP, + settings, + breakerSettings(), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + try (var threadPool = createThreadPool()) { final var esClient = new SuccessfulESMockClient( threadPool, service.getBreaker(CIRCUIT_BREAKER_NAME), diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java index 5f095a654fc89..b4bccf162d9e4 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RLike.java @@ -8,12 +8,11 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import java.io.IOException; -public class RLike extends RegexMatch { +public abstract class RLike extends RegexMatch { public RLike(Source source, Expression value, RLikePattern pattern) { super(source, value, pattern, false); @@ -33,13 +32,4 @@ public String getWriteableName() { throw new UnsupportedOperationException(); } - @Override - protected NodeInfo info() { - return NodeInfo.create(this, RLike::new, field(), pattern(), caseInsensitive()); - } - - @Override - protected RLike replaceChild(Expression newChild) { - return new RLike(source(), newChild, pattern(), caseInsensitive()); - } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java index 32e8b04573d2d..0f9116ade5a31 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.core.expression.predicate.regex; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; @@ -64,11 +63,7 @@ public boolean foldable() { @Override public Boolean fold() { - Object val = field().fold(); - if (val instanceof BytesRef br) { - val = br.utf8ToString(); - } - return RegexOperation.match(val, pattern().asJavaRegex()); + throw new UnsupportedOperationException(); } 
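The fold() change above is the core of the regex-folding fix gated by the FIXED_REGEX_FOLD capability added further down: the esql-core RegexMatch used to constant-fold through java.util.regex (RegexOperation.match on the Java form of the pattern), while the runtime evaluator matches against a Lucene automaton, and the two engines disagree on embedded flags such as (?i), which Java honors but Lucene's RegExp grammar has no syntax for and compiles as the literal characters '?' and 'i'. The following is a minimal standalone sketch of that divergence, assuming only Lucene on the classpath; it is not the ESQL code path itself.

    import org.apache.lucene.util.automaton.CharacterRunAutomaton;
    import org.apache.lucene.util.automaton.RegExp;

    import java.util.regex.Pattern;

    public class RegexEngineDivergence {
        public static void main(String[] args) {
            String pattern = "(?i)ba.*";

            // java.util.regex honors the embedded case-insensitivity flag: matches "Bar".
            boolean javaMatch = Pattern.matches(pattern, "Bar");

            // Lucene's RegExp parses "(?i)" as a group matching the literal characters
            // '?' and 'i', so the automaton expects "?iba..." and does not match "Bar".
            CharacterRunAutomaton automaton = new CharacterRunAutomaton(new RegExp(pattern).toAutomaton());
            boolean luceneMatch = automaton.run("Bar");

            System.out.println(javaMatch + " / " + luceneMatch); // prints: true / false
        }
    }

Throwing UnsupportedOperationException here, and having the concrete ESQL RLike and WildcardLike override fold() to go through EvaluatorMapper.super.fold(), routes the constant case through the same automaton evaluator used for non-constant input; that is exactly the behavior the new caseInsensitiveRegexFold csv-spec tests pin down.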
@Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java index bf54744667217..05027707326bd 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/WildcardLike.java @@ -8,12 +8,11 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import java.io.IOException; -public class WildcardLike extends RegexMatch { +public abstract class WildcardLike extends RegexMatch { public WildcardLike(Source source, Expression left, WildcardPattern pattern) { this(source, left, pattern, false); @@ -33,14 +32,4 @@ public String getWriteableName() { throw new UnsupportedOperationException(); } - @Override - protected NodeInfo info() { - return NodeInfo.create(this, WildcardLike::new, field(), pattern(), caseInsensitive()); - } - - @Override - protected WildcardLike replaceChild(Expression newLeft) { - return new WildcardLike(source(), newLeft, pattern(), caseInsensitive()); - } - } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/util/TestUtils.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/util/TestUtils.java index 9f8e23cb15a97..b37ca0431ec2d 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/util/TestUtils.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/util/TestUtils.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -46,10 +44,6 @@ public static Range rangeOf(Expression value, Expression lower, boolean includeL return new Range(EMPTY, value, lower, includeLower, upper, includeUpper, randomZone()); } - public static RLike rlike(Expression left, String exp) { - return new RLike(EMPTY, left, new RLikePattern(exp)); - } - public static FieldAttribute fieldAttribute() { return fieldAttribute(randomAlphaOfLength(10), randomFrom(DataType.types())); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java index 8a4d44a690571..692c385cef216 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvAssert.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.compute.data.Page; import 
org.elasticsearch.logging.Logger; @@ -197,7 +198,13 @@ public static void assertData( for (int row = 0; row < expectedValues.size(); row++) { try { if (row >= actualValues.size()) { - dataFailure("Expected more data but no more entries found after [" + row + "]", dataFailures, expected, actualValues); + dataFailure( + "Expected more data but no more entries found after [" + row + "]", + dataFailures, + expected, + actualValues, + valueTransformer + ); } if (logger != null) { @@ -208,45 +215,17 @@ public static void assertData( var actualRow = actualValues.get(row); for (int column = 0; column < expectedRow.size(); column++) { - var expectedValue = expectedRow.get(column); - var actualValue = actualRow.get(column); var expectedType = expected.columnTypes().get(column); + var expectedValue = convertExpectedValue(expectedType, expectedRow.get(column)); + var actualValue = actualRow.get(column); - if (expectedValue != null) { - // convert the long from CSV back to its STRING form - if (expectedType == Type.DATETIME) { - expectedValue = rebuildExpected(expectedValue, Long.class, x -> UTC_DATE_TIME_FORMATTER.formatMillis((long) x)); - } else if (expectedType == Type.DATE_NANOS) { - expectedValue = rebuildExpected( - expectedValue, - Long.class, - x -> DateFormatter.forPattern("strict_date_optional_time_nanos").formatNanos((long) x) - ); - } else if (expectedType == Type.GEO_POINT) { - expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> GEO.wkbToWkt((BytesRef) x)); - } else if (expectedType == Type.CARTESIAN_POINT) { - expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> CARTESIAN.wkbToWkt((BytesRef) x)); - } else if (expectedType == Type.GEO_SHAPE) { - expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> GEO.wkbToWkt((BytesRef) x)); - } else if (expectedType == Type.CARTESIAN_SHAPE) { - expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> CARTESIAN.wkbToWkt((BytesRef) x)); - } else if (expectedType == Type.IP) { - // convert BytesRef-packed IP to String, allowing subsequent comparison with what's expected - expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> DocValueFormat.IP.format((BytesRef) x)); - } else if (expectedType == Type.VERSION) { - // convert BytesRef-packed Version to String - expectedValue = rebuildExpected(expectedValue, BytesRef.class, x -> new Version((BytesRef) x).toString()); - } else if (expectedType == UNSIGNED_LONG) { - expectedValue = rebuildExpected(expectedValue, Long.class, x -> unsignedLongAsNumber((long) x)); - } - } var transformedExpected = valueTransformer.apply(expectedType, expectedValue); var transformedActual = valueTransformer.apply(expectedType, actualValue); if (Objects.equals(transformedExpected, transformedActual) == false) { dataFailures.add(new DataFailure(row, column, transformedExpected, transformedActual)); } if (dataFailures.size() > 10) { - dataFailure("", dataFailures, expected, actualValues); + dataFailure("", dataFailures, expected, actualValues, valueTransformer); } } @@ -255,7 +234,8 @@ public static void assertData( "Plan has extra columns, returned [" + actualRow.size() + "], expected [" + expectedRow.size() + "]", dataFailures, expected, - actualValues + actualValues, + valueTransformer ); } } catch (AssertionError ae) { @@ -267,10 +247,16 @@ public static void assertData( } } if (dataFailures.isEmpty() == false) { - dataFailure("", dataFailures, expected, actualValues); + dataFailure("", dataFailures, expected, actualValues, valueTransformer); } if 
(expectedValues.size() < actualValues.size()) { - dataFailure("Elasticsearch still has data after [" + expectedValues.size() + "] entries", dataFailures, expected, actualValues); + dataFailure( + "Elasticsearch still has data after [" + expectedValues.size() + "] entries", + dataFailures, + expected, + actualValues, + valueTransformer + ); } } @@ -278,42 +264,72 @@ private static void dataFailure( String description, List dataFailures, ExpectedResults expectedValues, - List> actualValues + List> actualValues, + BiFunction valueTransformer ) { - var expected = pipeTable("Expected:", expectedValues.columnNames(), expectedValues.values(), 25); - var actual = pipeTable("Actual:", expectedValues.columnNames(), actualValues, 25); + var expected = pipeTable( + "Expected:", + expectedValues.columnNames(), + expectedValues.columnTypes(), + expectedValues.values(), + (type, value) -> valueTransformer.apply(type, convertExpectedValue(type, value)) + ); + var actual = pipeTable("Actual:", expectedValues.columnNames(), expectedValues.columnTypes(), actualValues, valueTransformer); fail(description + System.lineSeparator() + describeFailures(dataFailures) + actual + expected); } - private static String pipeTable(String description, List headers, List> values, int maxRows) { + private static final int MAX_ROWS = 25; + + private static String pipeTable( + String description, + List headers, + List types, + List> values, + BiFunction valueTransformer + ) { + int rows = Math.min(MAX_ROWS, values.size()); int[] width = new int[headers.size()]; - for (int i = 0; i < width.length; i++) { - width[i] = headers.get(i).length(); - for (List row : values) { - width[i] = Math.max(width[i], String.valueOf(row.get(i)).length()); + String[][] printableValues = new String[rows][headers.size()]; + for (int c = 0; c < headers.size(); c++) { + width[c] = header(headers.get(c), types.get(c)).length(); + } + for (int r = 0; r < rows; r++) { + for (int c = 0; c < headers.size(); c++) { + printableValues[r][c] = String.valueOf(valueTransformer.apply(types.get(c), values.get(r).get(c))); + width[c] = Math.max(width[c], printableValues[r][c].length()); } } var result = new StringBuilder().append(System.lineSeparator()).append(description).append(System.lineSeparator()); - for (int c = 0; c < width.length; c++) { - appendValue(result, headers.get(c), width[c]); + // headers + appendPaddedValue(result, header(headers.get(0), types.get(0)), width[0]); + for (int c = 1; c < width.length; c++) { + result.append(" | "); + appendPaddedValue(result, header(headers.get(c), types.get(c)), width[c]); } - result.append('|').append(System.lineSeparator()); - for (int r = 0; r < Math.min(maxRows, values.size()); r++) { - for (int c = 0; c < width.length; c++) { - appendValue(result, values.get(r).get(c), width[c]); + result.append(System.lineSeparator()); + // values + for (int r = 0; r < printableValues.length; r++) { + appendPaddedValue(result, printableValues[r][0], width[0]); + for (int c = 1; c < printableValues[r].length; c++) { + result.append(" | "); + appendPaddedValue(result, printableValues[r][c], width[c]); } - result.append('|').append(System.lineSeparator()); + result.append(System.lineSeparator()); } - if (values.size() > maxRows) { + if (values.size() > rows) { result.append("...").append(System.lineSeparator()); } return result.toString(); } - private static void appendValue(StringBuilder result, Object value, int width) { - result.append('|').append(value); - for (int i = 0; i < width - String.valueOf(value).length(); i++) { 
+ private static String header(String name, Type type) { + return name + ':' + Strings.toLowercaseAscii(type.name()); + } + + private static void appendPaddedValue(StringBuilder result, String value, int width) { + result.append(value); + for (int i = 0; i < width - (value != null ? value.length() : 4); i++) { result.append(' '); } } @@ -369,6 +385,34 @@ private static Comparator> resultRowComparator(List types) { }; } + private static Object convertExpectedValue(Type expectedType, Object expectedValue) { + if (expectedValue == null) { + return null; + } + + // convert the long from CSV back to its STRING form + return switch (expectedType) { + case Type.DATETIME -> rebuildExpected(expectedValue, Long.class, x -> UTC_DATE_TIME_FORMATTER.formatMillis((long) x)); + case Type.DATE_NANOS -> rebuildExpected( + expectedValue, + Long.class, + x -> DateFormatter.forPattern("strict_date_optional_time_nanos").formatNanos((long) x) + ); + case Type.GEO_POINT, Type.GEO_SHAPE -> rebuildExpected(expectedValue, BytesRef.class, x -> GEO.wkbToWkt((BytesRef) x)); + case Type.CARTESIAN_POINT, Type.CARTESIAN_SHAPE -> rebuildExpected( + expectedValue, + BytesRef.class, + x -> CARTESIAN.wkbToWkt((BytesRef) x) + ); + case Type.IP -> // convert BytesRef-packed IP to String, allowing subsequent comparison with what's expected + rebuildExpected(expectedValue, BytesRef.class, x -> DocValueFormat.IP.format((BytesRef) x)); + case Type.VERSION -> // convert BytesRef-packed Version to String + rebuildExpected(expectedValue, BytesRef.class, x -> new Version((BytesRef) x).toString()); + case UNSIGNED_LONG -> rebuildExpected(expectedValue, Long.class, x -> unsignedLongAsNumber((long) x)); + default -> expectedValue; + }; + } + private static Object rebuildExpected(Object expectedValue, Class clazz, Function mapper) { if (List.class.isAssignableFrom(expectedValue.getClass())) { assertThat(((List) expectedValue).get(0), instanceOf(clazz)); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index 3b656ded94dd7..abfe90f80e372 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -41,7 +41,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.CsvTestUtils.COMMA_ESCAPING_REGEX; @@ -260,11 +259,22 @@ public static void main(String[] args) throws IOException { public static Set availableDatasetsForEs(RestClient client, boolean supportsIndexModeLookup) throws IOException { boolean inferenceEnabled = clusterHasInferenceEndpoint(client); - return CSV_DATASET_MAP.values() - .stream() - .filter(d -> d.requiresInferenceEndpoint == false || inferenceEnabled) - .filter(d -> supportsIndexModeLookup || d.indexName.endsWith("_lookup") == false) // TODO: use actual index settings - .collect(Collectors.toCollection(HashSet::new)); + Set testDataSets = new HashSet<>(); + + for (TestsDataset dataset : CSV_DATASET_MAP.values()) { + if ((inferenceEnabled || dataset.requiresInferenceEndpoint == false) + && (supportsIndexModeLookup || isLookupDataset(dataset) == false)) { + testDataSets.add(dataset); + } + } + + return testDataSets; + } + + 
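availableDatasetsForEs above now filters lookup datasets by the index's actual settings rather than the old indexName.endsWith("_lookup") heuristic flagged by the removed TODO; the isLookupDataset helper that follows reads the dataset's settings file and checks index.mode. Below is a minimal standalone sketch of the Settings behavior it relies on, using inline JSON instead of a classpath resource (the payloads are hypothetical examples, not the repository's real settings files):

    import org.elasticsearch.common.settings.Settings;

    import java.io.ByteArrayInputStream;
    import java.nio.charset.StandardCharsets;

    public class LookupModeCheck {
        public static void main(String[] args) throws Exception {
            // Both shapes flatten to the same "index.mode" key when loaded.
            String[] payloads = { "{\"index\": {\"mode\": \"lookup\"}}", "{\"index.mode\": \"lookup\"}" };
            for (String json : payloads) {
                Settings settings = Settings.builder()
                    // the ".json" suffix tells the loader which format to parse
                    .loadFromStream("inline.json", new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)), false)
                    .build();
                System.out.println("lookup".equalsIgnoreCase(settings.get("index.mode"))); // true, twice
            }
        }
    }

Reading the settings through the shared readSettingsFile() helper keeps this check consistent with what load() actually applies when it creates the index.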
public static boolean isLookupDataset(TestsDataset dataset) throws IOException { + Settings settings = dataset.readSettingsFile(); + String mode = settings.get("index.mode"); + return (mode != null && mode.equalsIgnoreCase("lookup")); } public static void loadDataSetIntoEs(RestClient client, boolean supportsIndexModeLookup) throws IOException { @@ -354,13 +364,8 @@ private static void load(RestClient client, TestsDataset dataset, Logger logger, if (data == null) { throw new IllegalArgumentException("Cannot find resource " + dataName); } - Settings indexSettings = Settings.EMPTY; - final String settingName = dataset.settingFileName != null ? "/" + dataset.settingFileName : null; - if (settingName != null) { - indexSettings = Settings.builder() - .loadFromStream(settingName, CsvTestsDataLoader.class.getResourceAsStream(settingName), false) - .build(); - } + + Settings indexSettings = dataset.readSettingsFile(); indexCreator.createIndex(client, dataset.indexName, readMappingFile(mapping, dataset.typeMapping), indexSettings); loadCsvData(client, dataset.indexName, data, dataset.allowSubFields, logger); } @@ -669,6 +674,18 @@ public TestsDataset withTypeMapping(Map typeMapping) { public TestsDataset withInferenceEndpoint(boolean needsInference) { return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping, needsInference); } + + private Settings readSettingsFile() throws IOException { + Settings indexSettings = Settings.EMPTY; + final String settingName = settingFileName != null ? "/" + settingFileName : null; + if (settingName != null) { + indexSettings = Settings.builder() + .loadFromStream(settingName, CsvTestsDataLoader.class.getResourceAsStream(settingName), false) + .build(); + } + + return indexSettings; + } } public record EnrichConfig(String policyName, String policyFileName) {} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec index cde5427bf37d6..2b3b0bee93471 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec @@ -223,7 +223,7 @@ null | null | null ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteName#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | dissect full_name "%{emp_no} %{b}" | keep full_name, emp_no, b | limit 3; @@ -245,7 +245,7 @@ emp_no:integer | first_name:keyword | rest:keyword ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteNameWhere#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | dissect full_name "%{emp_no} %{b}" | where emp_no == "Bezalel" | keep full_name, emp_no, b | limit 3; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec index 592b06107c8b5..72660c11d8b73 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec @@ -601,3 +601,39 @@ Mokhtar |Bernatsky |38992 |BM Parto |Bamford |61805 |BP Premal 
|Baek |52833 |BP ; + + +caseInsensitiveRegex +from employees | where first_name RLIKE "(?i)geor.*" | keep first_name +; + +first_name:keyword +; + + +caseInsensitiveRegex2 +from employees | where first_name RLIKE "(?i)Geor.*" | keep first_name +; + +first_name:keyword +; + + +caseInsensitiveRegexFold +required_capability: fixed_regex_fold +row foo = "Bar" | where foo rlike "(?i)ba.*" +; + +foo:keyword +; + + +caseInsensitiveRegexFold2 +required_capability: fixed_regex_fold +row foo = "Bar" | where foo rlike "(?i)Ba.*" +; + +foo:keyword +; + + diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec index eece1bdfbffa4..6dc9148ffc0e8 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec @@ -199,7 +199,7 @@ null | null | null ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteName#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | grok full_name "%{WORD:emp_no} %{WORD:b}" | keep full_name, emp_no, b | limit 3; @@ -210,7 +210,7 @@ Parto Bamford | Parto | Bamford ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions overwriteNameWhere#[skip:-8.12.99] from employees | sort emp_no asc | eval full_name = concat(first_name, " ", last_name) | grok full_name "%{WORD:emp_no} %{WORD:b}" | where emp_no == "Bezalel" | keep full_name, emp_no, b | limit 3; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 100c0d716d65c..80586ce9bcb09 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -564,7 +564,7 @@ c:long | gender:keyword | trunk_worked_seconds:long 0 | null | 200000000 ; -// the query is incorrectly physically plan (fails the verification) in pre-8.13.0 versions +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions byStringAndLongWithAlias#[skip:-8.12.99] FROM employees | EVAL trunk_worked_seconds = avg_worked_seconds / 100000000 * 100000000 @@ -720,7 +720,8 @@ c:long | d:date | gender:keyword | languages:integer 2 | 1987-01-01T00:00:00.000Z | M | 1 ; -byDateAndKeywordAndIntWithAlias +// the query is incorrectly physically planned (fails the verification) in pre-8.13.0 versions +byDateAndKeywordAndIntWithAlias#[skip:-8.12.99] from employees | eval d = date_trunc(1 year, hire_date) | rename gender as g, languages as l, emp_no as e | keep d, g, l, e | stats c = count(e) by d, g, l | sort c desc, d, l desc, g desc | limit 10; c:long | d:date | g:keyword | l:integer diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 649ec1eba9785..e9a0f89e4f448 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -577,7 +577,12 @@ public enum Cap 
{ /** * Additional types for match function and operator */ - MATCH_ADDITIONAL_TYPES; + MATCH_ADDITIONAL_TYPES, + + /** + * Fix for regex folding with case-insensitive pattern https://github.com/elastic/elasticsearch/issues/118371 + */ + FIXED_REGEX_FOLD; private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java index cd42711177510..996c90a8e40bc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java @@ -79,7 +79,7 @@ public String getWriteableName() { } @Override - protected NodeInfo info() { + protected NodeInfo info() { return NodeInfo.create(this, RLike::new, field(), pattern(), caseInsensitive()); } @@ -93,6 +93,11 @@ protected TypeResolution resolveType() { return isString(field(), sourceText(), DEFAULT); } + @Override + public Boolean fold() { + return (Boolean) EvaluatorMapper.super.fold(); + } + @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { return AutomataMatch.toEvaluator(source(), toEvaluator.apply(field()), pattern().createAutomaton()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java index c1b4f20f41795..d2edb0f92e8f2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java @@ -99,6 +99,11 @@ protected TypeResolution resolveType() { return isString(field(), sourceText(), DEFAULT); } + @Override + public Boolean fold() { + return (Boolean) EvaluatorMapper.super.fold(); + } + @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { return AutomataMatch.toEvaluator( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java index 6014e24e39c5f..b6fa82360d1e8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java @@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockStreamInput; import org.elasticsearch.index.Index; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; @@ -32,17 +34,24 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.security.authz.IndicesAndAliasesResolverField.NO_INDEX_PLACEHOLDER; +import static org.elasticsearch.xpack.core.security.authz.IndicesAndAliasesResolverField.NO_INDICES_OR_ALIASES_ARRAY; + final class 
DataNodeRequest extends TransportRequest implements IndicesRequest.Replaceable { + private static final Logger logger = LogManager.getLogger(DataNodeRequest.class); + private final String sessionId; private final Configuration configuration; private final String clusterAlias; - private final List shardIds; private final Map aliasFilters; private final PhysicalPlan plan; + private List shardIds; private String[] indices; private final IndicesOptions indicesOptions; @@ -115,6 +124,10 @@ public String[] indices() { @Override public IndicesRequest indices(String... indices) { this.indices = indices; + if (Arrays.equals(NO_INDICES_OR_ALIASES_ARRAY, indices) || Arrays.asList(indices).contains(NO_INDEX_PLACEHOLDER)) { + logger.trace(() -> format("Indices empty after index resolution, also clearing shardIds %s", shardIds)); + this.shardIds = Collections.emptyList(); + } return this; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 0a01dfc088cfb..1d10ebab267ce 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.VerificationException; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils; @@ -4910,6 +4911,8 @@ public void testPlanSanityCheck() throws Exception { } public void testPlanSanityCheckWithBinaryPlans() throws Exception { + assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + var plan = optimizedPlan(""" FROM test | RENAME languages AS language_code @@ -5913,6 +5916,8 @@ public void testLookupStats() { * \_EsRelation[language_code][LOOKUP][language_code{f}#18, language_name{f}#19] */ public void testLookupJoinPushDownFilterOnJoinKeyWithRename() { + assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + String query = """ FROM test | RENAME languages AS language_code @@ -5954,6 +5959,8 @@ public void testLookupJoinPushDownFilterOnJoinKeyWithRename() { * \_EsRelation[language_code][LOOKUP][language_code{f}#18, language_name{f}#19] */ public void testLookupJoinPushDownFilterOnLeftSideField() { + assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + String query = """ FROM test | RENAME languages AS language_code @@ -5996,6 +6003,8 @@ public void testLookupJoinPushDownFilterOnLeftSideField() { * \_EsRelation[language_code][LOOKUP][language_code{f}#18, language_name{f}#19] */ public void testLookupJoinPushDownDisabledForLookupField() { + assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + String query = """ FROM test | RENAME languages AS language_code @@ -6039,6 +6048,8 @@ public void testLookupJoinPushDownDisabledForLookupField() { * \_EsRelation[language_code][LOOKUP][language_code{f}#19, language_name{f}#20] */ public void testLookupJoinPushDownSeparatedForConjunctionBetweenLeftAndRightField() { + assumeTrue("Requires LOOKUP JOIN", 
EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + String query = """ FROM test | RENAME languages AS language_code @@ -6090,6 +6101,8 @@ public void testLookupJoinPushDownSeparatedForConjunctionBetweenLeftAndRightFiel * \_EsRelation[language_code][LOOKUP][language_code{f}#19, language_name{f}#20] */ public void testLookupJoinPushDownDisabledForDisjunctionBetweenLeftAndRightField() { + assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + String query = """ FROM test | RENAME languages AS language_code diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index dc3ae0a3388cb..d0c7a1cd61010 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -2331,6 +2331,8 @@ public void testVerifierOnMissingReferences() { } public void testVerifierOnMissingReferencesWithBinaryPlans() throws Exception { + assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V5.isEnabled()); + // Do not assert serialization: // This will have a LookupJoinExec, which is not serializable because it doesn't leave the coordinator. var plan = physicalPlan(""" diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java index c2e85cc43284a..c4f4dac67acd3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java @@ -17,11 +17,11 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java index e159e5ed0bd7d..bc22fbb6bd828 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java @@ -199,7 
+199,7 @@ public void testPushDownFilterOnAliasInEval() { public void testPushDownLikeRlikeFilter() { EsRelation relation = relation(); - org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike conditionA = rlike(getFieldAttribute("a"), "foo"); + RLike conditionA = rlike(getFieldAttribute("a"), "foo"); WildcardLike conditionB = wildcardLike(getFieldAttribute("b"), "bar"); Filter fa = new Filter(EMPTY, relation, conditionA); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java index 20d638a113bf2..c7206c6971bde 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java @@ -11,11 +11,11 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.util.StringUtils; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import static java.util.Arrays.asList; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java index 3dfc0f611eb2b..f2a619f0dbd89 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java @@ -65,7 +65,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected ClusterComputeRequest createTestInstance() { var sessionId = randomAlphaOfLength(10); String query = randomQuery(); - PhysicalPlan physicalPlan = DataNodeRequestTests.mapAndMaybeOptimize(parse(query)); + PhysicalPlan physicalPlan = DataNodeRequestSerializationTests.mapAndMaybeOptimize(parse(query)); OriginalIndices originalIndices = new OriginalIndices( generateRandomStringArray(10, 10, false, false), IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java new file mode 100644 index 0000000000000..d1ce064c35d81 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java @@ -0,0 +1,289 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexMode; +import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.analysis.Analyzer; +import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.index.EsIndex; +import org.elasticsearch.xpack.esql.index.IndexResolution; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; +import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext; +import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer; +import org.elasticsearch.xpack.esql.parser.EsqlParser; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.planner.mapper.Mapper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; +import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; + +public class DataNodeRequestSerializationTests extends AbstractWireSerializingTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return DataNodeRequest::new; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List writeables = new ArrayList<>(); + writeables.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables()); + writeables.addAll(new EsqlPlugin().getNamedWriteables()); + return new NamedWriteableRegistry(writeables); + } + + @Override + protected DataNodeRequest createTestInstance() { + var sessionId = randomAlphaOfLength(10); + String query = randomFrom(""" + from test + | where round(emp_no) > 10 + | eval c = salary + | stats x = avg(c) + """, """ + from test + | sort last_name + | limit 10 + | where round(emp_no) > 10 + | eval c = first_name + | stats x = avg(salary) + """); + List shardIds = randomList(1, 10, () -> new ShardId("index-" + between(1, 10), "n/a", between(1, 10))); + PhysicalPlan physicalPlan = mapAndMaybeOptimize(parse(query)); + Map aliasFilters = Map.of( + new Index("concrete-index", "n/a"), + AliasFilter.of(new TermQueryBuilder("id", "1"), "alias-1") + ); + DataNodeRequest 
request = new DataNodeRequest( + sessionId, + randomConfiguration(query, randomTables()), + randomAlphaOfLength(10), + shardIds, + aliasFilters, + physicalPlan, + generateRandomStringArray(10, 10, false, false), + IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()) + ); + request.setParentTask(randomAlphaOfLength(10), randomNonNegativeLong()); + return request; + } + + @Override + protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException { + return switch (between(0, 8)) { + case 0 -> { + var request = new DataNodeRequest( + randomAlphaOfLength(20), + in.configuration(), + in.clusterAlias(), + in.shardIds(), + in.aliasFilters(), + in.plan(), + in.indices(), + in.indicesOptions() + ); + request.setParentTask(in.getParentTask()); + yield request; + } + case 1 -> { + var request = new DataNodeRequest( + in.sessionId(), + randomConfiguration(), + in.clusterAlias(), + in.shardIds(), + in.aliasFilters(), + in.plan(), + in.indices(), + in.indicesOptions() + ); + request.setParentTask(in.getParentTask()); + yield request; + } + case 2 -> { + List shardIds = randomList(1, 10, () -> new ShardId("new-index-" + between(1, 10), "n/a", between(1, 10))); + var request = new DataNodeRequest( + in.sessionId(), + in.configuration(), + in.clusterAlias(), + shardIds, + in.aliasFilters(), + in.plan(), + in.indices(), + in.indicesOptions() + ); + request.setParentTask(in.getParentTask()); + yield request; + } + case 3 -> { + String newQuery = randomFrom(""" + from test + | where round(emp_no) > 100 + | eval c = salary + | stats x = avg(c) + """, """ + from test + | sort last_name + | limit 10 + | where round(emp_no) > 100 + | eval c = first_name + | stats x = avg(salary) + """); + var request = new DataNodeRequest( + in.sessionId(), + in.configuration(), + in.clusterAlias(), + in.shardIds(), + in.aliasFilters(), + mapAndMaybeOptimize(parse(newQuery)), + in.indices(), + in.indicesOptions() + ); + request.setParentTask(in.getParentTask()); + yield request; + } + case 4 -> { + final Map aliasFilters; + if (randomBoolean()) { + aliasFilters = Map.of(); + } else { + aliasFilters = Map.of(new Index("concrete-index", "n/a"), AliasFilter.of(new TermQueryBuilder("id", "2"), "alias-2")); + } + var request = new DataNodeRequest( + in.sessionId(), + in.configuration(), + in.clusterAlias(), + in.shardIds(), + aliasFilters, + in.plan(), + in.indices(), + in.indicesOptions() + ); + request.setParentTask(request.getParentTask()); + yield request; + } + case 5 -> { + var request = new DataNodeRequest( + in.sessionId(), + in.configuration(), + in.clusterAlias(), + in.shardIds(), + in.aliasFilters(), + in.plan(), + in.indices(), + in.indicesOptions() + ); + request.setParentTask( + randomValueOtherThan(request.getParentTask().getNodeId(), () -> randomAlphaOfLength(10)), + randomNonNegativeLong() + ); + yield request; + } + case 6 -> { + var clusterAlias = randomValueOtherThan(in.clusterAlias(), () -> randomAlphaOfLength(10)); + var request = new DataNodeRequest( + in.sessionId(), + in.configuration(), + clusterAlias, + in.shardIds(), + in.aliasFilters(), + in.plan(), + in.indices(), + in.indicesOptions() + ); + request.setParentTask(request.getParentTask()); + yield request; + } + case 7 -> { + var indices = randomValueOtherThan(in.indices(), () -> generateRandomStringArray(10, 10, false, false)); + var request = new DataNodeRequest( + in.sessionId(), + in.configuration(), + in.clusterAlias(), + in.shardIds(), + in.aliasFilters(), + in.plan(), + indices, + 
+                    in.indicesOptions()
+                );
+                request.setParentTask(request.getParentTask());
+                yield request;
+            }
+            case 8 -> {
+                var indicesOptions = randomValueOtherThan(
+                    in.indicesOptions(),
+                    () -> IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean())
+                );
+                var request = new DataNodeRequest(
+                    in.sessionId(),
+                    in.configuration(),
+                    in.clusterAlias(),
+                    in.shardIds(),
+                    in.aliasFilters(),
+                    in.plan(),
+                    in.indices(),
+                    indicesOptions
+                );
+                request.setParentTask(request.getParentTask());
+                yield request;
+            }
+            default -> throw new AssertionError("invalid value");
+        };
+    }
+
+    static LogicalPlan parse(String query) {
+        Map<String, EsField> mapping = loadMapping("mapping-basic.json");
+        EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
+        IndexResolution getIndexResult = IndexResolution.valid(test);
+        var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG));
+        var analyzer = new Analyzer(
+            new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
+            TEST_VERIFIER
+        );
+        return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));
+    }
+
+    static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan) {
+        var physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(TEST_CFG));
+        var mapper = new Mapper();
+        var physical = mapper.map(logicalPlan);
+        if (randomBoolean()) {
+            physical = physicalPlanOptimizer.optimize(physical);
+        }
+        return physical;
+    }
+
+    @Override
+    protected List<String> filteredWarnings() {
+        return withDefaultLimitWarning(super.filteredWarnings());
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java
index 4553551c40cd3..d0c5ddd0dc927 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java
@@ -8,282 +8,48 @@
 package org.elasticsearch.xpack.esql.plugin;
 
 import org.elasticsearch.action.support.IndicesOptions;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.index.Index;
-import org.elasticsearch.index.IndexMode;
-import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.index.shard.ShardId;
-import org.elasticsearch.search.SearchModule;
-import org.elasticsearch.search.internal.AliasFilter;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.esql.EsqlTestUtils;
-import org.elasticsearch.xpack.esql.analysis.Analyzer;
-import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
-import org.elasticsearch.xpack.esql.core.type.EsField;
-import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
-import org.elasticsearch.xpack.esql.index.EsIndex;
-import org.elasticsearch.xpack.esql.index.IndexResolution;
-import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
-import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
-import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.parser.EsqlParser;
-import
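Review note: the new DataNodeRequestSerializationTests above inherits its round-trip guarantees from AbstractWireSerializingTestCase: createTestInstance produces a random value, instanceReader deserializes it, and mutateInstance must change exactly one field so an equality failure is attributable. A minimal, self-contained sketch of that contract, using a hypothetical two-field request (PointRequest and PointRequestWireTests are illustrative names, not part of this change):

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;

public class PointRequestWireTests extends AbstractWireSerializingTestCase<PointRequestWireTests.PointRequest> {

    // hypothetical two-field payload; DataNodeRequest plays this role above
    record PointRequest(String id, long value) implements Writeable {
        PointRequest(StreamInput in) throws IOException {
            this(in.readString(), in.readLong());
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(id);
            out.writeLong(value);
        }
    }

    @Override
    protected Writeable.Reader<PointRequest> instanceReader() {
        return PointRequest::new;       // how the base class deserializes the round trip
    }

    @Override
    protected PointRequest createTestInstance() {
        return new PointRequest(randomAlphaOfLength(10), randomNonNegativeLong());
    }

    @Override
    protected PointRequest mutateInstance(PointRequest in) {
        // flip exactly one field so the equals/hashCode assertion can fail loudly
        return randomBoolean()
            ? new PointRequest(randomValueOtherThan(in.id(), () -> randomAlphaOfLength(10)), in.value())
            : new PointRequest(in.id(), randomValueOtherThan(in.value(), ESTestCase::randomNonNegativeLong));
    }
}

Each mutation case perturbs a single component with randomValueOtherThan, which is what makes the base class's equality assertions meaningful.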
org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; -import org.elasticsearch.xpack.esql.planner.mapper.Mapper; +import org.elasticsearch.test.ESTestCase; -import java.io.IOException; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; -import java.util.Map; +import static org.elasticsearch.xpack.core.security.authz.IndicesAndAliasesResolverField.NO_INDEX_PLACEHOLDER; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; -public class DataNodeRequestTests extends AbstractWireSerializingTestCase { +public class DataNodeRequestTests extends ESTestCase { - @Override - protected Writeable.Reader instanceReader() { - return DataNodeRequest::new; - } - - @Override - protected NamedWriteableRegistry getNamedWriteableRegistry() { - List writeables = new ArrayList<>(); - writeables.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables()); - writeables.addAll(new EsqlPlugin().getNamedWriteables()); - return new NamedWriteableRegistry(writeables); - } - - @Override - protected DataNodeRequest createTestInstance() { + public void testNoIndexPlaceholder() { var sessionId = randomAlphaOfLength(10); - String query = randomFrom(""" - from test - | where round(emp_no) > 10 - | eval c = salary - | stats x = avg(c) - """, """ - from test - | sort last_name - | limit 10 - | where round(emp_no) > 10 - | eval c = first_name - | stats x = avg(salary) - """); List shardIds = randomList(1, 10, () -> new ShardId("index-" + between(1, 10), "n/a", between(1, 10))); - PhysicalPlan physicalPlan = mapAndMaybeOptimize(parse(query)); - Map aliasFilters = Map.of( - new Index("concrete-index", "n/a"), - AliasFilter.of(new TermQueryBuilder("id", "1"), "alias-1") - ); + DataNodeRequest request = new DataNodeRequest( sessionId, - randomConfiguration(query, randomTables()), + randomConfiguration(""" + from test + | where round(emp_no) > 10 + | eval c = salary + | stats x = avg(c) + """, randomTables()), randomAlphaOfLength(10), shardIds, - aliasFilters, - physicalPlan, + Collections.emptyMap(), + null, generateRandomStringArray(10, 10, false, false), IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()) ); - request.setParentTask(randomAlphaOfLength(10), randomNonNegativeLong()); - return request; - } - @Override - protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException { - return switch (between(0, 8)) { - case 0 -> { - var request = new DataNodeRequest( - randomAlphaOfLength(20), - in.configuration(), - in.clusterAlias(), - in.shardIds(), - in.aliasFilters(), - in.plan(), - in.indices(), - in.indicesOptions() - ); - request.setParentTask(in.getParentTask()); - yield request; - } - case 1 -> { - var request = new DataNodeRequest( - in.sessionId(), - randomConfiguration(), - in.clusterAlias(), - in.shardIds(), - in.aliasFilters(), - in.plan(), - in.indices(), - in.indicesOptions() - 
); - request.setParentTask(in.getParentTask()); - yield request; - } - case 2 -> { - List shardIds = randomList(1, 10, () -> new ShardId("new-index-" + between(1, 10), "n/a", between(1, 10))); - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - in.clusterAlias(), - shardIds, - in.aliasFilters(), - in.plan(), - in.indices(), - in.indicesOptions() - ); - request.setParentTask(in.getParentTask()); - yield request; - } - case 3 -> { - String newQuery = randomFrom(""" - from test - | where round(emp_no) > 100 - | eval c = salary - | stats x = avg(c) - """, """ - from test - | sort last_name - | limit 10 - | where round(emp_no) > 100 - | eval c = first_name - | stats x = avg(salary) - """); - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - in.clusterAlias(), - in.shardIds(), - in.aliasFilters(), - mapAndMaybeOptimize(parse(newQuery)), - in.indices(), - in.indicesOptions() - ); - request.setParentTask(in.getParentTask()); - yield request; - } - case 4 -> { - final Map aliasFilters; - if (randomBoolean()) { - aliasFilters = Map.of(); - } else { - aliasFilters = Map.of(new Index("concrete-index", "n/a"), AliasFilter.of(new TermQueryBuilder("id", "2"), "alias-2")); - } - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - in.clusterAlias(), - in.shardIds(), - aliasFilters, - in.plan(), - in.indices(), - in.indicesOptions() - ); - request.setParentTask(request.getParentTask()); - yield request; - } - case 5 -> { - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - in.clusterAlias(), - in.shardIds(), - in.aliasFilters(), - in.plan(), - in.indices(), - in.indicesOptions() - ); - request.setParentTask( - randomValueOtherThan(request.getParentTask().getNodeId(), () -> randomAlphaOfLength(10)), - randomNonNegativeLong() - ); - yield request; - } - case 6 -> { - var clusterAlias = randomValueOtherThan(in.clusterAlias(), () -> randomAlphaOfLength(10)); - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - clusterAlias, - in.shardIds(), - in.aliasFilters(), - in.plan(), - in.indices(), - in.indicesOptions() - ); - request.setParentTask(request.getParentTask()); - yield request; - } - case 7 -> { - var indices = randomValueOtherThan(in.indices(), () -> generateRandomStringArray(10, 10, false, false)); - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - in.clusterAlias(), - in.shardIds(), - in.aliasFilters(), - in.plan(), - indices, - in.indicesOptions() - ); - request.setParentTask(request.getParentTask()); - yield request; - } - case 8 -> { - var indicesOptions = randomValueOtherThan( - in.indicesOptions(), - () -> IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()) - ); - var request = new DataNodeRequest( - in.sessionId(), - in.configuration(), - in.clusterAlias(), - in.shardIds(), - in.aliasFilters(), - in.plan(), - in.indices(), - indicesOptions - ); - request.setParentTask(request.getParentTask()); - yield request; - } - default -> throw new AssertionError("invalid value"); - }; - } + assertThat(request.shardIds(), equalTo(shardIds)); - static LogicalPlan parse(String query) { - Map mapping = loadMapping("mapping-basic.json"); - EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); - IndexResolution getIndexResult = IndexResolution.valid(test); - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG)); - var analyzer = new Analyzer( - new 
AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
-            TEST_VERIFIER
-        );
-        return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));
-    }
+        request.indices(generateRandomStringArray(10, 10, false, false));
 
-    static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan) {
-        var physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(TEST_CFG));
-        var mapper = new Mapper();
-        var physical = mapper.map(logicalPlan);
-        if (randomBoolean()) {
-            physical = physicalPlanOptimizer.optimize(physical);
-        }
-        return physical;
-    }
+        assertThat(request.shardIds(), equalTo(shardIds));
+
+        request.indices(NO_INDEX_PLACEHOLDER);
 
-    @Override
-    protected List<String> filteredWarnings() {
-        return withDefaultLimitWarning(super.filteredWarnings());
+        assertThat(request.shardIds(), empty());
     }
 }
diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/DeleteSamlServiceProviderRequestTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/DeleteSamlServiceProviderRequestTests.java
index 4beda3cc18792..110605b6a6de5 100644
--- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/DeleteSamlServiceProviderRequestTests.java
+++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/DeleteSamlServiceProviderRequestTests.java
@@ -29,7 +29,7 @@ public void testSerialization() throws IOException {
         );
         final TransportVersion version = TransportVersionUtils.randomVersionBetween(
             random(),
-            TransportVersions.V_7_7_0,
+            TransportVersions.V_8_0_0,
             TransportVersion.current()
         );
         final DeleteSamlServiceProviderRequest read = copyWriteable(
diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/PutSamlServiceProviderRequestTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/PutSamlServiceProviderRequestTests.java
index 60675b9355973..c9b1008ae795a 100644
--- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/PutSamlServiceProviderRequestTests.java
+++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/PutSamlServiceProviderRequestTests.java
@@ -87,7 +87,7 @@ public void testSerialization() throws IOException {
         final PutSamlServiceProviderRequest request = new PutSamlServiceProviderRequest(doc, RefreshPolicy.NONE);
         final TransportVersion version = TransportVersionUtils.randomVersionBetween(
             random(),
-            TransportVersions.V_7_7_0,
+            TransportVersions.V_8_0_0,
             TransportVersion.current()
         );
         final PutSamlServiceProviderRequest read = copyWriteable(
diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/SamlServiceProviderDocumentTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/SamlServiceProviderDocumentTests.java
index 233702d7ddd9a..b6c20cebe595f 100644
--- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/SamlServiceProviderDocumentTests.java
+++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/sp/SamlServiceProviderDocumentTests.java
@@ -162,7 +162,7 @@ private SamlServiceProviderDocument assertXContentRoundTrip(SamlServiceProviderD
     private SamlServiceProviderDocument assertSerializationRoundTrip(SamlServiceProviderDocument doc) throws IOException {
         final TransportVersion version =
TransportVersionUtils.randomVersionBetween( random(), - TransportVersions.V_7_7_0, + TransportVersions.V_8_0_0, TransportVersion.current() ); final SamlServiceProviderDocument read = copyWriteable( diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlAuthenticationStateTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlAuthenticationStateTests.java index 934126e7f0fa0..61018fe77d47e 100644 --- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlAuthenticationStateTests.java +++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlAuthenticationStateTests.java @@ -90,7 +90,7 @@ private SamlAuthenticationState assertXContentRoundTrip(SamlAuthenticationState private SamlAuthenticationState assertSerializationRoundTrip(SamlAuthenticationState state) throws IOException { final TransportVersion version = TransportVersionUtils.randomVersionBetween( random(), - TransportVersions.V_7_7_0, + TransportVersions.V_8_0_0, TransportVersion.current() ); final SamlAuthenticationState read = copyWriteable( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index f5f682b143a72..a6888f28159f4 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -17,7 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -37,7 +37,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import java.io.IOException; @@ -151,7 +151,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { @@ -176,9 +176,24 @@ private InferenceTextEmbeddingFloatResults makeResults(List input, int d return new InferenceTextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List input, int dimensions) { + private List makeChunkedResults(List input, int dimensions) { InferenceTextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions); - return InferenceChunkedTextEmbeddingFloatResults.listOf(input, nonChunkedResults); + + var results = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + 
results.add(
+                new ChunkedInferenceEmbeddingFloat(
+                    List.of(
+                        new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
+                            nonChunkedResults.embeddings().get(i).values(),
+                            input.get(i),
+                            new ChunkedInference.TextOffset(0, input.get(i).length())
+                        )
+                    )
+                )
+            );
+        }
+        return results;
     }
 
     protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
index fa1e27005c287..bbb773aa5129a 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
@@ -16,7 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.EmptySettingsConfiguration;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceExtension;
@@ -139,7 +139,7 @@ public void chunkedInfer(
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
-        ActionListener<List<ChunkedInferenceServiceResults>> listener
+        ActionListener<List<ChunkedInference>> listener
     ) {
         listener.onFailure(
             new ElasticsearchStatusException(
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
index 64569fd8c5c6a..eea64304f503a 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
@@ -16,7 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.EmptySettingsConfiguration;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceExtension;
@@ -35,9 +35,8 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 
 import java.io.IOException;
@@ -142,7 +141,7 @@ public void chunkedInfer(
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
-        ActionListener<List<ChunkedInferenceServiceResults>> listener
+        ActionListener<List<ChunkedInference>> listener
     ) {
         switch
(model.getConfigurations().getTaskType()) { case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input)); @@ -167,16 +166,22 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } - private List makeChunkedResults(List input) { - List results = new ArrayList<>(); + private List makeChunkedResults(List input) { + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); } results.add( - new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)) + new ChunkedInferenceEmbeddingSparse( + List.of( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + tokens, + input.get(i), + new ChunkedInference.TextOffset(0, input.get(i).length()) + ) + ) ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 80696a285fb26..a39be00d9e7fa 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -234,7 +234,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index a4187f4c4fa90..71fbcf6d8ef49 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -18,10 +18,6 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import 
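Review note: the makeChunkedResults methods above emit exactly one chunk per input, spanning the whole string. A compact sketch of that shape, using only the constructors this patch introduces (ChunkedInferenceEmbeddingSparse, SparseEmbeddingChunk, ChunkedInference.TextOffset); the single token and its weight are arbitrary illustration:

// one single-chunk sparse result per input, offset covering the full text
static List<ChunkedInference> singleChunkSparse(List<String> inputs) {
    var results = new ArrayList<ChunkedInference>(inputs.size());
    for (String text : inputs) {
        results.add(
            new ChunkedInferenceEmbeddingSparse(
                List.of(
                    new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
                        List.of(new WeightedToken("feature_0", 1.0f)),   // arbitrary token/weight
                        text,
                        new ChunkedInference.TextOffset(0, text.length())
                    )
                )
            )
        );
    }
    return results;
}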
org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; @@ -108,7 +104,6 @@ public static List getNamedWriteables() { ); addInferenceResultsNamedWriteables(namedWriteables); - addChunkedInferenceResultsNamedWriteables(namedWriteables); // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); @@ -433,37 +428,6 @@ private static void addInternalNamedWriteables(List namedWriteables) { - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - ErrorChunkedInferenceResults.NAME, - ErrorChunkedInferenceResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedSparseEmbeddingResults.NAME, - InferenceChunkedSparseEmbeddingResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedTextEmbeddingFloatResults.NAME, - InferenceChunkedTextEmbeddingFloatResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedTextEmbeddingByteResults.NAME, - InferenceChunkedTextEmbeddingByteResults::new - ) - ); - } - private static void addChunkingSettingsNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index d178e927aa65d..a9195ea24af3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -29,7 +29,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; @@ -37,11 +37,12 @@ import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -141,7 +142,7 @@ private record FieldInferenceResponse( int inputOrder, boolean isOriginalFieldInput, Model model, - ChunkedInferenceServiceResults chunkedResults + ChunkedInference chunkedResults ) {} private record FieldInferenceResponseAccumulator( @@ -273,19 +274,19 @@ public void onFailure(Exception exc) { final List currentBatch = 
requests.subList(0, currentBatchSize); final List nextBatch = requests.subList(currentBatchSize, requests.size()); final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); - ActionListener> completionListener = new ActionListener<>() { + ActionListener> completionListener = new ActionListener<>() { @Override - public void onResponse(List results) { + public void onResponse(List results) { try { var requestsIterator = requests.iterator(); - for (ChunkedInferenceServiceResults result : results) { + for (ChunkedInference result : results) { var request = requestsIterator.next(); var acc = inferenceResults.get(request.index); - if (result instanceof ErrorChunkedInferenceResults error) { + if (result instanceof ChunkedInferenceError error) { acc.addFailure( new ElasticsearchException( "Exception when running inference id [{}] on field [{}]", - error.getException(), + error.exception(), inferenceProvider.model.getInferenceEntityId(), request.field ) @@ -359,7 +360,7 @@ private void addInferenceResponseFailure(int id, Exception failure) { * Otherwise, the source of the request is augmented with the field inference results under the * {@link SemanticTextField#INFERENCE_FIELD} field. */ - private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException { if (response.failures().isEmpty() == false) { for (var failure : response.failures()) { item.abort(item.index(), failure); @@ -376,7 +377,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); - List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); + List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); var result = new SemanticTextField( fieldName, inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2aef54e56f4b9..9b0b1104df660 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -11,18 +11,17 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import 
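Review note: ShardBulkInferenceActionFilter above now receives per-input failures as ChunkedInferenceError values inside a successful response, reserving onFailure for whole-batch errors. A sketch of that listener shape; recordFailure and recordResult are hypothetical helpers standing in for the accumulator logic above:

ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
    @Override
    public void onResponse(List<ChunkedInference> results) {
        for (ChunkedInference result : results) {
            if (result instanceof ChunkedInferenceError error) {
                recordFailure(error.exception());   // only this input failed; keep going
            } else {
                recordResult(result);
            }
        }
    }

    @Override
    public void onFailure(Exception e) {
        recordFailure(e);   // the inference request itself failed for the whole batch
    }
};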
org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import java.util.ArrayList; import java.util.List; @@ -72,8 +71,8 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private List>> floatResults; private List>> byteResults; private List>> sparseResults; - private AtomicArray errors; - private ActionListener> finalListener; + private AtomicArray errors; + private ActionListener> finalListener; public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, EmbeddingType embeddingType) { this(inputs, maxNumberOfInputsPerBatch, DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP, embeddingType); @@ -189,7 +188,7 @@ private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) { * @param finalListener The listener to call once all the batches are processed * @return Batches and listeners */ - public List batchRequestsWithListeners(ActionListener> finalListener) { + public List batchRequestsWithListeners(ActionListener> finalListener) { this.finalListener = finalListener; int numberOfRequests = batchedRequests.size(); @@ -331,9 +330,8 @@ private ElasticsearchStatusException unexpectedResultTypeException(String got, S @Override public void onFailure(Exception e) { - var errorResult = new ErrorChunkedInferenceResults(e); for (var pos : positions) { - errors.set(pos.inputIndex(), errorResult); + errors.set(pos.inputIndex(), e); } if (resultCount.incrementAndGet() == totalNumberOfRequests) { @@ -342,10 +340,10 @@ public void onFailure(Exception e) { } private void sendResponse() { - var response = new ArrayList(chunkedOffsets.size()); + var response = new ArrayList(chunkedOffsets.size()); for (int i = 0; i < chunkedOffsets.size(); i++) { if (errors.get(i) != null) { - response.add(errors.get(i)); + response.add(new ChunkedInferenceError(errors.get(i))); } else { response.add(mergeResultsWithInputs(i)); } @@ -355,16 +353,16 @@ private void sendResponse() { } } - private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { + private ChunkedInference mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { - case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex)); - case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex)); - case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex)); + case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex), floatResults.get(resultIndex)); + case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex), sparseResults.get(resultIndex)); }; } - private 
InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -375,18 +373,22 @@ private InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { embeddingChunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(chunks.get(i), all.get(i).values()) + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + all.get(i).values(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) ); } - return new InferenceChunkedTextEmbeddingFloatResults(embeddingChunks); + return new ChunkedInferenceEmbeddingFloat(embeddingChunks); } - private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -397,18 +399,22 @@ private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { embeddingChunks.add( - new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(chunks.get(i), all.get(i).values()) + new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk( + all.get(i).values(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) ); } - return new InferenceChunkedTextEmbeddingByteResults(embeddingChunks, false); + return new ChunkedInferenceEmbeddingByte(embeddingChunks); } - private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingSparse mergeSparseResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -419,12 +425,18 @@ private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { - embeddingChunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunks.get(i), all.get(i).tokens())); + embeddingChunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + all.get(i).tokens(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) + ); } - return new InferenceChunkedSparseEmbeddingResults(embeddingChunks); + return new ChunkedInferenceEmbeddingSparse(embeddingChunks); } public record BatchRequest(List subBatches) { @@ -460,5 +472,13 @@ record ChunkOffsetsAndInput(List offsets, String input) { List toChunkText() { return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList()); } + + int size() { + return offsets.size(); + } + + String chunkText(int index) { + return input.substring(offsets.get(index).start(), offsets.get(index).end()); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java 
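Review note: the ChunkOffsetsAndInput accessors added above derive chunk text on demand from (start, end) offsets into the original input instead of storing each substring twice. A standalone sketch of that pattern; ChunkOffset here is an illustrative stand-in for the chunker's own offset record:

record ChunkOffset(int start, int end) {}

record ChunkOffsetsAndInput(List<ChunkOffset> offsets, String input) {
    int size() {
        return offsets.size();
    }

    String chunkText(int index) {
        // chunk text is always a view over the single stored input string
        return input.substring(offsets.get(index).start(), offsets.get(index).end());
    }
}

For example, chunkText(1) over offsets (0, 5) and (6, 11) of "hello world" yields "world".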
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 0f26f6577860f..d651729dee259 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -31,7 +31,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -70,7 +69,7 @@ public record SemanticTextField(String fieldName, List originalValues, I public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} - public record Chunk(String text, BytesReference rawEmbeddings) {} + record Chunk(String text, BytesReference rawEmbeddings) {} public record ModelSettings( TaskType taskType, @@ -307,13 +306,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } /** - * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. */ - public static List toSemanticTextFieldChunks(List results, XContentType contentType) { + public static List toSemanticTextFieldChunks(List results, XContentType contentType) throws IOException { List chunks = new ArrayList<>(); for (var result : results) { - for (Iterator it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it - .hasNext();) { + for (var it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it.hasNext();) { var chunkAsByteReference = it.next(); chunks.add(new Chunk(chunkAsByteReference.matchedText(), chunkAsByteReference.bytesReference())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index e9b75e9ec7796..7ea68b2c9bf1e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -13,7 +13,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -102,7 +102,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { init(); // a non-null query is not supported and is dropped by all providers @@ -131,7 +131,7 @@ protected abstract void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ); 
public void start(Model model, ActionListener listener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index ffd26b9ac534d..42a276c6ee838 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -299,7 +299,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AlibabaCloudSearchModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index d224e50bb650d..a88881220f933 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -16,7 +16,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -126,7 +126,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index f1840af18779f..41852e4758a8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import 
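Review note: after this change every SenderService subclass reports chunked results as List<ChunkedInference>. A hedged sketch of an override under the new signature; the leading Model and DocumentsOnlyInput parameters are elided by the hunks above and assumed here, and the body fabricates placeholder errors rather than calling a real backend:

@Override
protected void doChunkedInfer(
    Model model,
    DocumentsOnlyInput inputs,                       // assumed parameter, not visible in the hunks
    Map<String, Object> taskSettings,
    InputType inputType,
    TimeValue timeout,
    ActionListener<List<ChunkedInference>> listener
) {
    List<ChunkedInference> results = new ArrayList<>(inputs.getInputs().size());
    for (String input : inputs.getInputs()) {
        // placeholder: a real service would chunk, batch, and call its client here
        results.add(new ChunkedInferenceError(new UnsupportedOperationException("sketch only: " + input)));
    }
    listener.onResponse(results);
}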
org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -232,7 +232,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { throw new UnsupportedOperationException("Anthropic service does not support chunked inference"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index f8ea11e4b15a5..bcd4a1abfbf00 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -119,7 +119,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) { var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index a38c265d2613c..0ed69604c258a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -273,7 +273,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AzureOpenAiModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index ccb8d79dacd6c..a7d17192bfa92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ 
-14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -272,7 +272,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof CohereModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index fe8ee52eb8816..27abc03cb6e82 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -30,8 +30,8 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator; @@ -121,7 +121,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { // Pass-through without actually performing chunking (result will have a single chunk per input) ActionListener inferListener = listener.delegateFailureAndWrap( @@ -273,14 +273,12 @@ public void checkModelConfig(Model model, ActionListener listener) { } } - private static List translateToChunkedResults( - InferenceInputs inputs, - InferenceServiceResults inferenceResults - ) { + private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - return InferenceChunkedSparseEmbeddingResults.listOf(DocumentsOnlyInput.of(inputs).getInputs(), sparseEmbeddingResults); + var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs(); + return ChunkedInferenceEmbeddingSparse.listOf(inputsAsList, sparseEmbeddingResults); } else if 
(inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); + return List.of(new ChunkedInferenceError(error.getException())); } else { String expectedClass = Strings.format("%s", SparseEmbeddingResults.class.getSimpleName()); throw createInvalidChunkedResultException(expectedClass, inferenceResults.getWriteableName()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 5f613d6be5869..cb3eb5ca8e826 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceResults; @@ -699,7 +699,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); } @@ -712,7 +712,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if ((TaskType.TEXT_EMBEDDING.equals(model.getTaskType()) || TaskType.SPARSE_EMBEDDING.equals(model.getTaskType())) == false) { listener.onFailure( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index b681722a82136..837a001d1f8f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -326,7 +326,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 87a2d98dca92c..b412f20289880 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -225,7 +225,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index b74ec01cd76e7..acb082cd2de8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -119,7 +119,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof HuggingFaceModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 5b038781b96af..01d2a75ebff5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -28,9 +28,9 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import 
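Review note: the services above converge on one translation shape, sketched below using only helpers visible in this patch (ChunkedInferenceEmbeddingSparse.listOf, ErrorInferenceResults#getException, createInvalidChunkedResultException):

static List<ChunkedInference> translateToChunked(List<String> inputs, InferenceServiceResults results) {
    if (results instanceof SparseEmbeddingResults sparse) {
        // one ChunkedInference per input, each holding a single full-text chunk
        return ChunkedInferenceEmbeddingSparse.listOf(inputs, sparse);
    } else if (results instanceof ErrorInferenceResults error) {
        // errors collapse into a single ChunkedInferenceError entry
        return List.of(new ChunkedInferenceError(error.getException()));
    } else {
        throw createInvalidChunkedResultException(SparseEmbeddingResults.class.getSimpleName(), results.getWriteableName());
    }
}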
org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
@@ -44,12 +44,14 @@
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

+import java.util.ArrayList;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;

 import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
 import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
@@ -100,7 +102,7 @@ protected void doChunkedInfer(
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
-        ActionListener<List<ChunkedInferenceServiceResults>> listener
+        ActionListener<List<ChunkedInference>> listener
     ) {
         ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
             (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
@@ -110,16 +112,31 @@ protected void doChunkedInfer(
         doInfer(model, inputs, taskSettings, inputType, timeout, inferListener);
     }

-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        DocumentsOnlyInput inputs,
-        InferenceServiceResults inferenceResults
-    ) {
+    private static List<ChunkedInference> translateToChunkedResults(DocumentsOnlyInput inputs, InferenceServiceResults inferenceResults) {
         if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
-            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs.getInputs(), textEmbeddingResults);
+            validateInputSizeAgainstEmbeddings(inputs.getInputs(), textEmbeddingResults.embeddings().size());
+
+            var results = new ArrayList<ChunkedInference>(inputs.getInputs().size());
+
+            for (int i = 0; i < inputs.getInputs().size(); i++) {
+                results.add(
+                    new ChunkedInferenceEmbeddingFloat(
+                        List.of(
+                            new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(
+                                textEmbeddingResults.embeddings().get(i).values(),
+                                inputs.getInputs().get(i),
+                                new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length())
+                            )
+                        )
+                    )
+                );
+            }
+            return results;
         } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
-            return InferenceChunkedSparseEmbeddingResults.listOf(inputs.getInputs(), sparseEmbeddingResults);
+            var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs();
+            return ChunkedInferenceEmbeddingSparse.listOf(inputsAsList, sparseEmbeddingResults);
         } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
+            return List.of(new ChunkedInferenceError(error.getException()));
         } else {
String expectedClasses = Strings.format( "One of [%s,%s]", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index cc66d5fd7ee74..482554e060a47 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -295,7 +295,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 881e7d36f2a21..dc0576651aeba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -107,7 +107,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 7b51b068708ca..ed7a01e829dd2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -291,7 +291,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - 
ActionListener> listener + ActionListener> listener ) { if (model instanceof OpenAiModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index c68a629b999c5..0b7d136ffb04c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -21,7 +21,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; @@ -34,8 +34,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -57,9 +57,9 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; import static org.hamcrest.Matchers.containsString; @@ -160,8 +160,8 @@ public void testItemFailures() throws Exception { Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10) ); - model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain 
actionFilterChain = (task, action, request, listener) -> { try { @@ -290,10 +290,9 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; List inputs = (List) invocationOnMock.getArguments()[2]; - ActionListener> listener = (ActionListener< - List>) invocationOnMock.getArguments()[6]; + ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { - List results = new ArrayList<>(); + List results = new ArrayList<>(); for (String input : inputs) { results.add(model.getResults(input)); } @@ -348,7 +347,7 @@ private static BulkItemRequest[] randomBulkItemRequest( // This prevents a situation where embeddings in the expected docMap do not match those in the model, which could happen if // embeddings were overwritten. if (model.hasResult(inputText)) { - ChunkedInferenceServiceResults results = model.getResults(inputText); + var results = model.getResults(inputText); semanticTextField = semanticTextFieldFromChunkedInferenceResults( field, model, @@ -371,7 +370,7 @@ private static BulkItemRequest[] randomBulkItemRequest( } private static class StaticModel extends TestModel { - private final Map resultMap; + private final Map resultMap; StaticModel( String inferenceEntityId, @@ -397,11 +396,11 @@ public static StaticModel createRandomInstance() { ); } - ChunkedInferenceServiceResults getResults(String text) { - return resultMap.getOrDefault(text, new InferenceChunkedSparseEmbeddingResults(List.of())); + ChunkedInference getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedInferenceEmbeddingSparse(List.of())); } - void putResult(String text, ChunkedInferenceServiceResults result) { + void putResult(String text, ChunkedInference result) { resultMap.put(text, result); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index dec7d15760aa6..03249163c7f82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -8,12 +8,12 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import 
org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -313,16 +313,16 @@ public void testMergingListener_Float() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); @@ -333,15 +333,15 @@ public void testMergingListener_Float() { } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText()); } @@ -386,16 +386,16 @@ public void testMergingListener_Byte() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = 
(ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); @@ -406,15 +406,15 @@ public void testMergingListener_Byte() { } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText()); } @@ -466,34 +466,34 @@ public void testMergingListener_Sparse() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("1st small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("1st small", chunkedSparseResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("2nd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("2nd small", chunkedSparseResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("3rd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("3rd small", chunkedSparseResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple 
chunks var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(chunkedSparseResult.getChunkedResults().get(0).matchedText(), startsWith("passage_input0 ")); - assertThat(chunkedSparseResult.getChunkedResults().get(1).matchedText(), startsWith(" passage_input10 ")); - assertThat(chunkedSparseResult.getChunkedResults().get(8).matchedText(), startsWith(" passage_input80 ")); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each + assertThat(chunkedSparseResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); + assertThat(chunkedSparseResult.chunks().get(1).matchedText(), startsWith(" passage_input10 ")); + assertThat(chunkedSparseResult.chunks().get(8).matchedText(), startsWith(" passage_input80 ")); } } @@ -501,13 +501,13 @@ public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of("1st small", "2nd small", "3rd small"); var failureMessage = new AtomicReference(); - var listener = new ActionListener>() { + var listener = new ActionListener>() { @Override - public void onResponse(List chunkedInferenceServiceResults) { - assertThat(chunkedInferenceServiceResults.get(0), instanceOf(ErrorChunkedInferenceResults.class)); - var error = (ErrorChunkedInferenceResults) chunkedInferenceServiceResults.get(0); - failureMessage.set(error.getException().getMessage()); + public void onResponse(List chunkedResults) { + assertThat(chunkedResults.get(0), instanceOf(ChunkedInferenceError.class)); + var error = (ChunkedInferenceError) chunkedResults.get(0); + failureMessage.set(error.exception().getMessage()); } @Override @@ -531,12 +531,12 @@ private ChunkedResultsListener testListener() { return new ChunkedResultsListener(); } - private static class ChunkedResultsListener implements ActionListener> { - List results; + private static class ChunkedResultsListener implements ActionListener> { + List results; @Override - public void onResponse(List chunkedInferenceServiceResults) { - this.results = chunkedInferenceServiceResults; + public void onResponse(List chunks) { + this.results = chunks; } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 563093930c358..dcdd9b3d42341 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -19,9 +19,8 @@ import 
org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; import org.elasticsearch.xpack.inference.model.TestModel; @@ -158,38 +157,39 @@ public void testModelSettingsValidation() { assertThat(ex.getMessage(), containsString("required [element_type] field is missing")); } - public static InferenceChunkedTextEmbeddingFloatResults randomInferenceChunkedTextEmbeddingFloatResults( - Model model, - List inputs - ) throws IOException { - List chunks = new ArrayList<>(); + public static ChunkedInferenceEmbeddingFloat randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { + List chunks = new ArrayList<>(); for (String input : inputs) { float[] values = new float[model.getServiceSettings().dimensions()]; for (int j = 0; j < values.length; j++) { values[j] = (float) randomDouble(); } - chunks.add(new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(input, values)); + chunks.add( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(values, input, new ChunkedInference.TextOffset(0, input.length())) + ); } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); + return new ChunkedInferenceEmbeddingFloat(chunks); } - public static InferenceChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); + public static ChunkedInferenceEmbeddingSparse randomChunkedInferenceEmbeddingSparse(List inputs) { + List chunks = new ArrayList<>(); for (String input : inputs) { var tokens = new ArrayList(); for (var token : input.split("\\s+")) { tokens.add(new WeightedToken(token, randomFloat())); } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(input, tokens)); + chunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(tokens, input, new ChunkedInference.TextOffset(0, input.length())) + ); } - return new InferenceChunkedSparseEmbeddingResults(chunks); + return new ChunkedInferenceEmbeddingSparse(chunks); } public static SemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) throws IOException { - ChunkedInferenceServiceResults results = switch (model.getTaskType()) { - case TEXT_EMBEDDING -> randomInferenceChunkedTextEmbeddingFloatResults(model, inputs); - case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + ChunkedInference results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; return semanticTextFieldFromChunkedInferenceResults(fieldName, model, inputs, results, contentType); @@ -199,9 +199,9 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( String fieldName, Model model, List inputs, - ChunkedInferenceServiceResults 
results, + ChunkedInference results, XContentType contentType - ) { + ) throws IOException { return new SemanticTextField( fieldName, inputs, @@ -232,18 +232,24 @@ public static Object randomSemanticTextInput() { } } - public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) throws IOException { + public static ChunkedInference toChunkedResult(SemanticTextField field) throws IOException { switch (field.inference().modelSettings().taskType()) { case SPARSE_EMBEDDING -> { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var chunk : field.inference().chunks()) { var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + chunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + tokens, + chunk.text(), + new ChunkedInference.TextOffset(0, chunk.text().length()) + ) + ); } - return new InferenceChunkedSparseEmbeddingResults(chunks); + return new ChunkedInferenceEmbeddingSparse(chunks); } case TEXT_EMBEDDING -> { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var chunk : field.inference().chunks()) { double[] values = parseDenseVector( chunk.rawEmbeddings(), @@ -251,13 +257,14 @@ public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField f field.contentType() ); chunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + FloatConversionUtils.floatArrayOf(values), chunk.text(), - FloatConversionUtils.floatArrayOf(values) + new ChunkedInference.TextOffset(0, chunk.text().length()) ) ); } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); + return new ChunkedInferenceEmbeddingFloat(chunks); } default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java deleted file mode 100644 index 4be00ea9e5822..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; - -import java.io.IOException; - -public class ErrorChunkedInferenceResultsTests extends AbstractWireSerializingTestCase { - - public static ErrorChunkedInferenceResults createRandomResults() { - return new ErrorChunkedInferenceResults( - randomBoolean() - ? 
new ElasticsearchTimeoutException(randomAlphaOfLengthBetween(10, 50)) - : new ElasticsearchStatusException(randomAlphaOfLengthBetween(10, 50), randomFrom(RestStatus.values())) - ); - } - - @Override - protected Writeable.Reader instanceReader() { - return ErrorChunkedInferenceResults::new; - } - - @Override - protected ErrorChunkedInferenceResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected ErrorChunkedInferenceResults mutateInstance(ErrorChunkedInferenceResults instance) throws IOException { - return new ErrorChunkedInferenceResults(new RuntimeException(randomAlphaOfLengthBetween(10, 50))); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java deleted file mode 100644 index 8685ad9f0e124..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedSparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { - - public static InferenceChunkedSparseEmbeddingResults createRandomResults() { - var chunks = new ArrayList(); - int numChunks = randomIntBetween(1, 5); - - for (int i = 0; i < numChunks; i++) { - var tokenWeights = new ArrayList(); - int numTokens = randomIntBetween(1, 8); - for (int j = 0; j < numTokens; j++) { - tokenWeights.add(new WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false))); - } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights)); - } - - return new InferenceChunkedSparseEmbeddingResults(chunks); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult("text", List.of(new WeightedToken("token", 0.1f)))) - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 
0.1 - } - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_FromSparseEmbeddingResults() { - var entity = InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 0.1 - } - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text", "text2"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceChunkedSparseEmbeddingResults::new; - } - - @Override - protected InferenceChunkedSparseEmbeddingResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedSparseEmbeddingResults mutateInstance(InferenceChunkedSparseEmbeddingResults instance) throws IOException { - return randomValueOtherThan(instance, InferenceChunkedSparseEmbeddingResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java deleted file mode 100644 index c1215e8a3d71b..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase< - InferenceChunkedTextEmbeddingByteResults> { - - public static InferenceChunkedTextEmbeddingByteResults createRandomResults() { - int numChunks = randomIntBetween(1, 5); - var chunks = new ArrayList(numChunks); - - for (int i = 0; i < numChunks; i++) { - chunks.add(createRandomChunk()); - } - - return new InferenceChunkedTextEmbeddingByteResults(chunks, randomBoolean()); - } - - private static InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk createRandomChunk() { - int columns = randomIntBetween(1, 10); - byte[] bytes = new byte[columns]; - for (int i = 0; i < columns; i++) { - bytes[i] = randomByte(); - } - - return new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(randomAlphaOfLength(6), bytes); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedTextEmbeddingByteResults( - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })), - false - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_ForTextEmbeddingByteResults() { - var entity = InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text", "text2"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return 
InferenceChunkedTextEmbeddingByteResults::new; - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults mutateInstance(InferenceChunkedTextEmbeddingByteResults instance) - throws IOException { - return randomValueOtherThan(instance, InferenceChunkedTextEmbeddingByteResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 6768583598b2d..f8a5bd8a54d31 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -135,7 +135,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index a154ded395822..46c3a062f7db0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -29,8 +29,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -392,7 +392,7 @@ public void testChunkedInfer_InvalidTaskType() throws IOException { null ); - 
PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); try { service.chunkedInfer( model, @@ -417,7 +417,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { var model = createModelForTaskType(taskType, chunkingSettings); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); var results = listener.actionGet(TIMEOUT); @@ -425,9 +425,9 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin assertThat(results, hasSize(2)); var firstResult = results.get(0); if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - assertThat(firstResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + assertThat(firstResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { - assertThat(firstResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + assertThat(firstResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 35b5642b7a60c..7ca66d54d8ac3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -35,7 +35,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; @@ -1551,7 +1551,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep requestSender.enqueue(mockResults2); } - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, null, @@ -1565,15 +1565,15 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), 
CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("abc", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123F, 0.678F }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("xyz", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.223F, 0.278F }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 8636ba8890e87..fe60eaa760e69 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -36,7 +36,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1186,7 +1186,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, null, @@ -1200,15 +1200,15 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); 
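The unpacking pattern above repeats across the AmazonBedrock, AzureAiStudio, AzureOpenAi, and Cohere test updates, so here is the shared shape factored into one hedged helper sketch. The accessor names (chunks(), matchedText(), embedding()) are taken from the diff; the helper class itself and its expected-value parameters are illustrative only.

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;

class FloatChunkAssertionSketch {
    // Asserts that a result is a float-embedding chunk list with exactly one chunk
    // whose matched text and embedding values equal the expected ones.
    static void assertSingleChunk(ChunkedInference result, String expectedText, float[] expectedEmbedding) {
        var floatResult = (ChunkedInferenceEmbeddingFloat) result;
        assertEquals(1, floatResult.chunks().size());
        assertEquals(expectedText, floatResult.chunks().get(0).matchedText());
        assertArrayEquals(expectedEmbedding, floatResult.chunks().get(0).embedding(), 0.0f);
    }
}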
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("foo", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding(), 0.0f);
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("bar", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 1.0123f, -1.0123f }, floatResult.chunks().get(0).embedding(), 0.0f);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index b0c590e237a44..cf79d70749fda 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -19,7 +19,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -35,7 +35,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -1335,7 +1335,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
         model.setUri(new URI(getUrl(webServer)));

-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -1349,15 +1349,15 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("foo", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("bar", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 1.123f, -1.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 259a32aa6254d..bf67db3b2a117 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -20,7 +20,7 @@
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -36,8 +36,8 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -1442,7 +1442,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
             """;
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         // 2 inputs
         service.chunkedInfer(
             model,
@@ -1457,15 +1457,15 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("foo", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("bar", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f);
@@ -1533,7 +1533,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
             "model",
             CohereEmbeddingType.BYTE
         );
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         // 2 inputs
         service.chunkedInfer(
             model,
@@ -1548,15 +1548,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingByteResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingByteResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingByte.class));
+            var floatResult = (ChunkedInferenceEmbeddingByte) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("foo", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new byte[] { 23, -23 }, floatResult.chunks().get(0).embedding());
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingByteResults.class));
-            var byteResult = (InferenceChunkedTextEmbeddingByteResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingByte.class));
+            var byteResult = (ChunkedInferenceEmbeddingByte) results.get(1);
             assertThat(byteResult.chunks(), hasSize(1));
             assertEquals("bar", byteResult.chunks().get(0).matchedText());
             assertArrayEquals(new byte[] { 24, -24 }, byteResult.chunks().get(0).embedding());
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
index d3101099d06c7..6cdd58c4cd99f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
@@ -16,7 +16,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -31,8 +31,8 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -453,7 +453,7 @@ public void testChunkedInfer_PassesThrough() throws IOException {
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

         var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl);
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -465,22 +465,21 @@ public void testChunkedInfer_PassesThrough() throws IOException {
         );

         var results = listener.actionGet(TIMEOUT);
-        MatcherAssert.assertThat(
-            results.get(0).asMap(),
-            Matchers.is(
-                Map.of(
-                    InferenceChunkedSparseEmbeddingResults.FIELD_NAME,
-                    List.of(
-                        Map.of(
-                            ChunkedNlpInferenceResults.TEXT,
-                            "input text",
-                            ChunkedNlpInferenceResults.INFERENCE,
-                            Map.of("hello", 2.1259406f, "greet", 1.7073475f)
-                        )
+        assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class));
+        var sparseResult = (ChunkedInferenceEmbeddingSparse) results.get(0);
+        assertThat(
+            sparseResult.chunks(),
+            is(
+                List.of(
+                    new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
+                        List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
+                        "input text",
+                        new ChunkedInference.TextOffset(0, "input text".length())
                     )
                 )
             )
         );
+
         MatcherAssert.assertThat(webServer.requests(), hasSize(1));
         assertNull(webServer.requests().get(0).getUri().getQuery());
         MatcherAssert.assertThat(
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index 17e6583f11c8f..21d7efbc7b03c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -24,7 +24,7 @@
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceResults;
@@ -42,9 +42,9 @@
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
@@ -865,26 +865,26 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru
         var service = createService(client);

         var gotResults = new AtomicBoolean();
-        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+        var resultsListener = ActionListener.<List<ChunkedInference>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(2));
-            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var result1 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
+            assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var result1 = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(0);
             assertThat(result1.chunks(), hasSize(1));
             assertArrayEquals(
                 ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(),
-                result1.getChunks().get(0).embedding(),
+                result1.chunks().get(0).embedding(),
                 0.0001f
             );
-            assertEquals("foo", result1.getChunks().get(0).matchedText());
-            assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var result2 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(1);
+            assertEquals("foo", result1.chunks().get(0).matchedText());
+            assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var result2 = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(1);
             assertThat(result2.chunks(), hasSize(1));
             assertArrayEquals(
                 ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(),
-                result2.getChunks().get(0).embedding(),
+                result2.chunks().get(0).embedding(),
                 0.0001f
             );
-            assertEquals("bar", result2.getChunks().get(0).matchedText());
+            assertEquals("bar", result2.chunks().get(0).matchedText());

             gotResults.set(true);
         }, ESTestCase::fail);
@@ -940,22 +940,22 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int

         var gotResults = new AtomicBoolean();
-        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+        var resultsListener = ActionListener.<List<ChunkedInference>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(2));
-            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
-            var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0);
+            assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class));
+            var result1 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(0);
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
-                result1.getChunkedResults().get(0).weightedTokens()
+                result1.chunks().get(0).weightedTokens()
             );
-            assertEquals("foo", result1.getChunkedResults().get(0).matchedText());
-            assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
-            var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1);
+            assertEquals("foo", result1.chunks().get(0).matchedText());
+            assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingSparse.class));
+            var result2 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(1);
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
-                result2.getChunkedResults().get(0).weightedTokens()
+                result2.chunks().get(0).weightedTokens()
             );
-            assertEquals("bar", result2.getChunkedResults().get(0).matchedText());
+            assertEquals("bar", result2.chunks().get(0).matchedText());

             gotResults.set(true);
         }, ESTestCase::fail);
@@ -1010,22 +1010,22 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte
         var service = createService(client);

         var gotResults = new AtomicBoolean();
-        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+        var resultsListener = ActionListener.<List<ChunkedInference>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(2));
-            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
-            var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0);
+            assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class));
+            var result1 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(0);
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(),
-                result1.getChunkedResults().get(0).weightedTokens()
+                result1.chunks().get(0).weightedTokens()
             );
-            assertEquals("foo", result1.getChunkedResults().get(0).matchedText());
-            assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
-            var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1);
+            assertEquals("foo", result1.chunks().get(0).matchedText());
+            assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingSparse.class));
+            var result2 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(1);
             assertEquals(
                 ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(),
-                result2.getChunkedResults().get(0).weightedTokens()
+                result2.chunks().get(0).weightedTokens()
             );
-            assertEquals("bar", result2.getChunkedResults().get(0).matchedText());
+            assertEquals("bar", result2.chunks().get(0).matchedText());

             gotResults.set(true);
         }, ESTestCase::fail);
@@ -1126,12 +1126,12 @@ public void testChunkInfer_FailsBatch() throws InterruptedException {
         var service = createService(client);

         var gotResults = new AtomicBoolean();
-        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+        var resultsListener = ActionListener.<List<ChunkedInference>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(3)); // a single failure fails the batch
             for (var er : chunkedResponse) {
-                assertThat(er, instanceOf(ErrorChunkedInferenceResults.class));
-                assertEquals("boom", ((ErrorChunkedInferenceResults) er).getException().getMessage());
+                assertThat(er, instanceOf(ChunkedInferenceError.class));
+                assertEquals("boom", ((ChunkedInferenceError) er).exception().getMessage());
             }

             gotResults.set(true);
@@ -1190,10 +1190,10 @@ public void testChunkingLargeDocument() throws InterruptedException {
         var service = createService(client);

         var gotResults = new AtomicBoolean();
-        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+        var resultsListener = ActionListener.<List<ChunkedInference>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(1));
-            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
+            assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var sparseResults = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(0);
             assertThat(sparseResults.chunks(), hasSize(numChunks));

             gotResults.set(true);
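(Reading aid, not part of the patch: the rewritten tests above exercise the new chunked-result API only through its call sites. The sketch below restates the shapes those call sites imply; it is inferred from the accessors used in the assertions — chunks(), matchedText(), offset(), embedding(), weightedTokens(), exception() — and is not copied from the x-pack sources, so names, nesting, and modifiers may differ. In particular, the real chunk types appear to be nested, e.g. ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk; they are kept top-level here only for brevity.)

import java.util.List;

// Sketch only: result shapes as implied by the test assertions above.
interface ChunkedInference {
    // Position of a chunk within the original input, e.g. new TextOffset(0, "abc".length()).
    record TextOffset(int start, int end) {}
}

// Stand-in for org.elasticsearch.xpack.core.ml.search.WeightedToken (assumed token/weight pair).
record WeightedToken(String token, float weight) {}

// A dense chunk: the embedding plus the input text it was computed from.
record FloatEmbeddingChunk(float[] embedding, String matchedText, ChunkedInference.TextOffset offset) {}

record ChunkedInferenceEmbeddingFloat(List<FloatEmbeddingChunk> chunks) implements ChunkedInference {}

// A sparse chunk: weighted tokens plus the input text they were computed from.
record SparseEmbeddingChunk(List<WeightedToken> weightedTokens, String matchedText, ChunkedInference.TextOffset offset) {}

record ChunkedInferenceEmbeddingSparse(List<SparseEmbeddingChunk> chunks) implements ChunkedInference {}

// A failure for a whole batch entry; testChunkInfer_FailsBatch reads it back via exception().
record ChunkedInferenceError(Exception exception) implements ChunkedInference {}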
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index 375c583cce13a..3e43bd7c77684 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,7 +36,7 @@
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -868,7 +868,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);

         var results = listener.actionGet(TIMEOUT);
@@ -876,8 +876,8 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed

         // first result
         {
-            assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
@@ -885,8 +885,8 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed

         // second result
         {
-            assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
index 8f0e481213cdf..12811b6be87b3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
@@ -14,7 +14,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
@@ -24,14 +24,13 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
-import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Before;
@@ -39,7 +38,6 @@
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.TimeUnit;
@@ -49,6 +47,8 @@
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.mockito.Mockito.mock;
@@ -90,7 +90,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

         var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret");
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -102,14 +102,16 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE
         );

         var result = listener.actionGet(TIMEOUT).get(0);
-
-        MatcherAssert.assertThat(
-            result.asMap(),
-            Matchers.is(
-                Map.of(
-                    InferenceChunkedSparseEmbeddingResults.FIELD_NAME,
-                    List.of(
-                        Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, Map.of(".", 0.13315596f))
+        assertThat(result, instanceOf(ChunkedInferenceEmbeddingSparse.class));
+        var sparseResult = (ChunkedInferenceEmbeddingSparse) result;
+        assertThat(
+            sparseResult.chunks(),
+            is(
+                List.of(
+                    new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(
+                        List.of(new WeightedToken(".", 0.13315596f)),
+                        "abc",
+                        new ChunkedInference.TextOffset(0, "abc".length())
                     )
                 )
             )
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
index 022cbecd1ea6a..2d30fd5aba4c2 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -34,8 +34,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -46,7 +45,6 @@
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
 import org.hamcrest.CoreMatchers;
-import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Before;
@@ -59,7 +57,6 @@

 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
-import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -774,7 +771,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

         var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -786,19 +783,12 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th
         );

         var result = listener.actionGet(TIMEOUT).get(0);
-        assertThat(result, CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-
-        MatcherAssert.assertThat(
-            asMapWithListsInsteadOfArrays((InferenceChunkedTextEmbeddingFloatResults) result),
-            Matchers.is(
-                Map.of(
-                    InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME,
-                    List.of(
-                        Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, List.of(-0.0123f, 0.0123f))
-                    )
-                )
-            )
-        );
+        assertThat(result, CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+        var embeddingResult = (ChunkedInferenceEmbeddingFloat) result;
+        assertThat(embeddingResult.chunks(), hasSize(1));
+        assertThat(embeddingResult.chunks().get(0).matchedText(), is("abc"));
+        assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length())));
+        assertArrayEquals(new float[] { -0.0123f, 0.0123f }, embeddingResult.chunks().get(0).embedding(), 0.001f);

         assertThat(webServer.requests(), hasSize(1));
         assertNull(webServer.requests().get(0).getUri().getQuery());
         assertThat(
@@ -829,7 +819,7 @@ public void testChunkedInfer() throws IOException {
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

         var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -843,8 +833,8 @@ public void testChunkedInfer() throws IOException {
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(1));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("abc", floatResult.chunks().get(0).matchedText());
             assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
index 5aa826f1d80fe..3d298823ea19f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -35,7 +35,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.external.action.ibmwatsonx.IbmWatsonxActionCreator;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -684,7 +684,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws
             apiKey,
             getUrl(webServer)
         );
-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener);

         var results = listener.actionGet(TIMEOUT);
@@ -692,8 +692,8 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws

         // first result
         {
-            assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
@@ -701,8 +701,8 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws

         // second result
         {
-            assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
index 73bf03fd43ec5..f164afe604b43 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
@@ -18,7 +18,7 @@
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -34,7 +34,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -665,7 +665,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {
             """;
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -680,14 +680,14 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException {

         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding()));
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding()));
         }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 159b77789482d..b595862775053 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -19,7 +19,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -35,7 +35,7 @@
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -1613,7 +1613,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
             """;
         webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+        PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
         service.chunkedInfer(
             model,
             null,
@@ -1627,15 +1627,15 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException {
         var results = listener.actionGet(TIMEOUT);
         assertThat(results, hasSize(2));
         {
-            assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+            assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("foo", floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding()));
         }
         {
-            assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
-            var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+            assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class));
+            var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1);
             assertThat(floatResult.chunks(), hasSize(1));
             assertEquals("bar", floatResult.chunks().get(0).matchedText());
             assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding()));
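(Reading aid, not part of the patch: several rewritten assertions above construct offsets such as new ChunkedInference.TextOffset(0, "abc".length()). The small self-contained illustration below shows the relationship those assertions rely on, namely that a chunk's matchedText() equals the slice of the input covered by its offset. TextOffset's two-int shape is assumed from the constructor calls in the tests, not copied from the x-pack source.)

// Assumed two-int shape, mirroring calls like new ChunkedInference.TextOffset(0, "abc".length()).
record TextOffset(int start, int end) {}

class TextOffsetExample {
    // The invariant the tests rely on: a chunk's matched text is the input substring its offset covers.
    static String matchedText(String input, TextOffset offset) {
        return input.substring(offset.start(), offset.end());
    }

    public static void main(String[] args) {
        String input = "input text";
        TextOffset wholeInput = new TextOffset(0, input.length());
        // Prints "input text", mirroring the single-chunk case asserted in ElasticInferenceServiceTests.
        System.out.println(matchedText(input, wholeInput));
    }
}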
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/mr/DelegatingCircuitBreakerServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/mr/DelegatingCircuitBreakerServiceTests.java
index aa1a82116387e..a452ff49b7a74 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/mr/DelegatingCircuitBreakerServiceTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/mr/DelegatingCircuitBreakerServiceTests.java
@@ -26,49 +26,43 @@ public class DelegatingCircuitBreakerServiceTests extends ESTestCase {
      * the purpose of the test is not hitting the `IllegalStateException("already closed")` in
      * PreallocatedCircuitBreaker#addEstimateBytesAndMaybeBreak in {@link PreallocatedCircuitBreakerService}
      */
-    public void testThreadedExecution() {
+    public void testThreadedExecution() throws InterruptedException {
+        HierarchyCircuitBreakerService topBreaker = new HierarchyCircuitBreakerService(
+            CircuitBreakerMetrics.NOOP,
+            Settings.builder()
+                .put(REQUEST_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "10mb")
+                // Disable the real memory checking because it causes other tests to interfere with this one.
+                .put(USE_REAL_MEMORY_USAGE_SETTING.getKey(), false)
+                .build(),
+            List.of(),
+            new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
+        );
+        PreallocatedCircuitBreakerService preallocated = new PreallocatedCircuitBreakerService(
+            topBreaker,
+            CircuitBreaker.REQUEST,
+            10_000,
+            "test"
+        );
-        try (
-            HierarchyCircuitBreakerService topBreaker = new HierarchyCircuitBreakerService(
-                CircuitBreakerMetrics.NOOP,
-                Settings.builder()
-                    .put(REQUEST_CIRCUIT_BREAKER_LIMIT_SETTING.getKey(), "10mb")
-                    // Disable the real memory checking because it causes other tests to interfere with this one.
-                    .put(USE_REAL_MEMORY_USAGE_SETTING.getKey(), false)
-                    .build(),
-                List.of(),
-                new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
-            )
-        ) {
-            PreallocatedCircuitBreakerService preallocated = new PreallocatedCircuitBreakerService(
-                topBreaker,
-                CircuitBreaker.REQUEST,
-                10_000,
-                "test"
-            );
+        CircuitBreaker breaker = preallocated.getBreaker(CircuitBreaker.REQUEST);
-            CircuitBreaker breaker = preallocated.getBreaker(CircuitBreaker.REQUEST);
+        DelegatingCircuitBreakerService delegatingCircuitBreakerService = new DelegatingCircuitBreakerService(breaker, (bytes -> {
+            breaker.addEstimateBytesAndMaybeBreak(bytes, "test");
+        }));
-            DelegatingCircuitBreakerService delegatingCircuitBreakerService = new DelegatingCircuitBreakerService(breaker, (bytes -> {
-                breaker.addEstimateBytesAndMaybeBreak(bytes, "test");
-            }));
+        Thread consumerThread = new Thread(() -> {
+            for (int i = 0; i < 100; ++i) {
+                delegatingCircuitBreakerService.getBreaker("ignored").addEstimateBytesAndMaybeBreak(i % 2 == 0 ? 10 : -10, "ignored");
+            }
+        });
-            Thread consumerThread = new Thread(() -> {
-                for (int i = 0; i < 100; ++i) {
-                    delegatingCircuitBreakerService.getBreaker("ignored").addEstimateBytesAndMaybeBreak(i % 2 == 0 ? 10 : -10, "ignored");
-                }
-            });
-
-            final Thread producerThread = new Thread(() -> {
-                delegatingCircuitBreakerService.disconnect();
-                preallocated.close();
-            });
-            consumerThread.start();
-            producerThread.start();
-            consumerThread.join();
-            producerThread.join();
-        } catch (InterruptedException e) {
-            throw new AssertionError(e);
-        }
+        final Thread producerThread = new Thread(() -> {
+            delegatingCircuitBreakerService.disconnect();
+            preallocated.close();
+        });
+        consumerThread.start();
+        producerThread.start();
+        consumerThread.join();
+        producerThread.join();
     }
 }
diff --git a/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java b/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java
index 42fe09691d249..e36ae4994c872 100644
--- a/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java
+++ b/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java
@@ -27,6 +27,7 @@
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.IndexVersions;
 import org.elasticsearch.index.engine.Engine;
 import org.elasticsearch.index.engine.EngineFactory;
 import org.elasticsearch.index.engine.ReadOnlyEngine;
@@ -34,6 +35,7 @@
 import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.translog.TranslogStats;
+import org.elasticsearch.indices.recovery.RecoverySettings;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.LicensedFeature;
@@ -201,6 +203,12 @@ private static SegmentInfos convertToNewerLuceneVersion(OldSegmentInfos oldSegme
         if (map.containsKey(Engine.MAX_UNSAFE_AUTO_ID_TIMESTAMP_COMMIT_ID) == false) {
             map.put(Engine.MAX_UNSAFE_AUTO_ID_TIMESTAMP_COMMIT_ID, "-1");
         }
+        if (map.containsKey(Engine.ES_VERSION) == false) {
+            assert oldSegmentInfos.getLuceneVersion()
+                .onOrAfter(RecoverySettings.SEQ_NO_SNAPSHOT_RECOVERIES_SUPPORTED_VERSION.luceneVersion()) == false
+                : oldSegmentInfos.getLuceneVersion() + " should contain the ES_VERSION";
+            map.put(Engine.ES_VERSION, IndexVersions.MINIMUM_COMPATIBLE.toString());
+        }
         segmentInfos.setUserData(map, false);
         for (SegmentCommitInfo infoPerCommit : oldSegmentInfos.asList()) {
             final SegmentInfo newInfo = BWCCodec.wrap(infoPerCommit.info);
diff --git a/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldSegmentInfos.java b/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldSegmentInfos.java
index 18adebb145f98..b2af41653da61 100644
--- a/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldSegmentInfos.java
+++ b/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldSegmentInfos.java
@@ -564,6 +564,10 @@ public long getLastGeneration() {
         return lastGeneration;
     }

+    public Version getLuceneVersion() {
+        return luceneVersion;
+    }
+
     /**
      * Prints the given message to the infoStream. Note, this method does not check for null
      * infoStream. It assumes this check has been performed by the caller, which is recommended to
diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java
index 6115bec91ad62..a3ced0bf1b607 100644
--- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java
+++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java
@@ -74,7 +74,6 @@
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;

 @ESIntegTestCase.ClusterScope(supportsDedicatedMasters = false, numClientNodes = 0)
@@ -241,7 +240,7 @@ protected void checkSoftDeletesNotEagerlyLoaded(String restoredIndexName) {
         }
     }

-    protected void assertShardFolders(String indexName, boolean snapshotDirectory) throws IOException {
+    protected void assertShardFolders(String indexName, boolean isSearchableSnapshot) throws IOException {
         final Index restoredIndex = resolveIndex(indexName);
         final String customDataPath = resolveCustomDataPath(indexName);
         final ShardId shardId = new ShardId(restoredIndex, 0);
@@ -261,16 +260,16 @@ protected void assertShardFolders(String indexName, boolean snapshotDirectory) t
             translogExists
         );
         assertThat(
-            snapshotDirectory ? "Index file should not exist" : "Index file should exist",
+            isSearchableSnapshot ? "Index file should not exist" : "Index file should exist",
             indexExists,
-            not(snapshotDirectory)
+            not(isSearchableSnapshot)
         );
-        assertThat("Translog should exist", translogExists, is(true));
-        try (Stream<Path> dir = Files.list(shardPath.resolveTranslog())) {
-            final long translogFiles = dir.filter(path -> path.getFileName().toString().contains("translog")).count();
-            if (snapshotDirectory) {
-                assertThat("There should be 2 translog files for a snapshot directory", translogFiles, equalTo(2L));
-            } else {
+        if (isSearchableSnapshot) {
+            assertThat("Translog should not exist", translogExists, equalTo(false));
+        } else {
+            assertThat("Translog should exist", translogExists, equalTo(true));
+            try (Stream<Path> dir = Files.list(shardPath.resolveTranslog())) {
+                final long translogFiles = dir.filter(path -> path.getFileName().toString().contains("translog")).count();
                 assertThat(
                     "There should be 2+ translog files for a non-snapshot directory",
                     translogFiles,
diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java
index 67d9d7a82acf3..2797202e5f24e 100644
--- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java
+++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/FrozenSearchableSnapshotsIntegTests.java
@@ -59,7 +59,6 @@

 import java.time.ZoneId;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -169,12 +168,9 @@ public void testCreateAndRestorePartialSearchableSnapshot() throws Exception {
         logger.info("--> restoring partial index [{}] with cache enabled", restoredIndexName);
         Settings.Builder indexSettingsBuilder = Settings.builder().put(SearchableSnapshots.SNAPSHOT_CACHE_ENABLED_SETTING.getKey(), true);
-        final List<String> nonCachedExtensions;
         if (randomBoolean()) {
-            nonCachedExtensions = randomSubsetOf(Arrays.asList("fdt", "fdx", "nvd", "dvd", "tip", "cfs", "dim"));
+            var nonCachedExtensions = randomSubsetOf(Arrays.asList("fdt", "fdx", "nvd", "dvd", "tip", "cfs", "dim"));
             indexSettingsBuilder.putList(SearchableSnapshots.SNAPSHOT_CACHE_EXCLUDED_FILE_TYPES_SETTING.getKey(), nonCachedExtensions);
-        } else {
-            nonCachedExtensions = Collections.emptyList();
         }
         if (randomBoolean()) {
             indexSettingsBuilder.put(
@@ -264,8 +260,6 @@ public void testCreateAndRestorePartialSearchableSnapshot() throws Exception {
             final long originalSize = snapshotShards.get(shardRouting.getId()).getStats().getTotalSize();
             totalExpectedSize += originalSize;

-            // an extra segments_N file is created for bootstrapping new history and associating translog. We can extract the size of this
-            // extra file but we have to unwrap the in-memory directory first.
             final Directory unwrappedDir = FilterDirectory.unwrap(
                 internalCluster().getInstance(IndicesService.class, getDiscoveryNodes().resolveNode(shardRouting.currentNodeId()).getName())
                     .indexServiceSafe(shardRouting.index())
@@ -277,7 +271,7 @@ public void testCreateAndRestorePartialSearchableSnapshot() throws Exception {
             assertThat(shardRouting.toString(), unwrappedDir, instanceOf(ByteBuffersDirectory.class));
             final ByteBuffersDirectory inMemoryDir = (ByteBuffersDirectory) unwrappedDir;

-            assertThat(inMemoryDir.listAll(), arrayWithSize(1));
+            assertThat(inMemoryDir.listAll(), arrayWithSize(0));

             assertThat(shardRouting.toString(), store.totalDataSetSizeInBytes(), equalTo(originalSize));
         }
diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotIndexEventListener.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotIndexEventListener.java
index cf0306e3e6ef2..4caf932a99807 100644
--- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotIndexEventListener.java
+++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotIndexEventListener.java
@@ -21,6 +21,7 @@
 import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.index.store.ByteSizeCachingDirectory;
 import org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots;
@@ -61,6 +62,12 @@ public SearchableSnapshotIndexEventListener(
     public void beforeIndexShardRecovery(IndexShard indexShard, IndexSettings indexSettings, ActionListener<Void> listener) {
         assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
         ensureSnapshotIsLoaded(indexShard);
+        var sizeCachingDirectory = ByteSizeCachingDirectory.unwrapDirectory(indexShard.store().directory());
+        if (sizeCachingDirectory != null) {
+            // Marks the cached estimation of the directory size as stale in ByteSizeCachingDirectory since we just loaded the snapshot
+            // files list into the searchable snapshot directory.
+            sizeCachingDirectory.markEstimatedSizeAsStale();
+        }
         listener.onResponse(null);
     }
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
index 7b66a95609b05..5eb9fb9b41a22 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java
@@ -124,6 +124,7 @@
 import java.util.function.Consumer;

 import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
+import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
 import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
 import static org.elasticsearch.test.SecurityTestsUtils.assertAuthenticationException;
 import static org.elasticsearch.test.TestMatchers.throwableWithMessage;
@@ -1955,6 +1956,37 @@ public void testInvalidToken() throws Exception {
         final User user = new User("_username", "r1");
         when(firstRealm.token(threadContext)).thenReturn(token);
         when(firstRealm.supports(token)).thenReturn(true);
+
+        when(securityIndex.defensiveCopy()).thenReturn(securityIndex);
+        // An invalid token might decode to something that looks like a UUID
+        // Randomise it being invalid because the index doesn't exist, or the document doesn't exist
+        if (randomBoolean()) {
+            when(securityIndex.isAvailable(any())).thenReturn(false);
+            when(securityIndex.getUnavailableReason(any())).thenReturn(new ElasticsearchException(getTestName()));
+        } else {
+            when(securityIndex.isAvailable(any())).thenReturn(true);
+            doAnswer(inv -> {
+                final GetRequest request = inv.getArgument(0);
+                final ActionListener<GetResponse> listener = inv.getArgument(1);
+                listener.onResponse(
+                    new GetResponse(
+                        new GetResult(
+                            request.index(),
+                            request.id(),
+                            UNASSIGNED_SEQ_NO,
+                            UNASSIGNED_PRIMARY_TERM,
+                            0,
+                            false,
+                            null,
+                            Map.of(),
+                            Map.of()
+                        )
+                    )
+                );
+                return null;
+            }).when(client).get(any(GetRequest.class), any());
+        }
+
         mockAuthenticate(firstRealm, token, user);
         final int numBytes = randomIntBetween(TokenService.MINIMUM_BYTES, TokenService.MINIMUM_BYTES + 32);
         final byte[] randomBytes = new byte[numBytes];
diff --git a/x-pack/plugin/sql/qa/jdbc/security/build.gradle b/x-pack/plugin/sql/qa/jdbc/security/build.gradle
index 82510285cb996..b40a8aaaa5580 100644
--- a/x-pack/plugin/sql/qa/jdbc/security/build.gradle
+++ b/x-pack/plugin/sql/qa/jdbc/security/build.gradle
@@ -50,6 +50,7 @@ subprojects {

   tasks.withType(RestIntegTestTask).configureEach {
+    def taskName = name
     dependsOn copyTestClasses
     classpath += configurations.testArtifacts
     testClassesDirs = project.files(testArtifactsDir)
@@ -58,15 +59,15 @@ subprojects {
      project.gradle.sharedServices,
      TestClustersPlugin.REGISTRY_SERVICE_NAME
    )
-    project.getProviders().of(TestClusterValueSource.class) {
+
+    def clusterInfo = project.getProviders().of(TestClusterValueSource.class) {
      it.parameters.path.set(clusterPath)
-     it.parameters.clusterName.set("javaRestTest")
+     it.parameters.clusterName.set(taskName)
      it.parameters.service = serviceProvider
    }
-    nonInputProperties.systemProperty 'tests.audit.logfile',
-      "${-> testClusters.javaRestTest.singleNode().getAuditLog()}"
+
+    nonInputProperties.systemProperty 'tests.audit.logfile', clusterInfo.map { it.auditLogs.get(0) }
     nonInputProperties.systemProperty 'tests.audit.yesterday.logfile',
-      "${-> testClusters.javaRestTest.singleNode().getAuditLog().getParentFile()}/javaRestTest_audit-${new Date().format('yyyy-MM-dd')}.json"
+      clusterInfo.map { it.auditLogs.get(0).getParentFile().toString() + "/javaRestTest_audit-${new Date().format('yyyy-MM-dd')}-1.json" }
   }
 }
diff --git a/x-pack/plugin/sql/qa/server/security/build.gradle b/x-pack/plugin/sql/qa/server/security/build.gradle
index 2d9f7b563d073..ef81ebda5cb89 100644
--- a/x-pack/plugin/sql/qa/server/security/build.gradle
+++ b/x-pack/plugin/sql/qa/server/security/build.gradle
@@ -1,3 +1,9 @@
+import org.elasticsearch.gradle.internal.test.RestIntegTestTask
+import org.elasticsearch.gradle.testclusters.TestClusterValueSource
+import org.elasticsearch.gradle.testclusters.TestClustersPlugin
+import org.elasticsearch.gradle.testclusters.TestClustersRegistry
+import org.elasticsearch.gradle.util.GradleUtils
+
 apply plugin: 'elasticsearch.internal-test-artifact'

 dependencies {
@@ -16,6 +22,8 @@ subprojects {
   // Use tests from the root security qa project in subprojects
   configurations.create('testArtifacts').transitive(false)

+  def clusterPath = getPath()
+
   dependencies {
     javaRestTestImplementation project(":x-pack:plugin:core")
     javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
@@ -50,12 +58,23 @@ subprojects {
   tasks.named("javaRestTest").configure {
     dependsOn copyTestClasses
+
+    Provider<TestClustersRegistry> serviceProvider = GradleUtils.getBuildService(
+      project.gradle.sharedServices,
+      TestClustersPlugin.REGISTRY_SERVICE_NAME
+    )
+
+    def clusterInfo = project.getProviders().of(TestClusterValueSource.class) {
+      it.parameters.path.set(clusterPath)
+      it.parameters.clusterName.set("javaRestTest")
+      it.parameters.service = serviceProvider
+    }
+
     testClassesDirs += project.files(testArtifactsDir)
     classpath += configurations.testArtifacts
-    nonInputProperties.systemProperty 'tests.audit.logfile',
-      "${-> testClusters.javaRestTest.singleNode().getAuditLog()}"
+
+    nonInputProperties.systemProperty 'tests.audit.logfile', clusterInfo.map { it.auditLogs.get(0) }
     nonInputProperties.systemProperty 'tests.audit.yesterday.logfile',
-      "${-> testClusters.javaRestTest.singleNode().getAuditLog().getParentFile()}/javaRestTest_audit-${new Date().format('yyyy-MM-dd')}-1.json.gz"
+      clusterInfo.map { it.auditLogs.get(0).getParentFile().toString() + "/javaRestTest_audit-${new Date().format('yyyy-MM-dd')}-1.json.gz" }
   }
 }