From 1c859c9f7acee9cc432a61a6c26bc74e52507301 Mon Sep 17 00:00:00 2001 From: Artem Prigoda Date: Mon, 8 Jul 2024 09:55:34 +0200 Subject: [PATCH 01/64] Trying to make SharedBlobCacheService async? --- .../shared/SharedBlobCacheService.java | 74 +++++++++++++------ .../shared/SharedBlobCacheServiceTests.java | 32 ++++---- .../store/input/FrozenIndexInput.java | 3 +- 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index ac22d22d5affb..192685d7ede1a 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -643,12 +643,13 @@ private RangeMissingHandler writerWithOffset(RangeMissingHandler writer, int wri // no need to allocate a new capturing lambda if the offset isn't adjusted return writer; } - return (channel, channelPos, relativePos, len, progressUpdater) -> writer.fillCacheRange( + return (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> writer.fillCacheRange( channel, channelPos, relativePos - writeOffset, len, - progressUpdater + progressUpdater, + completionListener ); } @@ -710,7 +711,7 @@ public void close() { sharedBytes.decRef(); } - private record RegionKey(KeyType file, int region) { + private record RegionKey (KeyType file, int region) { @Override public String toString() { return "Chunk{" + "file=" + file + ", region=" + region + '}'; @@ -914,9 +915,12 @@ void populate( final List gaps = tracker.waitForRange( rangeToWrite, rangeToWrite, - Assertions.ENABLED ? ActionListener.releaseAfter(ActionListener.running(() -> { - assert regionOwners.get(io) == this; - }), refs.acquire()) : refs.acquireListener() + Assertions.ENABLED + ? 
ActionListener.releaseAfter( + ActionListener.running(() -> { assert regionOwners.get(io) == this; }), + refs.acquire() + ) + : refs.acquireListener() ); if (gaps.isEmpty()) { listener.onResponse(false); @@ -989,10 +993,12 @@ private AbstractRunnable fillGapRunnable(SparseFileTracker.Gap gap, RangeMissing start, start, Math.toIntExact(gap.end() - start), - progress -> gap.onProgress(start + progress) + progress -> gap.onProgress(start + progress), + ActionListener.running(() -> { + writeCount.increment(); + gap.onCompletion(); + }) ); - writeCount.increment(); - gap.onCompletion(); }); } @@ -1077,11 +1083,20 @@ public int populateAndRead( int channelPos, int relativePos, int length, - IntConsumer progressUpdater) -> { - writer.fillCacheRange(channel, channelPos, relativePos, length, progressUpdater); - var elapsedTime = TimeUnit.NANOSECONDS.toMicros(relativeTimeInNanosSupplier.getAsLong() - startTime); - SharedBlobCacheService.this.blobCacheMetrics.getCacheMissLoadTimes().record(elapsedTime); - SharedBlobCacheService.this.blobCacheMetrics.getCacheMissCounter().increment(); + IntConsumer progressUpdater, + ActionListener completionListener) -> { + writer.fillCacheRange( + channel, + channelPos, + relativePos, + length, + progressUpdater, + ActionListener.runAfter(completionListener, () -> { + var elapsedTime = TimeUnit.NANOSECONDS.toMicros(relativeTimeInNanosSupplier.getAsLong() - startTime); + SharedBlobCacheService.this.blobCacheMetrics.getCacheMissLoadTimes().record(elapsedTime); + SharedBlobCacheService.this.blobCacheMetrics.getCacheMissCounter().increment(); + }) + ); }; if (rangeToRead.isEmpty()) { // nothing to read, skip @@ -1165,20 +1180,29 @@ private RangeMissingHandler writerWithOffset(RangeMissingHandler writer, CacheFi // no need to allocate a new capturing lambda if the offset isn't adjusted adjustedWriter = writer; } else { - adjustedWriter = (channel, channelPos, relativePos, len, progressUpdater) -> writer.fillCacheRange( + adjustedWriter = (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> writer.fillCacheRange( channel, channelPos, relativePos - writeOffset, len, - progressUpdater + progressUpdater, + completionListener ); } if (Assertions.ENABLED) { - return (channel, channelPos, relativePos, len, progressUpdater) -> { + return (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> { assert assertValidRegionAndLength(fileRegion, channelPos, len); - adjustedWriter.fillCacheRange(channel, channelPos, relativePos, len, progressUpdater); - assert regionOwners.get(fileRegion.io) == fileRegion - : "File chunk [" + fileRegion.regionKey + "] no longer owns IO [" + fileRegion.io + "]"; + adjustedWriter.fillCacheRange( + channel, + channelPos, + relativePos, + len, + progressUpdater, + ActionListener.runAfter(completionListener, () -> { + assert regionOwners.get(fileRegion.io) == fileRegion + : "File chunk [" + fileRegion.regionKey + "] no longer owns IO [" + fileRegion.io + "]"; + }) + ); }; } return adjustedWriter; @@ -1250,8 +1274,14 @@ public interface RangeMissingHandler { * @param progressUpdater consumer to invoke with the number of copied bytes as they are written in cache. * This is used to notify waiting readers that data become available in cache. 
*/ - void fillCacheRange(SharedBytes.IO channel, int channelPos, int relativePos, int length, IntConsumer progressUpdater) - throws IOException; + void fillCacheRange( + SharedBytes.IO channel, + int channelPos, + int relativePos, + int length, + IntConsumer progressUpdater, + ActionListener completionListener + ) throws IOException; } public record Stats( diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java index edeed9a16034a..5a7e1f2067f25 100644 --- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java +++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java @@ -104,7 +104,7 @@ public void testBasicEviction() throws IOException { ByteRange.of(0L, 1L), ByteRange.of(0L, 1L), (channel, channelPos, relativePos, length) -> 1, - (channel, channelPos, relativePos, length, progressUpdater) -> progressUpdater.accept(length), + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> progressUpdater.accept(length), taskQueue.getThreadPool().generic(), bytesReadFuture ); @@ -538,7 +538,7 @@ public void execute(Runnable command) { final long size = size(250); AtomicLong bytesRead = new AtomicLong(size); final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchFullEntry(cacheKey, size, (channel, channelPos, relativePos, length, progressUpdater) -> { + cacheService.maybeFetchFullEntry(cacheKey, size, (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesRead.addAndGet(-length); progressUpdater.accept(length); }, bulkExecutor, future); @@ -552,7 +552,7 @@ public void execute(Runnable command) { // a download that would use up all regions should not run final var cacheKey = generateCacheKey(); assertEquals(2, cacheService.freeRegionCount()); - var configured = cacheService.maybeFetchFullEntry(cacheKey, size(500), (ch, chPos, relPos, len, update) -> { + var configured = cacheService.maybeFetchFullEntry(cacheKey, size(500), (ch, chPos, relPos, len, update, completionListener) -> { throw new AssertionError("Should never reach here"); }, bulkExecutor, ActionListener.noop()); assertFalse(configured); @@ -596,7 +596,7 @@ public void testFetchFullCacheEntryConcurrently() throws Exception { f -> cacheService.maybeFetchFullEntry( cacheKey, size, - (channel, channelPos, relativePos, length, progressUpdater) -> progressUpdater.accept(length), + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> progressUpdater.accept(length), bulkExecutor, f ) @@ -843,7 +843,7 @@ public void testMaybeEvictLeastUsed() throws Exception { var entry = cacheService.get(cacheKey, regionSize, 0); entry.populate( ByteRange.of(0L, regionSize), - (channel, channelPos, relativePos, length, progressUpdater) -> progressUpdater.accept(length), + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> progressUpdater.accept(length), taskQueue.getThreadPool().generic(), ActionListener.noop() ); @@ -934,7 +934,7 @@ public void execute(Runnable command) { final long blobLength = size(250); // 3 regions AtomicLong bytesRead = new AtomicLong(0L); final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater) 
-> { + cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesRead.addAndGet(length); progressUpdater.accept(length); }, bulkExecutor, future); @@ -961,7 +961,7 @@ public void execute(Runnable command) { cacheKey, region, blobLength, - (channel, channelPos, relativePos, length, progressUpdater) -> { + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesRead.addAndGet(length); progressUpdater.accept(length); }, @@ -985,7 +985,7 @@ public void execute(Runnable command) { cacheKey, randomIntBetween(0, 10), randomLongBetween(1L, regionSize), - (channel, channelPos, relativePos, length, progressUpdater) -> { + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { throw new AssertionError("should not be executed"); }, bulkExecutor, @@ -1003,7 +1003,7 @@ public void execute(Runnable command) { long blobLength = randomLongBetween(1L, regionSize); AtomicLong bytesRead = new AtomicLong(0L); final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater) -> { + cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesRead.addAndGet(length); progressUpdater.accept(length); }, bulkExecutor, future); @@ -1077,7 +1077,7 @@ public void execute(Runnable command) { region, range, blobLength, - (channel, channelPos, relativePos, length, progressUpdater) -> { + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { assertThat(range.start() + relativePos, equalTo(cacheService.getRegionStart(region) + regionRange.start())); assertThat(channelPos, equalTo(Math.toIntExact(regionRange.start()))); assertThat(length, equalTo(Math.toIntExact(regionRange.length()))); @@ -1117,7 +1117,7 @@ public void execute(Runnable command) { region, ByteRange.of(0L, blobLength), blobLength, - (channel, channelPos, relativePos, length, progressUpdater) -> bytesCopied.addAndGet(length), + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> bytesCopied.addAndGet(length), bulkExecutor, listener ); @@ -1140,7 +1140,7 @@ public void execute(Runnable command) { randomIntBetween(0, 10), ByteRange.of(0L, blobLength), blobLength, - (channel, channelPos, relativePos, length, progressUpdater) -> { + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { throw new AssertionError("should not be executed"); }, bulkExecutor, @@ -1163,7 +1163,7 @@ public void execute(Runnable command) { 0, ByteRange.of(0L, blobLength), blobLength, - (channel, channelPos, relativePos, length, progressUpdater) -> bytesCopied.addAndGet(length), + (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> bytesCopied.addAndGet(length), bulkExecutor, future ); @@ -1204,7 +1204,7 @@ public void testPopulate() throws Exception { var entry = cacheService.get(cacheKey, blobLength, 0); AtomicLong bytesWritten = new AtomicLong(0L); final PlainActionFuture future1 = new PlainActionFuture<>(); - entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater) -> { + entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesWritten.addAndGet(length); progressUpdater.accept(length); }, 
taskQueue.getThreadPool().generic(), future1); @@ -1215,7 +1215,7 @@ public void testPopulate() throws Exception { // start populating the second region entry = cacheService.get(cacheKey, blobLength, 1); final PlainActionFuture future2 = new PlainActionFuture<>(); - entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater) -> { + entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesWritten.addAndGet(length); progressUpdater.accept(length); }, taskQueue.getThreadPool().generic(), future2); @@ -1223,7 +1223,7 @@ public void testPopulate() throws Exception { // start populating again the first region, listener should be called immediately entry = cacheService.get(cacheKey, blobLength, 0); final PlainActionFuture future3 = new PlainActionFuture<>(); - entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater) -> { + entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { bytesWritten.addAndGet(length); progressUpdater.accept(length); }, taskQueue.getThreadPool().generic(), future3); diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java index 931e8790f98c6..34baa969f9abf 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java @@ -146,7 +146,7 @@ private void readWithoutBlobCacheSlow(ByteBuffer b, long position, int length) t final int read = SharedBytes.readCacheFile(channel, pos, relativePos, len, byteBufferReference); stats.addCachedBytesRead(read); return read; - }, (channel, channelPos, relativePos, len, progressUpdater) -> { + }, (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> { final long startTimeNanos = stats.currentTimeNanos(); try (InputStream input = openInputStreamFromBlobStore(rangeToWrite.start() + relativePos, len)) { assert ThreadPool.assertCurrentThreadPool(SearchableSnapshots.CACHE_FETCH_ASYNC_THREAD_POOL_NAME); @@ -169,6 +169,7 @@ private void readWithoutBlobCacheSlow(ByteBuffer b, long position, int length) t ); final long endTimeNanos = stats.currentTimeNanos(); stats.addCachedBytesWritten(len, endTimeNanos - startTimeNanos); + completionListener.onResponse(null); } }); assert bytesRead == length : bytesRead + " vs " + length; From 8d0a1ac71df13861f3a54799ee52ada1166289c7 Mon Sep 17 00:00:00 2001 From: Artem Prigoda Date: Mon, 8 Jul 2024 10:04:31 +0200 Subject: [PATCH 02/64] Revert "Trying to make SharedBlobCacheService async?" This reverts commit 1c859c9f7acee9cc432a61a6c26bc74e52507301. 
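For context on what this revert removes: the asynchronous contract that [PATCH 01/64] added to RangeMissingHandler.fillCacheRange can be sketched with self-contained stand-ins. This is a minimal illustration only — CompletionListener and CacheChannel below are hypothetical placeholders for org.elasticsearch.action.ActionListener and SharedBytes.IO, and the handler body is assumed rather than taken from the actual change.

import java.util.function.IntConsumer;

// Hypothetical stand-ins so the sketch is self-contained; the real change uses
// org.elasticsearch.action.ActionListener and org.elasticsearch.blobcache.shared.SharedBytes.IO.
interface CompletionListener<T> {
    void onResponse(T result);

    void onFailure(Exception e);
}

interface CacheChannel {} // placeholder for SharedBytes.IO

@FunctionalInterface
interface AsyncRangeMissingHandler {
    void fillCacheRange(
        CacheChannel channel,
        int channelPos,
        int relativePos,
        int length,
        IntConsumer progressUpdater,
        CompletionListener<Void> completionListener
    );
}

class AsyncFillSketch {
    static final AsyncRangeMissingHandler HANDLER =
        (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> {
            try {
                // ... copy `length` bytes into `channel` at `channelPos` ...
                progressUpdater.accept(length); // readers waiting on this gap can make progress
                // Completion is now signalled through the listener instead of by returning,
                // which is why patch 01 moves writeCount.increment() and gap.onCompletion()
                // into an ActionListener.running(...) callback.
                completionListener.onResponse(null);
            } catch (Exception e) {
                completionListener.onFailure(e); // propagate the failure to whoever waits on the gap
            }
        };
}

The revert below restores the synchronous signature, where returning normally from fillCacheRange implies the range was written.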
--- .../shared/SharedBlobCacheService.java | 74 ++++++------------- .../shared/SharedBlobCacheServiceTests.java | 32 ++++---- .../store/input/FrozenIndexInput.java | 3 +- 3 files changed, 39 insertions(+), 70 deletions(-) diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index 192685d7ede1a..ac22d22d5affb 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -643,13 +643,12 @@ private RangeMissingHandler writerWithOffset(RangeMissingHandler writer, int wri // no need to allocate a new capturing lambda if the offset isn't adjusted return writer; } - return (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> writer.fillCacheRange( + return (channel, channelPos, relativePos, len, progressUpdater) -> writer.fillCacheRange( channel, channelPos, relativePos - writeOffset, len, - progressUpdater, - completionListener + progressUpdater ); } @@ -711,7 +710,7 @@ public void close() { sharedBytes.decRef(); } - private record RegionKey (KeyType file, int region) { + private record RegionKey(KeyType file, int region) { @Override public String toString() { return "Chunk{" + "file=" + file + ", region=" + region + '}'; @@ -915,12 +914,9 @@ void populate( final List gaps = tracker.waitForRange( rangeToWrite, rangeToWrite, - Assertions.ENABLED - ? ActionListener.releaseAfter( - ActionListener.running(() -> { assert regionOwners.get(io) == this; }), - refs.acquire() - ) - : refs.acquireListener() + Assertions.ENABLED ? 
ActionListener.releaseAfter(ActionListener.running(() -> { + assert regionOwners.get(io) == this; + }), refs.acquire()) : refs.acquireListener() ); if (gaps.isEmpty()) { listener.onResponse(false); @@ -993,12 +989,10 @@ private AbstractRunnable fillGapRunnable(SparseFileTracker.Gap gap, RangeMissing start, start, Math.toIntExact(gap.end() - start), - progress -> gap.onProgress(start + progress), - ActionListener.running(() -> { - writeCount.increment(); - gap.onCompletion(); - }) + progress -> gap.onProgress(start + progress) ); + writeCount.increment(); + gap.onCompletion(); }); } @@ -1083,20 +1077,11 @@ public int populateAndRead( int channelPos, int relativePos, int length, - IntConsumer progressUpdater, - ActionListener completionListener) -> { - writer.fillCacheRange( - channel, - channelPos, - relativePos, - length, - progressUpdater, - ActionListener.runAfter(completionListener, () -> { - var elapsedTime = TimeUnit.NANOSECONDS.toMicros(relativeTimeInNanosSupplier.getAsLong() - startTime); - SharedBlobCacheService.this.blobCacheMetrics.getCacheMissLoadTimes().record(elapsedTime); - SharedBlobCacheService.this.blobCacheMetrics.getCacheMissCounter().increment(); - }) - ); + IntConsumer progressUpdater) -> { + writer.fillCacheRange(channel, channelPos, relativePos, length, progressUpdater); + var elapsedTime = TimeUnit.NANOSECONDS.toMicros(relativeTimeInNanosSupplier.getAsLong() - startTime); + SharedBlobCacheService.this.blobCacheMetrics.getCacheMissLoadTimes().record(elapsedTime); + SharedBlobCacheService.this.blobCacheMetrics.getCacheMissCounter().increment(); }; if (rangeToRead.isEmpty()) { // nothing to read, skip @@ -1180,29 +1165,20 @@ private RangeMissingHandler writerWithOffset(RangeMissingHandler writer, CacheFi // no need to allocate a new capturing lambda if the offset isn't adjusted adjustedWriter = writer; } else { - adjustedWriter = (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> writer.fillCacheRange( + adjustedWriter = (channel, channelPos, relativePos, len, progressUpdater) -> writer.fillCacheRange( channel, channelPos, relativePos - writeOffset, len, - progressUpdater, - completionListener + progressUpdater ); } if (Assertions.ENABLED) { - return (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> { + return (channel, channelPos, relativePos, len, progressUpdater) -> { assert assertValidRegionAndLength(fileRegion, channelPos, len); - adjustedWriter.fillCacheRange( - channel, - channelPos, - relativePos, - len, - progressUpdater, - ActionListener.runAfter(completionListener, () -> { - assert regionOwners.get(fileRegion.io) == fileRegion - : "File chunk [" + fileRegion.regionKey + "] no longer owns IO [" + fileRegion.io + "]"; - }) - ); + adjustedWriter.fillCacheRange(channel, channelPos, relativePos, len, progressUpdater); + assert regionOwners.get(fileRegion.io) == fileRegion + : "File chunk [" + fileRegion.regionKey + "] no longer owns IO [" + fileRegion.io + "]"; }; } return adjustedWriter; @@ -1274,14 +1250,8 @@ public interface RangeMissingHandler { * @param progressUpdater consumer to invoke with the number of copied bytes as they are written in cache. * This is used to notify waiting readers that data become available in cache. 
*/ - void fillCacheRange( - SharedBytes.IO channel, - int channelPos, - int relativePos, - int length, - IntConsumer progressUpdater, - ActionListener completionListener - ) throws IOException; + void fillCacheRange(SharedBytes.IO channel, int channelPos, int relativePos, int length, IntConsumer progressUpdater) + throws IOException; } public record Stats( diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java index 5a7e1f2067f25..edeed9a16034a 100644 --- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java +++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java @@ -104,7 +104,7 @@ public void testBasicEviction() throws IOException { ByteRange.of(0L, 1L), ByteRange.of(0L, 1L), (channel, channelPos, relativePos, length) -> 1, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> progressUpdater.accept(length), + (channel, channelPos, relativePos, length, progressUpdater) -> progressUpdater.accept(length), taskQueue.getThreadPool().generic(), bytesReadFuture ); @@ -538,7 +538,7 @@ public void execute(Runnable command) { final long size = size(250); AtomicLong bytesRead = new AtomicLong(size); final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchFullEntry(cacheKey, size, (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + cacheService.maybeFetchFullEntry(cacheKey, size, (channel, channelPos, relativePos, length, progressUpdater) -> { bytesRead.addAndGet(-length); progressUpdater.accept(length); }, bulkExecutor, future); @@ -552,7 +552,7 @@ public void execute(Runnable command) { // a download that would use up all regions should not run final var cacheKey = generateCacheKey(); assertEquals(2, cacheService.freeRegionCount()); - var configured = cacheService.maybeFetchFullEntry(cacheKey, size(500), (ch, chPos, relPos, len, update, completionListener) -> { + var configured = cacheService.maybeFetchFullEntry(cacheKey, size(500), (ch, chPos, relPos, len, update) -> { throw new AssertionError("Should never reach here"); }, bulkExecutor, ActionListener.noop()); assertFalse(configured); @@ -596,7 +596,7 @@ public void testFetchFullCacheEntryConcurrently() throws Exception { f -> cacheService.maybeFetchFullEntry( cacheKey, size, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> progressUpdater.accept(length), + (channel, channelPos, relativePos, length, progressUpdater) -> progressUpdater.accept(length), bulkExecutor, f ) @@ -843,7 +843,7 @@ public void testMaybeEvictLeastUsed() throws Exception { var entry = cacheService.get(cacheKey, regionSize, 0); entry.populate( ByteRange.of(0L, regionSize), - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> progressUpdater.accept(length), + (channel, channelPos, relativePos, length, progressUpdater) -> progressUpdater.accept(length), taskQueue.getThreadPool().generic(), ActionListener.noop() ); @@ -934,7 +934,7 @@ public void execute(Runnable command) { final long blobLength = size(250); // 3 regions AtomicLong bytesRead = new AtomicLong(0L); final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater, 
completionListener) -> { + cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater) -> { bytesRead.addAndGet(length); progressUpdater.accept(length); }, bulkExecutor, future); @@ -961,7 +961,7 @@ public void execute(Runnable command) { cacheKey, region, blobLength, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + (channel, channelPos, relativePos, length, progressUpdater) -> { bytesRead.addAndGet(length); progressUpdater.accept(length); }, @@ -985,7 +985,7 @@ public void execute(Runnable command) { cacheKey, randomIntBetween(0, 10), randomLongBetween(1L, regionSize), - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + (channel, channelPos, relativePos, length, progressUpdater) -> { throw new AssertionError("should not be executed"); }, bulkExecutor, @@ -1003,7 +1003,7 @@ public void execute(Runnable command) { long blobLength = randomLongBetween(1L, regionSize); AtomicLong bytesRead = new AtomicLong(0L); final PlainActionFuture future = new PlainActionFuture<>(); - cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + cacheService.maybeFetchRegion(cacheKey, 0, blobLength, (channel, channelPos, relativePos, length, progressUpdater) -> { bytesRead.addAndGet(length); progressUpdater.accept(length); }, bulkExecutor, future); @@ -1077,7 +1077,7 @@ public void execute(Runnable command) { region, range, blobLength, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + (channel, channelPos, relativePos, length, progressUpdater) -> { assertThat(range.start() + relativePos, equalTo(cacheService.getRegionStart(region) + regionRange.start())); assertThat(channelPos, equalTo(Math.toIntExact(regionRange.start()))); assertThat(length, equalTo(Math.toIntExact(regionRange.length()))); @@ -1117,7 +1117,7 @@ public void execute(Runnable command) { region, ByteRange.of(0L, blobLength), blobLength, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> bytesCopied.addAndGet(length), + (channel, channelPos, relativePos, length, progressUpdater) -> bytesCopied.addAndGet(length), bulkExecutor, listener ); @@ -1140,7 +1140,7 @@ public void execute(Runnable command) { randomIntBetween(0, 10), ByteRange.of(0L, blobLength), blobLength, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + (channel, channelPos, relativePos, length, progressUpdater) -> { throw new AssertionError("should not be executed"); }, bulkExecutor, @@ -1163,7 +1163,7 @@ public void execute(Runnable command) { 0, ByteRange.of(0L, blobLength), blobLength, - (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> bytesCopied.addAndGet(length), + (channel, channelPos, relativePos, length, progressUpdater) -> bytesCopied.addAndGet(length), bulkExecutor, future ); @@ -1204,7 +1204,7 @@ public void testPopulate() throws Exception { var entry = cacheService.get(cacheKey, blobLength, 0); AtomicLong bytesWritten = new AtomicLong(0L); final PlainActionFuture future1 = new PlainActionFuture<>(); - entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater) -> { bytesWritten.addAndGet(length); progressUpdater.accept(length); }, 
taskQueue.getThreadPool().generic(), future1); @@ -1215,7 +1215,7 @@ public void testPopulate() throws Exception { // start populating the second region entry = cacheService.get(cacheKey, blobLength, 1); final PlainActionFuture future2 = new PlainActionFuture<>(); - entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater) -> { bytesWritten.addAndGet(length); progressUpdater.accept(length); }, taskQueue.getThreadPool().generic(), future2); @@ -1223,7 +1223,7 @@ public void testPopulate() throws Exception { // start populating again the first region, listener should be called immediately entry = cacheService.get(cacheKey, blobLength, 0); final PlainActionFuture future3 = new PlainActionFuture<>(); - entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater, completionListener) -> { + entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, relativePos, length, progressUpdater) -> { bytesWritten.addAndGet(length); progressUpdater.accept(length); }, taskQueue.getThreadPool().generic(), future3); diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java index 34baa969f9abf..931e8790f98c6 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java @@ -146,7 +146,7 @@ private void readWithoutBlobCacheSlow(ByteBuffer b, long position, int length) t final int read = SharedBytes.readCacheFile(channel, pos, relativePos, len, byteBufferReference); stats.addCachedBytesRead(read); return read; - }, (channel, channelPos, relativePos, len, progressUpdater, completionListener) -> { + }, (channel, channelPos, relativePos, len, progressUpdater) -> { final long startTimeNanos = stats.currentTimeNanos(); try (InputStream input = openInputStreamFromBlobStore(rangeToWrite.start() + relativePos, len)) { assert ThreadPool.assertCurrentThreadPool(SearchableSnapshots.CACHE_FETCH_ASYNC_THREAD_POOL_NAME); @@ -169,7 +169,6 @@ private void readWithoutBlobCacheSlow(ByteBuffer b, long position, int length) t ); final long endTimeNanos = stats.currentTimeNanos(); stats.addCachedBytesWritten(len, endTimeNanos - startTimeNanos); - completionListener.onResponse(null); } }); assert bytesRead == length : bytesRead + " vs " + length; From ae4aa2ee778b524405fb96b32225d62e199cac10 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Mon, 8 Jul 2024 09:17:10 +0100 Subject: [PATCH 03/64] Add known-issues for all affected releases for the feature upgrade issue (#110523) --- docs/reference/release-notes/8.13.0.asciidoc | 7 +++++++ docs/reference/release-notes/8.13.1.asciidoc | 10 ++++++++++ docs/reference/release-notes/8.13.2.asciidoc | 10 ++++++++++ docs/reference/release-notes/8.13.3.asciidoc | 10 ++++++++++ docs/reference/release-notes/8.13.4.asciidoc | 10 ++++++++++ docs/reference/release-notes/8.14.0.asciidoc | 10 ++++++++++ docs/reference/release-notes/8.14.1.asciidoc | 10 ++++++++++ docs/reference/release-notes/8.14.2.asciidoc | 12 +++++++++++- 8 files changed, 78 
insertions(+), 1 deletion(-) diff --git a/docs/reference/release-notes/8.13.0.asciidoc b/docs/reference/release-notes/8.13.0.asciidoc index dba4fdbe5f67e..4bb2913f07be7 100644 --- a/docs/reference/release-notes/8.13.0.asciidoc +++ b/docs/reference/release-notes/8.13.0.asciidoc @@ -21,6 +21,13 @@ This affects clusters running version 8.10 or later, with an active downsampling https://www.elastic.co/guide/en/elasticsearch/reference/current/downsampling-ilm.html[configuration] or a configuration that was activated at some point since upgrading to version 8.10 or later. +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[breaking-8.13.0]] [float] === Breaking changes diff --git a/docs/reference/release-notes/8.13.1.asciidoc b/docs/reference/release-notes/8.13.1.asciidoc index 7b3dbff74cc6e..572f9fe1172a9 100644 --- a/docs/reference/release-notes/8.13.1.asciidoc +++ b/docs/reference/release-notes/8.13.1.asciidoc @@ -3,6 +3,16 @@ Also see <>. +[[known-issues-8.13.1]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.1]] [float] diff --git a/docs/reference/release-notes/8.13.2.asciidoc b/docs/reference/release-notes/8.13.2.asciidoc index 514118f5ea575..20ae7abbb5769 100644 --- a/docs/reference/release-notes/8.13.2.asciidoc +++ b/docs/reference/release-notes/8.13.2.asciidoc @@ -3,6 +3,16 @@ Also see <>. +[[known-issues-8.13.2]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.2]] [float] diff --git a/docs/reference/release-notes/8.13.3.asciidoc b/docs/reference/release-notes/8.13.3.asciidoc index 9aee0dd815f6d..ea51bd6f9b743 100644 --- a/docs/reference/release-notes/8.13.3.asciidoc +++ b/docs/reference/release-notes/8.13.3.asciidoc @@ -10,6 +10,16 @@ Also see <>. 
SQL:: * Limit how much space some string functions can use {es-pull}107333[#107333] +[[known-issues-8.13.3]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.3]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.13.4.asciidoc b/docs/reference/release-notes/8.13.4.asciidoc index bf3f2f497d8fc..b60c9f485bb31 100644 --- a/docs/reference/release-notes/8.13.4.asciidoc +++ b/docs/reference/release-notes/8.13.4.asciidoc @@ -3,6 +3,16 @@ Also see <>. +[[known-issues-8.13.4]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.4]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.14.0.asciidoc b/docs/reference/release-notes/8.14.0.asciidoc index 42f2f86a123ed..5b92c49ced70a 100644 --- a/docs/reference/release-notes/8.14.0.asciidoc +++ b/docs/reference/release-notes/8.14.0.asciidoc @@ -12,6 +12,16 @@ Security:: * Apply stricter Document Level Security (DLS) rules for the validate query API with the rewrite parameter {es-pull}105709[#105709] * Apply stricter Document Level Security (DLS) rules for terms aggregations when min_doc_count is set to 0 {es-pull}105714[#105714] +[[known-issues-8.14.0]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.14.0]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.14.1.asciidoc b/docs/reference/release-notes/8.14.1.asciidoc index f161c7d08099c..1cab442eb9ac1 100644 --- a/docs/reference/release-notes/8.14.1.asciidoc +++ b/docs/reference/release-notes/8.14.1.asciidoc @@ -4,6 +4,16 @@ Also see <>. +[[known-issues-8.14.1]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. 
+If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.14.1]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.14.2.asciidoc b/docs/reference/release-notes/8.14.2.asciidoc index 2bb374451b2ac..9273355106a03 100644 --- a/docs/reference/release-notes/8.14.2.asciidoc +++ b/docs/reference/release-notes/8.14.2.asciidoc @@ -5,6 +5,16 @@ coming[8.14.2] Also see <>. +[[known-issues-8.14.2]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.14.2]] [float] === Bug fixes @@ -35,4 +45,4 @@ Ranking:: Search:: * Add hexstring support byte painless scorers {es-pull}109492[#109492] -* Fix automatic tracking of collapse with `docvalue_fields` {es-pull}110103[#110103] \ No newline at end of file +* Fix automatic tracking of collapse with `docvalue_fields` {es-pull}110103[#110103] From 58bb05df94e5edf4a988a6be1266ee6c8bc03669 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 8 Jul 2024 09:20:28 +0100 Subject: [PATCH 04/64] Clarify logs/errors re. publish addresses (#110570) These warning logs and error messages assume some level of understanding of Elasticsearch's networking config and are not particularly actionable. This commit adds links to the relevant section of the manual, rewords them a little to match the terminology used in the manual, and also documents that each node must have its own publish address, distinct from those of all other nodes. --- docs/reference/modules/network.asciidoc | 2 ++ .../elasticsearch/common/ReferenceDocs.java | 1 + .../HandshakingTransportAddressConnector.java | 10 ++++++++-- .../transport/TransportService.java | 15 ++++++++++++++- .../common/reference-docs-links.json | 3 ++- ...shakingTransportAddressConnectorTests.java | 19 ++++++++++++------- .../TransportServiceHandshakeTests.java | 14 +++++++++++++- 7 files changed, 52 insertions(+), 12 deletions(-) diff --git a/docs/reference/modules/network.asciidoc b/docs/reference/modules/network.asciidoc index 55c236ce43574..593aa79ded4d9 100644 --- a/docs/reference/modules/network.asciidoc +++ b/docs/reference/modules/network.asciidoc @@ -153,6 +153,8 @@ The only requirements are that each node must be: cluster, and by any remote clusters that will discover it using <>. +Each node must have its own distinct publish address. 
+ If you specify the transport publish address using a hostname then {es} will resolve this hostname to an IP address once during startup, and other nodes will use the resulting IP address instead of resolving the name again diff --git a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java index 3605204a9b2a9..1953c1680040a 100644 --- a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java +++ b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java @@ -74,6 +74,7 @@ public enum ReferenceDocs { EXECUTABLE_JNA_TMPDIR, NETWORK_THREADING_MODEL, ALLOCATION_EXPLAIN_API, + NETWORK_BINDING_AND_PUBLISHING, // this comment keeps the ';' on the next line so every entry above has a trailing ',' which makes the diff for adding new links cleaner ; diff --git a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java index 209faa7207be1..1b68383b8f99f 100644 --- a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java +++ b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java @@ -15,6 +15,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.VersionInformation; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; @@ -154,10 +155,15 @@ public void onFailure(Exception e) { // publish address. logger.warn( () -> format( - "completed handshake with [%s] at [%s] but followup connection to [%s] failed", + """ + Successfully discovered master-eligible node [%s] at address [%s] but could not \ + connect to it at its publish address of [%s]. Each node in a cluster must be \ + accessible at its publish address by all other nodes in the cluster. See %s for \ + more information.""", remoteNode.descriptionWithoutAttributes(), transportAddress, - remoteNode.getAddress() + remoteNode.getAddress(), + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING ), e ); diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index c3d53855a9c75..33ea35ecffd94 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; @@ -518,7 +519,19 @@ public ConnectionManager.ConnectionValidator connectionValidator(DiscoveryNode n handshake(newConnection, actualProfile.getHandshakeTimeout(), Predicates.always(), listener.map(resp -> { final DiscoveryNode remote = resp.discoveryNode; if (node.equals(remote) == false) { - throw new ConnectTransportException(node, "handshake failed. 
unexpected remote node " + remote); + throw new ConnectTransportException( + node, + Strings.format( + """ + Connecting to [%s] failed: expected to connect to [%s] but found [%s] instead. Ensure that each node has \ + its own distinct publish address, and that your network is configured so that every connection to a node's \ + publish address is routed to the correct node. See %s for more information.""", + node.getAddress(), + node.descriptionWithoutAttributes(), + remote.descriptionWithoutAttributes(), + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING + ) + ); } return null; })); diff --git a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json index 931e0576b85b8..303ae22f16269 100644 --- a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json +++ b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json @@ -34,5 +34,6 @@ "UNASSIGNED_SHARDS": "red-yellow-cluster-status.html", "EXECUTABLE_JNA_TMPDIR": "executable-jna-tmpdir.html", "NETWORK_THREADING_MODEL": "modules-network.html#modules-network-threading-model", - "ALLOCATION_EXPLAIN_API": "cluster-allocation-explain.html" + "ALLOCATION_EXPLAIN_API": "cluster-allocation-explain.html", + "NETWORK_BINDING_AND_PUBLISHING": "modules-network.html#modules-network-binding-publishing" } diff --git a/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java b/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java index 8ca96aff9c3e5..5c6afc1e805ce 100644 --- a/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java @@ -18,6 +18,8 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.common.ReferenceDocs; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.core.Nullable; @@ -159,13 +161,16 @@ public void testLogsFullConnectionFailureAfterSuccessfulHandshake() throws Excep "message", HandshakingTransportAddressConnector.class.getCanonicalName(), Level.WARN, - "completed handshake with [" - + remoteNode.descriptionWithoutAttributes() - + "] at [" - + discoveryAddress - + "] but followup connection to [" - + remoteNodeAddress - + "] failed" + Strings.format( + """ + Successfully discovered master-eligible node [%s] at address [%s] but could not connect to it at its publish \ + address of [%s]. Each node in a cluster must be accessible at its publish address by all other nodes in the \ + cluster. 
See %s for more information.""", + remoteNode.descriptionWithoutAttributes(), + discoveryAddress, + remoteNodeAddress, + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING + ) ) ); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java index 761d369d6fc39..c5034f51d1e26 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java @@ -46,6 +46,7 @@ import static java.util.Collections.emptySet; import static org.elasticsearch.transport.AbstractSimpleTransportTestCase.IGNORE_DESERIALIZATION_ERRORS_SETTING; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; @@ -306,7 +307,18 @@ public void testNodeConnectWithDifferentNodeId() { ConnectTransportException.class, () -> AbstractSimpleTransportTestCase.connectToNode(transportServiceA, discoveryNode, TestProfiles.LIGHT_PROFILE) ); - assertThat(ex.getMessage(), containsString("unexpected remote node")); + assertThat( + ex.getMessage(), + allOf( + containsString("Connecting to [" + discoveryNode.getAddress() + "] failed"), + containsString("expected to connect to [" + discoveryNode.descriptionWithoutAttributes() + "]"), + containsString("found [" + transportServiceB.getLocalNode().descriptionWithoutAttributes() + "] instead"), + containsString("Ensure that each node has its own distinct publish address"), + containsString("routed to the correct node"), + containsString("https://www.elastic.co/guide/en/elasticsearch/reference/"), + containsString("modules-network.html") + ) + ); assertFalse(transportServiceA.nodeConnected(discoveryNode)); } From 3392f6193edf5423092cd71acb43d1faaaa95ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Mon, 8 Jul 2024 11:12:41 +0200 Subject: [PATCH 05/64] ESQL: Change TopList tests with random cases (#110327) Instead of using simple hardcoded cases for aggregation tests, use random ones provided by the `Cases()` methods --- .../function/MultiRowTestCaseSupplier.java | 281 ++++++++++++ .../expression/function/TestCaseSupplier.java | 24 ++ .../function/aggregate/TopTests.java | 403 +++++++++--------- 3 files changed, 505 insertions(+), 203 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java new file mode 100644 index 0000000000000..5621e63061e15 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java @@ -0,0 +1,281 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.type.DataType; + +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.test.ESTestCase.randomBoolean; +import static org.elasticsearch.test.ESTestCase.randomList; +import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TypedDataSupplier; + +/** + * Extension of {@link TestCaseSupplier} that provided multi-row test cases. + */ +public final class MultiRowTestCaseSupplier { + + private MultiRowTestCaseSupplier() {} + + public static List intCases(int minRows, int maxRows, int min, int max, boolean includeZero) { + List cases = new ArrayList<>(); + + if (0 <= max && 0 >= min && includeZero) { + cases.add(new TypedDataSupplier("<0 ints>", () -> randomList(minRows, maxRows, () -> 0), DataType.INTEGER, false, true)); + } + + if (max != 0) { + cases.add( + new TypedDataSupplier("<" + max + " ints>", () -> randomList(minRows, maxRows, () -> max), DataType.INTEGER, false, true) + ); + } + + if (min != 0 && min != max) { + cases.add( + new TypedDataSupplier("<" + min + " ints>", () -> randomList(minRows, maxRows, () -> min), DataType.INTEGER, false, true) + ); + } + + int lower = Math.max(min, 1); + int upper = Math.min(max, Integer.MAX_VALUE); + if (lower < upper) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomIntBetween(lower, upper)), + DataType.INTEGER, + false, + true + ) + ); + } + + int lower1 = Math.max(min, Integer.MIN_VALUE); + int upper1 = Math.min(max, -1); + if (lower1 < upper1) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomIntBetween(lower1, upper1)), + DataType.INTEGER, + false, + true + ) + ); + } + + if (min < 0 && max > 0) { + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> { + if (includeZero) { + return ESTestCase.randomIntBetween(min, max); + } + return randomBoolean() ? 
ESTestCase.randomIntBetween(min, -1) : ESTestCase.randomIntBetween(1, max); + }), DataType.INTEGER, false, true)); + } + + return cases; + } + + public static List longCases(int minRows, int maxRows, long min, long max, boolean includeZero) { + List cases = new ArrayList<>(); + + if (0 <= max && 0 >= min && includeZero) { + cases.add(new TypedDataSupplier("<0 longs>", () -> randomList(minRows, maxRows, () -> 0L), DataType.LONG, false, true)); + } + + if (max != 0) { + cases.add( + new TypedDataSupplier("<" + max + " longs>", () -> randomList(minRows, maxRows, () -> max), DataType.LONG, false, true) + ); + } + + if (min != 0 && min != max) { + cases.add( + new TypedDataSupplier("<" + min + " longs>", () -> randomList(minRows, maxRows, () -> min), DataType.LONG, false, true) + ); + } + + long lower = Math.max(min, 1); + long upper = Math.min(max, Long.MAX_VALUE); + if (lower < upper) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(lower, upper)), + DataType.LONG, + false, + true + ) + ); + } + + long lower1 = Math.max(min, Long.MIN_VALUE); + long upper1 = Math.min(max, -1); + if (lower1 < upper1) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(lower1, upper1)), + DataType.LONG, + false, + true + ) + ); + } + + if (min < 0 && max > 0) { + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> { + if (includeZero) { + return ESTestCase.randomLongBetween(min, max); + } + return randomBoolean() ? ESTestCase.randomLongBetween(min, -1) : ESTestCase.randomLongBetween(1, max); + }), DataType.LONG, false, true)); + } + + return cases; + } + + public static List doubleCases(int minRows, int maxRows, double min, double max, boolean includeZero) { + List cases = new ArrayList<>(); + + if (0d <= max && 0d >= min && includeZero) { + cases.add(new TypedDataSupplier("<0 doubles>", () -> randomList(minRows, maxRows, () -> 0d), DataType.DOUBLE, false, true)); + cases.add(new TypedDataSupplier("<-0 doubles>", () -> randomList(minRows, maxRows, () -> -0d), DataType.DOUBLE, false, true)); + } + + if (max != 0d) { + cases.add( + new TypedDataSupplier("<" + max + " doubles>", () -> randomList(minRows, maxRows, () -> max), DataType.DOUBLE, false, true) + ); + } + + if (min != 0d && min != max) { + cases.add( + new TypedDataSupplier("<" + min + " doubles>", () -> randomList(minRows, maxRows, () -> min), DataType.DOUBLE, false, true) + ); + } + + double lower1 = Math.max(min, 0d); + double upper1 = Math.min(max, 1d); + if (lower1 < upper1) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower1, upper1, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + double lower2 = Math.max(min, -1d); + double upper2 = Math.min(max, 0d); + if (lower2 < upper2) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower2, upper2, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + double lower3 = Math.max(min, 1d); + double upper3 = Math.min(max, Double.MAX_VALUE); + if (lower3 < upper3) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower3, upper3, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + double lower4 = Math.max(min, -Double.MAX_VALUE); + double upper4 = Math.min(max, -1d); + if (lower4 < upper4) { + cases.add( + new 
TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower4, upper4, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + if (min < 0 && max > 0) { + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> { + if (includeZero) { + return ESTestCase.randomDoubleBetween(min, max, true); + } + return randomBoolean() ? ESTestCase.randomDoubleBetween(min, -1, true) : ESTestCase.randomDoubleBetween(1, max, true); + }), DataType.DOUBLE, false, true)); + } + + return cases; + } + + public static List dateCases(int minRows, int maxRows) { + List cases = new ArrayList<>(); + + cases.add( + new TypedDataSupplier( + "<1970-01-01T00:00:00Z dates>", + () -> randomList(minRows, maxRows, () -> 0L), + DataType.DATETIME, + false, + true + ) + ); + + cases.add( + new TypedDataSupplier( + "", + // 1970-01-01T00:00:00Z - 2286-11-20T17:46:40Z + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(0, 10 * (long) 10e11)), + DataType.DATETIME, + false, + true + ) + ); + + cases.add( + new TypedDataSupplier( + "", + // 2286-11-20T17:46:40Z - +292278994-08-17T07:12:55.807Z + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(10 * (long) 10e11, Long.MAX_VALUE)), + DataType.DATETIME, + false, + true + ) + ); + + cases.add( + new TypedDataSupplier( + "", + // very close to +292278994-08-17T07:12:55.807Z, the maximum supported millis since epoch + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(Long.MAX_VALUE / 100 * 99, Long.MAX_VALUE)), + DataType.DATETIME, + false, + true + ) + ); + + return cases; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index 77c45bbd69854..6ece7151ccd7a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -820,6 +820,12 @@ public static void unary( unary(suppliers, expectedEvaluatorToString, valueSuppliers, expectedOutputType, expected, unused -> warnings); } + /** + * Generate cases for {@link DataType#INTEGER}. + *
<p>
+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#intCases}.
+ * </p>
+ */ public static List intCases(int min, int max, boolean includeZero) { List cases = new ArrayList<>(); if (0 <= max && 0 >= min && includeZero) { @@ -844,6 +850,12 @@ public static List intCases(int min, int max, boolean include return cases; } + /** + * Generate cases for {@link DataType#LONG}. + *
<p>
+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#longCases}.
+ * </p>
+ */ public static List longCases(long min, long max, boolean includeZero) { List cases = new ArrayList<>(); if (0L <= max && 0L >= min && includeZero) { @@ -909,6 +921,12 @@ public static List ulongCases(BigInteger min, BigInteger max, return cases; } + /** + * Generate cases for {@link DataType#DOUBLE}. + *
<p>
+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#doubleCases}.
+ * </p>
+ */ public static List doubleCases(double min, double max, boolean includeZero) { List cases = new ArrayList<>(); @@ -980,6 +998,12 @@ public static List booleanCases() { ); } + /** + * Generate cases for {@link DataType#DATETIME}. + *
<p>
+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#dateCases}.
+ * </p>
+ */ public static List dateCases() { return List.of( new TypedDataSupplier("<1970-01-01T00:00:00Z>", () -> 0L, DataType.DATETIME), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java index 7b77decb560a9..00457f46266d8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java @@ -15,10 +15,14 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.function.Supplier; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -29,212 +33,175 @@ public TopTests(@Name("TestCase") Supplier testCaseSu @ParametersFactory public static Iterable parameters() { - var suppliers = List.of( - // All types - new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.INTEGER, - equalTo(List.of(200, 8, 5, 0).subList(0, limit)) - ); - }), - new TestCaseSupplier(List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.LONG, - equalTo(List.of(200L, 8L, 5L, 0L).subList(0, limit)) - ); - }), - new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DOUBLE, - equalTo(List.of(200., 8., 5., 0.).subList(0, limit)) - ); - }), - new TestCaseSupplier(List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.DATETIME, 
"field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DATETIME, - equalTo(List.of(200L, 8L, 5L, 0L).subList(0, limit)) - ); - }), + var suppliers = new ArrayList(); - // Surrogates - new TestCaseSupplier( - List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.INTEGER, - equalTo(200) - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.LONG, - equalTo(200L) - ) - ), - new TestCaseSupplier( - List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DOUBLE, - equalTo(200.) 
- ) - ), - new TestCaseSupplier( - List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DATETIME, - equalTo(200L) - ) - ), + for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 1000, false)) { + for (String order : List.of("asc", "desc")) { + for (var fieldCaseSupplier : Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000) + ).flatMap(List::stream).toList()) { + suppliers.add(TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order)); + } + } + } - // Folding - new TestCaseSupplier( - List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.INTEGER, - equalTo(200) - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.LONG, - equalTo(200L) - ) - ), - new TestCaseSupplier( - List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DOUBLE, - equalTo(200.) 
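// (The hand-written per-type cases removed in this hunk are subsumed by the loops added
// above: every combination of a multi-row field supplier, a random limit in [1, 1000], and an
// "asc"/"desc" order now yields a generated case, with the expected answer computed
// independently in makeSupplier by sorting the input rows and truncating to the limit.)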
- ) - ), - new TestCaseSupplier( - List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DATETIME, - equalTo(200L) - ) - ), + suppliers.addAll( + List.of( + // Surrogates + new TestCaseSupplier( + List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DOUBLE, + equalTo(200.) 
+ ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DATETIME, + equalTo(200L) + ) + ), - // Resolution errors - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(0, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Limit must be greater than 0 in [], found [0]" - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(2, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("wrong-order"), DataType.KEYWORD, "order").forceLiteral() - ), - "Invalid order value in [], expected [ASC, DESC] but got [wrong-order]" - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(null, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "second argument of [] cannot be null, received [limit]" - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(null, DataType.KEYWORD, "order").forceLiteral() - ), - "third argument of [] cannot be null, received [order]" + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], 
order=Attribute[channel=2]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DATETIME, + equalTo(200L) + ) + ), + + // Resolution errors + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(0, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Limit must be greater than 0 in [], found [0]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(2, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("wrong-order"), DataType.KEYWORD, "order").forceLiteral() + ), + "Invalid order value in [], expected [ASC, DESC] but got [wrong-order]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(null, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "second argument of [] cannot be null, received [limit]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(null, DataType.KEYWORD, "order").forceLiteral() + ), + "third argument of [] cannot be null, received [order]" + ) ) ) ); @@ -246,4 +213,34 @@ public static Iterable parameters() { protected Expression build(Source source, List args) { return new Top(source, args.get(0), args.get(1), args.get(2)); } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier( + TestCaseSupplier.TypedDataSupplier fieldSupplier, + TestCaseSupplier.TypedDataSupplier limitCaseSupplier, + String order + ) { + return new 
TestCaseSupplier(List.of(fieldSupplier.type(), DataType.INTEGER, DataType.KEYWORD), () -> { + var fieldTypedData = fieldSupplier.get(); + var limitTypedData = limitCaseSupplier.get().forceLiteral(); + var limit = (int) limitTypedData.getValue(); + var expected = fieldTypedData.multiRowData() + .stream() + .map(v -> (Comparable>) v) + .sorted(order.equals("asc") ? Comparator.naturalOrder() : Comparator.reverseOrder()) + .limit(limit) + .toList(); + + return new TestCaseSupplier.TestCase( + List.of( + fieldTypedData, + limitTypedData, + new TestCaseSupplier.TypedData(new BytesRef(order), DataType.KEYWORD, order + " order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + fieldSupplier.type(), + equalTo(expected) + ); + }); + } } From 3802d84429f14d1cae5a39fc24449f9d8ecd0ea5 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 8 Jul 2024 10:29:48 +0100 Subject: [PATCH 06/64] Simplify `HandshakingTransportAddressConnector` (#110572) Replaces the deeply-nested listener stack with a `SubscribableListener` sequence to clarify the process, particularly its failure handling. Also adds assertions that each step does not double-complete its listener. --- .../HandshakingTransportAddressConnector.java | 199 +++++++++--------- .../elasticsearch/discovery/PeerFinder.java | 145 +++++++------ 2 files changed, 177 insertions(+), 167 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java index 1b68383b8f99f..d234c1797e090 100644 --- a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java +++ b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.VersionInformation; import org.elasticsearch.common.Randomness; @@ -27,6 +28,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; +import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportRequestOptions.Type; import org.elasticsearch.transport.TransportService; @@ -73,11 +75,26 @@ public HandshakingTransportAddressConnector(Settings settings, TransportService @Override public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionListener listener) { - try { + new ConnectionAttempt(transportAddress).run(listener); + } + + private class ConnectionAttempt { + private final TransportAddress transportAddress; + + ConnectionAttempt(TransportAddress transportAddress) { + this.transportAddress = transportAddress; + } + + void run(ActionListener listener) { + SubscribableListener.newForked(this::openProbeConnection) + .andThen(this::handshakeProbeConnection) + .andThen(this::openFullConnection) + .addListener(listener); + } + private void openProbeConnection(ActionListener listener) { // We could skip this if the transportService were already connected to the given address, but the savings would be minimal so // we open a new connection anyway. 
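// A minimal sketch of the SubscribableListener sequencing that run() above relies on, with the
// generics written out and `finalListener` as a stand-in name (not part of this patch). Each
// step is handed a freshly forked listener, and a failure in any step skips the remaining
// steps and completes the final listener directly:
//
//     SubscribableListener
//         .<Transport.Connection>newForked(l -> openProbeConnection(l))
//         .<DiscoveryNode>andThen((l, connection) -> handshakeProbeConnection(l, connection))
//         .<ProbeConnectionResult>andThen((l, remoteNode) -> openFullConnection(l, remoteNode))
//         .addListener(finalListener);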
- logger.trace("[{}] opening probe connection", transportAddress); transportService.openConnection( new DiscoveryNode( @@ -96,103 +113,91 @@ public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionL ) ), handshakeConnectionProfile, - listener.delegateFailure((l, connection) -> { - logger.trace("[{}] opened probe connection", transportAddress); - final var probeHandshakeTimeout = handshakeConnectionProfile.getHandshakeTimeout(); - // use NotifyOnceListener to make sure the following line does not result in onFailure being called when - // the connection is closed in the onResponse handler - transportService.handshake(connection, probeHandshakeTimeout, ActionListener.notifyOnce(new ActionListener<>() { - - @Override - public void onResponse(DiscoveryNode remoteNode) { - try { - // success means (amongst other things) that the cluster names match - logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode); - IOUtils.closeWhileHandlingException(connection); - - if (remoteNode.equals(transportService.getLocalNode())) { - listener.onFailure( - new ConnectTransportException( - remoteNode, - String.format( - Locale.ROOT, - "successfully discovered local node %s at [%s]", - remoteNode.descriptionWithoutAttributes(), - transportAddress - ) - ) - ); - } else if (remoteNode.isMasterNode() == false) { - listener.onFailure( - new ConnectTransportException( - remoteNode, - String.format( - Locale.ROOT, - """ - successfully discovered master-ineligible node %s at [%s]; to suppress this message, \ - remove address [%s] from your discovery configuration or ensure that traffic to this \ - address is routed only to master-eligible nodes""", - remoteNode.descriptionWithoutAttributes(), - transportAddress, - transportAddress - ) - ) - ); - } else { - transportService.connectToNode(remoteNode, new ActionListener<>() { - @Override - public void onResponse(Releasable connectionReleasable) { - logger.trace("[{}] completed full connection with [{}]", transportAddress, remoteNode); - listener.onResponse(new ProbeConnectionResult(remoteNode, connectionReleasable)); - } - - @Override - public void onFailure(Exception e) { - // we opened a connection and successfully performed a handshake, so we're definitely - // talking to a master-eligible node with a matching cluster name and a good version, but - // the attempt to open a full connection to its publish address failed; a common reason is - // that the remote node is listening on 0.0.0.0 but has made an inappropriate choice for its - // publish address. - logger.warn( - () -> format( - """ - Successfully discovered master-eligible node [%s] at address [%s] but could not \ - connect to it at its publish address of [%s]. Each node in a cluster must be \ - accessible at its publish address by all other nodes in the cluster. See %s for \ - more information.""", - remoteNode.descriptionWithoutAttributes(), - transportAddress, - remoteNode.getAddress(), - ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING - ), - e - ); - listener.onFailure(e); - } - }); - } - } catch (Exception e) { - listener.onFailure(e); - } - } - - @Override - public void onFailure(Exception e) { - // we opened a connection and successfully performed a low-level handshake, so we were definitely - // talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of - // mismatched configurations (e.g. 
cluster name) that the user should address - logger.warn(() -> "handshake to [" + transportAddress + "] failed", e); - IOUtils.closeWhileHandlingException(connection); - listener.onFailure(e); - } - - })); - - }) + ActionListener.assertOnce(listener) ); + } + + private void handshakeProbeConnection(ActionListener listener, Transport.Connection connection) { + logger.trace("[{}] opened probe connection", transportAddress); + final var probeHandshakeTimeout = handshakeConnectionProfile.getHandshakeTimeout(); + transportService.handshake(connection, probeHandshakeTimeout, ActionListener.assertOnce(new ActionListener<>() { + @Override + public void onResponse(DiscoveryNode remoteNode) { + // success means (amongst other things) that the cluster names match + logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode); + IOUtils.closeWhileHandlingException(connection); + listener.onResponse(remoteNode); + } + + @Override + public void onFailure(Exception e) { + // We opened a connection and successfully performed a low-level handshake, so we were definitely talking to an + // Elasticsearch node, but the high-level handshake failed indicating some kind of mismatched configurations (e.g. + // cluster name) that the user should address. + logger.warn(() -> "handshake to [" + transportAddress + "] failed", e); + IOUtils.closeWhileHandlingException(connection); + listener.onFailure(e); + } + })); + } - } catch (Exception e) { - listener.onFailure(e); + private void openFullConnection(ActionListener listener, DiscoveryNode remoteNode) { + if (remoteNode.equals(transportService.getLocalNode())) { + throw new ConnectTransportException( + remoteNode, + String.format( + Locale.ROOT, + "successfully discovered local node %s at [%s]", + remoteNode.descriptionWithoutAttributes(), + transportAddress + ) + ); + } + + if (remoteNode.isMasterNode() == false) { + throw new ConnectTransportException( + remoteNode, + String.format( + Locale.ROOT, + """ + successfully discovered master-ineligible node %s at [%s]; to suppress this message, remove address [%s] from \ + your discovery configuration or ensure that traffic to this address is routed only to master-eligible nodes""", + remoteNode.descriptionWithoutAttributes(), + transportAddress, + transportAddress + ) + ); + } + + transportService.connectToNode(remoteNode, ActionListener.assertOnce(new ActionListener<>() { + @Override + public void onResponse(Releasable connectionReleasable) { + logger.trace("[{}] completed full connection with [{}]", transportAddress, remoteNode); + listener.onResponse(new ProbeConnectionResult(remoteNode, connectionReleasable)); + } + + @Override + public void onFailure(Exception e) { + // We opened a connection and successfully performed a handshake, so we're definitely talking to a master-eligible node + // with a matching cluster name and a good version, but the attempt to open a full connection to its publish address + // failed; a common reason is that the remote node is listening on 0.0.0.0 but has made an inappropriate choice for its + // publish address. + logger.warn( + () -> format( + """ + Successfully discovered master-eligible node [%s] at address [%s] but could not connect to it at its \ + publish address of [%s]. Each node in a cluster must be accessible at its publish address by all other \ + nodes in the cluster. 
See %s for more information.""", + remoteNode.descriptionWithoutAttributes(), + transportAddress, + remoteNode.getAddress(), + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING + ), + e + ); + listener.onFailure(e); + } + })); } } } diff --git a/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java b/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java index 83660cede004e..11f3bbdc13bbf 100644 --- a/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java +++ b/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.cluster.coordination.ClusterFormationFailureHelper; import org.elasticsearch.cluster.coordination.PeersResponse; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -413,86 +414,90 @@ void establishConnection() { - activatedAtMillis > verbosityIncreaseTimeout.millis(); logger.trace("{} attempting connection", this); - transportAddressConnector.connectToRemoteMasterNode(transportAddress, new ActionListener() { - @Override - public void onResponse(ProbeConnectionResult connectResult) { - assert holdsLock() == false : "PeerFinder mutex is held in error"; - final DiscoveryNode remoteNode = connectResult.getDiscoveryNode(); - assert remoteNode.isMasterNode() : remoteNode + " is not master-eligible"; - assert remoteNode.equals(getLocalNode()) == false : remoteNode + " is the local node"; - boolean retainConnection = false; - try { - synchronized (mutex) { - if (isActive() == false) { - logger.trace("Peer#establishConnection inactive: {}", Peer.this); - return; + transportAddressConnector.connectToRemoteMasterNode( + transportAddress, + // may be completed on the calling thread, and therefore under the mutex, so must always fork + new ThreadedActionListener<>(clusterCoordinationExecutor, new ActionListener<>() { + @Override + public void onResponse(ProbeConnectionResult connectResult) { + assert holdsLock() == false : "PeerFinder mutex is held in error"; + final DiscoveryNode remoteNode = connectResult.getDiscoveryNode(); + assert remoteNode.isMasterNode() : remoteNode + " is not master-eligible"; + assert remoteNode.equals(getLocalNode()) == false : remoteNode + " is the local node"; + boolean retainConnection = false; + try { + synchronized (mutex) { + if (isActive() == false) { + logger.trace("Peer#establishConnection inactive: {}", Peer.this); + return; + } + + assert probeConnectionResult.get() == null + : "connection result unexpectedly already set to " + probeConnectionResult.get(); + probeConnectionResult.set(connectResult); + + requestPeers(); } - assert probeConnectionResult.get() == null - : "connection result unexpectedly already set to " + probeConnectionResult.get(); - probeConnectionResult.set(connectResult); - - requestPeers(); - } - - onFoundPeersUpdated(); + onFoundPeersUpdated(); - retainConnection = true; - } finally { - if (retainConnection == false) { - Releasables.close(connectResult); + retainConnection = true; + } finally { + if (retainConnection == false) { + Releasables.close(connectResult); + } } } - } - @Override - public void onFailure(Exception e) { - if (verboseFailureLogging) { - - final String believedMasterBy; - synchronized (mutex) { - believedMasterBy = peersByAddress.values() - .stream() - .filter(p -> 
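// (On the ThreadedActionListener wrapper introduced above: it is what guarantees the fork off
// the calling thread. A minimal sketch of the pattern, with `delegate` as a hypothetical
// listener:
//
//     ActionListener<ProbeConnectionResult> forked =
//         new ThreadedActionListener<>(clusterCoordinationExecutor, delegate);
//
// Whichever thread completes `forked`, the delegate's onResponse/onFailure is dispatched onto
// clusterCoordinationExecutor, so the callback never runs while the caller still holds the
// PeerFinder mutex.)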
p.lastKnownMasterNode.map(DiscoveryNode::getAddress).equals(Optional.of(transportAddress))) - .findFirst() - .map(p -> " [current master according to " + p.getDiscoveryNode().descriptionWithoutAttributes() + "]") - .orElse(""); - } + @Override + public void onFailure(Exception e) { + if (verboseFailureLogging) { + + final String believedMasterBy; + synchronized (mutex) { + believedMasterBy = peersByAddress.values() + .stream() + .filter(p -> p.lastKnownMasterNode.map(DiscoveryNode::getAddress).equals(Optional.of(transportAddress))) + .findFirst() + .map(p -> " [current master according to " + p.getDiscoveryNode().descriptionWithoutAttributes() + "]") + .orElse(""); + } - if (logger.isDebugEnabled()) { - // log message at level WARN, but since DEBUG logging is enabled we include the full stack trace - logger.warn(() -> format("%s%s discovery result", Peer.this, believedMasterBy), e); - } else { - final StringBuilder messageBuilder = new StringBuilder(); - Throwable cause = e; - while (cause != null && messageBuilder.length() <= 1024) { - messageBuilder.append(": ").append(cause.getMessage()); - cause = cause.getCause(); + if (logger.isDebugEnabled()) { + // log message at level WARN, but since DEBUG logging is enabled we include the full stack trace + logger.warn(() -> format("%s%s discovery result", Peer.this, believedMasterBy), e); + } else { + final StringBuilder messageBuilder = new StringBuilder(); + Throwable cause = e; + while (cause != null && messageBuilder.length() <= 1024) { + messageBuilder.append(": ").append(cause.getMessage()); + cause = cause.getCause(); + } + final String message = messageBuilder.length() < 1024 + ? messageBuilder.toString() + : (messageBuilder.substring(0, 1023) + "..."); + logger.warn( + "{}{} discovery result{}; for summary, see logs from {}; for troubleshooting guidance, see {}", + Peer.this, + believedMasterBy, + message, + ClusterFormationFailureHelper.class.getCanonicalName(), + ReferenceDocs.DISCOVERY_TROUBLESHOOTING + ); } - final String message = messageBuilder.length() < 1024 - ? 
messageBuilder.toString() + : (messageBuilder.substring(0, 1023) + "..."); + logger.warn( + "{}{} discovery result{}; for summary, see logs from {}; for troubleshooting guidance, see {}", + Peer.this, + believedMasterBy, + message, + ClusterFormationFailureHelper.class.getCanonicalName(), + ReferenceDocs.DISCOVERY_TROUBLESHOOTING + ); + } + } else { + logger.debug(() -> format("%s discovery result", Peer.this), e); + } + synchronized (mutex) { + assert probeConnectionResult.get() == null + : "discoveryNode unexpectedly already set to " + probeConnectionResult.get(); + if (isActive()) { + peersByAddress.remove(transportAddress); + } // else this Peer has been superseded by a different instance which should be left in place } - } - }); + }) + ); } private void requestPeers() { From 35c44f7ade82d08ac50b654ff899b7c2ed20f44d Mon Sep 17 00:00:00 2001 From: Craig Taverner Date: Mon, 8 Jul 2024 14:13:10 +0200 Subject: [PATCH 07/64] Fix bug in union-types with type-casting in grouping key of STATS (#110476) * Allow auto-generated type-cast fields in CsvTests. This allows, for example, a csv-spec test result header like `client_ip::ip:ip`, which is generated with a command like `STATS count=count(*) BY client_ip::ip`. It is also a small cleanup of the header parsing code, since it was using Strings.split() in an odd way (a worked parse of such a header follows below). * Fix bug in union-types with type-casting in grouping key of STATS * Update docs/changelog/110476.yaml * Added casting_operator required capability. Using the new `::` syntax requires disabling support for older versions in multi-cluster tests. * Added more tests for inline stats over long/datetime * Trying to fix the STATS...STATS bug. This makes two changes: * Keeps the Alias in the aggs.aggregates from the grouping key, so that ReplaceStatsNestedExpressionWithEval still works * Adds explicit support for union-types conversion at grouping key loading in the ordinalGroupingOperatorFactory Neither fixes the particular edge case, but both seem correct. * Added EsqlCapability for this change. So that mixed cluster tests don't fail these new queries. * Fix InsertFieldExtract for union types. Union types require a FieldExtractExec to be performed first thing at the bottom of local physical plans. In queries like ``` from testidx* | eval x = to_string(client_ip) | stats c = count(*) by x | keep c ``` The `stats` has the grouping `x` but the aggregates get pruned to just `c`. In cases like this, we did not insert a FieldExtractExec, which this fixes. * Revert query that previously failed. With Alex's fix, this query now passes. * Revert integration of union-types to ordinals aggregator. This is because we have not found a test case that actually demonstrates this is necessary. 
* More tests that would fail without the latest fix * Correct code style * Fix failing case when aggregating on union-type with invalid grouping key * Capabilities restrictions on the new YML tests * Update docs/changelog/110476.yaml --------- Co-authored-by: Alexander Spies --- docs/changelog/110476.yaml | 7 + .../xpack/esql/CsvTestUtils.java | 23 ++- .../src/main/resources/union_types.csv-spec | 175 +++++++++++++++++- .../xpack/esql/action/EsqlCapabilities.java | 7 +- .../xpack/esql/analysis/Analyzer.java | 22 +++ .../optimizer/LocalPhysicalPlanOptimizer.java | 10 +- .../test/esql/160_union_types.yml | 160 +++++++++++++++- 7 files changed, 386 insertions(+), 18 deletions(-) create mode 100644 docs/changelog/110476.yaml diff --git a/docs/changelog/110476.yaml b/docs/changelog/110476.yaml new file mode 100644 index 0000000000000..bc12b3711a366 --- /dev/null +++ b/docs/changelog/110476.yaml @@ -0,0 +1,7 @@ +pr: 110476 +summary: Fix bug in union-types with type-casting in grouping key of STATS +area: ES|QL +type: bug +issues: + - 109922 + - 110477 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java index d88d7f9b9448f..3b3e12978ae04 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java @@ -10,7 +10,6 @@ import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.common.time.DateFormatters; @@ -332,15 +331,15 @@ public static ExpectedResults loadCsvSpecValues(String csv) { columnTypes = new ArrayList<>(header.length); for (String c : header) { - String[] nameWithType = Strings.split(c, ":"); - if (nameWithType == null || nameWithType.length != 2) { + String[] nameWithType = escapeTypecast(c).split(":"); + if (nameWithType.length != 2) { throw new IllegalArgumentException("Invalid CSV header " + c); } - String typeName = nameWithType[1].trim(); - if (typeName.length() == 0) { - throw new IllegalArgumentException("A type is always expected in the csv file; found " + nameWithType); + String typeName = unescapeTypecast(nameWithType[1]).trim(); + if (typeName.isEmpty()) { + throw new IllegalArgumentException("A type is always expected in the csv file; found " + Arrays.toString(nameWithType)); } - String name = nameWithType[0].trim(); + String name = unescapeTypecast(nameWithType[0]).trim(); columnNames.add(name); Type type = Type.asType(typeName); if (type == null) { @@ -398,6 +397,16 @@ public static ExpectedResults loadCsvSpecValues(String csv) { } } + private static final String TYPECAST_SPACER = "__TYPECAST__"; + + private static String escapeTypecast(String typecast) { + return typecast.replace("::", TYPECAST_SPACER); + } + + private static String unescapeTypecast(String typecast) { + return typecast.replace(TYPECAST_SPACER, "::"); + } + public enum Type { INTEGER(Integer::parseInt, Integer.class), LONG(Long::parseLong, Long.class), diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec index ee8c4be385e0f..5783489195458 100644 
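A worked parse of the header format enabled by the CsvTestUtils change above, assuming the `client_ip::ip:ip` example from the commit message:

    // escapeTypecast: "client_ip::ip:ip" -> "client_ip__TYPECAST__ip:ip"
    String[] nameWithType = escapeTypecast("client_ip::ip:ip").split(":");
    String name = unescapeTypecast(nameWithType[0]).trim(); // "client_ip::ip"
    String type = unescapeTypecast(nameWithType[1]).trim(); // "ip"

Without the escape step, the plain split on ":" would produce four parts and the header would be rejected.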
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec @@ -45,8 +45,10 @@ FROM sample_data_ts_long ; singleIndexIpStats +required_capability: casting_operator + FROM sample_data -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | STATS count=count(*) BY client_ip | SORT count DESC, client_ip ASC | KEEP count, client_ip @@ -60,8 +62,10 @@ count:long | client_ip:ip ; singleIndexIpStringStats +required_capability: casting_operator + FROM sample_data_str -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | STATS count=count(*) BY client_ip | SORT count DESC, client_ip ASC | KEEP count, client_ip @@ -74,12 +78,28 @@ count:long | client_ip:ip 1 | 172.21.2.162 ; +singleIndexIpStringStatsInline +required_capability: casting_operator + +FROM sample_data_str +| STATS count=count(*) BY client_ip::ip +| STATS mc=count(count) BY count +| SORT mc DESC, count ASC +| KEEP mc, count +; + +mc:l | count:l +3 | 1 +1 | 4 +; + multiIndexIpString required_capability: union_types required_capability: metadata_fields +required_capability: casting_operator FROM sample_data, sample_data_str METADATA _index -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | KEEP _index, @timestamp, client_ip, event_duration, message | SORT _index ASC, @timestamp DESC ; @@ -104,9 +124,10 @@ sample_data_str | 2023-10-23T12:15:03.360Z | 172.21.2.162 | 3450233 multiIndexIpStringRename required_capability: union_types required_capability: metadata_fields +required_capability: casting_operator FROM sample_data, sample_data_str METADATA _index -| EVAL host_ip = TO_IP(client_ip) +| EVAL host_ip = client_ip::ip | KEEP _index, @timestamp, host_ip, event_duration, message | SORT _index ASC, @timestamp DESC ; @@ -191,9 +212,10 @@ sample_data_str | 2023-10-23T12:15:03.360Z | 3450233 | Connected multiIndexIpStringStats required_capability: union_types +required_capability: casting_operator FROM sample_data, sample_data_str -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | STATS count=count(*) BY client_ip | SORT count DESC, client_ip ASC | KEEP count, client_ip @@ -208,9 +230,10 @@ count:long | client_ip:ip multiIndexIpStringRenameStats required_capability: union_types +required_capability: casting_operator FROM sample_data, sample_data_str -| EVAL host_ip = TO_IP(client_ip) +| EVAL host_ip = client_ip::ip | STATS count=count(*) BY host_ip | SORT count DESC, host_ip ASC | KEEP count, host_ip @@ -240,6 +263,24 @@ count:long | host_ip:keyword 2 | 172.21.2.162 ; +multiIndexIpStringStatsDrop +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_str +| STATS count=count(*) BY client_ip::ip +| KEEP count +| SORT count DESC +; + +count:long +8 +2 +2 +2 +; + multiIndexIpStringStatsInline required_capability: union_types required_capability: union_types_inline_fix @@ -257,6 +298,39 @@ count:long | client_ip:ip 2 | 172.21.2.162 ; +multiIndexIpStringStatsInline2 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_str +| STATS count=count(*) BY client_ip::ip +| SORT count DESC, `client_ip::ip` ASC +; + +count:long | client_ip::ip:ip +8 | 172.21.3.15 +2 | 172.21.0.5 +2 | 172.21.2.113 +2 | 172.21.2.162 +; + +multiIndexIpStringStatsInline3 +required_capability: union_types 
+required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_str +| STATS count=count(*) BY client_ip::ip +| STATS mc=count(count) BY count +| SORT mc DESC, count ASC +; + +mc:l | count:l +3 | 2 +1 | 8 +; + multiIndexWhereIpStringStats required_capability: union_types @@ -385,6 +459,61 @@ count:long | @timestamp:date 4 | 2023-10-23T12:00:00.000Z ; +multiIndexTsLongStatsDrop +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_ts_long +| STATS count=count(*) BY @timestamp::datetime +| KEEP count +; + +count:long +2 +2 +2 +2 +2 +2 +2 +; + +multiIndexTsLongStatsInline2 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_ts_long +| STATS count=count(*) BY @timestamp::datetime +| SORT count DESC, `@timestamp::datetime` DESC +; + +count:long | @timestamp::datetime:datetime +2 | 2023-10-23T13:55:01.543Z +2 | 2023-10-23T13:53:55.832Z +2 | 2023-10-23T13:52:55.015Z +2 | 2023-10-23T13:51:54.732Z +2 | 2023-10-23T13:33:34.937Z +2 | 2023-10-23T12:27:28.948Z +2 | 2023-10-23T12:15:03.360Z +; + +multiIndexTsLongStatsInline3 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_ts_long +| STATS count=count(*) BY @timestamp::datetime +| STATS mc=count(count) BY count +| SORT mc DESC, count ASC +; + +mc:l | count:l +7 | 2 +; + multiIndexTsLongRenameStats required_capability: union_types @@ -717,3 +846,37 @@ null | null | 8268153 | Connection error | samp null | null | 8268153 | Connection error | sample_data_str | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 null | null | 8268153 | Connection error | sample_data_ts_long | 2023-10-23T13:52:55.015Z | 1698069175015 | 1698069175015 | 172.21.3.15 | 172.21.3.15 ; + +multiIndexMultiColumnTypesRenameAndKeep +required_capability: union_types +required_capability: metadata_fields + +FROM sample_data* METADATA _index +| WHERE event_duration > 8000000 +| EVAL ts = TO_DATETIME(@timestamp), ts_str = TO_STRING(@timestamp), ts_l = TO_LONG(@timestamp), ip = TO_IP(client_ip), ip_str = TO_STRING(client_ip) +| KEEP _index, ts, ts_str, ts_l, ip, ip_str, event_duration +| SORT _index ASC, ts DESC +; + +_index:keyword | ts:date | ts_str:keyword | ts_l:long | ip:ip | ip_str:k | event_duration:long +sample_data | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 | 8268153 +sample_data_str | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 | 8268153 +sample_data_ts_long | 2023-10-23T13:52:55.015Z | 1698069175015 | 1698069175015 | 172.21.3.15 | 172.21.3.15 | 8268153 +; + +multiIndexMultiColumnTypesRenameAndDrop +required_capability: union_types +required_capability: metadata_fields + +FROM sample_data* METADATA _index +| WHERE event_duration > 8000000 +| EVAL ts = TO_DATETIME(@timestamp), ts_str = TO_STRING(@timestamp), ts_l = TO_LONG(@timestamp), ip = TO_IP(client_ip), ip_str = TO_STRING(client_ip) +| DROP @timestamp, client_ip, message +| SORT _index ASC, ts DESC +; + +event_duration:long | _index:keyword | ts:date | ts_str:keyword | ts_l:long | ip:ip | ip_str:k +8268153 | sample_data | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 +8268153 | sample_data_str 
| 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 +8268153 | sample_data_ts_long | 2023-10-23T13:52:55.015Z | 1698069175015 | 1698069175015 | 172.21.3.15 | 172.21.3.15 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 07362311d37a5..88f6ff0c95b05 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -106,7 +106,12 @@ public enum Cap { /** * Support for WEIGHTED_AVG function. */ - AGG_WEIGHTED_AVG; + AGG_WEIGHTED_AVG, + + /** + * Fix for union-types when aggregating over an inline conversion with casting operator. Done in #110476. + */ + UNION_TYPES_AGG_CAST; private final boolean snapshotOnly; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index cdb5935f9bd72..30ffffd4770a9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -1086,6 +1086,19 @@ protected LogicalPlan doRule(LogicalPlan plan) { return plan; } + // In ResolveRefs the aggregates are resolved from the groupings, which might have an unresolved MultiTypeEsField. + // Now that we have resolved those, we need to re-resolve the aggregates. + if (plan instanceof EsqlAggregate agg && agg.expressionsResolved() == false) { + Map resolved = new HashMap<>(); + for (Expression e : agg.groupings()) { + Attribute attr = Expressions.attribute(e); + if (attr != null && attr.resolved()) { + resolved.put(attr, e); + } + } + plan = agg.transformExpressionsOnly(UnresolvedAttribute.class, ua -> resolveAttribute(ua, resolved)); + } + // Otherwise drop the converted attributes after the alias function, as they are only needed for this function, and // the original version of the attribute should still be seen as unconverted. plan = dropConvertedAttributes(plan, unionFieldAttributes); @@ -1109,6 +1122,15 @@ protected LogicalPlan doRule(LogicalPlan plan) { return plan; } + private Expression resolveAttribute(UnresolvedAttribute ua, Map resolved) { + var named = resolveAgainstList(ua, resolved.keySet()); + return switch (named.size()) { + case 0 -> ua; + case 1 -> named.get(0).equals(ua) ? 
ua : resolved.get(named.get(0)); + default -> ua.withUnresolvedMessage("Resolved [" + ua + "] unexpectedly to multiple attributes " + named); + }; + } + private LogicalPlan dropConvertedAttributes(LogicalPlan plan, List unionFieldAttributes) { List projections = new ArrayList<>(plan.output()); for (var e : unionFieldAttributes) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java index 1b40a1c2b02ad..f78ae6930d9ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java @@ -77,6 +77,7 @@ import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; import org.elasticsearch.xpack.esql.planner.EsqlTranslatorHandler; import org.elasticsearch.xpack.esql.stats.SearchStats; +import org.elasticsearch.xpack.esql.type.MultiTypeEsField; import java.nio.ByteOrder; import java.util.ArrayList; @@ -193,7 +194,10 @@ public PhysicalPlan apply(PhysicalPlan plan) { * it loads the field lazily. If we have more than one field we need to * make sure the fields are loaded for the standard hash aggregator. */ - if (p instanceof AggregateExec agg && agg.groupings().size() == 1) { + if (p instanceof AggregateExec agg + && agg.groupings().size() == 1 + && (isMultiTypeFieldAttribute(agg.groupings().get(0)) == false) // Union types rely on field extraction. + ) { var leaves = new LinkedList<>(); // TODO: this seems out of place agg.aggregates() @@ -217,6 +221,10 @@ public PhysicalPlan apply(PhysicalPlan plan) { return plan; } + private static boolean isMultiTypeFieldAttribute(Expression attribute) { + return attribute instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField; + } + private static Set missingAttributes(PhysicalPlan p) { var missing = new LinkedHashSet(); var input = p.inputSet(); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml index f3403ca8751c0..aac60d9aaa8d0 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml @@ -147,6 +147,9 @@ setup: - '{"index": {}}' - '{"@timestamp": "2023-10-23T12:15:03.360Z", "client_ip": "172.21.2.162", "event_duration": "3450233", "message": "Connected to 10.1.0.3"}' +############################################################################################################ +# Test a single index as a control of the expected results + --- load single index ip_long: - do: @@ -173,9 +176,6 @@ load single index ip_long: - match: { values.0.3: 1756467 } - match: { values.0.4: "Connected to 10.1.0.1" } -############################################################################################################ -# Test a single index as a control of the expected results - --- load single index keyword_keyword: - do: @@ -202,6 +202,83 @@ load single index keyword_keyword: - match: { values.0.3: "1756467" } - match: { values.0.4: "Connected to 10.1.0.1" } +--- +load single index ip_long and aggregate by client_ip: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: 
[casting_operator] + reason: "Casting operator introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_ip_long | STATS count = COUNT(*) BY client_ip::ip | SORT count DESC, `client_ip::ip` ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "client_ip::ip" } + - match: { columns.1.type: "ip" } + - length: { values: 4 } + - match: { values.0.0: 4 } + - match: { values.0.1: "172.21.3.15" } + - match: { values.1.0: 1 } + - match: { values.1.1: "172.21.0.5" } + - match: { values.2.0: 1 } + - match: { values.2.1: "172.21.2.113" } + - match: { values.3.0: 1 } + - match: { values.3.1: "172.21.2.162" } + +--- +load single index ip_long and aggregate client_ip by message: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator] + reason: "Casting operator introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_ip_long | STATS count = COUNT(client_ip::ip) BY message | SORT count DESC, message ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "message" } + - match: { columns.1.type: "keyword" } + - length: { values: 5 } + - match: { values.0.0: 3 } + - match: { values.0.1: "Connection error" } + - match: { values.1.0: 1 } + - match: { values.1.1: "Connected to 10.1.0.1" } + - match: { values.2.0: 1 } + - match: { values.2.1: "Connected to 10.1.0.2" } + - match: { values.3.0: 1 } + - match: { values.3.1: "Connected to 10.1.0.3" } + - match: { values.4.0: 1 } + - match: { values.4.1: "Disconnected" } + +--- +load single index ip_long stats invalid grouping: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator] + reason: "Casting operator introduced in 8.15.0" + - do: + catch: '/Unknown column \[x\]/' + esql.query: + body: + query: 'FROM events_ip_long | STATS count = COUNT(client_ip::ip) BY x' + ############################################################################################################ # Test two indices where the event_duration is mapped as a LONG and as a KEYWORD @@ -512,6 +589,83 @@ load two indices, convert, rename but not drop ambiguous field client_ip: - match: { values.1.5: "172.21.3.15" } - match: { values.1.6: "172.21.3.15" } +--- +load two indexes and group by converted client_ip: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator, union_types_agg_cast] + reason: "Casting operator and Union types introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_*_long | STATS count = COUNT(*) BY client_ip::ip | SORT count DESC, `client_ip::ip` ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "client_ip::ip" } + - match: { columns.1.type: "ip" } + - length: { values: 4 } + - match: { values.0.0: 8 } + - match: { values.0.1: "172.21.3.15" } + - match: { values.1.0: 2 } + - match: { values.1.1: "172.21.0.5" } + - match: { values.2.0: 2 } + - match: { values.2.1: "172.21.2.113" } + - match: { values.3.0: 2 
}
+  - match: { values.3.1: "172.21.2.162" }
+
+---
+load two indexes and aggregate converted client_ip:
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_query
+          parameters: [method, path, parameters, capabilities]
+          capabilities: [casting_operator, union_types_agg_cast]
+      reason: "Casting operator and Union types introduced in 8.15.0"
+  - do:
+      allowed_warnings_regex:
+        - "No limit defined, adding default limit of \\[.*\\]"
+      esql.query:
+        body:
+          query: 'FROM events_*_long | STATS count = COUNT(client_ip::ip) BY message | SORT count DESC, message ASC'
+
+  - match: { columns.0.name: "count" }
+  - match: { columns.0.type: "long" }
+  - match: { columns.1.name: "message" }
+  - match: { columns.1.type: "keyword" }
+  - length: { values: 5 }
+  - match: { values.0.0: 6 }
+  - match: { values.0.1: "Connection error" }
+  - match: { values.1.0: 2 }
+  - match: { values.1.1: "Connected to 10.1.0.1" }
+  - match: { values.2.0: 2 }
+  - match: { values.2.1: "Connected to 10.1.0.2" }
+  - match: { values.3.0: 2 }
+  - match: { values.3.1: "Connected to 10.1.0.3" }
+  - match: { values.4.0: 2 }
+  - match: { values.4.1: "Disconnected" }
+
+---
+load two indexes, convert client_ip and group by something invalid:
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_query
+          parameters: [method, path, parameters, capabilities]
+          capabilities: [casting_operator, union_types_agg_cast]
+      reason: "Casting operator and Union types introduced in 8.15.0"
+  - do:
+      catch: '/Unknown column \[x\]/'
+      esql.query:
+        body:
+          query: 'FROM events_*_long | STATS count = COUNT(client_ip::ip) BY x'
+
 ############################################################################################################
 # Test four indices with both the client_IP (IP and KEYWORD) and event_duration (LONG and KEYWORD) mappings

From 146e15a92ced94a24a9eb3049691b64c7b32e284 Mon Sep 17 00:00:00 2001
From: Nik Everett
Date: Mon, 8 Jul 2024 08:40:54 -0400
Subject: [PATCH 08/64] ESQL: Rework sequence for analyzing queries (#110566)

In service of the incoming INLINESTATS this flips the ordering of
analysis. Previously we made the entire sequence of analyze, optimize,
convert to physical plan, and optimize in a single async sequence in
`EsqlSession`. This flips it so `analyze` comes first in its own async
sequence and then runs the remaining stuff in a separate sequence.

That's nice for INLINESTATS where we want to analyze one time, and then
run many runs of the extra sequence.

While we're here, we also take that sequence and call it directly from
the CsvTests. That works well because that final sequence is exactly
what CsvTests have to do. They "analyze" totally differently, but they
run the final sequence in the same way.

Closes #107953
---
 .../xpack/esql/core/util/ActionListeners.java |  26 --
 .../xpack/esql/session/EsqlSession.java       |  85 ++++---
 .../elasticsearch/xpack/esql/CsvTests.java    | 228 ++++++++++--------
 3 files changed, 175 insertions(+), 164 deletions(-)
 delete mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java

diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java
deleted file mode 100644
index 025f9c2b6fd7a..0000000000000
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.util; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.CheckedConsumer; -import org.elasticsearch.core.CheckedFunction; - -import java.util.function.Consumer; - -public class ActionListeners { - - private ActionListeners() {} - - /** - * Combination of {@link ActionListener#wrap(CheckedConsumer, Consumer)} and {@link ActionListener#map} - */ - public static ActionListener map(ActionListener delegate, CheckedFunction fn) { - return delegate.delegateFailureAndWrap((l, r) -> l.onResponse(fn.apply(r))); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 2a4f07a1aa319..8c831cc260e03 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -64,7 +64,6 @@ import java.util.stream.Collectors; import static org.elasticsearch.index.query.QueryBuilders.boolQuery; -import static org.elasticsearch.xpack.esql.core.util.ActionListeners.map; import static org.elasticsearch.xpack.esql.core.util.StringUtils.WILDCARD; public class EsqlSession { @@ -111,33 +110,29 @@ public String sessionId() { return sessionId; } + /** + * Execute an ESQL request. + */ public void execute( EsqlQueryRequest request, BiConsumer> runPhase, ActionListener listener ) { LOGGER.debug("ESQL query:\n{}", request.query()); - LogicalPlan logicalPlan = parse(request.query(), request.params()); - logicalPlanToPhysicalPlan(logicalPlan, request, listener.delegateFailureAndWrap((l, r) -> runPhase.accept(r, l))); + analyzedPlan( + parse(request.query(), request.params()), + listener.delegateFailureAndWrap((next, analyzedPlan) -> executeAnalyzedPlan(request, runPhase, analyzedPlan, next)) + ); } - private void logicalPlanToPhysicalPlan(LogicalPlan logicalPlan, EsqlQueryRequest request, ActionListener listener) { - optimizedPhysicalPlan( - logicalPlan, - listener.map(plan -> EstimatesRowSize.estimateRowSize(0, plan.transformUp(FragmentExec.class, f -> { - QueryBuilder filter = request.filter(); - if (filter != null) { - var fragmentFilter = f.esFilter(); - // TODO: have an ESFilter and push down to EsQueryExec / EsSource - // This is an ugly hack to push the filter parameter to Lucene - // TODO: filter integration testing - filter = fragmentFilter != null ? boolQuery().filter(fragmentFilter).must(filter) : filter; - LOGGER.debug("Fold filter {} to EsQueryExec", filter); - f = f.withFilter(filter); - } - return f; - }))) - ); + public void executeAnalyzedPlan( + EsqlQueryRequest request, + BiConsumer> runPhase, + LogicalPlan analyzedPlan, + ActionListener listener + ) { + // TODO phased execution lands here. 
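+        // For now there is a single phase: convert the analyzed plan into an optimized physical plan and run it.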
+ runPhase.accept(logicalPlanToPhysicalPlan(analyzedPlan, request), listener); } private LogicalPlan parse(String query, QueryParams params) { @@ -155,6 +150,7 @@ public void analyzedPlan(LogicalPlan parsed, ActionListener listene preAnalyze(parsed, (indices, policies) -> { Analyzer analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indices, policies), verifier); var plan = analyzer.analyze(parsed); + plan.setAnalyzed(); LOGGER.debug("Analyzed plan:\n{}", plan); return plan; }, listener); @@ -315,28 +311,41 @@ private static Set subfields(Set names) { return names.stream().filter(name -> name.endsWith(WILDCARD) == false).map(name -> name + ".*").collect(Collectors.toSet()); } - public void optimizedPlan(LogicalPlan logicalPlan, ActionListener listener) { - analyzedPlan(logicalPlan, map(listener, p -> { - var plan = logicalPlanOptimizer.optimize(p); - LOGGER.debug("Optimized logicalPlan plan:\n{}", plan); - return plan; - })); + private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan logicalPlan, EsqlQueryRequest request) { + PhysicalPlan physicalPlan = optimizedPhysicalPlan(logicalPlan); + physicalPlan = physicalPlan.transformUp(FragmentExec.class, f -> { + QueryBuilder filter = request.filter(); + if (filter != null) { + var fragmentFilter = f.esFilter(); + // TODO: have an ESFilter and push down to EsQueryExec / EsSource + // This is an ugly hack to push the filter parameter to Lucene + // TODO: filter integration testing + filter = fragmentFilter != null ? boolQuery().filter(fragmentFilter).must(filter) : filter; + LOGGER.debug("Fold filter {} to EsQueryExec", filter); + f = f.withFilter(filter); + } + return f; + }); + return EstimatesRowSize.estimateRowSize(0, physicalPlan); } - public void physicalPlan(LogicalPlan optimized, ActionListener listener) { - optimizedPlan(optimized, map(listener, p -> { - var plan = mapper.map(p); - LOGGER.debug("Physical plan:\n{}", plan); - return plan; - })); + public LogicalPlan optimizedPlan(LogicalPlan logicalPlan) { + assert logicalPlan.analyzed(); + var plan = logicalPlanOptimizer.optimize(logicalPlan); + LOGGER.debug("Optimized logicalPlan plan:\n{}", plan); + return plan; } - public void optimizedPhysicalPlan(LogicalPlan logicalPlan, ActionListener listener) { - physicalPlan(logicalPlan, map(listener, p -> { - var plan = physicalPlanOptimizer.optimize(p); - LOGGER.debug("Optimized physical plan:\n{}", plan); - return plan; - })); + public PhysicalPlan physicalPlan(LogicalPlan logicalPlan) { + var plan = mapper.map(optimizedPlan(logicalPlan)); + LOGGER.debug("Physical plan:\n{}", plan); + return plan; + } + + public PhysicalPlan optimizedPhysicalPlan(LogicalPlan logicalPlan) { + var plan = physicalPlanOptimizer.optimize(physicalPlan(logicalPlan)); + LOGGER.debug("Optimized physical plan:\n{}", plan); + return plan; } public static InvalidMappedField specificValidity(String fieldName, Map types) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index f61f581f29a13..e8a403ae7d9d0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.TimeValue; import 
org.elasticsearch.core.Tuple; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -48,16 +47,16 @@ import org.elasticsearch.xpack.esql.CsvTestUtils.ActualResults; import org.elasticsearch.xpack.esql.CsvTestUtils.Type; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.action.EsqlQueryRequest; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.PreAnalyzer; import org.elasticsearch.xpack.esql.core.CsvSpecReader; import org.elasticsearch.xpack.esql.core.SpecReader; -import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; @@ -73,7 +72,6 @@ import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize; import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec; import org.elasticsearch.xpack.esql.plan.physical.OutputExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -85,6 +83,8 @@ import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; +import org.elasticsearch.xpack.esql.session.EsqlSession; +import org.elasticsearch.xpack.esql.session.Result; import org.elasticsearch.xpack.esql.stats.DisabledSearchStats; import org.junit.After; import org.junit.Before; @@ -100,6 +100,7 @@ import java.util.TreeMap; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import static org.elasticsearch.xpack.esql.CsvTestUtils.ExpectedResults; import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled; @@ -330,16 +331,14 @@ private static EnrichPolicy loadEnrichPolicyMapping(String policyFileName) { } } - private PhysicalPlan physicalPlan(LogicalPlan parsed, CsvTestsDataLoader.TestsDataset dataset) { + private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.TestsDataset dataset) { var indexResolution = loadIndexResolution(dataset.mappingFileName(), dataset.indexName()); var enrichPolicies = loadEnrichPolicies(); var analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies), TEST_VERIFIER); - var analyzed = analyzer.analyze(parsed); - var logicalOptimized = new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)).optimize(analyzed); - var physicalPlan = mapper.map(logicalOptimized); - var optimizedPlan = EstimatesRowSize.estimateRowSize(0, physicalPlanOptimizer.optimize(physicalPlan)); - opportunisticallyAssertPlanSerialization(physicalPlan, optimizedPlan); // comment out to disable serialization - return optimizedPlan; + LogicalPlan plan = analyzer.analyze(parsed); + plan.setAnalyzed(); + LOGGER.debug("Analyzed plan:\n{}", plan); + return plan; 
} private static CsvTestsDataLoader.TestsDataset testsDataset(LogicalPlan parsed) { @@ -381,90 +380,43 @@ private static TestPhysicalOperationProviders testOperationProviders(CsvTestsDat } private ActualResults executePlan(BigArrays bigArrays) throws Exception { - var parsed = parser.createStatement(testCase.query); + LogicalPlan parsed = parser.createStatement(testCase.query); var testDataset = testsDataset(parsed); + LogicalPlan analyzed = analyzedPlan(parsed, testDataset); - String sessionId = "csv-test"; - BlockFactory blockFactory = new BlockFactory( - bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST), - bigArrays, - ByteSizeValue.ofBytes(randomLongBetween(1, BlockFactory.DEFAULT_MAX_BLOCK_PRIMITIVE_ARRAY_SIZE.getBytes() * 2)) - ); - ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(between(1, 64), executor); - ExchangeSinkHandler exchangeSink = new ExchangeSinkHandler(blockFactory, between(1, 64), threadPool::relativeTimeInMillis); - LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( - sessionId, - "", - new CancellableTask(1, "transport", "esql", null, TaskId.EMPTY_TASK_ID, Map.of()), - bigArrays, - blockFactory, - randomNodeSettings(), + EsqlSession session = new EsqlSession( + getTestName(), configuration, - exchangeSource, - exchangeSink, - Mockito.mock(EnrichLookupService.class), - testOperationProviders(testDataset) + null, + null, + null, + functionRegistry, + new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)), + mapper, + TEST_VERIFIER ); - // - // Keep in sync with ComputeService#execute - // - PhysicalPlan physicalPlan = physicalPlan(parsed, testDataset); - Tuple coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode( - physicalPlan, - configuration + TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(testDataset); + + PlainActionFuture listener = new PlainActionFuture<>(); + + session.executeAnalyzedPlan( + new EsqlQueryRequest(), + runPhase(bigArrays, physicalOperationProviders), + analyzed, + listener.delegateFailureAndWrap( + // Wrap so we can capture the warnings in the calling thread + (next, result) -> next.onResponse( + new ActualResults( + result.schema().stream().map(Attribute::name).toList(), + result.schema().stream().map(a -> Type.asType(a.dataType().nameUpper())).toList(), + result.schema().stream().map(Attribute::dataType).toList(), + result.pages(), + threadPool.getThreadContext().getResponseHeaders() + ) + ) + ) ); - PhysicalPlan coordinatorPlan = coordinatorAndDataNodePlan.v1(); - PhysicalPlan dataNodePlan = coordinatorAndDataNodePlan.v2(); - - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Coordinator plan\n" + coordinatorPlan); - LOGGER.trace("DataNode plan\n" + dataNodePlan); - } - - List columnNames = Expressions.names(coordinatorPlan.output()); - List dataTypes = new ArrayList<>(columnNames.size()); - List columnTypes = coordinatorPlan.output() - .stream() - .peek(o -> dataTypes.add(o.dataType())) - .map(o -> Type.asType(o.dataType().nameUpper())) - .toList(); - - List drivers = new ArrayList<>(); - List collectedPages = Collections.synchronizedList(new ArrayList<>()); - - // replace fragment inside the coordinator plan - try { - LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(new OutputExec(coordinatorPlan, collectedPages::add)); - drivers.addAll(coordinatorNodeExecutionPlan.createDrivers(sessionId)); - if (dataNodePlan != null) { - var searchStats = new DisabledSearchStats(); - var logicalTestOptimizer = new 
LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats)); - var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer( - new LocalPhysicalOptimizerContext(configuration, searchStats) - ); - - var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer); - exchangeSource.addRemoteSink(exchangeSink::fetchPageAsync, randomIntBetween(1, 3)); - LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan); - drivers.addAll(dataNodeExecutionPlan.createDrivers(sessionId)); - Randomness.shuffle(drivers); - } - // Execute the driver - DriverRunner runner = new DriverRunner(threadPool.getThreadContext()) { - @Override - protected void start(Driver driver, ActionListener driverListener) { - Driver.start(threadPool.getThreadContext(), executor, driver, between(1, 1000), driverListener); - } - }; - PlainActionFuture future = new PlainActionFuture<>(); - runner.runToCompletion(drivers, ActionListener.releaseAfter(future, () -> Releasables.close(drivers)).map(ignore -> { - var responseHeaders = threadPool.getThreadContext().getResponseHeaders(); - return new ActualResults(columnNames, columnTypes, dataTypes, collectedPages, responseHeaders); - })); - return future.actionGet(TimeValue.timeValueSeconds(30)); - } finally { - Releasables.close(() -> Releasables.close(drivers)); - } + return listener.get(); } private Settings randomNodeSettings() { @@ -487,17 +439,15 @@ private Throwable reworkException(Throwable th) { } // Asserts that the serialization and deserialization of the plan creates an equivalent plan. - private void opportunisticallyAssertPlanSerialization(PhysicalPlan... plans) { - for (var plan : plans) { - var tmp = plan; - do { - if (tmp instanceof LocalSourceExec) { - return; // skip plans with localSourceExec - } - } while (tmp.children().isEmpty() == false && (tmp = tmp.children().get(0)) != null); + private void opportunisticallyAssertPlanSerialization(PhysicalPlan plan) { + var tmp = plan; + do { + if (tmp instanceof LocalSourceExec) { + return; // skip plans with localSourceExec + } + } while (tmp.children().isEmpty() == false && (tmp = tmp.children().get(0)) != null); - SerializationTestUtils.assertSerialization(plan, configuration); - } + SerializationTestUtils.assertSerialization(plan, configuration); } private void assertWarnings(List warnings) { @@ -511,4 +461,82 @@ private void assertWarnings(List warnings) { } EsqlTestUtils.assertWarnings(normalized, testCase.expectedWarnings(true), testCase.expectedWarningsRegex()); } + + BiConsumer> runPhase( + BigArrays bigArrays, + TestPhysicalOperationProviders physicalOperationProviders + ) { + return (physicalPlan, listener) -> runPhase(bigArrays, physicalOperationProviders, physicalPlan, listener); + } + + void runPhase( + BigArrays bigArrays, + TestPhysicalOperationProviders physicalOperationProviders, + PhysicalPlan physicalPlan, + ActionListener listener + ) { + // Keep in sync with ComputeService#execute + opportunisticallyAssertPlanSerialization(physicalPlan); + Tuple coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode( + physicalPlan, + configuration + ); + PhysicalPlan coordinatorPlan = coordinatorAndDataNodePlan.v1(); + PhysicalPlan dataNodePlan = coordinatorAndDataNodePlan.v2(); + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("Coordinator plan\n" + coordinatorPlan); + LOGGER.trace("DataNode plan\n" + dataNodePlan); + } + + BlockFactory blockFactory = new BlockFactory( + 
bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST),
+            bigArrays,
+            ByteSizeValue.ofBytes(randomLongBetween(1, BlockFactory.DEFAULT_MAX_BLOCK_PRIMITIVE_ARRAY_SIZE.getBytes() * 2))
+        );
+        ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(between(1, 64), executor);
+        ExchangeSinkHandler exchangeSink = new ExchangeSinkHandler(blockFactory, between(1, 64), threadPool::relativeTimeInMillis);
+
+        LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner(
+            getTestName(),
+            "",
+            new CancellableTask(1, "transport", "esql", null, TaskId.EMPTY_TASK_ID, Map.of()),
+            bigArrays,
+            blockFactory,
+            randomNodeSettings(),
+            configuration,
+            exchangeSource,
+            exchangeSink,
+            Mockito.mock(EnrichLookupService.class),
+            physicalOperationProviders
+        );
+
+        List collectedPages = Collections.synchronizedList(new ArrayList<>());
+
+        // replace fragment inside the coordinator plan
+        List drivers = new ArrayList<>();
+        LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(new OutputExec(coordinatorPlan, collectedPages::add));
+        drivers.addAll(coordinatorNodeExecutionPlan.createDrivers(getTestName()));
+        if (dataNodePlan != null) {
+            var searchStats = new DisabledSearchStats();
+            var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats));
+            var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats));
+
+            var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer);
+            exchangeSource.addRemoteSink(exchangeSink::fetchPageAsync, randomIntBetween(1, 3));
+            LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan);
+
+            drivers.addAll(dataNodeExecutionPlan.createDrivers(getTestName()));
+            Randomness.shuffle(drivers);
+        }
+        // Execute the drivers
+        DriverRunner runner = new DriverRunner(threadPool.getThreadContext()) {
+            @Override
+            protected void start(Driver driver, ActionListener driverListener) {
+                Driver.start(threadPool.getThreadContext(), executor, driver, between(1, 1000), driverListener);
+            }
+        };
+        listener = ActionListener.releaseAfter(listener, () -> Releasables.close(drivers));
+        runner.runToCompletion(drivers, listener.map(ignore -> new Result(physicalPlan.output(), collectedPages, List.of())));
+    }
 }

From facabf627bf8a37978fc6ab952f538d7e2d3d2b0 Mon Sep 17 00:00:00 2001
From: Rene Groeschke
Date: Mon, 8 Jul 2024 15:34:16 +0200
Subject: [PATCH 09/64] [Gradle] Only resolve latest patch version for
 resolveAllDependencies (#110584)

* Only resolve latest patch version for resolveAllDependencies

This should avoid downloading too many elasticsearch distributions and
reduce disk usage and speed up image creation.
* Some cleanup --- qa/packaging/build.gradle | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/qa/packaging/build.gradle b/qa/packaging/build.gradle index 758dfe6661766..02bc30ecd6b39 100644 --- a/qa/packaging/build.gradle +++ b/qa/packaging/build.gradle @@ -36,3 +36,13 @@ tasks.named("test").configure { enabled = false } tasks.register('destructivePackagingTest') { dependsOn 'destructiveDistroTest' } + +tasks.named('resolveAllDependencies') { + // Don't try and resolve all distros but only the latest patch versions of each minor + def latestBugfixVersions = org.elasticsearch.gradle.internal.info.BuildParams.getBwcVersions().getIndexCompatible() + .groupBy { [it.major, it.minor] } + .collectEntries { key, value -> [key, value.max()] } + .values() + + configs = configurations.matching { configName -> latestBugfixVersions.any { v -> configName.name.endsWith(v.toString()) } } +} From ec944f5b23a77ed67649ed797503a56bd95e3276 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 8 Jul 2024 15:37:42 +0200 Subject: [PATCH 10/64] [Inference API] Use extractOptionalPositiveInteger in AzureOpenAiEmbeddingsServiceSettings for dims and maxInputTokens (#110483) --- .../AzureOpenAiEmbeddingsServiceSettings.java | 11 ++- ...eOpenAiEmbeddingsServiceSettingsTests.java | 86 +++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java index 1c426815a83c0..a9e40569d4e7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java @@ -33,9 +33,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; @@ -88,8 +88,13 @@ private static CommonFields fromMap( String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); - Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + 
Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java index cbb9eea223802..8b754257e9d83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java @@ -203,6 +203,92 @@ public void testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() { ); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var dimensions = 0; + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, dimensions, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var dimensions = randomNegativeInt(); + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, dimensions, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var maxInputTokens = 0; + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, null, maxInputTokens); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. 
[max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, null, maxInputTokens); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_Persistent_CreatesSettingsCorrectly() { var resourceName = "this-resource"; var deploymentId = "this-deployment"; From 00e744ef5ce11d51399783f47d47c8623445fe23 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Mon, 8 Jul 2024 09:48:57 -0400 Subject: [PATCH 11/64] A small tidiness refactor of the GeoIpTaskState's Metadata (#110553) --- .../ingest/geoip/GeoIpDownloaderIT.java | 2 +- .../ingest/geoip/GeoIpDownloader.java | 19 ++++++++++++------- .../ingest/geoip/GeoIpTaskState.java | 16 +++++++--------- .../ingest/geoip/GeoIpDownloaderTests.java | 8 ++++---- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java index 9dcd8abc7bc57..9eab00fbadf20 100644 --- a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java +++ b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java @@ -242,7 +242,7 @@ public void testGeoIpDatabasesDownload() throws Exception { Set.of("GeoLite2-ASN.mmdb", "GeoLite2-City.mmdb", "GeoLite2-Country.mmdb", "MyCustomGeoLite2-City.mmdb"), state.getDatabases().keySet() ); - GeoIpTaskState.Metadata metadata = state.get(id); + GeoIpTaskState.Metadata metadata = state.getDatabases().get(id); int size = metadata.lastChunk() - metadata.firstChunk() + 1; assertResponse( prepareSearch(GeoIpDownloader.DATABASES_INDEX).setSize(size) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java index 895c9315d2325..5239e96856b7f 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java @@ -170,23 +170,28 @@ private List fetchDatabasesOverview() throws IOException { } // visible for testing - void processDatabase(Map databaseInfo) { + void processDatabase(final Map databaseInfo) { String name = databaseInfo.get("name").toString().replace(".tgz", "") + ".mmdb"; String md5 = (String) databaseInfo.get("md5_hash"); - if (state.contains(name) && Objects.equals(md5, state.get(name).md5())) { - updateTimestamp(name, state.get(name)); - return; - } - logger.debug("downloading geoip database [{}]", name); String url = databaseInfo.get("url").toString(); if (url.startsWith("http") == false) { // relative url, add it after last slash (i.e. 
resolve sibling) or at the end if there's no slash after http[s]:// int lastSlash = endpoint.substring(8).lastIndexOf('/'); url = (lastSlash != -1 ? endpoint.substring(0, lastSlash + 8) : endpoint) + "/" + url; } + processDatabase(name, md5, url); + } + + private void processDatabase(final String name, final String md5, final String url) { + Metadata metadata = state.getDatabases().getOrDefault(name, Metadata.EMPTY); + if (Objects.equals(metadata.md5(), md5)) { + updateTimestamp(name, metadata); + return; + } + logger.debug("downloading geoip database [{}]", name); long start = System.currentTimeMillis(); try (InputStream is = httpClient.get(url)) { - int firstChunk = state.contains(name) ? state.get(name).lastChunk() + 1 : 0; + int firstChunk = metadata.lastChunk() + 1; // if there is no metadata, then Metadata.EMPTY.lastChunk() + 1 = 0 int lastChunk = indexChunks(name, is, firstChunk, md5, start); if (lastChunk > firstChunk) { state = state.put(name, new Metadata(start, firstChunk, lastChunk - 1, md5, start)); diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java index d55f517b46e24..a405d90b24dcc 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java @@ -84,14 +84,6 @@ public Map getDatabases() { return databases; } - public boolean contains(String name) { - return databases.containsKey(name); - } - - public Metadata get(String name) { - return databases.get(name); - } - @Override public boolean equals(Object o) { if (this == o) return true; @@ -142,7 +134,13 @@ public void writeTo(StreamOutput out) throws IOException { record Metadata(long lastUpdate, int firstChunk, int lastChunk, String md5, long lastCheck) implements ToXContentObject { - static final String NAME = GEOIP_DOWNLOADER + "-metadata"; + /** + * An empty Metadata object useful for getOrDefault -type calls. Crucially, the 'lastChunk' is -1, so it's safe to use + * with logic that says the new firstChunk is the old lastChunk + 1. 
+ */ + static Metadata EMPTY = new Metadata(-1, -1, -1, "", -1); + + private static final String NAME = GEOIP_DOWNLOADER + "-metadata"; private static final ParseField LAST_CHECK = new ParseField("last_check"); private static final ParseField LAST_UPDATE = new ParseField("last_update"); private static final ParseField FIRST_CHUNK = new ParseField("first_chunk"); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java index 9cc5405c1b617..4834c581e9386 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java @@ -290,8 +290,8 @@ int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long @Override void updateTaskState() { - assertEquals(0, state.get("test.mmdb").firstChunk()); - assertEquals(10, state.get("test.mmdb").lastChunk()); + assertEquals(0, state.getDatabases().get("test.mmdb").firstChunk()); + assertEquals(10, state.getDatabases().get("test.mmdb").lastChunk()); } @Override @@ -341,8 +341,8 @@ int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long @Override void updateTaskState() { - assertEquals(9, state.get("test.mmdb").firstChunk()); - assertEquals(10, state.get("test.mmdb").lastChunk()); + assertEquals(9, state.getDatabases().get("test.mmdb").firstChunk()); + assertEquals(10, state.getDatabases().get("test.mmdb").lastChunk()); } @Override From 333e1bbb81d96815aad0dd81c3e4c082dae64f07 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Tue, 9 Jul 2024 00:36:34 +1000 Subject: [PATCH 12/64] Forward port release notes for v8.14.2 (#110538) --- docs/reference/release-notes/8.14.2.asciidoc | 2 - .../release-notes/highlights.asciidoc | 157 +++++++++++++++++- 2 files changed, 151 insertions(+), 8 deletions(-) diff --git a/docs/reference/release-notes/8.14.2.asciidoc b/docs/reference/release-notes/8.14.2.asciidoc index 9273355106a03..d94067f030c61 100644 --- a/docs/reference/release-notes/8.14.2.asciidoc +++ b/docs/reference/release-notes/8.14.2.asciidoc @@ -1,8 +1,6 @@ [[release-notes-8.14.2]] == {es} version 8.14.2 -coming[8.14.2] - Also see <>. [[known-issues-8.14.2]] diff --git a/docs/reference/release-notes/highlights.asciidoc b/docs/reference/release-notes/highlights.asciidoc index e70892ef25928..0ed01ff422700 100644 --- a/docs/reference/release-notes/highlights.asciidoc +++ b/docs/reference/release-notes/highlights.asciidoc @@ -30,13 +30,158 @@ Other versions: endif::[] -// The notable-highlights tag marks entries that -// should be featured in the Stack Installation and Upgrade Guide: // tag::notable-highlights[] -// [discrete] -// === Heading -// -// Description. + +[discrete] +[[stored_fields_are_compressed_with_zstandard_instead_of_lz4_deflate]] +=== Stored fields are now compressed with ZStandard instead of LZ4/DEFLATE +Stored fields are now compressed by splitting documents into blocks, which +are then compressed independently with ZStandard. `index.codec: default` +(default) uses blocks of at most 14kB or 128 documents compressed with level +0, while `index.codec: best_compression` uses blocks of at most 240kB or +2048 documents compressed at level 3. 
On most datasets that we tested
+against, this yielded storage improvements in the order of 10%, slightly
+faster indexing and similar retrieval latencies.
+
+{es-pull}103374[#103374]
+
+[discrete]
+[[stricter_failure_handling_in_multi_repo_get_snapshots_request_handling]]
+=== Stricter failure handling in multi-repo get-snapshots request handling
+If a multi-repo get-snapshots request encounters a failure in one of the
+targeted repositories then earlier versions of Elasticsearch would proceed
+as if the faulty repository did not exist, except for a per-repository
+failure report in a separate section of the response body. This makes it
+impossible to paginate the results properly in the presence of failures. In
+versions 8.15.0 and later this API's failure handling behaviour has been
+made stricter, reporting an overall failure if any targeted repository's
+contents cannot be listed.
+
+{es-pull}107191[#107191]
+
+[discrete]
+[[add_new_int4_quantization_to_dense_vector]]
+=== Add new int4 quantization to dense_vector
+New int4 (half-byte) scalar quantization support via two new index types: `int4_hnsw` and `int4_flat`.
+This gives an 8x reduction from `float32` with some accuracy loss. In addition to less memory required, this
+improves query and merge speed significantly when compared to raw vectors.
+
+{es-pull}109317[#109317]
+
+[discrete]
+[[mark_query_rules_as_ga]]
+=== Mark Query Rules as GA
+This PR marks query rules as Generally Available. All APIs are no longer
+in tech preview.
+
+{es-pull}110004[#110004]
+
+[discrete]
+[[adds_new_bit_element_type_for_dense_vectors]]
+=== Adds new `bit` `element_type` for `dense_vectors`
+This adds `bit` vector support by adding `element_type: bit` for
+vectors. This new element type works for indexed and non-indexed
+vectors. Additionally, it works with `hnsw` and `flat` index types. No
+quantization-based codec works with this element type; this is
+consistent with `byte` vectors.
+
+`bit` vectors accept up to `32768` dimensions in size and expect vectors
+that are being indexed to be encoded either as a hexadecimal string or a
+`byte[]` array where each element of the `byte` array represents `8`
+bits of the vector.
+
+`bit` vectors support script usage and regular query usage. When
+indexed, all comparisons done are `xor` and `popcount` summations (aka,
+hamming distance), and the scores are transformed and normalized given
+the vector dimensions.
+
+For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
+`sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.
+
+Note, the dimensions expected by this element_type must always be
+divisible by `8`, and the `byte[]` vectors provided for indexing must
+have size `dim/8`, where each byte element represents `8` bits of
+the vector.
+
+{es-pull}110059[#110059]
+
+[discrete]
+[[redact_processor_generally_available]]
+=== The Redact processor is Generally Available
+The Redact processor uses the Grok rules engine to obscure text in the input document matching the given Grok patterns. The Redact processor was initially released as Technical Preview in `8.7.0`, and is now released as Generally Available.
+
+{es-pull}110395[#110395]
+
 // end::notable-highlights[]
+[discrete]
+[[new_custom_parser_for_iso_8601_datetimes]]
+=== New custom parser for ISO-8601 datetimes
+This introduces a new custom parser for ISO-8601 datetimes, for the `iso8601`, `strict_date_optional_time`, and
+`strict_date_optional_time_nanos` built-in date formats.
This provides a performance improvement over the +default Java date-time parsing. Whilst it maintains much of the same behaviour, +the new parser does not accept nonsensical date-time strings that have multiple fractional seconds fields +or multiple timezone specifiers. If the new parser fails to parse a string, it will then use the previous parser +to parse it. If a large proportion of the input data consists of these invalid strings, this may cause +a small performance degradation. If you wish to force the use of the old parsers regardless, +set the JVM property `es.datetime.java_time_parsers=true` on all ES nodes. + +{es-pull}106486[#106486] + +[discrete] +[[new_custom_parser_for_more_iso_8601_date_formats]] +=== New custom parser for more ISO-8601 date formats +Following on from #106486, this extends the custom ISO-8601 datetime parser to cover the `strict_year`, +`strict_year_month`, `strict_date_time`, `strict_date_time_no_millis`, `strict_date_hour_minute_second`, +`strict_date_hour_minute_second_millis`, and `strict_date_hour_minute_second_fraction` date formats. +As before, the parser will use the existing java.time parser if there are parsing issues, and the +`es.datetime.java_time_parsers=true` JVM property will force the use of the old parsers regardless. + +{es-pull}108606[#108606] + +[discrete] +[[preview_support_for_connection_type_domain_isp_databases_in_geoip_processor]] +=== Preview: Support for the 'Connection Type, 'Domain', and 'ISP' databases in the geoip processor +As a Technical Preview, the {ref}/geoip-processor.html[`geoip`] processor can now use the commercial +https://dev.maxmind.com/geoip/docs/databases/connection-type[GeoIP2 'Connection Type'], +https://dev.maxmind.com/geoip/docs/databases/domain[GeoIP2 'Domain'], +and +https://dev.maxmind.com/geoip/docs/databases/isp[GeoIP2 'ISP'] +databases from MaxMind. + +{es-pull}108683[#108683] + +[discrete] +[[update_elasticsearch_to_lucene_9_11]] +=== Update Elasticsearch to Lucene 9.11 +Elasticsearch is now updated using the latest Lucene version 9.11. +Here are the full release notes: +But, here are some particular highlights: +- Usage of MADVISE for better memory management: https://github.com/apache/lucene/pull/13196 +- Use RWLock to access LRUQueryCache to reduce contention: https://github.com/apache/lucene/pull/13306 +- Speedup multi-segment HNSW graph search for nested kNN queries: https://github.com/apache/lucene/pull/13121 +- Add a MemorySegment Vector scorer - for scoring without copying on-heap vectors: https://github.com/apache/lucene/pull/13339 + +{es-pull}109219[#109219] + +[discrete] +[[synthetic_source_improvements]] +=== Synthetic `_source` improvements +There are multiple improvements to synthetic `_source` functionality: + +* Synthetic `_source` is now supported for all field types including `nested` and `object`. `object` fields are supported with `enabled` set to `false`. + +* Synthetic `_source` can be enabled together with `ignore_malformed` and `ignore_above` parameters for all field types that support them. + +{es-pull}109501[#109501] + +[discrete] +[[index_sorting_on_indexes_with_nested_fields]] +=== Index sorting on indexes with nested fields +Index sorting is now supported for indexes with mappings containing nested objects. +The index sort spec (as specified by `index.sort.field`) can't contain any nested +fields, still. 
+ +{es-pull}110251[#110251] + From e8556c1c1dd6688227fe9855ae01ebda6dae0915 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 8 Jul 2024 16:40:12 +0200 Subject: [PATCH 13/64] [Inference API] Use extractOptionalPositiveInteger in OpenAiEmbeddingsServiceSettings for dims and maxInputTokens (#110484) --- .../OpenAiEmbeddingsServiceSettings.java | 10 +- .../services/openai/OpenAiServiceTests.java | 27 +++--- .../OpenAiEmbeddingsServiceSettingsTests.java | 91 +++++++++++++++++++ 3 files changed, 114 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index 080251bf1ba3a..d474e935fbda7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -36,6 +36,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; @@ -99,8 +100,13 @@ private static CommonFields fromMap( String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 9e35180547bf2..9ff175ca9685e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -342,7 +342,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { public void 
testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", 100, false), + getServiceSettingsMap("model", "url", "org", 100, null, false), getTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -393,7 +393,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, true), + getServiceSettingsMap("model", null, null, null, null, true), getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -419,7 +419,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWi public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -450,7 +450,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user"), secretSettingsMap ); @@ -476,7 +476,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -503,7 +503,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, true); + var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user"), getSecretSettingsMap("secret")); @@ -532,7 +532,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap, getSecretSettingsMap("secret") ); @@ -558,7 +558,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + 
getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user") ); @@ -593,7 +593,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", null, null, null, true), getTaskSettingsMap(null)); + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null) + ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -611,7 +614,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUr public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user") ); persistedConfig.config().put("extra_key", "value"); @@ -631,7 +634,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, true); + var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user")); @@ -654,7 +657,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( var taskSettingsMap = getTaskSettingsMap("user"); taskSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, true), taskSettingsMap); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java index cc0004a2d678c..10ccbb4eb39f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java @@ -257,6 +257,92 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIs assertThat(settings, is(new OpenAiEmbeddingsServiceSettings("m", (URI) null, null, null, null, null, true, null))); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var dimensions = 0; + + var settingsMap = getServiceSettingsMap(modelId, url, org, dimensions, null, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> 
OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var dimensions = randomNegativeInt(); + + var settingsMap = getServiceSettingsMap(modelId, url, org, dimensions, null, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var maxInputTokens = 0; + + var settingsMap = getServiceSettingsMap(modelId, url, org, null, maxInputTokens, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = getServiceSettingsMap(modelId, url, org, null, maxInputTokens, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsSetByUserIsNull() { OpenAiEmbeddingsServiceSettings.fromMap( new HashMap<>(Map.of(ServiceFields.DIMENSIONS, 1, ServiceFields.MODEL_ID, "m")), @@ -464,6 +550,7 @@ public static Map getServiceSettingsMap( @Nullable String url, @Nullable String org, @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, @Nullable Boolean dimensionsSetByUser ) { var map = new HashMap(); @@ -481,6 +568,10 @@ public static Map getServiceSettingsMap( map.put(ServiceFields.DIMENSIONS, dimensions); } + if (maxInputTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (dimensionsSetByUser != null) { map.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, dimensionsSetByUser); } From fd790ff351f43523e6c05621b5d1be7fe30f141c Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 8 Jul 2024 10:50:44 -0400 Subject: [PATCH 14/64] Fix ExactKnnQueryBuilderTests testToQuery (#110357) (#110589) closes https://github.com/elastic/elasticsearch/issues/110357 With the loosening of what is considered a unit vector, we need to ensure we only normalize for equality checking if the query vector is indeed not a unit vector. 
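For context, a minimal self-contained sketch of the check involved. EPS mirrors the 1e-3f tolerance in DenseVectorFieldMapper below; the class and helper names here are illustrative stand-ins rather than the production code:

public final class UnitVectorCheckSketch {
    // Tolerance mirroring DenseVectorFieldMapper.EPS in the diff below.
    private static final float EPS = 1e-3f;

    // Same shape as DenseVectorFieldMapper#isNotUnitVector: callers pass the squared
    // L2 norm (dot(v, v)), which is within EPS of 1.0 only for a unit-length vector.
    static boolean isNotUnitVector(float magnitude) {
        return Math.abs(magnitude - 1.0f) > EPS;
    }

    // Stand-in for VectorUtil.dotProduct(v, v).
    static float squaredNorm(float[] v) {
        float sum = 0f;
        for (float x : v) {
            sum += x * x;
        }
        return sum;
    }

    // Stand-in for VectorUtil.l2normalize: scales v to unit length in place.
    static void l2normalize(float[] v) {
        float norm = (float) Math.sqrt(squaredNorm(v));
        for (int i = 0; i < v.length; i++) {
            v[i] /= norm;
        }
    }

    public static void main(String[] args) {
        float[] alreadyUnit = { 0.6f, 0.8f }; // squared norm 1.0, must stay untouched
        float[] notUnit = { 3.0f, 4.0f };     // squared norm 25.0, needs normalizing
        for (float[] query : new float[][] { alreadyUnit, notUnit }) {
            if (isNotUnitVector(squaredNorm(query))) {
                l2normalize(query);
            }
            System.out.println(java.util.Arrays.toString(query));
        }
    }
}

Normalizing a vector that is already unit length would perturb it by floating-point rounding, which is the spurious mismatch the guarded assertion avoids.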
--- muted-tests.yml | 3 --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 2 +- .../search/vectors/ExactKnnQueryBuilderTests.java | 5 ++++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 990b7d5dc5130..d46a9355c201f 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -88,9 +88,6 @@ tests: - class: org.elasticsearch.backwards.SearchWithMinCompatibleSearchNodeIT method: testMinVersionAsOldVersion issue: https://github.com/elastic/elasticsearch/issues/109454 -- class: org.elasticsearch.search.vectors.ExactKnnQueryBuilderTests - method: testToQuery - issue: https://github.com/elastic/elasticsearch/issues/110357 - class: org.elasticsearch.search.aggregations.bucket.terms.RareTermsIT method: testSingleValuedString issue: https://github.com/elastic/elasticsearch/issues/110388 diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 989c92e909ce2..d27c0acdb6b2e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -98,7 +98,7 @@ public class DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; private static final float EPS = 1e-3f; - static boolean isNotUnitVector(float magnitude) { + public static boolean isNotUnitVector(float magnitude) { return Math.abs(magnitude - 1.0f) > EPS; } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java index 1e77e35b60a4c..627f8a184a147 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractQueryTestCase; @@ -87,7 +88,9 @@ protected void doAssertLuceneQuery(ExactKnnQueryBuilder queryBuilder, Query quer DenseVectorQuery.Floats denseVectorQuery = (DenseVectorQuery.Floats) query; assertEquals(VECTOR_FIELD, denseVectorQuery.field); float[] expected = Arrays.copyOf(queryBuilder.getQuery().asFloatVector(), queryBuilder.getQuery().asFloatVector().length); - if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE)) { + float magnitude = VectorUtil.dotProduct(expected, expected); + if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE) + && DenseVectorFieldMapper.isNotUnitVector(magnitude)) { VectorUtil.l2normalize(expected); assertArrayEquals(expected, denseVectorQuery.getQuery(), 0.0f); } else { From c3fd01d14ceb24f5de58d939cf6066e9de771ab3 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Mon, 8 Jul 2024 11:12:04 -0400 Subject: [PATCH 15/64] AwaitsFix: https://github.com/elastic/elasticsearch/issues/110591 --- muted-tests.yml | 14 ++++++++++---- 1 file changed, 10 
insertions(+), 4 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index d46a9355c201f..79372be872928 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -4,7 +4,8 @@ tests: method: "testGuessIsDayFirstFromLocale" - class: "org.elasticsearch.test.rest.ClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/108857" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ + \ dependent mappings / dates}" - class: "org.elasticsearch.upgrades.SearchStatesIT" issue: "https://github.com/elastic/elasticsearch/issues/108991" method: "testCanMatch" @@ -13,7 +14,8 @@ tests: method: "testTrainedModelInference" - class: "org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109188" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ + \ dependent mappings / dates}" - class: "org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT" issue: "https://github.com/elastic/elasticsearch/issues/109189" method: "test {p0=esql/70_locale/Date format with Italian locale}" @@ -28,7 +30,8 @@ tests: method: "testTimestampFieldTypeExposedByAllIndicesServices" - class: "org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109318" - method: "test {yaml=analysis-common/50_char_filters/pattern_replace error handling (too complex pattern)}" + method: "test {yaml=analysis-common/50_char_filters/pattern_replace error handling\ + \ (too complex pattern)}" - class: "org.elasticsearch.xpack.ml.integration.ClassificationHousePricingIT" issue: "https://github.com/elastic/elasticsearch/issues/101598" method: "testFeatureImportanceValues" @@ -95,8 +98,11 @@ tests: issue: "https://github.com/elastic/elasticsearch/issues/110408" method: "testCreateAndRestorePartialSearchableSnapshot" - class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT - method: test {p0=search.vectors/41_knn_search_half_byte_quantized/Test create, merge, and search cosine} + method: test {p0=search.vectors/41_knn_search_half_byte_quantized/Test create, merge, + and search cosine} issue: https://github.com/elastic/elasticsearch/issues/109978 +- class: "org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT" + issue: "https://github.com/elastic/elasticsearch/issues/110591" # Examples: # From d05f97021cf5f1dea8cd54c2c42c261850e9c02a Mon Sep 17 00:00:00 2001 From: Oleksandr Kolomiiets Date: Mon, 8 Jul 2024 08:36:19 -0700 Subject: [PATCH 16/64] Fix MapperBuilderContext#isDataStream when used in dynamic mappers (#110554) --- docs/changelog/110554.yaml | 5 ++ .../index/mapper/DocumentParserContext.java | 2 +- .../mapper/DocumentParserContextTests.java | 52 +++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 docs/changelog/110554.yaml diff --git a/docs/changelog/110554.yaml b/docs/changelog/110554.yaml new file mode 100644 index 0000000000000..8c0b896a4c979 --- /dev/null +++ b/docs/changelog/110554.yaml @@ -0,0 +1,5 @@ +pr: 110554 +summary: Fix `MapperBuilderContext#isDataStream` when used in dynamic mappers +area: "Mapping" +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java 
b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java index d8fa2919b795f..248369b249007 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java @@ -673,7 +673,7 @@ public final MapperBuilderContext createDynamicMapperBuilderContext() { return new MapperBuilderContext( p, mappingLookup.isSourceSynthetic(), - false, + mappingLookup.isDataStreamTimestampFieldEnabled(), containsDimensions, dynamic, MergeReason.MAPPING_UPDATE, diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java index ab1c93cd98277..2826243e4c866 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java @@ -11,7 +11,9 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; @@ -81,4 +83,54 @@ public void testSwitchParser() throws IOException { assertEquals(parser, newContext.parser()); assertEquals("1", newContext.indexSettings().getSettings().get("index.mapping.total_fields.limit")); } + + public void testCreateDynamicMapperBuilderContextFromEmptyContext() throws IOException { + var resultFromEmptyParserContext = context.createDynamicMapperBuilderContext(); + + assertEquals("hey", resultFromEmptyParserContext.buildFullName("hey")); + assertFalse(resultFromEmptyParserContext.isSourceSynthetic()); + assertFalse(resultFromEmptyParserContext.isDataStream()); + assertFalse(resultFromEmptyParserContext.parentObjectContainsDimensions()); + assertEquals(ObjectMapper.Defaults.DYNAMIC, resultFromEmptyParserContext.getDynamic()); + assertEquals(MapperService.MergeReason.MAPPING_UPDATE, resultFromEmptyParserContext.getMergeReason()); + assertFalse(resultFromEmptyParserContext.isInNestedContext()); + } + + public void testCreateDynamicMapperBuilderContext() throws IOException { + var mapping = XContentBuilder.builder(XContentType.JSON.xContent()) + .startObject() + .startObject("_doc") + .startObject("_source") + .field("mode", "synthetic") + .endObject() + .startObject(DataStreamTimestampFieldMapper.NAME) + .field("enabled", "true") + .endObject() + .startObject("properties") + .startObject(DataStreamTimestampFieldMapper.DEFAULT_PATH) + .field("type", "date") + .endObject() + .startObject("foo") + .field("type", "passthrough") + .field("time_series_dimension", "true") + .field("priority", "100") + .endObject() + .endObject() + .endObject() + .endObject(); + var documentMapper = new MapperServiceTestCase() { + }.createDocumentMapper(mapping); + var parserContext = new TestDocumentParserContext(documentMapper.mappers(), null); + parserContext.path().add("foo"); + + var resultFromParserContext = parserContext.createDynamicMapperBuilderContext(); + + assertEquals("foo.hey", resultFromParserContext.buildFullName("hey")); + assertTrue(resultFromParserContext.isSourceSynthetic()); + assertTrue(resultFromParserContext.isDataStream()); + assertTrue(resultFromParserContext.parentObjectContainsDimensions()); + 
assertEquals(ObjectMapper.Defaults.DYNAMIC, resultFromParserContext.getDynamic()); + assertEquals(MapperService.MergeReason.MAPPING_UPDATE, resultFromParserContext.getMergeReason()); + assertFalse(resultFromParserContext.isInNestedContext()); + } } From 930ff47c2f7388b5cf6d0a3235256f7d91394e45 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 8 Jul 2024 17:37:06 +0200 Subject: [PATCH 17/64] [Inference API] Use extractOptionalPositiveInteger in MistralEmbeddingsServiceSettings for dims and maxInputTokens (#110485) --- .../MistralEmbeddingsServiceSettings.java | 3 +- ...MistralEmbeddingsServiceSettingsTests.java | 80 +++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index 62d06a4e0029c..2e4d546e1dc4c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -33,7 +33,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; public class MistralEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { @@ -67,7 +66,7 @@ public static MistralEmbeddingsServiceSettings fromMap(Map map, MistralService.NAME, context ); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java index 076986acdcee6..009a6dbdeb793 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.ByteArrayStreamInput; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.core.Nullable; @@ -27,6 +28,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; public class MistralEmbeddingsServiceSettingsTests 
extends ESTestCase { @@ -77,6 +79,84 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIs assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, null, null, null, null))); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var model = "mistral-embed"; + var dimensions = 0; + + var settingsMap = createRequestSettingsMap(model, dimensions, null, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var model = "mistral-embed"; + var dimensions = randomNegativeInt(); + + var settingsMap = createRequestSettingsMap(model, dimensions, null, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var model = "mistral-embed"; + var maxInputTokens = 0; + + var settingsMap = createRequestSettingsMap(model, null, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var model = "mistral-embed"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = createRequestSettingsMap(model, null, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() { var model = "mistral-embed"; From b01949c6aa82a2ab56f13f01c34da3768a1a56fe Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 8 Jul 2024 17:22:59 +0100 Subject: [PATCH 18/64] [ML] Fixes processing chunked results in AWS Bedrock service (#110592) Fixes error using the Amazon Bedrock service with a large input that was chunked. 
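For context, a minimal sketch of the per-batch listener pattern the fix moves to. Batch and the Consumer-based listener here are illustrative stand-ins for the EmbeddingRequestChunker.batchRequestsWithListeners machinery, not the real API:

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

public final class PerBatchListenerSketch {
    // A batched request paired with the listener that writes the batch's
    // embeddings back into the correct slot of the merged result.
    record Batch(List<String> inputs, Consumer<List<float[]>> listener) {}

    public static void main(String[] args) {
        List<String> chunks = List.of("chunk-1", "chunk-2", "chunk-3");
        List<float[]> merged = new ArrayList<>(java.util.Collections.nCopies(chunks.size(), (float[]) null));

        // One batch per chunk, each with its own listener. Previously every batch
        // response was translated against the full input list through one shared
        // listener, so large chunked inputs produced mismatched results.
        List<Batch> batches = new ArrayList<>();
        for (int i = 0; i < chunks.size(); i++) {
            int slot = i;
            batches.add(new Batch(List.of(chunks.get(i)), embeddings -> merged.set(slot, embeddings.get(0))));
        }

        // Stand-in for action.execute(..., request.listener()): each response is
        // routed to the listener belonging to the batch that produced it.
        for (Batch batch : batches) {
            batch.listener().accept(List.of(new float[] { batch.inputs().get(0).length() }));
        }

        merged.forEach(e -> System.out.println(e[0]));
    }
}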
--- .../amazonbedrock/AmazonBedrockService.java | 24 +------------------ .../azureopenai/AzureOpenAiService.java | 18 -------------- .../AmazonBedrockServiceTests.java | 21 +++++++++------- 3 files changed, 14 insertions(+), 49 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index dadcc8a40245e..459ca367058f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -23,10 +23,6 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionCreator; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; @@ -47,7 +43,6 @@ import java.util.Set; import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; -import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; @@ -115,10 +110,6 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - ActionListener inferListener = listener.delegateFailureAndWrap( - (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response)) - ); - var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); @@ -126,26 +117,13 @@ protected void doChunkedInfer( .batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, inferListener); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); } } - private static List translateToChunkedResults( - List inputs, - InferenceServiceResults inferenceResults - ) { - if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) { - return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults); - } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); - } else { - throw 
createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName()); - } - } - @Override public String name() { return NAME; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 3facb78864831..3c75243770f97 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -24,10 +24,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -44,7 +40,6 @@ import java.util.Map; import java.util.Set; -import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; @@ -246,19 +241,6 @@ protected void doChunkedInfer( } } - private static List translateToChunkedResults( - List inputs, - InferenceServiceResults inferenceResults - ) { - if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) { - return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults); - } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); - } else { - throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName()); - } - } - /** * For text embedding models get the embedding size and * update the service settings. 
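With each batch carrying its own listener, the per-service translateToChunkedResults helpers removed above become dead code: assembling chunked results is the chunker's job. A hedged sketch of that merge step follows, with stand-in types rather than EmbeddingRequestChunker's real internals; it also shows why the test diff below now enqueues one single-embedding response per batch:

import java.util.List;
import java.util.stream.IntStream;

public final class ChunkMergeSketch {
    // A chunk of the original input paired with the embedding its batch returned.
    record Chunk(String matchedText, float[] embedding) {}

    // Merge per-batch embeddings back into per-chunk results: one response per batch.
    static List<Chunk> merge(List<String> chunkTexts, List<float[]> perBatchEmbeddings) {
        if (chunkTexts.size() != perBatchEmbeddings.size()) {
            throw new IllegalStateException("expected exactly one embedding per batched chunk");
        }
        return IntStream.range(0, chunkTexts.size())
            .mapToObj(i -> new Chunk(chunkTexts.get(i), perBatchEmbeddings.get(i)))
            .toList();
    }

    public static void main(String[] args) {
        var merged = merge(
            List.of("abc", "xyz"),
            List.of(new float[] { 0.123f, 0.678f }, new float[] { 0.223f, 0.278f })
        );
        merged.forEach(c -> System.out.println(c.matchedText() + " -> " + c.embedding()[0]));
    }
}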
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
index 00a840c8d4812..ae413fc17425c 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
@@ -1048,13 +1048,18 @@ public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() th
         try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
             try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) {
-                var mockResults = new InferenceTextEmbeddingFloatResults(
-                    List.of(
-                        new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F }),
-                        new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.456F, 0.987F })
-                    )
-                );
-                requestSender.enqueue(mockResults);
+                {
+                    var mockResults1 = new InferenceTextEmbeddingFloatResults(
+                        List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F }))
+                    );
+                    requestSender.enqueue(mockResults1);
+                }
+                {
+                    var mockResults2 = new InferenceTextEmbeddingFloatResults(
+                        List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.223F, 0.278F }))
+                    );
+                    requestSender.enqueue(mockResults2);
+                }

                 var model = AmazonBedrockEmbeddingsModelTests.createModel(
                     "id",
@@ -1089,7 +1094,7 @@ public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() th
                 var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
                 assertThat(floatResult.chunks(), hasSize(1));
                 assertEquals("xyz", floatResult.chunks().get(0).matchedText());
-                assertArrayEquals(new float[] { 0.456F, 0.987F }, floatResult.chunks().get(0).embedding(), 0.0f);
+                assertArrayEquals(new float[] { 0.223F, 0.278F }, floatResult.chunks().get(0).embedding(), 0.0f);
             }
         }
     }

From fbcde9c0fd40c3f461af0cefe1af6eabe9da5091 Mon Sep 17 00:00:00 2001
From: Rene Groeschke
Date: Mon, 8 Jul 2024 19:02:18 +0200
Subject: [PATCH 19/64] [CI] Temporarily increase disk space for DRA build jobs (#110601)

---
 .buildkite/pipelines/dra-workflow.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.buildkite/pipelines/dra-workflow.yml b/.buildkite/pipelines/dra-workflow.yml
index 32a2b7d22134a..bcc6c9c57d756 100644
--- a/.buildkite/pipelines/dra-workflow.yml
+++ b/.buildkite/pipelines/dra-workflow.yml
@@ -7,7 +7,7 @@ steps:
       image: family/elasticsearch-ubuntu-2204
       machineType: custom-32-98304
       buildDirectory: /dev/shm/bk
-      diskSizeGb: 250
+      diskSizeGb: 350
   - wait
   # The hadoop build depends on the ES artifact
   # So let's trigger the hadoop build any time we build a new staging artifact

From 320b88ae37dfbd88c75eaa0476927dae67e90879 Mon Sep 17 00:00:00 2001
From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com>
Date: Mon, 8 Jul 2024 14:53:31 -0400
Subject: [PATCH 20/64] [Inference API] Semantic text delete inference (#110487)

* Prevent inference endpoints from being deleted if they are referenced by a semantic text field

* Update docs/changelog/110399.yaml

* fix tests

* remove erroneous logging

* Apply suggestions from code review

Co-authored-by: David Kyle
Co-authored-by: Carlos Delgado
<6339205+carlosdelest@users.noreply.github.com> * Fix serialization problem * Update error messages * Update Delete response to include new fields * Refactor Delete Transport Action to return the error message on dry run * Fix tests including disabling failing yaml tests * Fix YAML tests * move work off of transport thread onto utility threadpool * clean up semantic text indexes after IT * improvements from review --------- Co-authored-by: David Kyle Co-authored-by: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Co-authored-by: Elastic Machine --- docs/changelog/110399.yaml | 6 + .../org/elasticsearch/TransportVersions.java | 1 + .../action/DeleteInferenceEndpointAction.java | 35 ++++- .../ml/utils/SemanticTextInfoExtractor.java | 50 +++++++ .../inference/InferenceBaseRestTest.java | 19 +++ .../xpack/inference/InferenceCrudIT.java | 90 +++++++++++- ...ransportDeleteInferenceEndpointAction.java | 138 ++++++++++++------ ..._text_query_inference_endpoint_changes.yml | 3 + 8 files changed, 283 insertions(+), 59 deletions(-) create mode 100644 docs/changelog/110399.yaml create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java diff --git a/docs/changelog/110399.yaml b/docs/changelog/110399.yaml new file mode 100644 index 0000000000000..9e04e2656809e --- /dev/null +++ b/docs/changelog/110399.yaml @@ -0,0 +1,6 @@ +pr: 110399 +summary: "[Inference API] Prevent inference endpoints from being deleted if they are\ + \ referenced by semantic text" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index ff50d1513d28a..f64a43d463d47 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -209,6 +209,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0); public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0); public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0); + public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_703_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java index dfb77ccd49fc2..e9d612751e48f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java @@ -11,8 +11,10 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; @@ -105,10 +107,16 @@ public static class Response extends AcknowledgedResponse { private final String PIPELINE_IDS = "pipelines"; Set pipelineIds; + private final String REFERENCED_INDEXES = "indexes"; + Set indexes; + private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run + String dryRunMessage; - public Response(boolean acknowledged, Set pipelineIds) { + public Response(boolean acknowledged, Set pipelineIds, Set semanticTextIndexes, @Nullable String dryRunMessage) { super(acknowledged); this.pipelineIds = pipelineIds; + this.indexes = semanticTextIndexes; + this.dryRunMessage = dryRunMessage; } public Response(StreamInput in) throws IOException { @@ -118,6 +126,15 @@ public Response(StreamInput in) throws IOException { } else { pipelineIds = Set.of(); } + + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { + indexes = in.readCollectionAsSet(StreamInput::readString); + dryRunMessage = in.readOptionalString(); + } else { + indexes = Set.of(); + dryRunMessage = null; + } + } @Override @@ -126,23 +143,25 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) { out.writeCollection(pipelineIds, StreamOutput::writeString); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { + out.writeCollection(indexes, StreamOutput::writeString); + out.writeOptionalString(dryRunMessage); + } } @Override protected void addCustomFields(XContentBuilder builder, Params params) throws IOException { super.addCustomFields(builder, params); builder.field(PIPELINE_IDS, pipelineIds); + builder.field(REFERENCED_INDEXES, indexes); + if (dryRunMessage != null) { + builder.field(DRY_RUN_MESSAGE, dryRunMessage); + } } @Override public String toString() { - StringBuilder returnable = new StringBuilder(); - returnable.append("acknowledged: ").append(this.acknowledged); - returnable.append(", pipelineIdsByEndpoint: "); - for (String entry : pipelineIds) { - returnable.append(entry).append(", "); - } - return returnable.toString(); + return Strings.toString(this); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java new file mode 100644 index 
0000000000000..544c1e344c91f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.core.ml.utils; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.transport.Transports; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class SemanticTextInfoExtractor { + private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class); + + public static Set extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set endpointIds) { + assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); + assert endpointIds.isEmpty() == false; + assert metadata != null; + + Set referenceIndices = new HashSet<>(); + + Map indices = metadata.indices(); + + indices.forEach((indexName, indexMetadata) -> { + if (indexMetadata.getInferenceFields() != null) { + Map inferenceFields = indexMetadata.getInferenceFields(); + if (inferenceFields.entrySet() + .stream() + .anyMatch( + entry -> entry.getValue().getInferenceId() != null && endpointIds.contains(entry.getValue().getInferenceId()) + )) { + referenceIndices.add(indexName); + } + } + }); + + return referenceIndices; + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 419869c0c4a5e..f30f2e8fe201a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -126,6 +126,25 @@ protected void deleteModel(String modelId, TaskType taskType) throws IOException assertOkOrCreated(response); } + protected void putSemanticText(String endpointId, String indexName) throws IOException { + var request = new Request("PUT", Strings.format("%s", indexName)); + String body = Strings.format(""" + { + "mappings": { + "properties": { + "inference_field": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + } + """, endpointId); + request.setJsonEntity(body); + var response = client().performRequest(request); + assertOkOrCreated(response); + } + protected Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return putRequest(endpoint, modelConfig); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 
75e392b6d155f..242f786e95364 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.util.List; +import java.util.Set; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; @@ -124,14 +125,15 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException { putPipeline(pipelineId, endpointId); { + var errorString = new StringBuilder().append("Inference endpoint ") + .append(endpointId) + .append(" is referenced by pipelines: ") + .append(Set.of(pipelineId)) + .append(". ") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); - assertThat( - e.getMessage(), - containsString( - "Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. " - + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it." - ) - ); + assertThat(e.getMessage(), containsString(errorString.toString())); } { var response = deleteModel(endpointId, "dry_run=true"); @@ -146,4 +148,78 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException { } deletePipeline(pipelineId); } + + public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException { + String endpointId = "endpoint_referenced_by_semantic_text"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + String indexName = randomAlphaOfLength(10).toLowerCase(); + putSemanticText(endpointId, indexName); + { + + var errorString = new StringBuilder().append(" Inference endpoint ") + .append(endpointId) + .append(" is being used in the mapping for indexes: ") + .append(Set.of(indexName)) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); + assertThat(e.getMessage(), containsString(errorString.toString())); + } + { + var response = deleteModel(endpointId, "dry_run=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":false")); + assertThat(entityString, containsString(indexName)); + } + { + var response = deleteModel(endpointId, "force=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":true")); + } + deleteIndex(indexName); + } + + public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException { + String endpointId = "endpoint_referenced_by_semantic_text"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + String indexName = randomAlphaOfLength(10).toLowerCase(); + putSemanticText(endpointId, indexName); + var pipelineId = "pipeline_referencing_model"; + putPipeline(pipelineId, endpointId); + { + + var errorString = new StringBuilder().append("Inference endpoint ") + .append(endpointId) + .append(" is referenced by pipelines: ") + .append(Set.of(pipelineId)) + .append(". 
") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint.") + .append(" Inference endpoint ") + .append(endpointId) + .append(" is being used in the mapping for indexes: ") + .append(Set.of(indexName)) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + + var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); + assertThat(e.getMessage(), containsString(errorString.toString())); + } + { + var response = deleteModel(endpointId, "dry_run=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":false")); + assertThat(entityString, containsString(indexName)); + assertThat(entityString, containsString(pipelineId)); + } + { + var response = deleteModel(endpointId, "force=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":true")); + } + deletePipeline(pipelineId); + deleteIndex(indexName); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 07d5e1e618578..e59ac4e1356f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -3,6 +3,8 @@ * or more contributor license agreements. Licensed under the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
+ * + * this file was contributed to by a Generative AI */ package org.elasticsearch.xpack.inference.action; @@ -11,6 +13,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.TransportMasterNodeAction; @@ -18,12 +21,10 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -34,6 +35,10 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.Set; +import java.util.concurrent.Executor; + +import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction< DeleteInferenceEndpointAction.Request, @@ -42,6 +47,7 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceEndpointAction.class); + private final Executor executor; @Inject public TransportDeleteInferenceEndpointAction( @@ -66,6 +72,7 @@ public TransportDeleteInferenceEndpointAction( ); this.modelRegistry = modelRegistry; this.serviceRegistry = serviceRegistry; + this.executor = threadPool.executor(UTILITY_THREAD_POOL_NAME); } @Override @@ -74,6 +81,15 @@ protected void masterOperation( DeleteInferenceEndpointAction.Request request, ClusterState state, ActionListener masterListener + ) { + // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can + executor.execute(ActionRunnable.wrap(masterListener, l -> doExecuteForked(request, state, l))); + } + + private void doExecuteForked( + DeleteInferenceEndpointAction.Request request, + ClusterState state, + ActionListener masterListener ) { SubscribableListener.newForked(modelConfigListener -> { // Get the model from the registry @@ -89,17 +105,15 @@ protected void masterOperation( } if (request.isDryRun()) { - masterListener.onResponse( - new DeleteInferenceEndpointAction.Response( - false, - InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())) - ) - ); + handleDryRun(request, state, masterListener); return; - } else if (request.isForceDelete() == false - && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), listener)) { + } else if (request.isForceDelete() == false) { + var errorString = endpointIsReferencedInPipelinesOrIndexes(state, request.getInferenceEndpointId()); + if (errorString != null) { + listener.onFailure(new 
ElasticsearchStatusException(errorString, RestStatus.CONFLICT)); return; } + } var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { @@ -126,47 +140,83 @@ && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), list }) .addListener( masterListener.delegateFailure( - (l3, didDeleteModel) -> masterListener.onResponse(new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of())) + (l3, didDeleteModel) -> masterListener.onResponse( + new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of(), Set.of(), null) + ) ) ); } - private static boolean endpointIsReferencedInPipelines( - final ClusterState state, - final String inferenceEndpointId, - ActionListener listener + private static void handleDryRun( + DeleteInferenceEndpointAction.Request request, + ClusterState state, + ActionListener masterListener ) { - Metadata metadata = state.getMetadata(); - if (metadata == null) { - listener.onFailure( - new ElasticsearchStatusException( - " Could not determine if the endpoint is referenced in a pipeline as cluster state metadata was unexpectedly null. " - + "Use `force` to delete it anyway", - RestStatus.INTERNAL_SERVER_ERROR - ) - ); - // Unsure why the ClusterState metadata would ever be null, but in this case it seems safer to assume the endpoint is referenced - return true; + Set pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())); + + Set indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints( + state.getMetadata(), + Set.of(request.getInferenceEndpointId()) + ); + + masterListener.onResponse( + new DeleteInferenceEndpointAction.Response( + false, + pipelines, + indexesReferencedBySemanticText, + buildErrorString(request.getInferenceEndpointId(), pipelines, indexesReferencedBySemanticText) + ) + ); + } + + private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterState state, final String inferenceEndpointId) { + + var pipelines = endpointIsReferencedInPipelines(state, inferenceEndpointId); + var indexes = endpointIsReferencedInIndex(state, inferenceEndpointId); + + if (pipelines.isEmpty() == false || indexes.isEmpty() == false) { + return buildErrorString(inferenceEndpointId, pipelines, indexes); } - IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); - if (ingestMetadata == null) { - logger.debug("No ingest metadata found in cluster state while attempting to delete inference endpoint"); - } else { - Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.getModelIdsFromInferenceProcessors(ingestMetadata); - if (modelIdsReferencedByPipelines.contains(inferenceEndpointId)) { - listener.onFailure( - new ElasticsearchStatusException( - "Inference endpoint " - + inferenceEndpointId - + " is referenced by pipelines and cannot be deleted. " - + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it.", - RestStatus.CONFLICT - ) - ); - return true; - } + return null; + } + + private static String buildErrorString(String inferenceEndpointId, Set pipelines, Set indexes) { + StringBuilder errorString = new StringBuilder(); + + if (pipelines.isEmpty() == false) { + errorString.append("Inference endpoint ") + .append(inferenceEndpointId) + .append(" is referenced by pipelines: ") + .append(pipelines) + .append(". 
") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); } - return false; + + if (indexes.isEmpty() == false) { + errorString.append(" Inference endpoint ") + .append(inferenceEndpointId) + .append(" is being used in the mapping for indexes: ") + .append(indexes) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + } + + return errorString.toString(); + } + + private static Set endpointIsReferencedInIndex(final ClusterState state, final String inferenceEndpointId) { + Set indexes = extractIndexesReferencingInferenceEndpoints(state.getMetadata(), Set.of(inferenceEndpointId)); + return indexes; + } + + private static Set endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) { + Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource( + state, + Set.of(inferenceEndpointId) + ); + return modelIdsReferencedByPipelines; } @Override diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index fd656c9d5d950..f6a7073914609 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -81,6 +81,7 @@ setup: - do: inference.delete: inference_id: sparse-inference-id + force: true - do: inference.put: @@ -119,6 +120,7 @@ setup: - do: inference.delete: inference_id: dense-inference-id + force: true - do: inference.put: @@ -155,6 +157,7 @@ setup: - do: inference.delete: inference_id: dense-inference-id + force: true - do: inference.put: From 52b2a414eaaabf61f270f9c86c09599306a620ac Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Mon, 8 Jul 2024 14:09:33 -0500 Subject: [PATCH 21/64] Do not run TickerScheduleTriggerEngine watches if the schedule trigger engine is paused (#110061) --- docs/changelog/110061.yaml | 6 +++++ .../engine/TickerScheduleTriggerEngine.java | 24 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 docs/changelog/110061.yaml diff --git a/docs/changelog/110061.yaml b/docs/changelog/110061.yaml new file mode 100644 index 0000000000000..1880a2a197722 --- /dev/null +++ b/docs/changelog/110061.yaml @@ -0,0 +1,6 @@ +pr: 110061 +summary: Avoiding running watch jobs in TickerScheduleTriggerEngine if it is paused +area: Watcher +type: bug +issues: + - 105933 diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java index ba07c3137340d..ced131640f0ee 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import 
java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.common.settings.Setting.positiveTimeSetting; @@ -50,6 +51,7 @@ public class TickerScheduleTriggerEngine extends ScheduleTriggerEngine { private final TimeValue tickInterval; private final Map schedules = new ConcurrentHashMap<>(); private final Ticker ticker; + private final AtomicBoolean isRunning = new AtomicBoolean(false); public TickerScheduleTriggerEngine(Settings settings, ScheduleRegistry scheduleRegistry, Clock clock) { super(scheduleRegistry, clock); @@ -60,7 +62,8 @@ public TickerScheduleTriggerEngine(Settings settings, ScheduleRegistry scheduleR @Override public synchronized void start(Collection jobs) { long startTime = clock.millis(); - logger.info("Watcher starting watches at {}", WatcherDateTimeUtils.dateTimeFormatter.formatMillis(startTime)); + isRunning.set(true); + logger.info("Starting watcher engine at {}", WatcherDateTimeUtils.dateTimeFormatter.formatMillis(startTime)); Map startingSchedules = Maps.newMapWithExpectedSize(jobs.size()); for (Watch job : jobs) { if (job.trigger() instanceof ScheduleTrigger trigger) { @@ -81,17 +84,22 @@ public synchronized void start(Collection jobs) { @Override public void stop() { + logger.info("Stopping watcher engine"); + isRunning.set(false); schedules.clear(); ticker.close(); } @Override - public synchronized void pauseExecution() { + public void pauseExecution() { + logger.info("Pausing watcher engine"); + isRunning.set(false); schedules.clear(); } @Override public void add(Watch watch) { + logger.trace("Adding watch [{}] to engine (engine is running: {})", watch.id(), isRunning.get()); assert watch.trigger() instanceof ScheduleTrigger; ScheduleTrigger trigger = (ScheduleTrigger) watch.trigger(); ActiveSchedule currentSchedule = schedules.get(watch.id()); @@ -106,13 +114,25 @@ public void add(Watch watch) { @Override public boolean remove(String jobId) { + logger.debug("Removing watch [{}] from engine (engine is running: {})", jobId, isRunning.get()); return schedules.remove(jobId) != null; } void checkJobs() { + if (isRunning.get() == false) { + logger.debug( + "Watcher not running because the engine is paused. 
Currently scheduled watches being skipped: {}", + schedules.size() + ); + return; + } long triggeredTime = clock.millis(); List events = new ArrayList<>(); for (ActiveSchedule schedule : schedules.values()) { + if (isRunning.get() == false) { + logger.debug("Watcher paused while running [{}]", schedule.name); + break; + } long scheduledTime = schedule.check(triggeredTime); if (scheduledTime > 0) { ZonedDateTime triggeredDateTime = utcDateTimeAtEpochMillis(triggeredTime); From f2382a9a7475e674f05200b0e53ea35ac6e27416 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:14:17 -0400 Subject: [PATCH 22/64] [ML] Refactor inference input (#108167) * Passing inference input through to 3rd party request classes * Plumbing changes through the rest of the support integrations * Clean up * Including input changes with google vertex * Backing out semantic query changes * Addressing feedback and adding aws --- ...onBedrockChatCompletionRequestManager.java | 6 ++--- ...AmazonBedrockEmbeddingsRequestManager.java | 6 ++--- .../AnthropicCompletionRequestManager.java | 7 +++--- ...eAiStudioChatCompletionRequestManager.java | 6 ++--- ...AzureAiStudioEmbeddingsRequestManager.java | 6 ++--- .../AzureOpenAiCompletionRequestManager.java | 7 +++--- .../AzureOpenAiEmbeddingsRequestManager.java | 7 +++--- .../CohereCompletionRequestManager.java | 6 ++--- .../CohereEmbeddingsRequestManager.java | 6 ++--- .../sender/CohereRerankRequestManager.java | 7 +++--- .../http/sender/DocumentsOnlyInput.java | 10 +++++++- ...oogleAiStudioCompletionRequestManager.java | 6 ++--- ...oogleAiStudioEmbeddingsRequestManager.java | 6 ++--- ...oogleVertexAiEmbeddingsRequestManager.java | 6 ++--- .../GoogleVertexAiRerankRequestManager.java | 7 +++--- .../sender/HuggingFaceRequestManager.java | 6 ++--- .../external/http/sender/InferenceInputs.java | 8 ++++++- .../http/sender/InferenceRequest.java | 10 ++------ .../MistralEmbeddingsRequestManager.java | 6 ++--- .../OpenAiCompletionRequestManager.java | 7 +++--- .../OpenAiEmbeddingsRequestManager.java | 6 ++--- .../http/sender/QueryAndDocsInputs.java | 10 +++++++- .../http/sender/RequestExecutorService.java | 4 ++-- .../external/http/sender/RequestManager.java | 5 +--- .../external/http/sender/RequestTask.java | 24 ++++--------------- .../http/sender/BaseRequestManagerTests.java | 19 +++++---------- .../http/sender/HttpRequestSenderTests.java | 2 +- ... 
OpenAiEmbeddingsRequestManagerTests.java} | 2 +- .../sender/RequestExecutorServiceTests.java | 4 ++-- .../http/sender/RequestManagerTests.java | 5 ++-- .../http/sender/RequestTaskTests.java | 10 ++++---- 31 files changed, 106 insertions(+), 121 deletions(-) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/{OpenAiEmbeddingsExecutableRequestCreatorTests.java => OpenAiEmbeddingsRequestManagerTests.java} (95%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java index 1d8226664979c..8642a19b26a7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -41,13 +41,13 @@ public AmazonBedrockChatCompletionRequestManager( @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, input); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java index e9bc6b574865c..2f94cdf342938 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java @@ -49,14 +49,14 @@ public AmazonBedrockEmbeddingsRequestManager( @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); var serviceSettings = embeddingsModel.getServiceSettings(); - var truncatedInput = truncate(input, serviceSettings.maxInputTokens()); + var truncatedInput = truncate(docsInput, serviceSettings.maxInputTokens()); var requestEntity = AmazonBedrockEmbeddingsEntityFactory.createEntity(embeddingsModel, truncatedInput); var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java index 7dd1a66db13e7..7c527bbd2ee98 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.anthropic.AnthropicResponseHandler; @@ -43,13 +42,13 @@ private AnthropicCompletionRequestManager(AnthropicChatCompletionModel model, Th @Override public void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index e295cf5cc43dd..c5e5a5251f7db 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -37,13 +37,13 @@ public AzureAiStudioChatCompletionRequestManager(AzureAiStudioChatCompletionMode @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index f0f87402fb3a5..c610a7f31f7ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -41,13 +41,13 @@ public AzureAiStudioEmbeddingsRequestManager(AzureAiStudioEmbeddingsModel model, @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, 
ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index 5206d6c2c23cc..8c9b848f78e3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler; @@ -43,13 +42,13 @@ public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, Thr @Override public void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index e0fcee30e5af3..8d4162858b36f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -55,13 +55,14 @@ public AzureOpenAiEmbeddingsRequestManager(AzureOpenAiEmbeddingsModel model, Tru @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); + AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } 
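Every documents-only request manager changed above repeats the same narrowing step: execute now receives the abstract InferenceInputs parameter and immediately casts it through DocumentsOnlyInput.of, which throws if a query-bearing input was routed to the wrong manager. A minimal self-contained sketch of that fail-fast narrowing pattern, using simplified stand-in types rather than the actual Elasticsearch classes:

import java.util.List;

// Simplified stand-in for the InferenceInputs base class in this patch.
abstract class InputsSketch {
    static IllegalArgumentException unsupported(InputsSketch inputs) {
        return new IllegalArgumentException("Unsupported inference inputs type: [" + inputs.getClass() + "]");
    }
}

// Stand-in for DocumentsOnlyInput: just the documents, no query.
final class DocsOnlySketch extends InputsSketch {
    private final List<String> docs;

    DocsOnlySketch(List<String> docs) {
        this.docs = docs;
    }

    List<String> getInputs() {
        return docs;
    }

    // Fail-fast narrowing: a rerank-style input reaching a documents-only manager is a bug.
    static DocsOnlySketch of(InputsSketch inputs) {
        if (inputs instanceof DocsOnlySketch docsOnly) {
            return docsOnly;
        }
        throw unsupported(inputs);
    }
}

class NarrowingDemo {
    public static void main(String[] args) {
        InputsSketch inputs = new DocsOnlySketch(List.of("some document text"));
        // Mirrors the line repeated across the managers above.
        List<String> docsInput = DocsOnlySketch.of(inputs).getInputs();
        System.out.println(docsInput); // prints: [some document text]
    }
}

The trade-off is deliberate: the old signature enforced the query/documents split at compile time, while the single polymorphic parameter catches mis-routing at runtime and keeps every execute signature identical.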
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index 8a4b0e45b93fa..423093a14a9f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -46,13 +46,13 @@ private CohereCompletionRequestManager(CohereCompletionModel model, ThreadPool t @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - CohereCompletionRequest request = new CohereCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index a51910f1d0a67..402f91a0838dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -44,13 +44,13 @@ private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool t @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 1351eec406569..9d565e7124b03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.cohere.CohereRankedResponseEntity; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -44,13 +43,13 @@ private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPoo @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier 
hasRequestCompletedFunction, ActionListener listener ) { - CohereRerankRequest request = new CohereRerankRequest(query, input, model); + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index a11be003585fd..a32e2018117f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -12,7 +12,15 @@ public class DocumentsOnlyInput extends InferenceInputs { - List input; + public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { + if (inferenceInputs instanceof DocumentsOnlyInput == false) { + throw createUnsupportedTypeException(inferenceInputs); + } + + return (DocumentsOnlyInput) inferenceInputs; + } + + private final List input; public DocumentsOnlyInput(List chunks) { super(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index 2b191b046477b..426102f7f2376 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -42,13 +42,13 @@ public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel mode @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java index 6436e0231ab48..c7f87fb1cbf7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java @@ -48,13 +48,13 @@ public GoogleAiStudioEmbeddingsRequestManager(GoogleAiStudioEmbeddingsModel mode @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier 
hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java index c682da9a1694a..94f44c64b04da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java @@ -56,13 +56,13 @@ public static RateLimitGrouping of(GoogleVertexAiEmbeddingsModel model) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); var request = new GoogleVertexAiEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java index ab49ecc7ab9f9..e74f0049fffb0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.googlevertexai.GoogleVertexAiRerankResponseEntity; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -57,13 +56,13 @@ public static RateLimitGrouping of(GoogleVertexAiRerankModel model) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(query, input, model); + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 6c8fc446d5243..a33eb724551f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -55,13 +55,13 @@ private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler respon @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getTokenLimit()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index d7e07e734ce80..dd241857ef0c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -7,4 +7,10 @@ package org.elasticsearch.xpack.inference.external.http.sender; -public abstract class InferenceInputs {} +import org.elasticsearch.common.Strings; + +public abstract class InferenceInputs { + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { + return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java index 6199a75a41a7d..52be5d8be2b6f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; -import java.util.List; import java.util.function.Supplier; /** @@ -24,14 +23,9 @@ public interface InferenceRequest { RequestManager getRequestManager(); /** - * Returns the query associated with this request. Used for Rerank tasks. + * Returns the inputs associated with the request. */ - String getQuery(); - - /** - * Returns the text input associated with this request. - */ - List getInput(); + InferenceInputs getInferenceInputs(); /** * Returns the listener to notify of the results. 
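The InferenceRequest change above is the other half of the refactor: a request task no longer decomposes its payload into a nullable query plus a document list, but carries one InferenceInputs object until a concrete manager narrows it. A rough sketch of the rerank-side consumer, again with hypothetical stand-ins rather than the real QueryAndDocsInputs and request classes:

import java.util.List;

// Hypothetical stand-in for QueryAndDocsInputs: a query plus the documents to rerank.
record QueryAndDocsSketch(String query, List<String> chunks) {
    static QueryAndDocsSketch of(Object inferenceInputs) {
        if (inferenceInputs instanceof QueryAndDocsSketch queryAndDocs) {
            return queryAndDocs;
        }
        throw new IllegalArgumentException("Unsupported inference inputs type: [" + inferenceInputs.getClass() + "]");
    }
}

class RerankDispatchDemo {
    // Mirrors the shape of CohereRerankRequestManager#execute: narrow first, then build the request.
    static String buildRerankRequest(Object inferenceInputs) {
        var rerankInput = QueryAndDocsSketch.of(inferenceInputs);
        return "rerank(query=" + rerankInput.query() + ", docs=" + rerankInput.chunks() + ")";
    }

    public static void main(String[] args) {
        System.out.println(buildRerankRequest(new QueryAndDocsSketch("best laptop", List.of("doc one", "doc two"))));
    }
}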
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index 1807712a31ac5..d550749cc2348 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -51,13 +51,13 @@ public MistralEmbeddingsRequestManager(MistralEmbeddingsModel model, Truncator t @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 7bc09fd76736b..65f25c0baf8dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -43,13 +42,13 @@ private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPo @Override public void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java index 41f91d2b89ee5..5c164f2eb9644 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java @@ -55,13 +55,13 @@ private OpenAiEmbeddingsRequestManager(OpenAiEmbeddingsModel model, Truncator tr @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 4d24598d67831..0d5f98c180ba9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -12,7 +12,15 @@ public class QueryAndDocsInputs extends InferenceInputs { - String query; + public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { + if (inferenceInputs instanceof QueryAndDocsInputs == false) { + throw createUnsupportedTypeException(inferenceInputs); + } + + return (QueryAndDocsInputs) inferenceInputs; + } + + private final String query; public String getQuery() { return query; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 38d47aec68eb6..ad1324d0a315f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -413,7 +413,7 @@ private TimeValue executeEnqueuedTaskInternal() { assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have"; task.getRequestManager() - .execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener()); + .execute(task.getInferenceInputs(), requestSender, task.getRequestCompletedFunction(), task.getListener()); return EXECUTED_A_TASK; } @@ -423,7 +423,7 @@ private static boolean shouldExecuteTask(RejectableTask task) { private static boolean isNoopRequest(InferenceRequest inferenceRequest) { return inferenceRequest.getRequestManager() == null - || inferenceRequest.getInput() == null + || inferenceRequest.getInferenceInputs() == null || inferenceRequest.getListener() == null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java index 79ef1b56ad231..853d6fdcb2473 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java @@ -8,12 +8,10 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.ratelimit.RateLimitable; -import java.util.List; import java.util.function.Supplier; /** @@ -21,8 +19,7 @@ */ public interface RequestManager extends RateLimitable { void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index 7a5f482412289..9ccb93a0858ae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; -import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; @@ -27,8 +26,7 @@ class RequestTask implements RejectableTask { private final AtomicBoolean finished = new AtomicBoolean(); private final RequestManager requestCreator; - private final String query; - private final List input; + private final InferenceInputs inferenceInputs; private final ActionListener listener; RequestTask( @@ -40,16 +38,7 @@ class RequestTask implements RejectableTask { ) { this.requestCreator = Objects.requireNonNull(requestCreator); this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool)); - - if (inferenceInputs instanceof QueryAndDocsInputs) { - this.query = ((QueryAndDocsInputs) inferenceInputs).getQuery(); - this.input = ((QueryAndDocsInputs) inferenceInputs).getChunks(); - } else if (inferenceInputs instanceof DocumentsOnlyInput) { - this.query = null; - this.input = ((DocumentsOnlyInput) inferenceInputs).getInputs(); - } else { - throw new IllegalArgumentException("Unsupported inference inputs type: " + inferenceInputs.getClass()); - } + this.inferenceInputs = Objects.requireNonNull(inferenceInputs); } private ActionListener getListener( @@ -91,13 +80,8 @@ public Supplier getRequestCompletedFunction() { } @Override - public List getInput() { - return input; - } - - @Override - public String getQuery() { - return query; + public InferenceInputs getInferenceInputs() { + return inferenceInputs; } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java index 03838896b879d..bf120be621ad3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -30,8 +29,7 @@ public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() { var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -43,8 +41,7 @@ public void execute( var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -62,8 +59,7 @@ public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() { var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -75,8 +71,7 @@ public void execute( var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -94,8 +89,7 @@ public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -107,8 +101,7 @@ public void execute( var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 2b8b5f178b3de..79f6aa8164b75 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -106,7 +106,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), new 
DocumentsOnlyInput(List.of("abc")), null, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsExecutableRequestCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java similarity index 95% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsExecutableRequestCreatorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java index 37fce8d3f3a7b..eb7f7c4a0035d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsExecutableRequestCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java @@ -13,7 +13,7 @@ import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; -public class OpenAiEmbeddingsExecutableRequestCreatorTests { +public class OpenAiEmbeddingsRequestManagerTests { public static OpenAiEmbeddingsRequestManager makeCreator( String url, @Nullable String org, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index 9a45e10007643..762a3a74184a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -131,7 +131,7 @@ public void testIsTerminated_AfterStopFromSeparateThread() { PlainActionFuture listener = new PlainActionFuture<>(); service.execute( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "id", null, threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "id", null, threadPool), new DocumentsOnlyInput(List.of()), null, listener @@ -208,7 +208,7 @@ public void testTaskThrowsError_CallsOnFailure() { PlainActionFuture listener = new PlainActionFuture<>(); service.execute( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "id", null, threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "id", null, threadPool), new DocumentsOnlyInput(List.of()), null, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java index 291de740aca34..8b7c01ae133cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.when; @@ -44,7 +43,7 @@ public static RequestManager createMock(RequestSender requestSender, String infe doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[4]; + ActionListener listener = (ActionListener) invocation.getArguments()[3]; requestSender.send( mock(Logger.class), RequestTests.mockRequest(inferenceEntityId), @@ -55,7 +54,7 @@ public static RequestManager createMock(RequestSender requestSender, String infe ); return Void.TYPE; - }).when(mockManager).execute(any(), anyList(), any(), any(), any()); + }).when(mockManager).execute(any(), any(), any(), any()); // just return something consistent so the hashing works when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index 13c395180cd16..c839c266e9320 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -59,7 +59,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc ActionListener listener = mock(ActionListener.class); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), mockThreadPool, @@ -79,7 +79,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), threadPool, @@ -105,7 +105,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio }).when(listener).onFailure(any()); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), threadPool, @@ -137,7 +137,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { }).when(listener).onFailure(any()); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), threadPool, @@ -167,7 +167,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp ActionListener listener = mock(ActionListener.class); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + 
OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), mockThreadPool, From 9dbe97b2cbaa95eb7913879a5e1e0c1a0e330fc0 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 8 Jul 2024 17:28:31 -0400 Subject: [PATCH 23/64] Fix flaky test #109978 (#110245) CCS tests could split the vectors over any number of shards. Through empirical testing, I determined that this commit's values provide the expected order even when the vectors are not all on the same shard. Quantization can behave oddly when the values are uniform, as they are in this test. closes #109978 --- muted-tests.yml | 4 ---- .../search.vectors/41_knn_search_half_byte_quantized.yml | 8 ++++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 79372be872928..dc1ba7b855d83 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -97,10 +97,6 @@ tests: - class: "org.elasticsearch.xpack.searchablesnapshots.FrozenSearchableSnapshotsIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/110408" method: "testCreateAndRestorePartialSearchableSnapshot" -- class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT - method: test {p0=search.vectors/41_knn_search_half_byte_quantized/Test create, merge, - and search cosine} - issue: https://github.com/elastic/elasticsearch/issues/109978 - class: "org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT" issue: "https://github.com/elastic/elasticsearch/issues/110591" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index cb5aae482507a..5f1af2ca5c52f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -428,7 +428,7 @@ setup: index: hnsw_byte_quantized_merge_cosine id: "1" body: - embedding: [1.0, 1.0, 1.0, 1.0] + embedding: [0.5, 0.5, 0.5, 0.5, 0.5, 1.0] # Flush in order to provoke a merge later - do: @@ -439,7 +439,7 @@ setup: index: hnsw_byte_quantized_merge_cosine id: "2" body: - embedding: [1.0, 1.0, 1.0, 2.0] + embedding: [0.0, 0.0, 0.0, 1.0, 1.0, 0.5] # Flush in order to provoke a merge later - do: @@ -450,7 +450,7 @@ setup: index: hnsw_byte_quantized_merge_cosine id: "3" body: - embedding: [1.0, 1.0, 1.0, 3.0] + embedding: [0.0, 0.0, 0.0, 0.0, 0.0, 10.5] - do: indices.forcemerge: @@ -468,7 +468,7 @@ setup: query: knn: field: embedding - query_vector: [1.0, 1.0, 1.0, 1.0] + query_vector: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] num_candidates: 10 - length: { hits.hits: 3 } From 80b6611ff21d8b625585b82e9a9a685aa574c1c8 Mon Sep 17 00:00:00 2001 From: Patrick Doyle <810052+prdoyle@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:36:31 -0400 Subject: [PATCH 24/64] Use FileSystemProvider instead of RandomGenerator. (#110607) FileSystemProvider is no longer provided by SPI as of Java 23 EA build 24. 
See https://github.com/openjdk/jdk/commit/42e3c842ae2684265c794868fc76eb0ff2dea3d9#diff-03546451d9e4189f639e3af3dcd6a9e44318fff5ceaadce8478cbf0203ac3f45L422-L435 --- .../org/elasticsearch/plugins/UberModuleClassLoaderTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java b/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java index e3cd11c8f3b68..ecc2f458cdd60 100644 --- a/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java @@ -427,12 +427,12 @@ public String getTestString() { package p; import java.util.ServiceLoader; - import java.util.random.RandomGenerator; + import java.nio.file.spi.FileSystemProvider; public class ServiceCaller { public static String demo() { // check no error if we load a service from the jdk - ServiceLoader<RandomGenerator> randomLoader = ServiceLoader.load(RandomGenerator.class); + ServiceLoader<FileSystemProvider> fileSystemLoader = ServiceLoader.load(FileSystemProvider.class); ServiceLoader<MyService> loader = ServiceLoader.load(MyService.class, ServiceCaller.class.getClassLoader()); return loader.findFirst().get().getTestString(); From 0a518a32bbbb83a14c41e862810631705e785126 Mon Sep 17 00:00:00 2001 From: Dianna Hohensee Date: Mon, 8 Jul 2024 17:41:09 -0400 Subject: [PATCH 25/64] Add debug logging for snapshots (#110246) Specifically around pausing shard snapshots on node removal, and finalizing shard snapshots that change the shard generation of other non-finalized snapshots. Relates ES-8566. --- .../cluster/SnapshotsInProgress.java | 41 ++++++-- .../repositories/ShardGenerations.java | 13 ++- .../snapshots/SnapshotShardsService.java | 17 +++- .../snapshots/SnapshotsService.java | 95 +++++++++++-------- 4 files changed, 116 insertions(+), 50 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java b/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java index 532a33d07b25d..b6fb370991a93 100644 --- a/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java +++ b/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java @@ -27,6 +27,8 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.repositories.IndexId; import org.elasticsearch.repositories.RepositoryOperation; import org.elasticsearch.repositories.RepositoryShardId; @@ -58,6 +60,8 @@ */ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implements Custom { + private static final Logger logger = LogManager.getLogger(SnapshotsInProgress.class); + public static final SnapshotsInProgress EMPTY = new SnapshotsInProgress(Map.of(), Set.of()); public static final String TYPE = "snapshots"; @@ -207,6 +211,17 @@ public Map<RepositoryShardId, Set<ShardGeneration>> obsoleteGenerations(String r // We moved from a non-null generation successful generation to a different non-null successful generation // so the original generation is clearly obsolete because it was in-flight before and is now unreferenced everywhere. obsoleteGenerations.computeIfAbsent(repositoryShardId, ignored -> new HashSet<>()).add(oldStatus.generation()); + logger.debug( """ Marking shard generation [{}] file for cleanup. 
The finalized shard generation is now [{}], for shard \ + snapshot [{}] with shard ID [{}] on node [{}] + """, + oldStatus.generation(), + newStatus.generation(), + entry.snapshot(), + repositoryShardId.shardId(), + oldStatus.nodeId() + ); } } } @@ -441,7 +456,9 @@ public SnapshotsInProgress withUpdatedNodeIdsForRemoval(ClusterState clusterStat updatedNodeIdsForRemoval.addAll(nodeIdsMarkedForRemoval); // remove any nodes which are no longer marked for shutdown if they have no running shard snapshots - updatedNodeIdsForRemoval.removeAll(getObsoleteNodeIdsForRemoval(nodeIdsMarkedForRemoval)); + var restoredNodeIds = getObsoleteNodeIdsForRemoval(nodeIdsMarkedForRemoval); + updatedNodeIdsForRemoval.removeAll(restoredNodeIds); + logger.debug("Resuming shard snapshots on nodes [{}]", restoredNodeIds); if (updatedNodeIdsForRemoval.equals(nodesIdsForRemoval)) { return this; @@ -469,19 +486,26 @@ private static Set getNodesIdsMarkedForRemoval(ClusterState clusterState return result; } + /** + * Identifies any nodes that are no longer marked for removal AND have no running shard snapshots. + * @param latestNodeIdsMarkedForRemoval the current nodes marked for removal in the cluster state. + */ private Set getObsoleteNodeIdsForRemoval(Set latestNodeIdsMarkedForRemoval) { - final var obsoleteNodeIdsForRemoval = new HashSet<>(nodesIdsForRemoval); - obsoleteNodeIdsForRemoval.removeIf(latestNodeIdsMarkedForRemoval::contains); - if (obsoleteNodeIdsForRemoval.isEmpty()) { + // Find any nodes no longer marked for removal. + final var nodeIdsNoLongerMarkedForRemoval = new HashSet<>(nodesIdsForRemoval); + nodeIdsNoLongerMarkedForRemoval.removeIf(latestNodeIdsMarkedForRemoval::contains); + if (nodeIdsNoLongerMarkedForRemoval.isEmpty()) { return Set.of(); } + // If any nodes have INIT state shard snapshots, then the node's snapshots are not concurrency safe to resume yet. All shard + // snapshots on a newly revived node (no longer marked for shutdown) must finish moving to paused before any can resume. for (final var byRepo : entries.values()) { for (final var entry : byRepo.entries()) { if (entry.state() == State.STARTED && entry.hasShardsInInitState()) { for (final var shardSnapshotStatus : entry.shards().values()) { if (shardSnapshotStatus.state() == ShardState.INIT) { - obsoleteNodeIdsForRemoval.remove(shardSnapshotStatus.nodeId()); - if (obsoleteNodeIdsForRemoval.isEmpty()) { + nodeIdsNoLongerMarkedForRemoval.remove(shardSnapshotStatus.nodeId()); + if (nodeIdsNoLongerMarkedForRemoval.isEmpty()) { return Set.of(); } } @@ -489,7 +513,7 @@ private Set getObsoleteNodeIdsForRemoval(Set latestNodeIdsMarked } } } - return obsoleteNodeIdsForRemoval; + return nodeIdsNoLongerMarkedForRemoval; } public boolean nodeIdsForRemovalChanged(SnapshotsInProgress other) { @@ -616,6 +640,9 @@ public record ShardSnapshotStatus( "missing index" ); + /** + * Initializes status with state {@link ShardState#INIT}. 
+ */ public ShardSnapshotStatus(String nodeId, ShardGeneration generation) { this(nodeId, ShardState.INIT, generation); } diff --git a/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java b/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java index 4c34f2e192a26..0dcb28278a66d 100644 --- a/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java +++ b/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java @@ -8,6 +8,8 @@ package org.elasticsearch.repositories; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.cluster.SnapshotsInProgress; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; @@ -30,6 +32,8 @@ */ public final class ShardGenerations { + private static final Logger logger = LogManager.getLogger(ShardGenerations.class); + public static final ShardGenerations EMPTY = new ShardGenerations(Collections.emptyMap()); /** @@ -88,7 +92,7 @@ public Collection indices() { } /** - * Computes the obsolete shard index generations that can be deleted once this instance was written to the repository. + * Computes the obsolete shard index generations that can be deleted once this instance is written to the repository. * Note: This method should only be used when finalizing a snapshot and we can safely assume that data has only been added but not * removed from shard paths. * @@ -109,6 +113,13 @@ public Map> obsoleteShardGenerations(Shar // Since this method assumes only additions and no removals of shards, a null updated generation means no update if (updatedGeneration != null && oldGeneration != null && oldGeneration.equals(updatedGeneration) == false) { obsoleteShardIndices.put(i, oldGeneration); + logger.debug( + "Marking snapshot generation [{}] for cleanup. The new generation is [{}]. 
Index [{}], shard ID [{}]", + oldGeneration, + updatedGeneration, + indexId, + i + ); } } result.put(indexId, Collections.unmodifiableMap(obsoleteShardIndices)); diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 7b3a83dfc9bb3..7606299c62bc8 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -241,7 +241,7 @@ private void handleUpdatedSnapshotsInProgressEntry(String localNodeId, boolean r } if (removingLocalNode) { - pauseShardSnapshots(localNodeId, entry); + pauseShardSnapshotsForNodeRemoval(localNodeId, entry); } else { startNewShardSnapshots(localNodeId, entry); } @@ -318,7 +318,7 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> shardSnapshotTasks.forEach(Runnable::run)); } - private void pauseShardSnapshots(String localNodeId, SnapshotsInProgress.Entry entry) { + private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInProgress.Entry entry) { final var localShardSnapshots = shardSnapshots.getOrDefault(entry.snapshot(), Map.of()); for (final Map.Entry shardEntry : entry.shards().entrySet()) { @@ -606,8 +606,9 @@ private void syncShardStatsOnNewMaster(List entries) } else if (stage == Stage.PAUSED) { // but we think the shard has paused - we need to make new master know that logger.debug(""" - [{}] new master thinks the shard [{}] is still running but the shard paused locally, updating status on \ - master""", snapshot.snapshot(), shardId); + new master thinks that shard [{}] snapshot [{}], with shard generation [{}], is still running, but the \ + shard snapshot is paused locally, updating status on master + """, shardId, snapshot.snapshot(), localShard.getValue().generation()); notifyUnsuccessfulSnapshotShard( snapshot.snapshot(), shardId, @@ -648,6 +649,14 @@ private void notifyUnsuccessfulSnapshotShard( shardId, new ShardSnapshotStatus(clusterService.localNode().getId(), shardState, generation, failure) ); + if (shardState == ShardState.PAUSED_FOR_NODE_REMOVAL) { + logger.debug( + "Pausing shard [{}] snapshot [{}], with shard generation [{}], because this node is marked for removal", + shardId, + snapshot, + generation + ); + } } /** Updates the shard snapshot status by sending a {@link UpdateIndexShardSnapshotStatusRequest} to the master node */ diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java index cd7516a8f1232..9178050ff2a0b 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java @@ -999,39 +999,42 @@ public ClusterState execute(ClusterState currentState) { // We keep a cache of shards that failed in this map. If we fail a shardId for a given repository because of // a node leaving or shard becoming unassigned for one snapshot, we will also fail it for all subsequent enqueued // snapshots for the same repository + // // TODO: the code in this state update duplicates large chunks of the logic in #SHARD_STATE_EXECUTOR. 
// We should refactor it to ideally also go through #SHARD_STATE_EXECUTOR by hand-crafting shard state updates // that encapsulate nodes leaving or indices having been deleted and passing them to the executor instead. - SnapshotsInProgress updated = snapshots; + SnapshotsInProgress updatedSnapshots = snapshots; + for (final List snapshotsInRepo : snapshots.entriesByRepo()) { boolean changed = false; final List updatedEntriesForRepo = new ArrayList<>(); final Map knownFailures = new HashMap<>(); - final String repository = snapshotsInRepo.get(0).repository(); - for (SnapshotsInProgress.Entry snapshot : snapshotsInRepo) { - if (statesToUpdate.contains(snapshot.state())) { - if (snapshot.isClone()) { - if (snapshot.shardsByRepoShardId().isEmpty()) { + final String repositoryName = snapshotsInRepo.get(0).repository(); + for (SnapshotsInProgress.Entry snapshotEntry : snapshotsInRepo) { + if (statesToUpdate.contains(snapshotEntry.state())) { + if (snapshotEntry.isClone()) { + if (snapshotEntry.shardsByRepoShardId().isEmpty()) { // Currently initializing clone - if (initializingClones.contains(snapshot.snapshot())) { - updatedEntriesForRepo.add(snapshot); + if (initializingClones.contains(snapshotEntry.snapshot())) { + updatedEntriesForRepo.add(snapshotEntry); } else { - logger.debug("removing not yet start clone operation [{}]", snapshot); + logger.debug("removing not yet start clone operation [{}]", snapshotEntry); changed = true; } } else { // see if any clones may have had a shard become available for execution because of failures - if (deletes.hasExecutingDeletion(repository)) { + if (deletes.hasExecutingDeletion(repositoryName)) { // Currently executing a delete for this repo, no need to try and update any clone operations. // The logic for finishing the delete will update running clones with the latest changes. - updatedEntriesForRepo.add(snapshot); + updatedEntriesForRepo.add(snapshotEntry); continue; } ImmutableOpenMap.Builder clones = null; InFlightShardSnapshotStates inFlightShardSnapshotStates = null; for (Map.Entry failureEntry : knownFailures.entrySet()) { final RepositoryShardId repositoryShardId = failureEntry.getKey(); - final ShardSnapshotStatus existingStatus = snapshot.shardsByRepoShardId().get(repositoryShardId); + final ShardSnapshotStatus existingStatus = snapshotEntry.shardsByRepoShardId() + .get(repositoryShardId); if (ShardSnapshotStatus.UNASSIGNED_QUEUED.equals(existingStatus)) { if (inFlightShardSnapshotStates == null) { inFlightShardSnapshotStates = InFlightShardSnapshotStates.forEntries(updatedEntriesForRepo); @@ -1044,7 +1047,7 @@ public ClusterState execute(ClusterState currentState) { continue; } if (clones == null) { - clones = ImmutableOpenMap.builder(snapshot.shardsByRepoShardId()); + clones = ImmutableOpenMap.builder(snapshotEntry.shardsByRepoShardId()); } // We can use the generation from the shard failure to start the clone operation here // because #processWaitingShardsAndRemovedNodes adds generations to failure statuses that @@ -1060,50 +1063,54 @@ public ClusterState execute(ClusterState currentState) { } if (clones != null) { changed = true; - updatedEntriesForRepo.add(snapshot.withClones(clones.build())); + updatedEntriesForRepo.add(snapshotEntry.withClones(clones.build())); } else { - updatedEntriesForRepo.add(snapshot); + updatedEntriesForRepo.add(snapshotEntry); } } } else { + // Not a clone, and the snapshot is in STARTED or ABORTED state. 
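+                                // (Hedged illustration of the contract relied on below, mirroring the
+                                // call rather than adding behavior: processWaitingShardsAndRemovedNodes
+                                // returns null when no shard status needed to change, so the existing
+                                // entry instance can be reused as-is.)
+                                //   var shards = processWaitingShardsAndRemovedNodes(entry, ...);
+                                //   updatedEntriesForRepo.add(shards == null ? entry : entry.withShardStates(shards));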
+
                             ImmutableOpenMap shards = processWaitingShardsAndRemovedNodes(
-                                snapshot,
+                                snapshotEntry,
                                 routingTable,
                                 nodes,
                                 snapshots::isNodeIdForRemoval,
                                 knownFailures
                             );
                             if (shards != null) {
-                                final SnapshotsInProgress.Entry updatedSnapshot = snapshot.withShardStates(shards);
+                                final SnapshotsInProgress.Entry updatedSnapshot = snapshotEntry.withShardStates(shards);
                                 changed = true;
                                 if (updatedSnapshot.state().completed()) {
                                     finishedSnapshots.add(updatedSnapshot);
                                 }
                                 updatedEntriesForRepo.add(updatedSnapshot);
                             } else {
-                                updatedEntriesForRepo.add(snapshot);
+                                updatedEntriesForRepo.add(snapshotEntry);
                             }
                         }
-                    } else if (snapshot.repositoryStateId() == RepositoryData.UNKNOWN_REPO_GEN) {
+                    } else if (snapshotEntry.repositoryStateId() == RepositoryData.UNKNOWN_REPO_GEN) {
                         // BwC path, older versions could create entries with unknown repo GEN in INIT or ABORTED state that did not
                         // yet write anything to the repository physically. This means we can simply remove these from the cluster
                         // state without having to do any additional cleanup.
                         changed = true;
-                        logger.debug("[{}] was found in dangling INIT or ABORTED state", snapshot);
+                        logger.debug("[{}] was found in dangling INIT or ABORTED state", snapshotEntry);
                     } else {
-                        if (snapshot.state().completed() || completed(snapshot.shardsByRepoShardId().values())) {
-                            finishedSnapshots.add(snapshot);
+                        // Now we're down to completed or unmodified snapshots
+
+                        if (snapshotEntry.state().completed() || completed(snapshotEntry.shardsByRepoShardId().values())) {
+                            finishedSnapshots.add(snapshotEntry);
                         }
-                        updatedEntriesForRepo.add(snapshot);
+                        updatedEntriesForRepo.add(snapshotEntry);
                     }
                 }
                 if (changed) {
-                    updated = updated.withUpdatedEntriesForRepo(repository, updatedEntriesForRepo);
+                    updatedSnapshots = updatedSnapshots.withUpdatedEntriesForRepo(repositoryName, updatedEntriesForRepo);
                 }
             }
             final ClusterState res = readyDeletions(
-                updated != snapshots
-                    ? ClusterState.builder(currentState).putCustom(SnapshotsInProgress.TYPE, updated).build()
+                updatedSnapshots != snapshots
+                    ? ClusterState.builder(currentState).putCustom(SnapshotsInProgress.TYPE, updatedSnapshots).build()
                     : currentState
             ).v1();
             for (SnapshotDeletionsInProgress.Entry delete : SnapshotDeletionsInProgress.get(res).getEntries()) {
@@ -1151,31 +1158,39 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState)
         });
     }

+    /**
+     * Walks through the snapshot entries' shard snapshots and applies updates derived from removed nodes, deleted indices, and
+     * failures already known for other snapshots of the same shard IDs.
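+     * <p>For illustration only (a hedged paraphrase of the WAITING/PAUSED branch below, using simplified
+     * names rather than adding behavior):
+     * <pre>{@code
+     * if (status.state() == ShardState.PAUSED_FOR_NODE_REMOVAL && primary.started()) {
+     *     // the shard can be snapshotted again: restart it on the primary's current node
+     *     shards.put(shardId, new ShardSnapshotStatus(primary.currentNodeId(), status.generation()));
+     * }
+     * }</pre>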
+ * + * @param nodeIdRemovalPredicate identify any nodes that are marked for removal / in shutdown mode + * @param knownFailures already known failed shard snapshots, but more may be found in this method + * @return an updated map of shard statuses + */ private static ImmutableOpenMap processWaitingShardsAndRemovedNodes( - SnapshotsInProgress.Entry entry, + SnapshotsInProgress.Entry snapshotEntry, RoutingTable routingTable, DiscoveryNodes nodes, Predicate nodeIdRemovalPredicate, Map knownFailures ) { - assert entry.isClone() == false : "clones take a different path"; + assert snapshotEntry.isClone() == false : "clones take a different path"; boolean snapshotChanged = false; ImmutableOpenMap.Builder shards = ImmutableOpenMap.builder(); - for (Map.Entry shardEntry : entry.shardsByRepoShardId().entrySet()) { - ShardSnapshotStatus shardStatus = shardEntry.getValue(); - ShardId shardId = entry.shardId(shardEntry.getKey()); + for (Map.Entry shardSnapshotEntry : snapshotEntry.shardsByRepoShardId().entrySet()) { + ShardSnapshotStatus shardStatus = shardSnapshotEntry.getValue(); + ShardId shardId = snapshotEntry.shardId(shardSnapshotEntry.getKey()); if (shardStatus.equals(ShardSnapshotStatus.UNASSIGNED_QUEUED)) { // this shard snapshot is waiting for a previous snapshot to finish execution for this shard - final ShardSnapshotStatus knownFailure = knownFailures.get(shardEntry.getKey()); + final ShardSnapshotStatus knownFailure = knownFailures.get(shardSnapshotEntry.getKey()); if (knownFailure == null) { final IndexRoutingTable indexShardRoutingTable = routingTable.index(shardId.getIndex()); if (indexShardRoutingTable == null) { // shard became unassigned while queued after a delete or clone operation so we can fail as missing here - assert entry.partial(); + assert snapshotEntry.partial(); snapshotChanged = true; logger.debug("failing snapshot of shard [{}] because index got deleted", shardId); shards.put(shardId, ShardSnapshotStatus.MISSING); - knownFailures.put(shardEntry.getKey(), ShardSnapshotStatus.MISSING); + knownFailures.put(shardSnapshotEntry.getKey(), ShardSnapshotStatus.MISSING); } else { // if no failure is known for the shard we keep waiting shards.put(shardId, shardStatus); @@ -1187,6 +1202,7 @@ private static ImmutableOpenMap processWaitingShar shards.put(shardId, knownFailure); } } else if (shardStatus.state() == ShardState.WAITING || shardStatus.state() == ShardState.PAUSED_FOR_NODE_REMOVAL) { + // The shard primary wasn't assigned, or the shard snapshot was paused because the node was shutting down. IndexRoutingTable indexShardRoutingTable = routingTable.index(shardId.getIndex()); if (indexShardRoutingTable != null) { IndexShardRoutingTable shardRouting = indexShardRoutingTable.shard(shardId.id()); @@ -1208,7 +1224,10 @@ private static ImmutableOpenMap processWaitingShar } else if (shardRouting.primaryShard().started()) { // Shard that we were waiting for has started on a node, let's process it snapshotChanged = true; - logger.trace("starting shard that we were waiting for [{}] on node [{}]", shardId, shardStatus.nodeId()); + logger.debug(""" + Starting shard [{}] with shard generation [{}] that we were waiting to start on node [{}]. 
Previous \ + shard state [{}] + """, shardId, shardStatus.generation(), shardStatus.nodeId(), shardStatus.state()); shards.put(shardId, new ShardSnapshotStatus(primaryNodeId, shardStatus.generation())); continue; } else if (shardRouting.primaryShard().initializing() || shardRouting.primaryShard().relocating()) { @@ -1218,7 +1237,7 @@ private static ImmutableOpenMap processWaitingShar } } } - // Shard that we were waiting for went into unassigned state or disappeared - giving up + // Shard that we were waiting for went into unassigned state or disappeared (index or shard is gone) - giving up snapshotChanged = true; logger.warn("failing snapshot of shard [{}] on unassigned shard [{}]", shardId, shardStatus.nodeId()); final ShardSnapshotStatus failedState = new ShardSnapshotStatus( @@ -1228,7 +1247,7 @@ private static ImmutableOpenMap processWaitingShar "shard is unassigned" ); shards.put(shardId, failedState); - knownFailures.put(shardEntry.getKey(), failedState); + knownFailures.put(shardSnapshotEntry.getKey(), failedState); } else if (shardStatus.state().completed() == false && shardStatus.nodeId() != null) { if (nodes.nodeExists(shardStatus.nodeId())) { shards.put(shardId, shardStatus); @@ -1243,7 +1262,7 @@ private static ImmutableOpenMap processWaitingShar "node left the cluster during snapshot" ); shards.put(shardId, failedState); - knownFailures.put(shardEntry.getKey(), failedState); + knownFailures.put(shardSnapshotEntry.getKey(), failedState); } } else { shards.put(shardId, shardStatus); From e95cbb48aa436b9fce8c77a868d99ee67104e13f Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Mon, 8 Jul 2024 17:01:47 -0500 Subject: [PATCH 26/64] Updating CloseIndexRequestTests to account for master term (#110611) --- .../action/admin/indices/close/CloseIndexRequestTests.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java index b3caf93fbcddf..24c0f9d97800b 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java @@ -49,6 +49,9 @@ public void testBwcSerialization() throws Exception { in.setTransportVersion(out.getTransportVersion()); assertEquals(request.getParentTask(), TaskId.readFromStream(in)); assertEquals(request.masterNodeTimeout(), in.readTimeValue()); + if (in.getTransportVersion().onOrAfter(TransportVersions.VERSIONED_MASTER_NODE_REQUESTS)) { + assertEquals(request.masterTerm(), in.readVLong()); + } assertEquals(request.ackTimeout(), in.readTimeValue()); assertArrayEquals(request.indices(), in.readStringArray()); final IndicesOptions indicesOptions = IndicesOptions.readIndicesOptions(in); @@ -75,6 +78,9 @@ public void testBwcSerialization() throws Exception { out.setTransportVersion(version); sample.getParentTask().writeTo(out); out.writeTimeValue(sample.masterNodeTimeout()); + if (out.getTransportVersion().onOrAfter(TransportVersions.VERSIONED_MASTER_NODE_REQUESTS)) { + out.writeVLong(sample.masterTerm()); + } out.writeTimeValue(sample.ackTimeout()); out.writeStringArray(sample.indices()); sample.indicesOptions().writeIndicesOptions(out); From bdf9a2e0cc683820c6f9d8749106b9dc2d02385c Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 8 Jul 2024 16:34:44 -0700 Subject: [PATCH 27/64] Fix BWC for compute listener (#110615) ComputeResponse 
from old nodes may have a null value instead of an empty list for profiles. Relates #110400 Closes #110591 --- muted-tests.yml | 2 -- .../org/elasticsearch/xpack/esql/plugin/ComputeListener.java | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index dc1ba7b855d83..ccbdb68fbb8c7 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -97,8 +97,6 @@ tests: - class: "org.elasticsearch.xpack.searchablesnapshots.FrozenSearchableSnapshotsIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/110408" method: "testCreateAndRestorePartialSearchableSnapshot" -- class: "org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT" - issue: "https://github.com/elastic/elasticsearch/issues/110591" # Examples: # diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java index f8f35bb6f0b4f..01d50d505f7f2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java @@ -76,8 +76,9 @@ ActionListener acquireAvoid() { ActionListener acquireCompute() { return acquireAvoid().map(resp -> { responseHeaders.collect(); - if (resp != null && resp.getProfiles().isEmpty() == false) { - collectedProfiles.addAll(resp.getProfiles()); + var profiles = resp.getProfiles(); + if (profiles != null && profiles.isEmpty() == false) { + collectedProfiles.addAll(profiles); } return null; }); From 3878ae779eea969c6abc708eae1e748a641f7ac7 Mon Sep 17 00:00:00 2001 From: Oleksandr Kolomiiets Date: Mon, 8 Jul 2024 16:36:51 -0700 Subject: [PATCH 28/64] Actually fix RollupIndexerStateTests#testMultipleJobTriggering (#110616) `assertBusy` needs an assert to do something and there were none. Closes #109627. --- .../xpack/rollup/job/RollupIndexerStateTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java b/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java index 105711c4057a6..7a947fcb5ce02 100644 --- a/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java +++ b/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java @@ -556,7 +556,7 @@ public void testMultipleJobTriggering() throws Exception { assertThat(indexer.getState(), equalTo(IndexerState.STARTED)); // This may take more than one attempt due to a cleanup/transition phase // that happens after state change to STARTED (`isJobFinishing`). 
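        // A hedged aside (illustration only, not part of the test): assertBusy retries its body only
        // while the body throws an AssertionError, so a lambda that merely calls a boolean method
        // "passes" on the first run even when the condition is still false. Asserting on the returned
        // value is what makes the busy-wait real:
        //   assertBusy(() -> indexer.maybeTriggerAsyncJob(System.currentTimeMillis()));             // result ignored, no retry
        //   assertBusy(() -> assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()))); // retries until it triggers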
- assertBusy(() -> indexer.maybeTriggerAsyncJob(System.currentTimeMillis())); + assertBusy(() -> assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()))); assertThat(indexer.getState(), equalTo(IndexerState.INDEXING)); assertFalse(indexer.maybeTriggerAsyncJob(System.currentTimeMillis())); assertThat(indexer.getState(), equalTo(IndexerState.INDEXING)); @@ -566,7 +566,7 @@ public void testMultipleJobTriggering() throws Exception { assertThat(indexer.getStats().getNumPages(), equalTo((long) i + 1)); } final CountDownLatch latch = indexer.newLatch(); - assertBusy(() -> indexer.maybeTriggerAsyncJob(System.currentTimeMillis())); + assertBusy(() -> assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()))); assertThat(indexer.stop(), equalTo(IndexerState.STOPPING)); assertThat(indexer.getState(), Matchers.either(Matchers.is(IndexerState.STOPPING)).or(Matchers.is(IndexerState.STOPPED))); latch.countDown(); From 822ab728676b4a80b7eb44cf4279e470170d267a Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 9 Jul 2024 01:42:38 +0200 Subject: [PATCH 29/64] Enhance test utility for running tasks in parallel (#110610) Follow up to #110552, add utility for starting tasks at the same time via a barrier as discussed there. Also, make use of the new tooling in a couple more spots to save LoC and thread creation. --- .../core/AbstractRefCountedTests.java | 39 +++--- .../reindex/DeleteByQueryConcurrentTests.java | 70 +++-------- .../elasticsearch/backwards/IndexingIT.java | 21 ++-- .../action/bulk/BulkWithUpdatesIT.java | 9 +- .../elasticsearch/blocks/SimpleBlocksIT.java | 42 ++----- .../index/engine/MaxDocsLimitIT.java | 5 +- .../index/mapper/DynamicMappingIT.java | 5 +- .../index/seqno/GlobalCheckpointSyncIT.java | 10 +- .../mapping/UpdateMappingIntegrationIT.java | 94 ++++++-------- .../indices/state/CloseIndexIT.java | 11 +- .../state/CloseWhileRelocatingShardsIT.java | 10 +- .../action/ActionListenerTests.java | 27 ++-- .../concurrent/AsyncIOProcessorTests.java | 118 ++++++------------ .../AbstractProfileBreakdownTests.java | 35 ++---- .../index/engine/EngineTestCase.java | 40 +++--- .../org/elasticsearch/test/ESTestCase.java | 12 ++ 16 files changed, 179 insertions(+), 369 deletions(-) diff --git a/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java b/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java index 9610bae32a775..74dcd19248834 100644 --- a/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java +++ b/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.test.ESTestCase; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import static org.hamcrest.Matchers.equalTo; @@ -62,32 +61,22 @@ public void testRefCount() { public void testMultiThreaded() throws InterruptedException { final AbstractRefCounted counted = createRefCounted(); - final Thread[] threads = new Thread[randomIntBetween(2, 5)]; - final CountDownLatch latch = new CountDownLatch(1); - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - try { - latch.await(); - for (int j = 0; j < 10000; j++) { - assertTrue(counted.hasReferences()); - if (randomBoolean()) { - counted.incRef(); - } else { - assertTrue(counted.tryIncRef()); - } - assertTrue(counted.hasReferences()); - counted.decRef(); + startInParallel(randomIntBetween(2, 5), i -> { + try { + for (int j = 0; j < 10000; j++) { + assertTrue(counted.hasReferences()); 
+ if (randomBoolean()) { + counted.incRef(); + } else { + assertTrue(counted.tryIncRef()); } - } catch (Exception e) { - throw new AssertionError(e); + assertTrue(counted.hasReferences()); + counted.decRef(); } - }); - threads[i].start(); - } - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + } catch (Exception e) { + throw new AssertionError(e); + } + }); counted.decRef(); assertFalse(counted.hasReferences()); assertThat( diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java index 323b829fe93ff..190616b9980f0 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java @@ -11,11 +11,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.reindex.BulkByScrollResponse; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; import static org.elasticsearch.index.query.QueryBuilders.matchQuery; @@ -26,44 +24,29 @@ public class DeleteByQueryConcurrentTests extends ReindexTestCase { public void testConcurrentDeleteByQueriesOnDifferentDocs() throws Throwable { - final Thread[] threads = new Thread[scaledRandomIntBetween(2, 5)]; + final int threadCount = scaledRandomIntBetween(2, 5); final long docs = randomIntBetween(1, 50); List builders = new ArrayList<>(); for (int i = 0; i < docs; i++) { - for (int t = 0; t < threads.length; t++) { + for (int t = 0; t < threadCount; t++) { builders.add(prepareIndex("test").setSource("field", t)); } } indexRandom(true, true, true, builders); - final CountDownLatch start = new CountDownLatch(1); - for (int t = 0; t < threads.length; t++) { - final int threadNum = t; - assertHitCount(prepareSearch("test").setSize(0).setQuery(QueryBuilders.termQuery("field", threadNum)), docs); - - Runnable r = () -> { - try { - start.await(); - - assertThat( - deleteByQuery().source("_all").filter(termQuery("field", threadNum)).refresh(true).get(), - matcher().deleted(docs) - ); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }; - threads[t] = new Thread(r); - threads[t].start(); + for (int t = 0; t < threadCount; t++) { + assertHitCount(prepareSearch("test").setSize(0).setQuery(QueryBuilders.termQuery("field", t)), docs); } - - start.countDown(); - for (Thread thread : threads) { - thread.join(); - } - - for (int t = 0; t < threads.length; t++) { + startInParallel( + threadCount, + threadNum -> assertThat( + deleteByQuery().source("_all").filter(termQuery("field", threadNum)).refresh(true).get(), + matcher().deleted(docs) + ) + ); + + for (int t = 0; t < threadCount; t++) { assertHitCount(prepareSearch("test").setSize(0).setQuery(QueryBuilders.termQuery("field", t)), 0); } } @@ -77,33 +60,12 @@ public void testConcurrentDeleteByQueriesOnSameDocs() throws Throwable { } indexRandom(true, true, true, builders); - final Thread[] threads = new Thread[scaledRandomIntBetween(2, 9)]; + final int threadCount = scaledRandomIntBetween(2, 9); - final CountDownLatch start = new CountDownLatch(1); final MatchQueryBuilder query = matchQuery("foo", "bar"); final AtomicLong deleted = new AtomicLong(0); - - for (int t = 0; t < 
threads.length; t++) { - Runnable r = () -> { - try { - start.await(); - - BulkByScrollResponse response = deleteByQuery().source("test").filter(query).refresh(true).get(); - // Some deletions might fail due to version conflict, but - // what matters here is the total of successful deletions - deleted.addAndGet(response.getDeleted()); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }; - threads[t] = new Thread(r); - threads[t].start(); - } - - start.countDown(); - for (Thread thread : threads) { - thread.join(); - } + // Some deletions might fail due to version conflict, but what matters here is the total of successful deletions + startInParallel(threadCount, i -> deleted.addAndGet(deleteByQuery().source("test").filter(query).refresh(true).get().getDeleted())); assertHitCount(prepareSearch("test").setSize(0), 0L); assertThat(deleted.get(), equalTo(docs)); diff --git a/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java b/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java index aac4b6a020d4b..6c924fe8e429a 100644 --- a/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java +++ b/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java @@ -59,20 +59,13 @@ private int indexDocs(String index, final int idStart, final int numDocs) throws */ private int indexDocWithConcurrentUpdates(String index, final int docId, int nUpdates) throws IOException, InterruptedException { indexDocs(index, docId, 1); - Thread[] indexThreads = new Thread[nUpdates]; - for (int i = 0; i < nUpdates; i++) { - indexThreads[i] = new Thread(() -> { - try { - indexDocs(index, docId, 1); - } catch (IOException e) { - throw new AssertionError("failed while indexing [" + e.getMessage() + "]"); - } - }); - indexThreads[i].start(); - } - for (Thread indexThread : indexThreads) { - indexThread.join(); - } + runInParallel(nUpdates, i -> { + try { + indexDocs(index, docId, 1); + } catch (IOException e) { + throw new AssertionError("failed while indexing [" + e.getMessage() + "]"); + } + }); return nUpdates + 1; } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java index cfdf667f6c02e..5251f171150b7 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java @@ -39,7 +39,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CyclicBarrier; import java.util.function.Function; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -518,14 +517,8 @@ public void testFailingVersionedUpdatedOnBulk() throws Exception { createIndex("test"); indexDoc("test", "1", "field", "1"); final BulkResponse[] responses = new BulkResponse[30]; - final CyclicBarrier cyclicBarrier = new CyclicBarrier(responses.length); - runInParallel(responses.length, threadID -> { - try { - cyclicBarrier.await(); - } catch (Exception e) { - return; - } + startInParallel(responses.length, threadID -> { BulkRequestBuilder requestBuilder = client().prepareBulk(); requestBuilder.add( client().prepareUpdate("test", "1") diff --git a/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java index 
1cc771ab72c09..c5c3e441363da 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java @@ -32,7 +32,6 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.function.Consumer; import java.util.stream.IntStream; @@ -327,10 +326,8 @@ public void testConcurrentAddBlock() throws InterruptedException, ExecutionExcep final APIBlock block = randomAddableBlock(); final int threadCount = randomIntBetween(2, 5); - final CyclicBarrier barrier = new CyclicBarrier(threadCount); try { - runInParallel(threadCount, i -> { - safeAwait(barrier); + startInParallel(threadCount, i -> { try { indicesAdmin().prepareAddBlock(block, indexName).get(); assertIndexHasBlock(block, indexName); @@ -414,34 +411,17 @@ public void testAddBlockWhileDeletingIndices() throws Exception { }; try { - for (final String indexToDelete : indices) { - threads.add(new Thread(() -> { - safeAwait(latch); - try { - assertAcked(indicesAdmin().prepareDelete(indexToDelete)); - } catch (final Exception e) { - exceptionConsumer.accept(e); - } - })); - } - for (final String indexToBlock : indices) { - threads.add(new Thread(() -> { - safeAwait(latch); - try { - indicesAdmin().prepareAddBlock(block, indexToBlock).get(); - } catch (final Exception e) { - exceptionConsumer.accept(e); + startInParallel(indices.length * 2, i -> { + try { + if (i < indices.length) { + assertAcked(indicesAdmin().prepareDelete(indices[i])); + } else { + indicesAdmin().prepareAddBlock(block, indices[i - indices.length]).get(); } - })); - } - - for (Thread thread : threads) { - thread.start(); - } - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + } catch (final Exception e) { + exceptionConsumer.accept(e); + } + }); } finally { for (final String indexToBlock : indices) { try { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java index 409a57b35ac4b..d475208d7e1ff 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java @@ -26,7 +26,6 @@ import java.util.Collection; import java.util.Optional; -import java.util.concurrent.Phaser; import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -155,9 +154,7 @@ static IndexingResult indexDocs(int numRequests, int numThreads) throws Exceptio final AtomicInteger completedRequests = new AtomicInteger(); final AtomicInteger numSuccess = new AtomicInteger(); final AtomicInteger numFailure = new AtomicInteger(); - Phaser phaser = new Phaser(numThreads); - runInParallel(numThreads, i -> { - phaser.arriveAndAwaitAdvance(); + startInParallel(numThreads, i -> { while (completedRequests.incrementAndGet() <= numRequests) { try { final DocWriteResponse resp = prepareIndex("test").setSource("{}", XContentType.JSON).get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java index 463ac49d60e47..3f79d7723beb3 100644 --- 
a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java @@ -46,7 +46,6 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -162,12 +161,10 @@ public void testConcurrentDynamicIgnoreBeyondLimitUpdates() throws Throwable { private Map indexConcurrently(int numberOfFieldsToCreate, Settings.Builder settings) throws Throwable { indicesAdmin().prepareCreate("index").setSettings(settings).get(); ensureGreen("index"); - final CyclicBarrier barrier = new CyclicBarrier(numberOfFieldsToCreate); final AtomicReference error = new AtomicReference<>(); - runInParallel(numberOfFieldsToCreate, i -> { + startInParallel(numberOfFieldsToCreate, i -> { final String id = Integer.toString(i); try { - barrier.await(); assertEquals( DocWriteResponse.Result.CREATED, prepareIndex("index").setId(id).setSource("field" + id, "bar").get().getResult() diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java index 6a7c7bcf9d9bf..53f632f6ba8d5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java @@ -26,8 +26,6 @@ import org.elasticsearch.xcontent.XContentType; import java.util.Collection; -import java.util.concurrent.BrokenBarrierException; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.stream.Stream; @@ -141,15 +139,9 @@ private void runGlobalCheckpointSyncTest( final int numberOfDocuments = randomIntBetween(0, 256); final int numberOfThreads = randomIntBetween(1, 4); - final CyclicBarrier barrier = new CyclicBarrier(numberOfThreads); // start concurrent indexing threads - runInParallel(numberOfThreads, index -> { - try { - barrier.await(); - } catch (BrokenBarrierException | InterruptedException e) { - throw new RuntimeException(e); - } + startInParallel(numberOfThreads, index -> { for (int j = 0; j < numberOfDocuments; j++) { final String id = Integer.toString(index * numberOfDocuments + j); prepareIndex("test").setId(id).setSource("{\"foo\": " + id + "}", XContentType.JSON).get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java index 70cd143686dc8..0008ec1f9cbd2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java @@ -37,7 +37,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -179,66 +178,53 @@ public void testUpdateMappingConcurrently() throws Throwable { final AtomicReference threadException = new AtomicReference<>(); final AtomicBoolean stop = new AtomicBoolean(false); - Thread[] threads = new Thread[3]; - final CyclicBarrier barrier = 
new CyclicBarrier(threads.length); final ArrayList clientArray = new ArrayList<>(); for (Client c : clients()) { clientArray.add(c); } - for (int j = 0; j < threads.length; j++) { - threads[j] = new Thread(() -> { - try { - barrier.await(); - - for (int i = 0; i < 100; i++) { - if (stop.get()) { - return; - } - - Client client1 = clientArray.get(i % clientArray.size()); - Client client2 = clientArray.get((i + 1) % clientArray.size()); - String indexName = i % 2 == 0 ? "test2" : "test1"; - String fieldName = Thread.currentThread().getName() + "_" + i; - - AcknowledgedResponse response = client1.admin() - .indices() - .preparePutMapping(indexName) - .setSource( - JsonXContent.contentBuilder() - .startObject() - .startObject("_doc") - .startObject("properties") - .startObject(fieldName) - .field("type", "text") - .endObject() - .endObject() - .endObject() - .endObject() - ) - .setMasterNodeTimeout(TimeValue.timeValueMinutes(5)) - .get(); - - assertThat(response.isAcknowledged(), equalTo(true)); - GetMappingsResponse getMappingResponse = client2.admin().indices().prepareGetMappings(indexName).get(); - MappingMetadata mappings = getMappingResponse.getMappings().get(indexName); - @SuppressWarnings("unchecked") - Map properties = (Map) mappings.getSourceAsMap().get("properties"); - assertThat(properties.keySet(), Matchers.hasItem(fieldName)); + startInParallel(3, j -> { + try { + for (int i = 0; i < 100; i++) { + if (stop.get()) { + return; } - } catch (Exception e) { - threadException.set(e); - stop.set(true); - } - }); - - threads[j].setName("t_" + j); - threads[j].start(); - } - for (Thread t : threads) { - t.join(); - } + Client client1 = clientArray.get(i % clientArray.size()); + Client client2 = clientArray.get((i + 1) % clientArray.size()); + String indexName = i % 2 == 0 ? 
"test2" : "test1"; + String fieldName = "t_" + j + "_" + i; + + AcknowledgedResponse response = client1.admin() + .indices() + .preparePutMapping(indexName) + .setSource( + JsonXContent.contentBuilder() + .startObject() + .startObject("_doc") + .startObject("properties") + .startObject(fieldName) + .field("type", "text") + .endObject() + .endObject() + .endObject() + .endObject() + ) + .setMasterNodeTimeout(TimeValue.timeValueMinutes(5)) + .get(); + + assertThat(response.isAcknowledged(), equalTo(true)); + GetMappingsResponse getMappingResponse = client2.admin().indices().prepareGetMappings(indexName).get(); + MappingMetadata mappings = getMappingResponse.getMappings().get(indexName); + @SuppressWarnings("unchecked") + Map properties = (Map) mappings.getSourceAsMap().get("properties"); + assertThat(properties.keySet(), Matchers.hasItem(fieldName)); + } + } catch (Exception e) { + threadException.set(e); + stop.set(true); + } + }); if (threadException.get() != null) { throw threadException.get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java index 1751ffd7f1cfb..d52294d7584b8 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java @@ -197,9 +197,7 @@ public void testConcurrentClose() throws InterruptedException, ExecutionExceptio assertThat(healthResponse.getIndices().get(indexName).getStatus().value(), lessThanOrEqualTo(ClusterHealthStatus.YELLOW.value())); final int tasks = randomIntBetween(2, 5); - final CyclicBarrier barrier = new CyclicBarrier(tasks); - runInParallel(tasks, i -> { - safeAwait(barrier); + startInParallel(tasks, i -> { try { indicesAdmin().prepareClose(indexName).get(); } catch (final Exception e) { @@ -247,9 +245,7 @@ public void testCloseWhileDeletingIndices() throws Exception { } assertThat(clusterAdmin().prepareState().get().getState().metadata().indices().size(), equalTo(indices.length)); - final CyclicBarrier barrier = new CyclicBarrier(indices.length * 2); - runInParallel(indices.length * 2, i -> { - safeAwait(barrier); + startInParallel(indices.length * 2, i -> { final String index = indices[i % indices.length]; try { if (i < indices.length) { @@ -275,9 +271,8 @@ public void testConcurrentClosesAndOpens() throws Exception { final int opens = randomIntBetween(1, 3); final CyclicBarrier barrier = new CyclicBarrier(opens + closes); - runInParallel(opens + closes, i -> { + startInParallel(opens + closes, i -> { try { - safeAwait(barrier); if (i < closes) { indicesAdmin().prepareClose(indexName).get(); } else { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java index 9eb69c87a52e8..6647356f070ae 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java @@ -38,7 +38,6 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.CyclicBarrier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -187,13 +186,8 @@ public void testCloseWhileRelocatingShards() throws Exception { 
ClusterRerouteUtils.reroute(client(), commands.toArray(AllocationCommand[]::new)); // start index closing threads - final CyclicBarrier barrier = new CyclicBarrier(indices.length); - runInParallel(indices.length, i -> { - try { - safeAwait(barrier); - } finally { - release.countDown(); - } + startInParallel(indices.length, i -> { + release.countDown(); // Closing is not always acknowledged when shards are relocating: this is the case when the target shard is initializing // or is catching up operations. In these cases the TransportVerifyShardBeforeCloseAction will detect that the global // and max sequence number don't match and will not ack the close. diff --git a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java index 0543bce08a4f0..463203c1357b9 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -308,25 +307,13 @@ public String toString() { }); assertThat(listener.toString(), equalTo("notifyOnce[inner-listener]")); - final var threads = new Thread[between(1, 10)]; - final var startBarrier = new CyclicBarrier(threads.length); - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - safeAwait(startBarrier); - if (randomBoolean()) { - listener.onResponse(null); - } else { - listener.onFailure(new RuntimeException("test")); - } - }); - } - - for (Thread thread : threads) { - thread.start(); - } - for (Thread thread : threads) { - thread.join(); - } + startInParallel(between(1, 10), i -> { + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new RuntimeException("test")); + } + }); assertTrue(completed.get()); } diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java index 65bcb473f7d22..0392a3f5ab4e1 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java @@ -54,32 +54,19 @@ protected void write(List>> candidates) throws }; Semaphore semaphore = new Semaphore(Integer.MAX_VALUE); final int count = randomIntBetween(1000, 20000); - Thread[] thread = new Thread[randomIntBetween(3, 10)]; - CountDownLatch latch = new CountDownLatch(thread.length); - for (int i = 0; i < thread.length; i++) { - thread[i] = new Thread() { - @Override - public void run() { - try { - latch.countDown(); - latch.await(); - for (int i = 0; i < count; i++) { - semaphore.acquire(); - processor.put(new Object(), (ex) -> semaphore.release()); - } - } catch (Exception ex) { - throw new RuntimeException(ex); - } + final int threads = randomIntBetween(3, 10); + startInParallel(threads, t -> { + for (int i = 0; i < count; i++) { + try { + semaphore.acquire(); + processor.put(new Object(), (ex) -> semaphore.release()); + } catch (Exception ex) { + throw new RuntimeException(ex); } - }; - thread[i].start(); - } - - for (int i = 0; i < thread.length; i++) { - thread[i].join(); - } + } + }); safeAcquire(10, 
semaphore); - assertEquals(count * thread.length, received.get()); + assertEquals(count * threads, received.get()); } public void testRandomFail() throws InterruptedException { @@ -102,37 +89,24 @@ protected void write(List>> candidates) throws }; Semaphore semaphore = new Semaphore(Integer.MAX_VALUE); final int count = randomIntBetween(1000, 20000); - Thread[] thread = new Thread[randomIntBetween(3, 10)]; - CountDownLatch latch = new CountDownLatch(thread.length); - for (int i = 0; i < thread.length; i++) { - thread[i] = new Thread() { - @Override - public void run() { - try { - latch.countDown(); - latch.await(); - for (int i = 0; i < count; i++) { - semaphore.acquire(); - processor.put(new Object(), (ex) -> { - if (ex != null) { - actualFailed.incrementAndGet(); - } - semaphore.release(); - }); + final int threads = randomIntBetween(3, 10); + startInParallel(threads, t -> { + try { + for (int i = 0; i < count; i++) { + semaphore.acquire(); + processor.put(new Object(), (ex) -> { + if (ex != null) { + actualFailed.incrementAndGet(); } - } catch (Exception ex) { - throw new RuntimeException(ex); - } + semaphore.release(); + }); } - }; - thread[i].start(); - } - - for (int i = 0; i < thread.length; i++) { - thread[i].join(); - } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + }); safeAcquire(Integer.MAX_VALUE, semaphore); - assertEquals(count * thread.length, received.get()); + assertEquals(count * threads, received.get()); assertEquals(actualFailed.get(), failed.get()); } @@ -226,7 +200,7 @@ public void run() { threads.forEach(t -> assertFalse(t.isAlive())); } - public void testSlowConsumer() { + public void testSlowConsumer() throws InterruptedException { AtomicInteger received = new AtomicInteger(0); AtomicInteger notified = new AtomicInteger(0); @@ -240,39 +214,23 @@ protected void write(List>> candidates) throws int threadCount = randomIntBetween(2, 10); CyclicBarrier barrier = new CyclicBarrier(threadCount); Semaphore serializePutSemaphore = new Semaphore(1); - List threads = IntStream.range(0, threadCount).mapToObj(i -> new Thread(getTestName() + "_" + i) { - { - setDaemon(true); - } - - @Override - public void run() { - try { - assertTrue(serializePutSemaphore.tryAcquire(10, TimeUnit.SECONDS)); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - processor.put(new Object(), (e) -> { - serializePutSemaphore.release(); - try { - barrier.await(10, TimeUnit.SECONDS); - } catch (InterruptedException | BrokenBarrierException | TimeoutException ex) { - throw new RuntimeException(ex); - } - notified.incrementAndGet(); - }); - } - }).toList(); - threads.forEach(Thread::start); - threads.forEach(t -> { + runInParallel(threadCount, t -> { try { - t.join(20000); + assertTrue(serializePutSemaphore.tryAcquire(10, TimeUnit.SECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } + processor.put(new Object(), (e) -> { + serializePutSemaphore.release(); + try { + barrier.await(10, TimeUnit.SECONDS); + } catch (InterruptedException | BrokenBarrierException | TimeoutException ex) { + throw new RuntimeException(ex); + } + notified.incrementAndGet(); + }); }); assertEquals(threadCount, notified.get()); assertEquals(threadCount, received.get()); - threads.forEach(t -> assertFalse(t.isAlive())); } } diff --git a/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java b/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java index b8b12357b085e..e988599fccc3b 100644 --- 
a/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java +++ b/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.test.ESTestCase; import java.util.Map; -import java.util.concurrent.CountDownLatch; import static org.hamcrest.Matchers.equalTo; @@ -107,35 +106,21 @@ public void testGetBreakdownAndNodeTime() { public void testMultiThreaded() throws InterruptedException { TestProfileBreakdown testBreakdown = new TestProfileBreakdown(); - Thread[] threads = new Thread[200]; - final CountDownLatch latch = new CountDownLatch(1); + final int threads = 200; int startsPerThread = between(1, 5); - for (int t = 0; t < threads.length; t++) { - final TestTimingTypes timingType = randomFrom(TestTimingTypes.values()); - threads[t] = new Thread(() -> { - try { - latch.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - Timer timer = testBreakdown.getNewTimer(timingType); - for (int runs = 0; runs < startsPerThread; runs++) { - timer.start(); - timer.stop(); - } - }); - threads[t].start(); - } // starting all threads simultaneously increases the likelihood of failure in case we don't synchronize timer access properly - latch.countDown(); - for (Thread t : threads) { - t.join(); - } + startInParallel(threads, t -> { + final TestTimingTypes timingType = randomFrom(TestTimingTypes.values()); + Timer timer = testBreakdown.getNewTimer(timingType); + for (int runs = 0; runs < startsPerThread; runs++) { + timer.start(); + timer.stop(); + } + }); Map breakdownMap = testBreakdown.toBreakdownMap(); long totalCounter = breakdownMap.get(TestTimingTypes.ONE + "_count") + breakdownMap.get(TestTimingTypes.TWO + "_count") + breakdownMap.get(TestTimingTypes.THREE + "_count"); - assertEquals(threads.length * startsPerThread, totalCounter); - + assertEquals(threads * startsPerThread, totalCounter); } private void runTimerNTimes(Timer t, int n) { diff --git a/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java index 1c7cabb541581..70738c510f62a 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java @@ -125,7 +125,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -1179,33 +1178,24 @@ public static void assertOpsOnReplica( } public static void concurrentlyApplyOps(List ops, InternalEngine engine) throws InterruptedException { - Thread[] thread = new Thread[randomIntBetween(3, 5)]; - CountDownLatch startGun = new CountDownLatch(thread.length); + final int threadCount = randomIntBetween(3, 5); AtomicInteger offset = new AtomicInteger(-1); - for (int i = 0; i < thread.length; i++) { - thread[i] = new Thread(() -> { - startGun.countDown(); - safeAwait(startGun); - int docOffset; - while ((docOffset = offset.incrementAndGet()) < ops.size()) { - try { - applyOperation(engine, ops.get(docOffset)); - if ((docOffset + 1) % 4 == 0) { - engine.refresh("test"); - } - if (rarely()) { - engine.flush(); - } - } catch (IOException e) { - throw new AssertionError(e); + startInParallel(threadCount, i -> { + int docOffset; + while ((docOffset = offset.incrementAndGet()) 
< ops.size()) { + try { + applyOperation(engine, ops.get(docOffset)); + if ((docOffset + 1) % 4 == 0) { + engine.refresh("test"); + } + if (rarely()) { + engine.flush(); } + } catch (IOException e) { + throw new AssertionError(e); } - }); - thread[i].start(); - } - for (int i = 0; i < thread.length; i++) { - thread[i].join(); - } + } + }); } public static void applyOperations(Engine engine, List operations) throws IOException { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 68fc6b41e0be0..7295dce7a257a 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -2434,6 +2434,18 @@ public static T expectThrows(Class expectedType, Reques ); } + /** + * Same as {@link #runInParallel(int, IntConsumer)} but also attempts to start all tasks at the same time by blocking execution on a + * barrier until all threads are started and ready to execute their task. + */ + public static void startInParallel(int numberOfTasks, IntConsumer taskFactory) throws InterruptedException { + final CyclicBarrier barrier = new CyclicBarrier(numberOfTasks); + runInParallel(numberOfTasks, i -> { + safeAwait(barrier); + taskFactory.accept(i); + }); + } + /** * Run {@code numberOfTasks} parallel tasks that were created by the given {@code taskFactory}. On of the tasks will be run on the * calling thread, the rest will be run on a new thread. From b54bf0b1f8809c0cb27b1d838b544071da361d31 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 9 Jul 2024 08:18:03 +0300 Subject: [PATCH 30/64] Updating ESSingleNodeTestCase to ensure that all free_context actions have been consumed before tearDown (#110595) --- .../org/elasticsearch/test/ESSingleNodeTestCase.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java index 8526acc851c72..7fdc5765a90e8 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java @@ -69,6 +69,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.action.search.SearchTransportService.FREE_CONTEXT_ACTION_NAME; import static org.elasticsearch.cluster.coordination.ClusterBootstrapService.INITIAL_MASTER_NODES_SETTING; import static org.elasticsearch.discovery.SettingsBasedSeedHostsProvider.DISCOVERY_SEED_HOSTS_SETTING; import static org.elasticsearch.test.NodeRoles.dataNode; @@ -130,6 +131,8 @@ public void tearDown() throws Exception { logger.trace("[{}#{}]: cleaning up after test", getTestClass().getSimpleName(), getTestName()); awaitIndexShardCloseAsyncTasks(); ensureNoInitializingShards(); + ensureAllFreeContextActionsAreConsumed(); + SearchService searchService = getInstanceFromNode(SearchService.class); assertThat(searchService.getActiveContexts(), equalTo(0)); assertThat(searchService.getOpenScrollContexts(), equalTo(0)); @@ -455,6 +458,14 @@ protected void ensureNoInitializingShards() { assertFalse("timed out waiting for shards to initialize", actionGet.isTimedOut()); } + /** + * waits until all free_context actions have been handled by the generic thread pool + */ + protected void ensureAllFreeContextActionsAreConsumed() throws Exception { + logger.info("--> 
waiting for all free_context tasks to complete within a reasonable time"); + safeGet(clusterAdmin().prepareListTasks().setActions(FREE_CONTEXT_ACTION_NAME + "*").setWaitForCompletion(true).execute()); + } + /** * Whether we'd like to enable inter-segment search concurrency and increase the likelihood of leveraging it, by creating multiple * slices with a low amount of documents in them, which would not be allowed in production. From fb10d61db3004d3c1de735f45f72244fb5f22b62 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 8 Jul 2024 22:34:31 -0700 Subject: [PATCH 31/64] Fix translate metrics without rate (#110614) Currently, we incorrectly remove the `@timestamp` attribute from the EsRelation when translating metric aggregates. --- .../rules/TranslateMetricsAggregate.java | 2 +- .../optimizer/LogicalPlanOptimizerTests.java | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java index 64555184be12d..10c7a7325debc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java @@ -217,7 +217,7 @@ private static Aggregate toStandardAggregate(Aggregate metrics) { final LogicalPlan child = metrics.child().transformDown(EsRelation.class, r -> { var attributes = new ArrayList<>(new AttributeSet(metrics.inputSet())); attributes.removeIf(a -> a.name().equals(MetadataAttribute.TSID_FIELD)); - if (attributes.stream().noneMatch(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD)) == false) { + if (attributes.stream().noneMatch(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD))) { attributes.removeIf(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD)); } return new EsRelation(r.source(), r.index(), new ArrayList<>(attributes), IndexMode.STANDARD); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 7ace781652419..dea3a974fbd5a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -5477,6 +5477,56 @@ METRICS k8s avg(round(1.05 * rate(network.total_bytes_in))) BY bucket(@timestamp assertThat(Expressions.attribute(values.field()).name(), equalTo("cluster")); } + public void testMetricsWithoutRate() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + List queries = List.of(""" + METRICS k8s count(to_long(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """, """ + METRICS k8s | STATS count(to_long(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """, """ + FROM k8s | STATS count(to_long(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """); + List plans = new ArrayList<>(); + for (String query : queries) { + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + plans.add(plan); + } + for (LogicalPlan plan : plans) { + Limit limit = as(plan, Limit.class); + Aggregate aggregate = as(limit.child(), 
Aggregate.class); + assertThat(aggregate.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + assertThat(aggregate.aggregates(), hasSize(2)); + assertThat(aggregate.groupings(), hasSize(1)); + Eval eval = as(aggregate.child(), Eval.class); + assertThat(eval.fields(), hasSize(2)); + assertThat(Alias.unwrap(eval.fields().get(0)), instanceOf(Bucket.class)); + assertThat(Alias.unwrap(eval.fields().get(1)), instanceOf(ToLong.class)); + EsRelation relation = as(eval.child(), EsRelation.class); + assertThat(relation.indexMode(), equalTo(IndexMode.STANDARD)); + } + for (int i = 1; i < plans.size(); i++) { + assertThat(plans.get(i), equalTo(plans.get(0))); + } + } + + public void testRateInStats() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = """ + METRICS k8s | STATS max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """; + VerificationException error = expectThrows( + VerificationException.class, + () -> logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))) + ); + assertThat(error.getMessage(), equalTo(""" + Found 1 problem + line 1:25: the rate aggregate[rate(network.total_bytes_in)] can only be used within the metrics command""")); + } + public void testMvSortInvalidOrder() { VerificationException e = expectThrows(VerificationException.class, () -> plan(""" from test From 67da6ba645afc5a6c9bf5478ce3e92a7d8cb7d08 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 9 Jul 2024 09:15:08 +0200 Subject: [PATCH 32/64] Deduplicate FieldInfo attributes and field names (#110561) We can use a similar strategy to what worked with mappers+settings and reuse the string deduplicator to deal with a large chunk (more than 70% from heap dumps we've seen in production) of the `FieldInfo` duplication overhead without any Lucene changes. There's generally only a very limited number of attribute maps out there and the "dedup up to 100" logic in here deals with all scenarios I have observed in the wild thus far. As a side effect of deduplicating the field name and always working with an interned string now, I would expect the performance of field caps filtering for empty fields to improve measurably. --- .../codec/DeduplicatingFieldInfosFormat.java | 96 +++++++++++++++++++ .../index/codec/Elasticsearch814Codec.java | 9 ++ 2 files changed, 105 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java b/server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java new file mode 100644 index 0000000000000..75ec265a68391 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */
+
+package org.elasticsearch.index.codec;
+
+import org.apache.lucene.codecs.FieldInfosFormat;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.SegmentInfo;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.elasticsearch.common.util.Maps;
+import org.elasticsearch.common.util.StringLiteralDeduplicator;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.index.mapper.FieldMapper;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Wrapper around a {@link FieldInfosFormat} that will deduplicate and intern all field names, attribute-keys and -values, and in most
+ * cases attribute maps on read. We use this to reduce the per-field overhead for Elasticsearch instances holding a large number of
+ * segments.
+ */
+public final class DeduplicatingFieldInfosFormat extends FieldInfosFormat {
+
+    private static final Map<Map<String, String>, Map<String, String>> attributeDeduplicator = ConcurrentCollections.newConcurrentMap();
+
+    private static final StringLiteralDeduplicator attributesDeduplicator = new StringLiteralDeduplicator();
+
+    private final FieldInfosFormat delegate;
+
+    public DeduplicatingFieldInfosFormat(FieldInfosFormat delegate) {
+        this.delegate = delegate;
+    }
+
+    @Override
+    public FieldInfos read(Directory directory, SegmentInfo segmentInfo, String segmentSuffix, IOContext iocontext) throws IOException {
+        final FieldInfos fieldInfos = delegate.read(directory, segmentInfo, segmentSuffix, iocontext);
+        final FieldInfo[] deduplicated = new FieldInfo[fieldInfos.size()];
+        int i = 0;
+        for (FieldInfo fi : fieldInfos) {
+            deduplicated[i++] = new FieldInfo(
+                FieldMapper.internFieldName(fi.getName()),
+                fi.number,
+                fi.hasVectors(),
+                fi.omitsNorms(),
+                fi.hasPayloads(),
+                fi.getIndexOptions(),
+                fi.getDocValuesType(),
+                fi.getDocValuesGen(),
+                internStringStringMap(fi.attributes()),
+                fi.getPointDimensionCount(),
+                fi.getPointIndexDimensionCount(),
+                fi.getPointNumBytes(),
+                fi.getVectorDimension(),
+                fi.getVectorEncoding(),
+                fi.getVectorSimilarityFunction(),
+                fi.isSoftDeletesField(),
+                fi.isParentField()
+            );
+        }
+        return new FieldInfos(deduplicated);
+    }
+
+    private static Map<String, String> internStringStringMap(Map<String, String> m) {
+        if (m.size() > 10) {
+            return m;
+        }
+        var res = attributeDeduplicator.get(m);
+        if (res == null) {
+            if (attributeDeduplicator.size() > 100) {
+                // Unexpected edge case to have more than 100 different attribute maps
+                // Just to be safe, don't retain more than 100 maps to prevent a potential memory leak
+                attributeDeduplicator.clear();
+            }
+            final Map<String, String> interned = Maps.newHashMapWithExpectedSize(m.size());
+            m.forEach((key, value) -> interned.put(attributesDeduplicator.deduplicate(key), attributesDeduplicator.deduplicate(value)));
+            res = Map.copyOf(interned);
+            attributeDeduplicator.put(res, res);
+        }
+        return res;
+    }
+
+    @Override
+    public void write(Directory directory, SegmentInfo segmentInfo, String segmentSuffix, FieldInfos infos, IOContext context)
+        throws IOException {
+        delegate.write(directory, segmentInfo, segmentSuffix, infos, context);
+    }
+
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java
index e85e05c87b083..dd7a668605e57 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java
@@ -9,6 +9,7 @@
 package org.elasticsearch.index.codec;
 
 import org.apache.lucene.codecs.DocValuesFormat;
+import org.apache.lucene.codecs.FieldInfosFormat;
 import org.apache.lucene.codecs.FilterCodec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.PostingsFormat;
@@ -30,6 +31,8 @@ public class Elasticsearch814Codec extends FilterCodec {
 
     private final StoredFieldsFormat storedFieldsFormat;
 
+    private final FieldInfosFormat fieldInfosFormat;
+
     private final PostingsFormat defaultPostingsFormat;
     private final PostingsFormat postingsFormat = new PerFieldPostingsFormat() {
         @Override
@@ -69,6 +72,7 @@ public Elasticsearch814Codec(Zstd814StoredFieldsFormat.Mode mode) {
         this.defaultPostingsFormat = new Lucene99PostingsFormat();
         this.defaultDVFormat = new Lucene90DocValuesFormat();
         this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat();
+        this.fieldInfosFormat = new DeduplicatingFieldInfosFormat(delegate.fieldInfosFormat());
     }
 
     @Override
@@ -127,4 +131,9 @@ public DocValuesFormat getDocValuesFormatForField(String field) {
     public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
         return defaultKnnVectorsFormat;
     }
+
+    @Override
+    public FieldInfosFormat fieldInfosFormat() {
+        return fieldInfosFormat;
+    }
 }
From af779d68a28b6ec75bf0257fb6f3ec5abfe9c1a9 Mon Sep 17 00:00:00 2001
From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com>
Date: Tue, 9 Jul 2024 08:59:28 +0100
Subject: [PATCH 33/64] Upgrade to JMH 1.37 (#110580)

This commit upgrades to JMH 1.37. There are some fixes for Mac that allow
easier running of profilers, etc.
---
 benchmarks/build.gradle                 |  4 ++--
 build-tools-internal/version.properties |  2 +-
 gradle/verification-metadata.xml        | 22 +++++++++++-----------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle
index 8753d4a4762b7..49e81a67e85f9 100644
--- a/benchmarks/build.gradle
+++ b/benchmarks/build.gradle
@@ -47,8 +47,8 @@ dependencies {
   api "org.openjdk.jmh:jmh-core:$versions.jmh"
   annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:$versions.jmh"
   // Dependencies of JMH
-  runtimeOnly 'net.sf.jopt-simple:jopt-simple:4.6'
-  runtimeOnly 'org.apache.commons:commons-math3:3.2'
+  runtimeOnly 'net.sf.jopt-simple:jopt-simple:5.0.4'
+  runtimeOnly 'org.apache.commons:commons-math3:3.6.1'
 }
 
 // enable the JMH's BenchmarkProcessor to generate the final benchmark classes
diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties
index 728f44a365974..1dd9fb95bd17b 100644
--- a/build-tools-internal/version.properties
+++ b/build-tools-internal/version.properties
@@ -49,7 +49,7 @@ commonsCompress = 1.24.0
 reflections = 0.10.2
 
 # benchmark dependencies
-jmh = 1.26
+jmh = 1.37
 
 # test dependencies
 # when updating this version, also update :qa:evil-tests
diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml
index 02313c5ed82a2..5e26d96c4ca17 100644
--- a/gradle/verification-metadata.xml
+++ b/gradle/verification-metadata.xml
@@ -1699,16 +1699,16 @@
 [replaced sha256 checksum entries for the updated benchmark dependencies]
@@ -3837,14 +3837,14 @@
 [replaced sha256 checksum entries for the updated benchmark dependencies]
From 362e049bd30501e19abc99d9119a7931c93ed2e0 Mon Sep 17 00:00:00 2001
From: Craig Taverner
Date: Tue, 9 Jul 2024 09:59:44 +0200
Subject: [PATCH 34/64] An alternative approach to supporting union-types on
 stats grouping field (#110600)

* Added union-types field extraction to ordinals aggregation

* Revert previous approach to getting union-types working in aggregations,
  where the grouping field is erased by
later commands, like a subsequent stats. Instead we include union-type
support in the ordinals aggregation and mark the block loader as not
supporting ordinals.
---
 .../lucene/ValueSourceReaderTypeConversionTests.java |  7 ++++---
 .../esql/optimizer/LocalPhysicalPlanOptimizer.java   | 10 +---------
 .../esql/planner/EsPhysicalOperationProviders.java   | 10 ++++++----
 3 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java
index 66bcf2a57e393..09f63e9fa45bb 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java
@@ -1687,12 +1687,13 @@ public StoredFieldsSpec rowStrideStoredFieldSpec() {
 
         @Override
         public boolean supportsOrdinals() {
-            return delegate.supportsOrdinals();
+            // Fields with mismatching types cannot use ordinals for uniqueness determination, but must convert the values first
+            return false;
         }
 
         @Override
-        public SortedSetDocValues ordinals(LeafReaderContext context) throws IOException {
-            return delegate.ordinals(context);
+        public SortedSetDocValues ordinals(LeafReaderContext context) {
+            throw new IllegalArgumentException("Ordinals are not supported for type conversion");
         }
 
         @Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java
index f78ae6930d9ba..1b40a1c2b02ad 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java
@@ -77,7 +77,6 @@
 import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
 import org.elasticsearch.xpack.esql.planner.EsqlTranslatorHandler;
 import org.elasticsearch.xpack.esql.stats.SearchStats;
-import org.elasticsearch.xpack.esql.type.MultiTypeEsField;
 
 import java.nio.ByteOrder;
 import java.util.ArrayList;
@@ -194,10 +193,7 @@ public PhysicalPlan apply(PhysicalPlan plan) {
              * it loads the field lazily. If we have more than one field we need to
              * make sure the fields are loaded for the standard hash aggregator.
              */
-            if (p instanceof AggregateExec agg
-                && agg.groupings().size() == 1
-                && (isMultiTypeFieldAttribute(agg.groupings().get(0)) == false) // Union types rely on field extraction.
-            ) {
+        if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
             var leaves = new LinkedList<>();
             // TODO: this seems out of place
             agg.aggregates()
@@ -221,10 +217,6 @@ public PhysicalPlan apply(PhysicalPlan plan) {
         return plan;
     }
 
-    private static boolean isMultiTypeFieldAttribute(Expression attribute) {
-        return attribute instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField;
-    }
-
     private static Set<Attribute> missingAttributes(PhysicalPlan p) {
         var missing = new LinkedHashSet<Attribute>();
         var input = p.inputSet();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
index 9e1e1a50fe8f0..8611d2c6fa9fb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
@@ -233,8 +233,9 @@ public final Operator.OperatorFactory ordinalGroupingOperatorFactory(
         // The grouping-by values are ready, let's group on them directly.
         // Costin: why are they ready and not already exposed in the layout?
         boolean isUnsupported = attrSource.dataType() == DataType.UNSUPPORTED;
+        var unionTypes = findUnionTypes(attrSource);
         return new OrdinalsGroupingOperator.OrdinalsGroupingOperatorFactory(
-            shardIdx -> shardContexts.get(shardIdx).blockLoader(attrSource.name(), isUnsupported, NONE),
+            shardIdx -> getBlockLoaderFor(shardIdx, attrSource.name(), isUnsupported, NONE, unionTypes),
             vsShardContexts,
            groupElementType,
             docChannel,
@@ -434,12 +435,13 @@
 
         @Override
         public boolean supportsOrdinals() {
-            return delegate.supportsOrdinals();
+            // Fields with mismatching types cannot use ordinals for uniqueness determination, but must convert the values first
+            return false;
         }
 
         @Override
-        public SortedSetDocValues ordinals(LeafReaderContext context) throws IOException {
-            return delegate.ordinals(context);
+        public SortedSetDocValues ordinals(LeafReaderContext context) {
+            throw new IllegalArgumentException("Ordinals are not supported for type conversion");
         }
 
         @Override
From bd0eff6be6ceea15857b5277c698eadb1f4667ee Mon Sep 17 00:00:00 2001
From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com>
Date: Tue, 9 Jul 2024 09:17:08 +0100
Subject: [PATCH 35/64] Implement xorBitCount in Elasticsearch (#110599)

This commit adds an implementation of XOR bit count computed over signed
bytes that is ~4x faster than that of Lucene 9.11, on ARM. While already
fixed in Lucene, it'll be in a Lucene version > 9.11. This is effectively
a temporary workaround until Lucene 9.12, after which we can revert this.
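
To make the behaviour concrete: the helper XORs the two byte arrays and
counts the set bits of the result, which for binary vectors is exactly the
hamming distance. A minimal worked example with illustrative values (the
real strided implementation is the ESVectorUtil class added below):

    // Sketch: what xorBitCount computes for a single pair of bytes.
    byte[] a = { (byte) 0b1010_1010 };
    byte[] b = { (byte) 0b0110_0110 };
    // a[0] ^ b[0] == 0b1100_1100, which has four set bits, so the XOR bit
    // count (hamming distance) of these one-byte vectors is 4.
    int distance = Integer.bitCount((a[0] ^ b[0]) & 0xFF); // == 4

The ARM speedup comes from striding the arrays 4 bytes at a time with
Integer::bitCount, which vectorizes well on aarch64, rather than 8 bytes at
a time with Long::bitCount.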
--- .../vectors/ES815BitFlatVectorsFormat.java | 4 +- .../field/vectors/ByteBinaryDenseVector.java | 2 +- .../field/vectors/ByteKnnDenseVector.java | 2 +- .../script/field/vectors/ESVectorUtil.java | 72 +++++++++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index 659cc89bfe46d..de91833c99842 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -16,11 +16,11 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.elasticsearch.script.field.vectors.ESVectorUtil; import java.io.IOException; @@ -100,7 +100,7 @@ public RandomVectorScorer getRandomVectorScorer( } static float hammingScore(byte[] a, byte[] b) { - return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); + return ((a.length * Byte.SIZE) - ESVectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); } static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java index f2ff8fbccd2fb..e5c2d6a370f12 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java @@ -102,7 +102,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return VectorUtil.xorBitCount(queryVector, vectorValue); + return ESVectorUtil.xorBitCount(queryVector, vectorValue); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java index e0ba032826aa1..0145eb3eae04b 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java @@ -103,7 +103,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return VectorUtil.xorBitCount(queryVector, docVector); + return ESVectorUtil.xorBitCount(queryVector, docVector); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java new file mode 100644 index 0000000000000..7d9542bccf357 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.Constants; + +/** + * This class consists of a single utility method that provides XOR bit count computed over signed bytes. + * Remove this class when Lucene version > 9.11 is released, and replace with Lucene's VectorUtil directly. + */ +public class ESVectorUtil { + + /** + * For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. + * On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when + * compared to Integer::bitCount. While Long::bitCount is optimal on x64. + */ + static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64"); + + /** + * XOR bit count computed over signed bytes. + * + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the XOR bit count of the two vectors + */ + public static int xorBitCount(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + if (XOR_BIT_COUNT_STRIDE_AS_INT) { + return xorBitCountInt(a, b); + } else { + return xorBitCountLong(a, b); + } + } + + /** XOR bit count striding over 4 bytes at a time. */ + static int xorBitCountInt(byte[] a, byte[] b) { + int distance = 0, i = 0; + for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) ^ (int) BitUtil.VH_NATIVE_INT.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + + /** XOR bit count striding over 8 bytes at a time. 
*/ + static int xorBitCountLong(byte[] a, byte[] b) { + int distance = 0, i = 0; + for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + + private ESVectorUtil() {} +} From 13096f4d7e82d03a258f552d3b97cea0ee9dc247 Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Tue, 9 Jul 2024 11:51:30 +0200 Subject: [PATCH 36/64] Remove 'emulated' option for CSV tests (#110124) it's redundant, as we can use warningRegex instead --- .../xpack/esql/core/CsvSpecReader.java | 23 ++----------------- .../xpack/esql/qa/rest/EsqlSpecTestCase.java | 6 +---- .../src/main/resources/ip.csv-spec | 4 ++-- .../elasticsearch/xpack/esql/CsvTests.java | 2 +- 4 files changed, 6 insertions(+), 29 deletions(-) diff --git a/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java b/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java index a1f524e525eee..8e5a228af00d6 100644 --- a/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java +++ b/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java @@ -15,7 +15,6 @@ import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; public final class CsvSpecReader { @@ -113,34 +112,16 @@ public static class CsvTestCase { public boolean ignoreOrder; public List requiredCapabilities = List.of(); - // The emulated-specific warnings must always trail the non-emulated ones, if these are present. Otherwise, the closing bracket - // would need to be changed to a less common sequence (like `]#` maybe). - private static final String EMULATED_PREFIX = "#[emulated:"; - /** * Returns the warning headers expected to be added by the test. To declare such a header, use the `warning:definition` format * in the CSV test declaration. The `definition` can use the `EMULATED_PREFIX` string to specify the format of the warning run on * emulated physical operators, if this differs from the format returned by SingleValueQuery. - * @param forEmulated if true, the tests are run on emulated physical operators; if false, the test case is for queries executed - * on a "full stack" ESQL, having data loaded from Lucene. * @return the list of headers that are expected to be returned part of the response. 
*/ - public List expectedWarnings(boolean forEmulated) { + public List expectedWarnings() { List warnings = new ArrayList<>(expectedWarnings.size()); for (String warning : expectedWarnings) { - int idx = warning.toLowerCase(Locale.ROOT).indexOf(EMULATED_PREFIX); - if (idx >= 0) { - assertTrue("Invalid warning spec: closing delimiter (]) missing: `" + warning + "`", warning.endsWith("]")); - if (forEmulated) { - if (idx + EMULATED_PREFIX.length() < warning.length() - 1) { - warnings.add(warning.substring(idx + EMULATED_PREFIX.length(), warning.length() - 1)); - } - } else if (idx > 0) { - warnings.add(warning.substring(0, idx)); - } // else: no warnings expected for non-emulated - } else { - warnings.add(warning); - } + warnings.add(warning); } return warnings; } diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index e25eb84023867..e650f0815f964 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -204,11 +204,7 @@ protected final void doTest() throws Throwable { builder.tables(tables()); } - Map answer = runEsql( - builder.query(testCase.query), - testCase.expectedWarnings(false), - testCase.expectedWarningsRegex() - ); + Map answer = runEsql(builder.query(testCase.query), testCase.expectedWarnings(), testCase.expectedWarningsRegex()); var expectedColumnsWithValues = loadCsvSpecValues(testCase.expectedResults); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec index 54d5484bb4172..697b1c899d65e 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec @@ -285,8 +285,8 @@ str1:keyword |str2:keyword |ip1:ip |ip2:ip pushDownIP from hosts | where ip1 == to_ip("::1") | keep card, host, ip0, ip1; ignoreOrder:true -warning:#[Emulated:Line 1:20: evaluation of [ip1 == to_ip(\"::1\")] failed, treating result as null. Only first 20 failures recorded.] -warning:#[Emulated:Line 1:20: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[ip1 == to_ip\(\\\"::1\\\"\)\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth1 |alpha |::1 |::1 diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index e8a403ae7d9d0..20b4d3a503f0c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -459,7 +459,7 @@ private void assertWarnings(List warnings) { normalized.add(normW); } } - EsqlTestUtils.assertWarnings(normalized, testCase.expectedWarnings(true), testCase.expectedWarningsRegex()); + EsqlTestUtils.assertWarnings(normalized, testCase.expectedWarnings(), testCase.expectedWarningsRegex()); } BiConsumer> runPhase( From 4cfe6cc2cdeeb7ed8d9bbbe7f744f28fd7c3f58f Mon Sep 17 00:00:00 2001 From: Jan Kuipers <148754765+jan-elastic@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:51:44 +0200 Subject: [PATCH 37/64] Inference autoscaling (#109667) * Python dev tool for inference autoscaling simulation. Squashed commit of the following: commit d98bd3d39d833329ab83a8274885473db41ed08a Author: Jan Kuipers Date: Mon May 13 17:27:38 2024 +0200 Increase measurement interval to 10secs commit e808ae5be52c5ea4d5ff8ccb881a4a80de0254f9 Author: Jan Kuipers Date: Mon May 13 17:09:33 2024 +0200 jump -> jumps commit c38cbdebfcec43e6982bb8bd1670519293161154 Author: Jan Kuipers Date: Mon May 13 14:32:42 2024 +0200 Remove unused estimator commit 16101f32b539481cd4d648ebb5637a3309853552 Author: Jan Kuipers Date: Mon May 13 14:31:30 2024 +0200 Measure latency periodically + documentation commit bc73bf29fde1d772701f0b71a7c8a0908669eb0f Author: Jan Kuipers Date: Mon May 13 12:53:19 2024 +0200 Init variance to None commit 0e73fa836fa9deec6ba55ef1161cc0dd71f35044 Author: Jan Kuipers Date: Mon May 13 11:18:21 2024 +0200 No autodetection of dynamics changes for latency commit 75924a744d26a72835529598a6df1a2d22bdaddc Author: Jan Kuipers Date: Mon May 13 10:10:34 2024 +0200 Move autoscaling code to own class commit 23553bb8cccd6ed80ac667b12ec38a6d5562dd29 Author: Jan Kuipers Date: Wed May 8 18:01:55 2024 +0200 Improved autoscaling simulation commit 2db606b2bba69d741fa231f369c633ea793294d5 Author: Tom Veasey Date: Tue Apr 30 15:01:40 2024 +0100 Correct the dependency on allocations commit 0e45cfbaf901cf9d440efa9b404058a67d000653 Author: Tom Veasey Date: Tue Apr 30 11:11:05 2024 +0100 Tweak commit a0f23a4a05875cd5df3863e5ad067b46a67c8cda Author: Tom Veasey Date: Tue Apr 30 11:09:30 2024 +0100 Correction commit f9cdb140d298bd99c64c79f020c058d60bfba134 Author: Tom Veasey Date: Tue Apr 30 09:57:59 2024 +0100 Allow extrapolation commit 57eb1a661a2b97412f479606c23c54dfb7887f52 Author: Tom Veasey Date: Tue Apr 30 09:55:17 2024 +0100 Simplify and estimate average duration rather than rate commit 36dff17194f2bcf816013b112cf07d70c9eec161 Author: Tom Veasey Date: Mon Apr 29 21:42:25 2024 +0100 Kalman filter for simple state model for average inference duration as a function of time and allocation count commit a1b85bd0deeabd5162f2ccd5a28672299025cee5 Author: Jan Kuipers Date: Mon Apr 29 12:15:59 2024 +0200 Improvements commit 51040655fcfbfd221f2446542a955fb0f19fb145 Author: Jan Kuipers Date: Mon Apr 29 09:33:10 2024 +0200 Account for virtual cores / hyperthreading commit 7a93407ecae6b6044108299a1d05f72cdf0d752a Author: Jan Kuipers Date: Fri Apr 26 16:58:25 2024 +0200 Simulator for inference 
autoscaling. * Better process variance upon dynamics changes, and propagate dynamics changes to the next iteration. * Inference autoscaling (WIP) * Inference autoscaling test scripts * Debug logs * Inference autoscaling API * Update Autoscalers upon cluster changes * Polish code / fix bugs * Use correct string formatter * More fixes * Autoscaling tests * spotless * Remove scripts (moved to ml-data) * Rebrand to "adaptive allocations". * Move serialized field to end * Rebranding leftover * Improve adaptive allocation timing * SystemAuditor for scaling messages * Fix test * Add documentation * Update docs/changelog/109667.yaml * Cooldown of 5mins after scaleup * Polish code * High-variance adaptive allocations test * Fix AdaptiveAllocationsScalerServiceTests * Fix typo in package name * Wire adaptive allocations setting into put inference API * Checkstyle * Fix serialization of ElserInternalServiceSettings. * Propagate adaptive allocations settings from put inference request to create trained model request * Fix CustomElandInternalTextEmbeddingServiceSettingsTests * Javadocs * Improvements / fixes * Disallow setting num_allocations when adaptive allocations is enabled * Fix AdaptiveAllocationsScalerServiceTests * spotless * NPE fixes * spotless * Allow autoscaler to update num allocations * Fix AdaptiveAllocationsScalerServiceTests. * Fix bug in inference stats api * Fix PyTorchResultProcessorTests --- docs/changelog/109667.yaml | 5 + .../org/elasticsearch/TransportVersions.java | 1 + .../CreateTrainedModelAssignmentAction.java | 23 +- .../StartTrainedModelDeploymentAction.java | 90 ++++- .../UpdateTrainedModelDeploymentAction.java | 82 ++++- .../AdaptiveAllocationsSettings.java | 181 ++++++++++ .../inference/assignment/AssignmentStats.java | 22 ++ .../assignment/TrainedModelAssignment.java | 86 ++++- ...inedModelAssignmentActionRequestTests.java | 2 +- ...TrainedModelsStatsActionResponseTests.java | 7 + ...artTrainedModelDeploymentRequestTests.java | 6 +- .../assignment/AssignmentStatsTests.java | 4 + .../TrainedModelAssignmentTests.java | 46 +-- .../inference/services/ServiceUtils.java | 15 + .../CustomElandInternalServiceSettings.java | 39 +- ...dInternalTextEmbeddingServiceSettings.java | 26 +- .../elasticsearch/CustomElandModel.java | 1 + .../ElasticsearchInternalService.java | 1 + .../ElasticsearchInternalServiceSettings.java | 36 +- ...lingualE5SmallInternalServiceSettings.java | 49 ++- .../MultilingualE5SmallModel.java | 1 + .../services/elser/ElserInternalService.java | 1 + .../elser/ElserInternalServiceSettings.java | 43 ++- .../settings/InternalServiceSettings.java | 38 +- ...rnalTextEmbeddingServiceSettingsTests.java | 9 +- .../ElasticsearchInternalServiceTests.java | 20 +- ...alE5SmallInternalServiceSettingsTests.java | 14 +- .../ElserInternalServiceSettingsTests.java | 30 +- .../elser/ElserInternalServiceTests.java | 12 +- .../MlInitializationServiceIT.java | 10 +- .../xpack/ml/MachineLearning.java | 1 + .../xpack/ml/MlInitializationService.java | 10 + ...ortCreateTrainedModelAssignmentAction.java | 2 +- .../TransportGetDeploymentStatsAction.java | 3 + ...portStartTrainedModelDeploymentAction.java | 7 +- ...ortUpdateTrainedModelDeploymentAction.java | 4 +- .../AdaptiveAllocationsScaler.java | 154 ++++++++ .../AdaptiveAllocationsScalerService.java | 340 ++++++++++++++++++ .../adaptiveallocations/KalmanFilter1d.java | 121 +++++++ .../TrainedModelAssignmentClusterService.java | 145 ++++++-- .../TrainedModelAssignmentRebalancer.java | 39 +- .../TrainedModelAssignmentService.java | 
5 +- .../planning/AbstractPreserveAllocations.java | 1 + .../assignment/planning/AssignmentPlan.java | 9 +- .../planning/AssignmentPlanner.java | 2 + .../planning/ZoneAwareAssignmentPlanner.java | 2 + .../process/PyTorchResultProcessor.java | 2 +- ...RestStartTrainedModelDeploymentAction.java | 3 +- ...chineLearningInfoTransportActionTests.java | 2 + .../ml/MlInitializationServiceTests.java | 10 + .../xpack/ml/MlLifeCycleServiceTests.java | 10 +- .../xpack/ml/MlMetricsTests.java | 6 +- ...ransportGetDeploymentStatsActionTests.java | 5 +- .../MlAutoscalingResourceTrackerTests.java | 30 +- .../MlMemoryAutoscalingDeciderTests.java | 18 +- .../MlProcessorAutoscalingDeciderTests.java | 42 ++- ...AdaptiveAllocationsScalerServiceTests.java | 239 ++++++++++++ .../AdaptiveAllocationsScalerTests.java | 141 ++++++++ .../KalmanFilter1dTests.java | 122 +++++++ ...nedModelAssignmentClusterServiceTests.java | 136 +++---- .../TrainedModelAssignmentMetadataTests.java | 10 +- ...rainedModelAssignmentNodeServiceTests.java | 38 +- ...TrainedModelAssignmentRebalancerTests.java | 94 ++--- .../planning/AllocationReducerTests.java | 3 +- .../planning/AssignmentPlanTests.java | 92 +++-- .../planning/AssignmentPlannerTests.java | 114 +++--- .../planning/PreserveAllAllocationsTests.java | 21 +- .../planning/PreserveOneAllocationTests.java | 23 +- .../ZoneAwareAssignmentPlannerTests.java | 24 +- .../process/PyTorchResultProcessorTests.java | 25 +- .../xpack/ml/job/NodeLoadDetectorTests.java | 3 +- ...dateTrainedModelDeploymentActionTests.java | 4 +- 72 files changed, 2517 insertions(+), 445 deletions(-) create mode 100644 docs/changelog/109667.yaml create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java diff --git a/docs/changelog/109667.yaml b/docs/changelog/109667.yaml new file mode 100644 index 0000000000000..782a1b1cf6c9b --- /dev/null +++ b/docs/changelog/109667.yaml @@ -0,0 +1,5 @@ +pr: 109667 +summary: Inference autoscaling +area: Machine Learning +type: feature +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f64a43d463d47..65606465b8502 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -210,6 +210,7 @@ static TransportVersion def(int id) { public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0); public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0); public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = 
def(8_703_00_0); + public static final TransportVersion INFERENCE_ADAPTIVE_ALLOCATIONS = def(8_704_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java index 9b383b2652af4..c6976ab4b513e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -18,6 +19,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,15 +36,22 @@ private CreateTrainedModelAssignmentAction() { public static class Request extends MasterNodeRequest { private final StartTrainedModelDeploymentAction.TaskParams taskParams; + private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; - public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) { + public Request(StartTrainedModelDeploymentAction.TaskParams taskParams, AdaptiveAllocationsSettings adaptiveAllocationsSettings) { super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT); this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams"); + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } public Request(StreamInput in) throws IOException { super(in); this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); + } else { + this.adaptiveAllocationsSettings = null; + } } @Override @@ -54,6 +63,9 @@ public ActionRequestValidationException validate() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); taskParams.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } @Override @@ -61,17 +73,22 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Request request = (Request) o; - return Objects.equals(taskParams, request.taskParams); + return Objects.equals(taskParams, request.taskParams) + && Objects.equals(adaptiveAllocationsSettings, request.adaptiveAllocationsSettings); } @Override public int hashCode() { - return Objects.hash(taskParams); + return Objects.hash(taskParams, adaptiveAllocationsSettings); } public StartTrainedModelDeploymentAction.TaskParams getTaskParams() { return taskParams; } + + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } } public static class Response extends ActionResponse implements ToXContentObject { 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index ca9b86a90f875..e635851a4c5e8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -29,8 +29,10 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlTaskParams; @@ -40,7 +42,6 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription; public class StartTrainedModelDeploymentAction extends ActionType { @@ -99,6 +100,7 @@ public static class Request extends MasterNodeRequest implements ToXCon public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY; public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE; public static final ParseField PRIORITY = TaskParams.PRIORITY; + public static final ParseField ADAPTIVE_ALLOCATIONS = TrainedModelAssignment.ADAPTIVE_ALLOCATIONS; public static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); @@ -117,6 +119,12 @@ public static class Request extends MasterNodeRequest implements ToXCon ObjectParser.ValueType.VALUE ); PARSER.declareString(Request::setPriority, PRIORITY); + PARSER.declareObjectOrNull( + Request::setAdaptiveAllocationsSettings, + (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + null, + ADAPTIVE_ALLOCATIONS + ); } public static Request parseRequest(String modelId, String deploymentId, XContentParser parser) { @@ -140,7 +148,8 @@ public static Request parseRequest(String modelId, String deploymentId, XContent private TimeValue timeout = DEFAULT_TIMEOUT; private AllocationStatus.State waitForState = DEFAULT_WAITFOR_STATE; private ByteSizeValue cacheSize; - private int numberOfAllocations = DEFAULT_NUM_ALLOCATIONS; + private Integer numberOfAllocations; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings = null; private int threadsPerAllocation = DEFAULT_NUM_THREADS; private int queueCapacity = DEFAULT_QUEUE_CAPACITY; private Priority priority = DEFAULT_PRIORITY; @@ -160,7 +169,11 @@ public Request(StreamInput in) throws IOException { modelId = in.readString(); timeout = in.readTimeValue(); waitForState = in.readEnum(AllocationStatus.State.class); - numberOfAllocations = in.readVInt(); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + numberOfAllocations = in.readOptionalVInt(); + } else { + numberOfAllocations = in.readVInt(); + } threadsPerAllocation = in.readVInt(); queueCapacity = in.readVInt(); if 
(in.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { @@ -171,12 +184,16 @@ public Request(StreamInput in) throws IOException { } else { this.priority = Priority.NORMAL; } - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { this.deploymentId = in.readString(); } else { this.deploymentId = modelId; } + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); + } else { + this.adaptiveAllocationsSettings = null; + } } public final void setModelId(String modelId) { @@ -212,14 +229,34 @@ public Request setWaitForState(AllocationStatus.State waitForState) { return this; } - public int getNumberOfAllocations() { + public Integer getNumberOfAllocations() { return numberOfAllocations; } - public void setNumberOfAllocations(int numberOfAllocations) { + public int computeNumberOfAllocations() { + if (numberOfAllocations != null) { + return numberOfAllocations; + } else { + if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null) { + return DEFAULT_NUM_ALLOCATIONS; + } else { + return adaptiveAllocationsSettings.getMinNumberOfAllocations(); + } + } + } + + public void setNumberOfAllocations(Integer numberOfAllocations) { this.numberOfAllocations = numberOfAllocations; } + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + + public void setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + } + public int getThreadsPerAllocation() { return threadsPerAllocation; } @@ -258,7 +295,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeTimeValue(timeout); out.writeEnum(waitForState); - out.writeVInt(numberOfAllocations); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalVInt(numberOfAllocations); + } else { + out.writeVInt(numberOfAllocations); + } out.writeVInt(threadsPerAllocation); out.writeVInt(queueCapacity); if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { @@ -270,6 +311,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { out.writeString(deploymentId); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } @Override @@ -279,7 +323,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(DEPLOYMENT_ID.getPreferredName(), deploymentId); builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); builder.field(WAIT_FOR.getPreferredName(), waitForState); - builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); + if (numberOfAllocations != null) { + builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); + } + if (adaptiveAllocationsSettings != null) { + builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings); + } builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation); builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); if (cacheSize != null) { @@ -301,12 +350,25 @@ public ActionRequestValidationException validate() { + 
                     Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
                 );
             }
-            if (numberOfAllocations < 1) {
-                validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
+            if (numberOfAllocations != null) {
+                if (numberOfAllocations < 1) {
+                    validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
+                }
+                if (adaptiveAllocationsSettings != null && adaptiveAllocationsSettings.getEnabled()) {
+                    validationException.addValidationError(
+                        "[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled"
+                    );
+                }
             }
             if (threadsPerAllocation < 1) {
                 validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
             }
+            ActionRequestValidationException autoscaleException = adaptiveAllocationsSettings == null
+                ? null
+                : adaptiveAllocationsSettings.validate();
+            if (autoscaleException != null) {
+                validationException.addValidationErrors(autoscaleException.validationErrors());
+            }
             if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2(threadsPerAllocation) == false) {
                 validationException.addValidationError(
                     "[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION
@@ -322,7 +384,7 @@ public ActionRequestValidationException validate() {
                 validationException.addValidationError("[" + TIMEOUT + "] must be positive");
             }
             if (priority == Priority.LOW) {
-                if (numberOfAllocations > 1) {
+                if (numberOfAllocations != null && numberOfAllocations > 1) {
                     validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be 1 when [" + PRIORITY + "] is low");
                 }
                 if (threadsPerAllocation > 1) {
@@ -344,6 +406,7 @@ public int hashCode() {
                 timeout,
                 waitForState,
                 numberOfAllocations,
+                adaptiveAllocationsSettings,
                 threadsPerAllocation,
                 queueCapacity,
                 cacheSize,
@@ -365,7 +428,8 @@ public boolean equals(Object obj) {
                 && Objects.equals(timeout, other.timeout)
                 && Objects.equals(waitForState, other.waitForState)
                 && Objects.equals(cacheSize, other.cacheSize)
-                && numberOfAllocations == other.numberOfAllocations
+                && Objects.equals(numberOfAllocations, other.numberOfAllocations)
+                && Objects.equals(adaptiveAllocationsSettings, other.adaptiveAllocationsSettings)
                 && threadsPerAllocation == other.threadsPerAllocation
                 && queueCapacity == other.queueCapacity
                 && priority == other.priority;
@@ -430,7 +494,7 @@ public static boolean mayAssignToNode(@Nullable DiscoveryNode node) {
             PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
             PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
             PARSER.declareField(
-                optionalConstructorArg(),
+                ConstructingObjectParser.optionalConstructorArg(),
                 (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
                 CACHE_SIZE,
                 ObjectParser.ValueType.VALUE
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java
index 62a7d84c60a62..c69a88600f915 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java
@@ -7,6 +7,7 @@

 package org.elasticsearch.xpack.core.ml.action;

+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.AcknowledgedRequest;
@@ -19,12 +20,14 @@
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
 import java.util.Objects;

+import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.ADAPTIVE_ALLOCATIONS;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_ID;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;

@@ -46,6 +49,12 @@ public static class Request extends AcknowledgedRequest<Request> implements ToXContentObject {
         static {
             PARSER.declareString(Request::setDeploymentId, MODEL_ID);
             PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
+            PARSER.declareObjectOrNull(
+                Request::setAdaptiveAllocationsSettings,
+                (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(),
+                AdaptiveAllocationsSettings.RESET_PLACEHOLDER,
+                ADAPTIVE_ALLOCATIONS
+            );
             PARSER.declareString((r, val) -> r.ackTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
         }

@@ -62,7 +71,9 @@ public static Request parseRequest(String deploymentId, XContentParser parser) {
         }

         private String deploymentId;
-        private int numberOfAllocations;
+        private Integer numberOfAllocations;
+        private AdaptiveAllocationsSettings adaptiveAllocationsSettings;
+        private boolean isInternal;

         private Request() {
             super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
@@ -76,7 +87,15 @@ public Request(String deploymentId) {
         public Request(StreamInput in) throws IOException {
             super(in);
             deploymentId = in.readString();
-            numberOfAllocations = in.readVInt();
+            if (in.getTransportVersion().before(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
+                numberOfAllocations = in.readVInt();
+                adaptiveAllocationsSettings = null;
+                isInternal = false;
+            } else {
+                numberOfAllocations = in.readOptionalVInt();
+                adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new);
+                isInternal = in.readBoolean();
+            }
         }

         public final void setDeploymentId(String deploymentId) {
@@ -87,26 +106,53 @@ public String getDeploymentId() {
             return deploymentId;
         }

-        public void setNumberOfAllocations(int numberOfAllocations) {
+        public void setNumberOfAllocations(Integer numberOfAllocations) {
             this.numberOfAllocations = numberOfAllocations;
         }

-        public int getNumberOfAllocations() {
+        public Integer getNumberOfAllocations() {
             return numberOfAllocations;
         }

+        public void setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
+            this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
+        }
+
+        public boolean isInternal() {
+            return isInternal;
+        }
+
+        public void setIsInternal(boolean isInternal) {
+            this.isInternal = isInternal;
+        }
+
+        public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() {
+            return adaptiveAllocationsSettings;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
             out.writeString(deploymentId);
-            out.writeVInt(numberOfAllocations);
+            if (out.getTransportVersion().before(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
+                out.writeVInt(numberOfAllocations);
+            } else {
+                out.writeOptionalVInt(numberOfAllocations);
+                out.writeOptionalWriteable(adaptiveAllocationsSettings);
+                out.writeBoolean(isInternal);
+            }
         }

         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
             builder.field(MODEL_ID.getPreferredName(), deploymentId);
-            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
+            if (numberOfAllocations != null) {
+                builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
+            }
+            if (adaptiveAllocationsSettings != null) {
+                builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings);
+            }
             builder.endObject();
             return builder;
         }
@@ -114,15 +160,28 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         @Override
         public ActionRequestValidationException validate() {
             ActionRequestValidationException validationException = new ActionRequestValidationException();
-            if (numberOfAllocations < 1) {
-                validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
+            if (numberOfAllocations != null) {
+                if (numberOfAllocations < 1) {
+                    validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
+                }
+                if (isInternal == false && adaptiveAllocationsSettings != null && adaptiveAllocationsSettings.getEnabled()) {
+                    validationException.addValidationError(
+                        "[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled"
+                    );
+                }
+            }
+            ActionRequestValidationException autoscaleException = adaptiveAllocationsSettings == null
+                ? null
+                : adaptiveAllocationsSettings.validate();
+            if (autoscaleException != null) {
+                validationException.addValidationErrors(autoscaleException.validationErrors());
             }
             return validationException.validationErrors().isEmpty() ? null : validationException;
         }

         @Override
         public int hashCode() {
-            return Objects.hash(deploymentId, numberOfAllocations);
+            return Objects.hash(deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal);
         }

         @Override
@@ -134,7 +193,10 @@ public boolean equals(Object obj) {
                 return false;
             }
             Request other = (Request) obj;
-            return Objects.equals(deploymentId, other.deploymentId) && numberOfAllocations == other.numberOfAllocations;
+            return Objects.equals(deploymentId, other.deploymentId)
+                && Objects.equals(numberOfAllocations, other.numberOfAllocations)
+                && Objects.equals(adaptiveAllocationsSettings, other.adaptiveAllocationsSettings)
+                && isInternal == other.isInternal;
         }

         @Override
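The declareObjectOrNull wiring above gives an explicit "adaptive_allocations": null in an update request a meaning distinct from omitting the field: the parser substitutes RESET_PLACEHOLDER (enabled=false, min=max=-1), and merge() in the new AdaptiveAllocationsSettings class that follows maps the -1 markers back to null. A minimal sketch of that round trip, using only the API added in this patch; the concrete values are illustrative, and the update transport action that presumably calls merge() is not part of this hunk:

    // Existing settings on the deployment (illustrative values).
    AdaptiveAllocationsSettings current = new AdaptiveAllocationsSettings(true, 1, 8);

    // "adaptive_allocations": null in the request body parses to RESET_PLACEHOLDER...
    AdaptiveAllocationsSettings update = AdaptiveAllocationsSettings.RESET_PLACEHOLDER;

    // ...and merge() turns the -1 markers back into nulls, clearing both bounds.
    AdaptiveAllocationsSettings merged = current.merge(update);
    assert Boolean.FALSE.equals(merged.getEnabled());
    assert merged.getMinNumberOfAllocations() == null;
    assert merged.getMaxNumberOfAllocations() == null;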
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java
new file mode 100644
index 0000000000000..0b5a62ccb588c
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java
@@ -0,0 +1,181 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.assignment;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class AdaptiveAllocationsSettings implements ToXContentObject, Writeable {
+
+    public static final AdaptiveAllocationsSettings RESET_PLACEHOLDER = new AdaptiveAllocationsSettings(false, -1, -1);
+
+    public static final ParseField ENABLED = new ParseField("enabled");
+    public static final ParseField MIN_NUMBER_OF_ALLOCATIONS = new ParseField("min_number_of_allocations");
+    public static final ParseField MAX_NUMBER_OF_ALLOCATIONS = new ParseField("max_number_of_allocations");
+
+    public static final ObjectParser<AdaptiveAllocationsSettings.Builder, Void> PARSER = new ObjectParser<>(
+        "autoscaling_settings",
+        AdaptiveAllocationsSettings.Builder::new
+    );
+
+    static {
+        PARSER.declareBoolean(Builder::setEnabled, ENABLED);
+        PARSER.declareIntOrNull(Builder::setMinNumberOfAllocations, -1, MIN_NUMBER_OF_ALLOCATIONS);
+        PARSER.declareIntOrNull(Builder::setMaxNumberOfAllocations, -1, MAX_NUMBER_OF_ALLOCATIONS);
+    }
+
+    public static AdaptiveAllocationsSettings parseRequest(XContentParser parser) {
+        return PARSER.apply(parser, null).build();
+    }
+
+    public static class Builder {
+        private Boolean enabled;
+        private Integer minNumberOfAllocations;
+        private Integer maxNumberOfAllocations;
+
+        public Builder() {}
+
+        public Builder(AdaptiveAllocationsSettings settings) {
+            enabled = settings.enabled;
+            minNumberOfAllocations = settings.minNumberOfAllocations;
+            maxNumberOfAllocations = settings.maxNumberOfAllocations;
+        }
+
+        public void setEnabled(Boolean enabled) {
+            this.enabled = enabled;
+        }
+
+        public void setMinNumberOfAllocations(Integer minNumberOfAllocations) {
+            this.minNumberOfAllocations = minNumberOfAllocations;
+        }
+
+        public void setMaxNumberOfAllocations(Integer maxNumberOfAllocations) {
+            this.maxNumberOfAllocations = maxNumberOfAllocations;
+        }
+
+        public AdaptiveAllocationsSettings build() {
+            return new AdaptiveAllocationsSettings(enabled, minNumberOfAllocations, maxNumberOfAllocations);
+        }
+    }
+
+    private final Boolean enabled;
+    private final Integer minNumberOfAllocations;
+    private final Integer maxNumberOfAllocations;
+
+    public AdaptiveAllocationsSettings(Boolean enabled, Integer minNumberOfAllocations, Integer maxNumberOfAllocations) {
+        this.enabled = enabled;
+        this.minNumberOfAllocations = minNumberOfAllocations;
+        this.maxNumberOfAllocations = maxNumberOfAllocations;
+    }
+
+    public AdaptiveAllocationsSettings(StreamInput in) throws IOException {
+        enabled = in.readOptionalBoolean();
+        minNumberOfAllocations = in.readOptionalInt();
+        maxNumberOfAllocations = in.readOptionalInt();
+    }
+
+    public Boolean getEnabled() {
+        return enabled;
+    }
+
+    public Integer getMinNumberOfAllocations() {
+        return minNumberOfAllocations;
+    }
+
+    public Integer getMaxNumberOfAllocations() {
+        return maxNumberOfAllocations;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (enabled != null) {
+            builder.field(ENABLED.getPreferredName(), enabled);
+        }
+        if (minNumberOfAllocations != null) {
+            builder.field(MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), minNumberOfAllocations);
+        }
+        if (maxNumberOfAllocations != null) {
+            builder.field(MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), maxNumberOfAllocations);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalBoolean(enabled);
+        out.writeOptionalInt(minNumberOfAllocations);
+        out.writeOptionalInt(maxNumberOfAllocations);
+    }
+
+    public AdaptiveAllocationsSettings merge(AdaptiveAllocationsSettings updates) {
+        AdaptiveAllocationsSettings.Builder builder = new Builder(this);
+        if (updates.getEnabled() != null) {
+            builder.setEnabled(updates.enabled);
+        }
+        if (updates.minNumberOfAllocations != null) {
+            if (updates.minNumberOfAllocations == -1) {
+                builder.setMinNumberOfAllocations(null);
+            } else {
+                builder.setMinNumberOfAllocations(updates.minNumberOfAllocations);
+            }
+        }
+        if (updates.maxNumberOfAllocations != null) {
+            if (updates.maxNumberOfAllocations == -1) {
+                builder.setMaxNumberOfAllocations(null);
+            } else {
+                builder.setMaxNumberOfAllocations(updates.maxNumberOfAllocations);
+            }
+        }
+        return builder.build();
+    }
+
+    public ActionRequestValidationException validate() {
+        ActionRequestValidationException validationException = new ActionRequestValidationException();
+        boolean hasMinNumberOfAllocations = (minNumberOfAllocations != null && minNumberOfAllocations != -1);
+        if (hasMinNumberOfAllocations && minNumberOfAllocations < 1) {
+            validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null");
+        }
+        boolean hasMaxNumberOfAllocations = (maxNumberOfAllocations != null && maxNumberOfAllocations != -1);
+        if (hasMaxNumberOfAllocations && maxNumberOfAllocations < 1) {
+            validationException.addValidationError("[" + MAX_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null");
+        }
+        if (hasMinNumberOfAllocations && hasMaxNumberOfAllocations && minNumberOfAllocations > maxNumberOfAllocations) {
+            validationException.addValidationError(
+                "[" + MIN_NUMBER_OF_ALLOCATIONS + "] must not be larger than [" + MAX_NUMBER_OF_ALLOCATIONS + "]"
+            );
+        }
+        return validationException.validationErrors().isEmpty() ? null : validationException;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AdaptiveAllocationsSettings that = (AdaptiveAllocationsSettings) o;
+        return Objects.equals(enabled, that.enabled)
+            && Objects.equals(minNumberOfAllocations, that.minNumberOfAllocations)
+            && Objects.equals(maxNumberOfAllocations, that.maxNumberOfAllocations);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(enabled, minNumberOfAllocations, maxNumberOfAllocations);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java
index d8e5d7a6d9603..aadaa5254ff15 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java
@@ -423,6 +423,8 @@ public int hashCode() {
     @Nullable
     private final Integer numberOfAllocations;
     @Nullable
+    private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
+    @Nullable
     private final Integer queueCapacity;
     @Nullable
     private final ByteSizeValue cacheSize;
@@ -435,6 +437,7 @@ public AssignmentStats(
         String modelId,
         @Nullable Integer threadsPerAllocation,
         @Nullable Integer numberOfAllocations,
+        @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings,
         @Nullable Integer queueCapacity,
         @Nullable ByteSizeValue cacheSize,
         Instant startTime,
@@ -445,6 +448,7 @@ public AssignmentStats(
         this.modelId = modelId;
         this.threadsPerAllocation = threadsPerAllocation;
         this.numberOfAllocations = numberOfAllocations;
+        this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
         this.queueCapacity = queueCapacity;
         this.startTime = Objects.requireNonNull(startTime);
         this.nodeStats = nodeStats;
@@ -479,6 +483,11 @@ public AssignmentStats(StreamInput in) throws IOException {
         } else {
             deploymentId = modelId;
         }
+        if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
+            adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new);
+        } else {
+            adaptiveAllocationsSettings = null;
+        }
     }

     public String getDeploymentId() {
@@ -499,6 +508,11 @@ public Integer getNumberOfAllocations() {
         return numberOfAllocations;
     }

+    @Nullable
+    public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() {
+        return adaptiveAllocationsSettings;
+    }
+
     @Nullable
     public Integer getQueueCapacity() {
         return queueCapacity;
@@ -575,6 +589,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         if (numberOfAllocations != null) {
             builder.field(StartTrainedModelDeploymentAction.TaskParams.NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
         }
+        if (adaptiveAllocationsSettings != null) {
+            builder.field(StartTrainedModelDeploymentAction.Request.ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings);
+        }
         if (queueCapacity != null) {
             builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
         }
@@ -649,6 +666,9 @@ public void writeTo(StreamOutput out) throws IOException {
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
             out.writeString(deploymentId);
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } @Override @@ -660,6 +680,7 @@ public boolean equals(Object o) { && Objects.equals(modelId, that.modelId) && Objects.equals(threadsPerAllocation, that.threadsPerAllocation) && Objects.equals(numberOfAllocations, that.numberOfAllocations) + && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings) && Objects.equals(queueCapacity, that.queueCapacity) && Objects.equals(startTime, that.startTime) && Objects.equals(state, that.state) @@ -677,6 +698,7 @@ public int hashCode() { modelId, threadsPerAllocation, numberOfAllocations, + adaptiveAllocationsSettings, queueCapacity, startTime, nodeStats, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index b7219fbaa2061..60e0c0e86a828 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -23,6 +23,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -52,6 +53,7 @@ public final class TrainedModelAssignment implements SimpleDiffable PARSER = new ConstructingObjectParser<>( @@ -64,7 +66,8 @@ public final class TrainedModelAssignment implements SimpleDiffable AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + null, + ADAPTIVE_ALLOCATIONS + ); } private final StartTrainedModelDeploymentAction.TaskParams taskParams; @@ -96,6 +105,7 @@ public final class TrainedModelAssignment implements SimpleDiffable assignableNodeIds) { int allocations = nodeRoutingTable.entrySet() .stream() @@ -301,12 +324,21 @@ public boolean equals(Object o) { && Objects.equals(reason, that.reason) && Objects.equals(assignmentState, that.assignmentState) && Objects.equals(startTime, that.startTime) - && maxAssignedAllocations == that.maxAssignedAllocations; + && maxAssignedAllocations == that.maxAssignedAllocations + && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings); } @Override public int hashCode() { - return Objects.hash(nodeRoutingTable, taskParams, assignmentState, reason, startTime, maxAssignedAllocations); + return Objects.hash( + nodeRoutingTable, + taskParams, + assignmentState, + reason, + startTime, + maxAssignedAllocations, + adaptiveAllocationsSettings + ); } @Override @@ -320,6 +352,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.timeField(START_TIME.getPreferredName(), startTime); builder.field(MAX_ASSIGNED_ALLOCATIONS.getPreferredName(), maxAssignedAllocations); + builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings); builder.endObject(); return builder; } @@ -334,6 +367,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { out.writeVInt(maxAssignedAllocations); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + 
out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } public Optional calculateAllocationStatus() { @@ -355,6 +391,7 @@ public static class Builder { private String reason; private Instant startTime; private int maxAssignedAllocations; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; public static Builder fromAssignment(TrainedModelAssignment assignment) { return new Builder( @@ -363,12 +400,20 @@ public static Builder fromAssignment(TrainedModelAssignment assignment) { assignment.assignmentState, assignment.reason, assignment.startTime, - assignment.maxAssignedAllocations + assignment.maxAssignedAllocations, + assignment.adaptiveAllocationsSettings ); } - public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) { - return new Builder(taskParams); + public static Builder empty(CreateTrainedModelAssignmentAction.Request request) { + return new Builder(request.getTaskParams(), request.getAdaptiveAllocationsSettings()); + } + + public static Builder empty( + StartTrainedModelDeploymentAction.TaskParams taskParams, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + return new Builder(taskParams, adaptiveAllocationsSettings); } private Builder( @@ -377,7 +422,8 @@ private Builder( AssignmentState assignmentState, String reason, Instant startTime, - int maxAssignedAllocations + int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { this.taskParams = taskParams; this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable); @@ -385,10 +431,11 @@ private Builder( this.reason = reason; this.startTime = startTime; this.maxAssignedAllocations = maxAssignedAllocations; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } - private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) { - this(taskParams, new LinkedHashMap<>(), AssignmentState.STARTING, null, Instant.now(), 0); + private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams, AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this(taskParams, new LinkedHashMap<>(), AssignmentState.STARTING, null, Instant.now(), 0, adaptiveAllocationsSettings); } public Builder setStartTime(Instant startTime) { @@ -401,6 +448,11 @@ public Builder setMaxAssignedAllocations(int maxAssignedAllocations) { return this; } + public Builder setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + return this; + } + public Builder addRoutingEntry(String nodeId, RoutingInfo routingInfo) { if (nodeRoutingTable.containsKey(nodeId)) { throw new ResourceAlreadyExistsException( @@ -518,7 +570,15 @@ public Builder setNumberOfAllocations(int numberOfAllocations) { } public TrainedModelAssignment build() { - return new TrainedModelAssignment(taskParams, nodeRoutingTable, assignmentState, reason, startTime, maxAssignedAllocations); + return new TrainedModelAssignment( + taskParams, + nodeRoutingTable, + assignmentState, + reason, + startTime, + maxAssignedAllocations, + adaptiveAllocationsSettings + ); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java index 71a68a65b7977..39f646df0d582 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java @@ -14,7 +14,7 @@ public class CreateTrainedModelAssignmentActionRequestTests extends AbstractWire @Override protected Request createTestInstance() { - return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom()); + return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java index 8c175c17fccc8..d60bbc6cc7713 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java @@ -156,6 +156,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), null, stats.getDeploymentStats().getStartTime(), @@ -228,6 +229,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), null, stats.getDeploymentStats().getStartTime(), @@ -300,6 +302,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), null, stats.getDeploymentStats().getStartTime(), @@ -372,6 +375,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), @@ -445,6 +449,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), @@ -518,6 +523,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), @@ -591,6 +597,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), 
stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java index ad33a85d42e53..730d994fc5e35 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java @@ -71,7 +71,8 @@ public static Request createRandom() { } if (randomBoolean()) { request.setPriority(randomFrom(Priority.values()).toString()); - if (request.getNumberOfAllocations() > 1 || request.getThreadsPerAllocation() > 1) { + if ((request.getNumberOfAllocations() != null && request.getNumberOfAllocations() > 1) + || request.getThreadsPerAllocation() > 1) { request.setPriority(Priority.NORMAL.toString()); } } @@ -230,7 +231,8 @@ public void testDefaults() { Request request = new Request(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(30))); assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED)); - assertThat(request.getNumberOfAllocations(), equalTo(1)); + assertThat(request.getNumberOfAllocations(), nullValue()); + assertThat(request.computeNumberOfAllocations(), equalTo(1)); assertThat(request.getThreadsPerAllocation(), equalTo(1)); assertThat(request.getQueueCapacity(), equalTo(1024)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java index a1ab023a6935f..07c56b073cd00 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java @@ -50,6 +50,7 @@ public static AssignmentStats randomDeploymentStats() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 10000000)), Instant.now(), @@ -102,6 +103,7 @@ public void testGetOverallInferenceStats() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -166,6 +168,7 @@ public void testGetOverallInferenceStatsWithNoNodes() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -187,6 +190,7 @@ public void testGetOverallInferenceStatsWithOnlyStoppedNodes() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? 
null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index 75706f3d6a9bf..6d70105dfedba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -39,7 +39,7 @@ public class TrainedModelAssignmentTests extends AbstractXContentSerializingTestCase { public static TrainedModelAssignment randomInstance() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams(), null); List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList(); for (String node : nodes) { builder.addRoutingEntry(node, RoutingInfoTests.randomInstance()); @@ -72,7 +72,7 @@ protected TrainedModelAssignment mutateInstance(TrainedModelAssignment instance) } public void testBuilderAddingExistingRoute() { - TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams()); + TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams(), null); String addingNode = "new-node"; assignment.addRoutingEntry(addingNode, RoutingInfoTests.randomInstance()); @@ -80,7 +80,7 @@ public void testBuilderAddingExistingRoute() { } public void testBuilderUpdatingMissingRoute() { - TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams()); + TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams(), null); String addingNode = "new-node"; expectThrows( ResourceNotFoundException.class, @@ -93,7 +93,7 @@ public void testGetStartedNodes() { String startedNode2 = "started-node-2"; String nodeInAnotherState1 = "another-state-node-1"; String nodeInAnotherState2 = "another-state-node-2"; - TrainedModelAssignment allocation = TrainedModelAssignment.Builder.empty(randomParams()) + TrainedModelAssignment allocation = TrainedModelAssignment.Builder.empty(randomParams(), null) .addRoutingEntry(startedNode1, RoutingInfoTests.randomInstance(RoutingState.STARTED)) .addRoutingEntry(startedNode2, RoutingInfoTests.randomInstance(RoutingState.STARTED)) .addRoutingEntry( @@ -114,20 +114,20 @@ public void testGetStartedNodes() { public void testCalculateAllocationStatus_GivenNoAllocations() { assertThat( - TrainedModelAssignment.Builder.empty(randomTaskParams(5)).build().calculateAllocationStatus(), + TrainedModelAssignment.Builder.empty(randomTaskParams(5), null).build().calculateAllocationStatus(), isPresentWith(new AllocationStatus(0, 5)) ); } public void testCalculateAllocationStatus_GivenStoppingAssignment() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")); assertThat(builder.stopAssignment("test").build().calculateAllocationStatus(), isEmpty()); } public void 
testCalculateAllocationStatus_GivenPartiallyAllocated() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTING, "")); @@ -135,28 +135,28 @@ public void testCalculateAllocationStatus_GivenPartiallyAllocated() { } public void testCalculateAllocationStatus_GivenFullyAllocated() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat(builder.build().calculateAllocationStatus(), isPresentWith(new AllocationStatus(5, 5))); } public void testCalculateAssignmentState_GivenNoStartedAssignments() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTING, "")); assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTING)); } public void testCalculateAssignmentState_GivenOneStartedAssignment() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTED)); } public void testCalculateAndSetAssignmentState_GivenStoppingAssignment() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat( @@ -166,7 +166,7 @@ public void testCalculateAndSetAssignmentState_GivenStoppingAssignment() { } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoStartedAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, "")); TrainedModelAssignment assignment = builder.build(); @@ -175,7 +175,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoS } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() { - 
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); @@ -185,7 +185,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSin } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNoNodes() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); @@ -195,7 +195,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNode1() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STOPPING, "")); TrainedModelAssignment assignment = builder.build(); @@ -205,7 +205,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh } public void testSingleRequestWith2Nodes() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); @@ -216,7 +216,7 @@ public void testSingleRequestWith2Nodes() { } public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodes() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -239,7 +239,7 @@ public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodesWithZeroAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(0, 0, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, "")); @@ -257,7 +257,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul } public void testIsSatisfied_GivenEnoughAllocations() { - 
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -266,7 +266,7 @@ public void testIsSatisfied_GivenEnoughAllocations() { } public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNotAssignable() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -275,7 +275,7 @@ public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNotAssignable() { } public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNeitherStartingNorStarted() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry( "node-1", new RoutingInfo(1, 1, randomFrom(RoutingState.FAILED, RoutingState.STOPPING, RoutingState.STOPPED), "") @@ -287,7 +287,7 @@ public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNeitherStartingNo } public void testIsSatisfied_GivenNotEnoughAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(7)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(7), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -296,7 +296,7 @@ public void testIsSatisfied_GivenNotEnoughAllocations() { } public void testMaxAssignedAllocations() { - TrainedModelAssignment assignment = TrainedModelAssignment.Builder.empty(randomTaskParams(10)) + TrainedModelAssignment assignment = TrainedModelAssignment.Builder.empty(randomTaskParams(10), null) .addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTING, "")) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 966cc029232b1..99779ac378d89 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbedding; +import 
org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; import java.net.URI; @@ -37,6 +38,9 @@ import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED; +import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS; +import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; public final class ServiceUtils { @@ -126,6 +130,17 @@ public static Object removeAsOneOfTypes( return null; } + public static AdaptiveAllocationsSettings removeAsAdaptiveAllocationsSettings(Map sourceMap, String key) { + Map settingsMap = ServiceUtils.removeFromMap(sourceMap, key); + return settingsMap == null + ? null + : new AdaptiveAllocationsSettings( + ServiceUtils.removeAsType(settingsMap, ENABLED.getPreferredName(), Boolean.class), + ServiceUtils.removeAsType(settingsMap, MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class), + ServiceUtils.removeAsType(settingsMap, MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class) + ); + } + @SuppressWarnings("unchecked") public static Map removeFromMap(Map sourceMap, String fieldName) { return (Map) sourceMap.remove(fieldName); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java index 6c81cc9948b70..b74dbe482acc6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java @@ -9,11 +9,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; import java.util.Map; @@ -25,8 +28,13 @@ public class CustomElandInternalServiceSettings extends ElasticsearchInternalSer public static final String NAME = "custom_eland_model_internal_service_settings"; - public CustomElandInternalServiceSettings(int numAllocations, int numThreads, String modelId) { - super(numAllocations, numThreads, modelId); + public CustomElandInternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } /** @@ -50,6 +58,16 @@ public static CustomElandInternalServiceSettings fromMap(Map map validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, 
ModelConfigurations.SERVICE_SETTINGS, validationException); + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { @@ -59,12 +77,18 @@ public static CustomElandInternalServiceSettings fromMap(Map map var builder = new Builder() { @Override public CustomElandInternalServiceSettings build() { - return new CustomElandInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId()); + return new CustomElandInternalServiceSettings( + getNumAllocations(), + getNumThreads(), + getModelId(), + getAdaptiveAllocationsSettings() + ); } }; builder.setNumAllocations(numAllocations); builder.setNumThreads(numThreads); builder.setModelId(modelId); + builder.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); return builder.build(); } @@ -74,7 +98,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public CustomElandInternalServiceSettings(StreamInput in) throws IOException { - super(in.readVInt(), in.readVInt(), in.readString()); + super( + in.readVInt(), + in.readVInt(), + in.readString(), + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java index 5ef9ce1a0507f..8413d06045601 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; @@ -107,19 +108,38 @@ private static CommonFields commonFieldsFromMap(Map map, Validat private final SimilarityMeasure similarityMeasure; private final DenseVectorFieldMapper.ElementType elementType; - public CustomElandInternalTextEmbeddingServiceSettings(int numAllocations, int numThreads, String modelId) { - this(numAllocations, numThreads, modelId, null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT); + public CustomElandInternalTextEmbeddingServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + this( + numAllocations, + numThreads, + modelId, + adaptiveAllocationsSettings, + null, + SimilarityMeasure.COSINE, + DenseVectorFieldMapper.ElementType.FLOAT + ); } public 
CustomElandInternalTextEmbeddingServiceSettings( int numAllocations, int numThreads, String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, Integer dimensions, SimilarityMeasure similarityMeasure, DenseVectorFieldMapper.ElementType elementType ) { - internalServiceSettings = new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId); + internalServiceSettings = new ElasticsearchInternalServiceSettings( + numAllocations, + numThreads, + modelId, + adaptiveAllocationsSettings + ); this.dimensions = dimensions; this.similarityMeasure = Objects.requireNonNull(similarityMeasure); this.elementType = Objects.requireNonNull(elementType); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java index 5a82e73299b85..703fca8c74c31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java @@ -37,6 +37,7 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA var startRequest = new StartTrainedModelDeploymentAction.Request(internalServiceSettings.getModelId(), this.getInferenceEntityId()); startRequest.setNumberOfAllocations(internalServiceSettings.getNumAllocations()); startRequest.setThreadsPerAllocation(internalServiceSettings.getNumThreads()); + startRequest.setAdaptiveAllocationsSettings(internalServiceSettings.getAdaptiveAllocationsSettings()); startRequest.setWaitForState(STARTED); return startRequest; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index d5401f61823db..9dc88be16ddbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -271,6 +271,7 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE model.getServiceSettings().getElasticsearchInternalServiceSettings().getNumAllocations(), model.getServiceSettings().getElasticsearchInternalServiceSettings().getNumThreads(), model.getServiceSettings().getElasticsearchInternalServiceSettings().getModelId(), + model.getServiceSettings().getElasticsearchInternalServiceSettings().getAdaptiveAllocationsSettings(), embeddingSize, model.getServiceSettings().similarity(), model.getServiceSettings().elementType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 45d616074dded..ff4ef4ff0358f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -9,9 +9,12 @@ 
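Before the ElasticsearchInternalServiceSettings hunks, a note on the wire-compatibility guard that this file, CustomElandInternalServiceSettings above, and the E5 and ELSER settings below all repeat: a field added in this patch may only be read or written when the other node's transport version is on or after INFERENCE_ADAPTIVE_ALLOCATIONS, otherwise the stream position would drift and deserialization would corrupt. A minimal sketch of the pattern with a hypothetical Example writeable; only the TransportVersions constant and the AdaptiveAllocationsSettings type come from this patch:

    // Hypothetical sketch of version-gated serialization, not code from this patch.
    class Example implements Writeable {
        private final String modelId;                                          // pre-existing field
        private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; // field added in this patch

        Example(StreamInput in) throws IOException {
            modelId = in.readString(); // always on the wire
            // Only read the new field when the sender is new enough to have written it.
            adaptiveAllocationsSettings = in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)
                ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new)
                : null; // older sender: field absent, fall back to null
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(modelId);
            if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
                out.writeOptionalWriteable(adaptiveAllocationsSettings); // mirrors the read side exactly
            }
        }
    }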
import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; import java.io.IOException; @@ -34,23 +37,46 @@ public static ElasticsearchInternalServiceSettings fromMap(Map m validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - // if an error occurred while parsing, we'll set these to an invalid value so we don't accidentally get a + // if an error occurred while parsing, we'll set these to an invalid value, so we don't accidentally get a // null pointer when doing unboxing return new ElasticsearchInternalServiceSettings( Objects.requireNonNullElse(numAllocations, FAILED_INT_PARSE_VALUE), Objects.requireNonNullElse(numThreads, FAILED_INT_PARSE_VALUE), - modelId + modelId, + adaptiveAllocationsSettings ); } - public ElasticsearchInternalServiceSettings(int numAllocations, int numThreads, String modelVariant) { - super(numAllocations, numThreads, modelVariant); + public ElasticsearchInternalServiceSettings( + int numAllocations, + int numThreads, + String modelVariant, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelVariant, adaptiveAllocationsSettings); } public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { - super(in.readVInt(), in.readVInt(), in.readString()); + super( + in.readVInt(), + in.readVInt(), + in.readString(), + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? 
in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java index 602f3a5c6c4e8..e4aa9616fb332 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -14,6 +16,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; @@ -30,12 +33,24 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt static final int DIMENSIONS = 384; static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE; - public MultilingualE5SmallInternalServiceSettings(int numAllocations, int numThreads, String modelId) { - super(numAllocations, numThreads, modelId); + public MultilingualE5SmallInternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } public MultilingualE5SmallInternalServiceSettings(StreamInput in) throws IOException { - super(in.readVInt(), in.readVInt(), in.readString()); + super( + in.readVInt(), + in.readVInt(), + in.readString(), + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? 
in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null + ); } /** @@ -66,7 +81,16 @@ private static RequestFields extractRequestFields(Map map, Valid validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); - + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = ServiceUtils.removeAsType(map, MODEL_ID, String.class); if (modelId != null) { if (ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId) == false) { @@ -79,23 +103,34 @@ private static RequestFields extractRequestFields(Map map, Valid } } - return new RequestFields(numAllocations, numThreads, modelId); + return new RequestFields(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } private static MultilingualE5SmallInternalServiceSettings.Builder createBuilder(RequestFields requestFields) { var builder = new InternalServiceSettings.Builder() { @Override public MultilingualE5SmallInternalServiceSettings build() { - return new MultilingualE5SmallInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId()); + return new MultilingualE5SmallInternalServiceSettings( + getNumAllocations(), + getNumThreads(), + getModelId(), + getAdaptiveAllocationsSettings() + ); } }; builder.setNumAllocations(requestFields.numAllocations); builder.setNumThreads(requestFields.numThreads); builder.setModelId(requestFields.modelId); + builder.setAdaptiveAllocationsSettings(requestFields.adaptiveAllocationsSettings); return builder; } - private record RequestFields(@Nullable Integer numAllocations, @Nullable Integer numThreads, @Nullable String modelId) {} + private record RequestFields( + @Nullable Integer numAllocations, + @Nullable Integer numThreads, + @Nullable String modelId, + @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) {} @Override public boolean isFragment() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java index 60d68eb2fcee7..f22118d00cc29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java @@ -47,6 +47,7 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA ); startRequest.setNumberOfAllocations(this.getServiceSettings().getNumAllocations()); startRequest.setThreadsPerAllocation(this.getServiceSettings().getNumThreads()); + startRequest.setAdaptiveAllocationsSettings(this.getServiceSettings().getAdaptiveAllocationsSettings()); startRequest.setWaitForState(STARTED); return startRequest; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 11c97f8b8e37e..54434a7563dab 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -216,6 +216,7 @@ private static StartTrainedModelDeploymentAction.Request startDeploymentRequest( ); startRequest.setNumberOfAllocations(serviceSettings.getNumAllocations()); startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads()); + startRequest.setAdaptiveAllocationsSettings(serviceSettings.getAdaptiveAllocationsSettings()); startRequest.setWaitForState(STARTED); return startRequest; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java index 603c218d4dd21..fcbf7394ccb33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java @@ -9,10 +9,13 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; import java.io.IOException; @@ -45,6 +48,16 @@ public static ElserInternalServiceSettings.Builder fromMap(Map m validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (modelId != null && ElserInternalService.VALID_ELSER_MODEL_IDS.contains(modelId) == false) { @@ -58,17 +71,28 @@ public static ElserInternalServiceSettings.Builder fromMap(Map m var builder = new InternalServiceSettings.Builder() { @Override public ElserInternalServiceSettings build() { - return new ElserInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId()); + return new ElserInternalServiceSettings( + getNumAllocations(), + getNumThreads(), + getModelId(), + getAdaptiveAllocationsSettings() + ); } }; builder.setNumAllocations(numAllocations); builder.setNumThreads(numThreads); + builder.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); builder.setModelId(modelId); return builder; } - public ElserInternalServiceSettings(int numAllocations, int numThreads, String modelId) { - super(numAllocations, numThreads, modelId); + public ElserInternalServiceSettings( + int numAllocations, + int numThreads, + String 
modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); Objects.requireNonNull(modelId); } @@ -76,7 +100,10 @@ public ElserInternalServiceSettings(StreamInput in) throws IOException { super( in.readVInt(), in.readVInt(), - in.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X) ? in.readString() : ElserInternalService.ELSER_V2_MODEL + in.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X) ? in.readString() : ElserInternalService.ELSER_V2_MODEL, + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null ); } @@ -97,11 +124,14 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X)) { out.writeString(getModelId()); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(getAdaptiveAllocationsSettings()); + } } @Override public int hashCode() { - return Objects.hash(NAME, getNumAllocations(), getNumThreads(), getModelId()); + return Objects.hash(NAME, getNumAllocations(), getNumThreads(), getModelId(), getAdaptiveAllocationsSettings()); } @Override @@ -111,6 +141,7 @@ public boolean equals(Object o) { ElserInternalServiceSettings that = (ElserInternalServiceSettings) o; return getNumAllocations() == that.getNumAllocations() && getNumThreads() == that.getNumThreads() - && Objects.equals(getModelId(), that.getModelId()); + && Objects.equals(getModelId(), that.getModelId()) + && Objects.equals(getAdaptiveAllocationsSettings(), that.getAdaptiveAllocationsSettings()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java index 00bb48ae2302a..2cbe2f930c84d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.inference.services.settings; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; import java.util.Objects; @@ -20,15 +22,23 @@ public abstract class InternalServiceSettings implements ServiceSettings { public static final String NUM_ALLOCATIONS = "num_allocations"; public static final String NUM_THREADS = "num_threads"; public static final String MODEL_ID = "model_id"; + public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations"; private final int numAllocations; private final int numThreads; private final String modelId; - - public InternalServiceSettings(int numAllocations, int numThreads, String modelId) { + private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; + + public InternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { this.numAllocations = numAllocations; 
this.numThreads = numThreads; this.modelId = modelId; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } public int getNumAllocations() { @@ -43,16 +53,23 @@ public String getModelId() { return modelId; } + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InternalServiceSettings that = (InternalServiceSettings) o; - return numAllocations == that.numAllocations && numThreads == that.numThreads && Objects.equals(modelId, that.modelId); + return numAllocations == that.numAllocations + && numThreads == that.numThreads + && Objects.equals(modelId, that.modelId) + && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings); } @Override public int hashCode() { - return Objects.hash(numAllocations, numThreads, modelId); + return Objects.hash(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } @Override @@ -67,6 +84,7 @@ public void addXContentFragment(XContentBuilder builder, Params params) throws I builder.field(NUM_ALLOCATIONS, getNumAllocations()); builder.field(NUM_THREADS, getNumThreads()); builder.field(MODEL_ID, getModelId()); + builder.field(ADAPTIVE_ALLOCATIONS, getAdaptiveAllocationsSettings()); } @Override @@ -84,12 +102,16 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVInt(getNumAllocations()); out.writeVInt(getNumThreads()); out.writeString(getModelId()); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(getAdaptiveAllocationsSettings()); + } } public abstract static class Builder { private int numAllocations; private int numThreads; private String modelId; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; public abstract InternalServiceSettings build(); @@ -105,6 +127,10 @@ public void setModelId(String modelId) { this.modelId = modelId; } + public void setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + } + public String getModelId() { return modelId; } @@ -116,5 +142,9 @@ public int getNumAllocations() { public int getNumThreads() { return numThreads; } + + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java index 0cc3e6698388d..8e8a1db76da14 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java @@ -47,6 +47,7 @@ public static CustomElandInternalTextEmbeddingServiceSettings createRandom() { numAllocations, numThreads, modelId, + null, dims, similarityMeasure, elementType @@ -84,6 +85,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { numThreads, modelId, null, + null, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -108,6 +110,7 @@ public 
void testFromMap_Request_DoesNotDefaultSimilarityElementType() { numThreads, modelId, null, + null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -148,6 +151,7 @@ public void testFromMap_Request_IgnoresDimensions() { numThreads, modelId, null, + null, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -187,6 +191,7 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { numAllocations, numThreads, modelId, + null, 1, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT @@ -200,6 +205,7 @@ public void testToXContent_WritesAllValues() throws IOException { 1, 1, "model_id", + null, 100, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE @@ -210,7 +216,8 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - {"num_allocations":1,"num_threads":1,"model_id":"model_id","dimensions":100,"similarity":"cosine","element_type":"byte"}""")); + {"num_allocations":1,"num_threads":1,"model_id":"model_id","adaptive_allocations":null,"dimensions":100,""" + """ + "similarity":"cosine","element_type":"byte"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3bec202ed9e5e..ad1910cb9fc0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -134,7 +134,8 @@ public void testParseRequestConfig() { var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( 1, 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + null ); service.parseRequestConfig( @@ -400,7 +401,7 @@ public void testParsePersistedConfig() { taskType, settings ); - var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "invalid"); + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "invalid", null); assertEquals( new CustomElandEmbeddingModel(randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, elandServiceSettings), parsedModel @@ -430,7 +431,8 @@ public void testParsePersistedConfig() { var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( 1, 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + null ); MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig( @@ -500,7 +502,7 @@ public void testChunkInfer() { "foo", TaskType.TEXT_EMBEDDING, "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform") + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) ); var service = createService(client); @@ -594,7 +596,7 @@ public void testChunkInferSetsTokenization() { "foo", TaskType.TEXT_EMBEDDING, "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform") + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) ); var service = createService(client); @@ -726,11 +728,11 @@ private 
CustomElandModel getCustomElandModel(TaskType taskType) { randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, - new CustomElandInternalServiceSettings(1, 4, "custom-model"), + new CustomElandInternalServiceSettings(1, 4, "custom-model", null), CustomElandRerankTaskSettings.DEFAULT_SETTINGS ); } else if (taskType == TaskType.TEXT_EMBEDDING) { - var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model"); + var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model", null); expectedModel = new CustomElandEmbeddingModel( randomInferenceEntityId, @@ -786,7 +788,7 @@ public void testPutModel() { "my-e5", TaskType.TEXT_EMBEDDING, "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, ".multilingual-e5-small") + new MultilingualE5SmallInternalServiceSettings(1, 1, ".multilingual-e5-small", null) ); service.putModel(model, new ActionListener<>() { @@ -827,6 +829,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { 1, 4, "custom-model", + null, 1, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT @@ -850,6 +853,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { 4, "custom-model", null, + null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java index fbff04efe6883..927d53360a2c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java @@ -24,7 +24,8 @@ public static MultilingualE5SmallInternalServiceSettings createRandom() { return new MultilingualE5SmallInternalServiceSettings( randomIntBetween(1, 4), randomIntBetween(1, 4), - randomFrom(ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS) + randomFrom(ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS), + null ); } @@ -56,7 +57,7 @@ public void testFromMap() { ) ) ).build(); - assertEquals(new MultilingualE5SmallInternalServiceSettings(1, 4, randomModelVariant), serviceSettings); + assertEquals(new MultilingualE5SmallInternalServiceSettings(1, 4, randomModelVariant, null), serviceSettings); } public void testFromMapInvalidVersion() { @@ -130,12 +131,14 @@ protected MultilingualE5SmallInternalServiceSettings mutateInstance(Multilingual case 0 -> new MultilingualE5SmallInternalServiceSettings( instance.getNumAllocations() + 1, instance.getNumThreads(), - instance.getModelId() + instance.getModelId(), + null ); case 1 -> new MultilingualE5SmallInternalServiceSettings( instance.getNumAllocations(), instance.getNumThreads() + 1, - instance.getModelId() + instance.getModelId(), + null ); case 2 -> { var versions = new HashSet<>(ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS); @@ -143,7 +146,8 @@ protected MultilingualE5SmallInternalServiceSettings mutateInstance(Multilingual yield new MultilingualE5SmallInternalServiceSettings( instance.getNumAllocations(), instance.getNumThreads(), - versions.iterator().next() + versions.iterator().next(), + null ); } default -> throw new IllegalStateException(); diff 
--git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java index c0e425144a618..e7fbbffa2d3fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java @@ -25,7 +25,8 @@ public static ElserInternalServiceSettings createRandom() { return new ElserInternalServiceSettings( randomIntBetween(1, 4), randomIntBetween(1, 2), - randomFrom(ElserInternalService.VALID_ELSER_MODEL_IDS) + randomFrom(ElserInternalService.VALID_ELSER_MODEL_IDS), + null ); } @@ -49,7 +50,7 @@ public void testFromMap() { ) ) ).build(); - assertEquals(new ElserInternalServiceSettings(1, 4, ".elser_model_1"), serviceSettings); + assertEquals(new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), serviceSettings); } public void testFromMapInvalidVersion() { @@ -89,12 +90,12 @@ public void testFromMapMissingOptions() { public void testBwcWrite() throws IOException { { - var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1"); + var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1", null); var copy = copyInstance(settings, TransportVersions.V_8_12_0); assertEquals(settings, copy); } { - var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1"); + var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1", null); var copy = copyInstance(settings, TransportVersions.V_8_11_X); assertEquals(settings, copy); } @@ -123,12 +124,27 @@ protected ElserInternalServiceSettings createTestInstance() { @Override protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettings instance) { return switch (randomIntBetween(0, 2)) { - case 0 -> new ElserInternalServiceSettings(instance.getNumAllocations() + 1, instance.getNumThreads(), instance.getModelId()); - case 1 -> new ElserInternalServiceSettings(instance.getNumAllocations(), instance.getNumThreads() + 1, instance.getModelId()); + case 0 -> new ElserInternalServiceSettings( + instance.getNumAllocations() + 1, + instance.getNumThreads(), + instance.getModelId(), + null + ); + case 1 -> new ElserInternalServiceSettings( + instance.getNumAllocations(), + instance.getNumThreads() + 1, + instance.getModelId(), + null + ); case 2 -> { var versions = new HashSet<>(ElserInternalService.VALID_ELSER_MODEL_IDS); versions.remove(instance.getModelId()); - yield new ElserInternalServiceSettings(instance.getNumAllocations(), instance.getNumThreads(), versions.iterator().next()); + yield new ElserInternalServiceSettings( + instance.getNumAllocations(), + instance.getNumThreads(), + versions.iterator().next(), + null + ); } default -> throw new IllegalStateException(); }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java index bc7dca4f11960..5ee55003e7fe1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java @@ -108,7 
+108,7 @@ public void testParseConfigStrict() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1"), + new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), ElserMlNodeTaskSettings.DEFAULT ); @@ -141,7 +141,7 @@ public void testParseConfigLooseWithOldModelId() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1"), + new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), ElserMlNodeTaskSettings.DEFAULT ); @@ -171,7 +171,7 @@ public void testParseConfigStrictWithNoTaskSettings() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ElserInternalService.ELSER_V2_MODEL), + new ElserInternalServiceSettings(1, 4, ElserInternalService.ELSER_V2_MODEL, null), ElserMlNodeTaskSettings.DEFAULT ); @@ -373,7 +373,7 @@ public void testChunkInfer() { "foo", TaskType.SPARSE_EMBEDDING, "elser", - new ElserInternalServiceSettings(1, 1, "elser"), + new ElserInternalServiceSettings(1, 1, "elser", null), new ElserMlNodeTaskSettings() ); var service = createService(client); @@ -437,7 +437,7 @@ public void testChunkInferSetsTokenization() { "foo", TaskType.SPARSE_EMBEDDING, "elser", - new ElserInternalServiceSettings(1, 1, "elser"), + new ElserInternalServiceSettings(1, 1, "elser", null), new ElserMlNodeTaskSettings() ); var service = createService(client); @@ -489,7 +489,7 @@ public void testPutModel() { "my-elser", TaskType.SPARSE_EMBEDDING, "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2"), + new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), ElserMlNodeTaskSettings.DEFAULT ); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java index 30f84a97bcfb0..1d67639f712a0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MlDailyMaintenanceService; import org.elasticsearch.xpack.ml.MlInitializationService; +import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; import org.junit.Before; import java.util.List; @@ -47,7 +48,14 @@ public void setUpMocks() { when(threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); MlDailyMaintenanceService mlDailyMaintenanceService = mock(MlDailyMaintenanceService.class); ClusterService clusterService = mock(ClusterService.class); - mlInitializationService = new MlInitializationService(client(), threadPool, mlDailyMaintenanceService, clusterService); + AdaptiveAllocationsScalerService adaptiveAllocationsScalerService = mock(AdaptiveAllocationsScalerService.class); + mlInitializationService = new MlInitializationService( + client(), + threadPool, + mlDailyMaintenanceService, + adaptiveAllocationsScalerService, + clusterService + ); } public void testThatMlIndicesBecomeHiddenWhenTheNodeBecomesMaster() throws Exception { diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 6fdc4e73e184f..22a9c2dbcc281 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -1282,6 +1282,7 @@ public Collection createComponents(PluginServices services) { threadPool, clusterService, client, + inferenceAuditor, mlAssignmentNotifier, machineLearningExtension.get().isAnomalyDetectionEnabled(), machineLearningExtension.get().isDataFrameAnalyticsEnabled(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java index a2d8fd1d60316..346b67a169912 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java @@ -32,6 +32,8 @@ import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.annotations.AnnotationIndex; +import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.Collections; import java.util.Map; @@ -55,6 +57,8 @@ public final class MlInitializationService implements ClusterStateListener { private final MlDailyMaintenanceService mlDailyMaintenanceService; + private final AdaptiveAllocationsScalerService adaptiveAllocationsScalerService; + private boolean isMaster = false; MlInitializationService( @@ -62,6 +66,7 @@ public final class MlInitializationService implements ClusterStateListener { ThreadPool threadPool, ClusterService clusterService, Client client, + InferenceAuditor inferenceAuditor, MlAssignmentNotifier mlAssignmentNotifier, boolean isAnomalyDetectionEnabled, boolean isDataFrameAnalyticsEnabled, @@ -81,6 +86,7 @@ public final class MlInitializationService implements ClusterStateListener { isDataFrameAnalyticsEnabled, isNlpEnabled ), + new AdaptiveAllocationsScalerService(threadPool, clusterService, client, inferenceAuditor, isNlpEnabled), clusterService ); } @@ -90,11 +96,13 @@ public MlInitializationService( Client client, ThreadPool threadPool, MlDailyMaintenanceService dailyMaintenanceService, + AdaptiveAllocationsScalerService adaptiveAllocationsScalerService, ClusterService clusterService ) { this.client = Objects.requireNonNull(client); this.threadPool = threadPool; this.mlDailyMaintenanceService = dailyMaintenanceService; + this.adaptiveAllocationsScalerService = adaptiveAllocationsScalerService; clusterService.addListener(this); clusterService.addLifecycleListener(new LifecycleListener() { @Override @@ -115,11 +123,13 @@ public void beforeStop() { public void onMaster() { mlDailyMaintenanceService.start(); + adaptiveAllocationsScalerService.start(); threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(this::makeMlInternalIndicesHidden); } public void offMaster() { mlDailyMaintenanceService.stop(); + adaptiveAllocationsScalerService.stop(); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java index 
348cb396f9c9f..30371fcbe115a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java @@ -75,7 +75,7 @@ public TransportCreateTrainedModelAssignmentAction( @Override protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { trainedModelAssignmentClusterService.createNewModelAssignment( - request.getTaskParams(), + request, listener.delegateFailureAndWrap((l, trainedModelAssignment) -> l.onResponse(new Response(trainedModelAssignment))) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java index 04b597292dad6..590aeded2b674 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java @@ -238,6 +238,7 @@ static GetDeploymentStatsAction.Response addFailedRoutes( stat.getModelId(), stat.getThreadsPerAllocation(), stat.getNumberOfAllocations(), + stat.getAdaptiveAllocationsSettings(), stat.getQueueCapacity(), stat.getCacheSize(), stat.getStartTime(), @@ -277,6 +278,7 @@ static GetDeploymentStatsAction.Response addFailedRoutes( assignment.getModelId(), assignment.getTaskParams().getThreadsPerAllocation(), assignment.getTaskParams().getNumberOfAllocations(), + assignment.getAdaptiveAllocationsSettings(), assignment.getTaskParams().getQueueCapacity(), assignment.getTaskParams().getCacheSize().orElse(null), assignment.getStartTime(), @@ -346,6 +348,7 @@ protected void taskOperation( task.getParams().getModelId(), task.getParams().getThreadsPerAllocation(), assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(), + assignment == null ? 
null : assignment.getAdaptiveAllocationsSettings(),
             task.getParams().getQueueCapacity(),
             task.getParams().getCacheSize().orElse(null),
             TrainedModelAssignmentMetadata.fromState(clusterService.state())
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
index de93a41fb7296..ae0da7dc9cc69 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
@@ -207,7 +207,7 @@ protected void masterOperation(
             modelIdAndSizeInBytes.v1(),
             request.getDeploymentId(),
             modelIdAndSizeInBytes.v2(),
-            request.getNumberOfAllocations(),
+            request.computeNumberOfAllocations(),
             request.getThreadsPerAllocation(),
             request.getQueueCapacity(),
             Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelIdAndSizeInBytes.v2())),
@@ -219,7 +219,10 @@ protected void masterOperation(
         memoryTracker.refresh(
             persistentTasks,
             ActionListener.wrap(
-                aVoid -> trainedModelAssignmentService.createNewModelAssignment(taskParams, waitForDeploymentToStart),
+                aVoid -> trainedModelAssignmentService.createNewModelAssignment(
+                    new CreateTrainedModelAssignmentAction.Request(taskParams, request.getAdaptiveAllocationsSettings()),
+                    waitForDeploymentToStart
+                ),
                 listener::onFailure
             )
         );
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java
index 7d4143d9e722a..fa38b30ae8b84 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java
@@ -81,9 +81,11 @@ protected void masterOperation(
             )
         );

-        trainedModelAssignmentClusterService.updateNumberOfAllocations(
+        trainedModelAssignmentClusterService.updateDeployment(
             request.getDeploymentId(),
             request.getNumberOfAllocations(),
+            request.getAdaptiveAllocationsSettings(),
+            request.isInternal(),
             ActionListener.wrap(updatedAssignment -> {
                 auditor.info(
                     request.getDeploymentId(),
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java
new file mode 100644
index 0000000000000..b33e86d434f95
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.adaptiveallocations;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.Strings;
+
+/**
+ * Processes measured request counts and inference times and decides whether
+ * the number of allocations should be scaled up or down.
+ */
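
(Aside: to make the thresholds in the class that follows concrete: with an estimated request rate of 8/s and an estimated inference time of 0.2 s, the load is 1.6 parallel requests, so a single allocation is saturated (1.6 / 1 > 0.9) and the scaler grows to two; it shrinks again only once the upper load bound fits below 0.85 of one allocation less. A minimal arithmetic sketch of just that rule, not the class itself:)

public class ScaleDecisionSketch {
    static final double SCALE_UP_THRESHOLD = 0.9;
    static final double SCALE_DOWN_THRESHOLD = 0.85;

    static int scale(double loadLower, double loadUpper, int allocations) {
        // Grow while even the lower confidence bound of the load saturates the current allocations.
        while (loadLower / allocations > SCALE_UP_THRESHOLD) {
            allocations++;
        }
        // Shrink while the upper confidence bound of the load would still fit in one allocation less.
        while (allocations > 1 && loadUpper / (allocations - 1) < SCALE_DOWN_THRESHOLD) {
            allocations--;
        }
        return allocations;
    }

    public static void main(String[] args) {
        // 8 requests/s at 0.2 s each -> load of 1.6 parallel requests; one allocation is saturated.
        System.out.println(scale(1.6, 1.6, 1));   // prints 2
        // Load collapses to 0.1: the upper bound fits well below one allocation, so scale back down.
        System.out.println(scale(0.1, 0.1, 2));   // prints 1
    }
}
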
+public class AdaptiveAllocationsScaler {
+
+    // visible for testing
+    static final double SCALE_UP_THRESHOLD = 0.9;
+    private static final double SCALE_DOWN_THRESHOLD = 0.85;
+
+    private static final Logger logger = LogManager.getLogger(AdaptiveAllocationsScaler.class);
+
+    private final String deploymentId;
+    private final KalmanFilter1d requestRateEstimator;
+    private final KalmanFilter1d inferenceTimeEstimator;
+
+    private int numberOfAllocations;
+    private Integer minNumberOfAllocations;
+    private Integer maxNumberOfAllocations;
+    private boolean dynamicsChanged;
+
+    AdaptiveAllocationsScaler(String deploymentId, int numberOfAllocations) {
+        this.deploymentId = deploymentId;
+        // A smoothing factor of 100 roughly means the last 100 measurements have an effect
+        // on the estimated values. The sampling time is 10 seconds, so approximately the
+        // last 15 minutes are taken into account.
+        // For the request rate, use auto-detection for dynamics changes, because the request
+        // rate may change due to changed user behaviour.
+        // For the inference time, don't use this auto-detection. The dynamics may change when
+        // the number of allocations changes, which is passed explicitly to the estimator.
+        requestRateEstimator = new KalmanFilter1d(deploymentId + ":rate", 100, true);
+        inferenceTimeEstimator = new KalmanFilter1d(deploymentId + ":time", 100, false);
+        this.numberOfAllocations = numberOfAllocations;
+        this.minNumberOfAllocations = null;
+        this.maxNumberOfAllocations = null;
+        this.dynamicsChanged = false;
+    }
+
+    void setMinMaxNumberOfAllocations(Integer minNumberOfAllocations, Integer maxNumberOfAllocations) {
+        this.minNumberOfAllocations = minNumberOfAllocations;
+        this.maxNumberOfAllocations = maxNumberOfAllocations;
+    }
+
+    void process(AdaptiveAllocationsScalerService.Stats stats, double timeIntervalSeconds, int numberOfAllocations) {
+        // The request rate (per second) is the request count divided by the time.
+        // Assuming a Poisson process for the requests, the variance in the request
+        // count equals the mean request count, and the variance in the request rate
+        // equals that variance divided by the time interval squared.
+        // The minimum request count is set to 1, because lower request counts can't
+        // be reliably measured.
+        // The estimated request rate should be used for the variance calculations,
+        // because the measured request rate gives biased estimates.
+        double requestRate = (double) stats.requestCount() / timeIntervalSeconds;
+        double requestRateEstimate = requestRateEstimator.hasValue() ? requestRateEstimator.estimate() : requestRate;
+        double requestRateVariance = Math.max(1.0, requestRateEstimate * timeIntervalSeconds) / Math.pow(timeIntervalSeconds, 2);
+        requestRateEstimator.add(requestRate, requestRateVariance, false);
+
+        if (stats.requestCount() > 0 && Double.isNaN(stats.inferenceTime()) == false) {
+            // The inference time distribution is unknown. For simplicity, we assume
+            // a std.error equal to the mean, so that the variance equals the mean
+            // value squared. The variance of the mean is inversely proportional to
+            // the number of inference measurements it contains.
+            // Again, the estimated inference time should be used for the variance
+            // calculations to prevent biased estimates.
+            double inferenceTime = stats.inferenceTime();
+            double inferenceTimeEstimate = inferenceTimeEstimator.hasValue() ? inferenceTimeEstimator.estimate() : inferenceTime;
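            // (Editorial aside, not part of the patch.) Worked numbers for the variance
            // reasoning above, assuming a 10 second interval: 50 requests in 10 s give a
            // rate of 5.0/s; the Poisson variance of the count is ~50, so the variance of
            // the rate is ~50 / 10^2 = 0.5, i.e. a stderr of ~0.71 requests/s.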
+            double inferenceTimeVariance = Math.pow(inferenceTimeEstimate, 2) / stats.requestCount();
+            inferenceTimeEstimator.add(inferenceTime, inferenceTimeVariance, dynamicsChanged);
+        }
+
+        this.numberOfAllocations = numberOfAllocations;
+        dynamicsChanged = false;
+    }
+
+    double getLoadLower() {
+        double requestRateLower = Math.max(0.0, requestRateEstimator.lower());
+        double inferenceTimeLower = Math.max(0.0, inferenceTimeEstimator.hasValue() ? inferenceTimeEstimator.lower() : 1.0);
+        return requestRateLower * inferenceTimeLower;
+    }
+
+    double getLoadUpper() {
+        double requestRateUpper = requestRateEstimator.upper();
+        double inferenceTimeUpper = inferenceTimeEstimator.hasValue() ? inferenceTimeEstimator.upper() : 1.0;
+        return requestRateUpper * inferenceTimeUpper;
+    }
+
+    Integer scale() {
+        if (requestRateEstimator.hasValue() == false) {
+            return null;
+        }
+
+        int oldNumberOfAllocations = numberOfAllocations;
+
+        double loadLower = getLoadLower();
+        while (loadLower / numberOfAllocations > SCALE_UP_THRESHOLD) {
+            numberOfAllocations++;
+        }
+
+        double loadUpper = getLoadUpper();
+        while (numberOfAllocations > 1 && loadUpper / (numberOfAllocations - 1) < SCALE_DOWN_THRESHOLD) {
+            numberOfAllocations--;
+        }
+
+        if (minNumberOfAllocations != null) {
+            numberOfAllocations = Math.max(numberOfAllocations, minNumberOfAllocations);
+        }
+        if (maxNumberOfAllocations != null) {
+            numberOfAllocations = Math.min(numberOfAllocations, maxNumberOfAllocations);
+        }
+
+        if (numberOfAllocations != oldNumberOfAllocations) {
+            logger.debug(
+                () -> Strings.format(
+                    "[%s] adaptive allocations scaler: load in [%.3f, %.3f], scaling from %d to %d allocations.",
+                    deploymentId,
+                    loadLower,
+                    loadUpper,
+                    oldNumberOfAllocations,
+                    numberOfAllocations
+                )
+            );
+        } else {
+            logger.debug(
+                () -> Strings.format(
+                    "[%s] adaptive allocations scaler: load in [%.3f, %.3f], keeping %d allocations.",
+                    deploymentId,
+                    loadLower,
+                    loadUpper,
+                    numberOfAllocations
+                )
+            );
+        }
+
+        if (numberOfAllocations != oldNumberOfAllocations) {
+            this.dynamicsChanged = true;
+            return numberOfAllocations;
+        } else {
+            return null;
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java
new file mode 100644
index 0000000000000..30e3871ad5ad0
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java
@@ -0,0 +1,340 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Periodically schedules adaptive allocations scaling. This process consists + * of calling the trained model stats API, processing the results, determining + * whether scaling should be applied, and potentially calling the trained + * model update API. + */ +public class AdaptiveAllocationsScalerService implements ClusterStateListener { + + record Stats(long successCount, long pendingCount, long failedCount, double inferenceTime) { + + long requestCount() { + return successCount + pendingCount + failedCount; + } + + double totalInferenceTime() { + return successCount * inferenceTime; + } + + Stats add(Stats value) { + long newSuccessCount = successCount + value.successCount; + long newPendingCount = pendingCount + value.pendingCount; + long newFailedCount = failedCount + value.failedCount; + double newInferenceTime = newSuccessCount > 0 + ? (totalInferenceTime() + value.totalInferenceTime()) / newSuccessCount + : Double.NaN; + return new Stats(newSuccessCount, newPendingCount, newFailedCount, newInferenceTime); + } + + Stats sub(Stats value) { + long newSuccessCount = Math.max(0, successCount - value.successCount); + long newPendingCount = Math.max(0, pendingCount - value.pendingCount); + long newFailedCount = Math.max(0, failedCount - value.failedCount); + double newInferenceTime = newSuccessCount > 0 + ? (totalInferenceTime() - value.totalInferenceTime()) / newSuccessCount + : Double.NaN; + return new Stats(newSuccessCount, newPendingCount, newFailedCount, newInferenceTime); + } + } + + /** + * The time interval between the adaptive allocations triggers. + */ + private static final int DEFAULT_TIME_INTERVAL_SECONDS = 10; + /** + * The time that has to pass after scaling up, before scaling down is allowed. + * Note that the ML autoscaling has its own cooldown time to release the hardware. 
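
(Aside: the cooldown described just above gates scale-downs only; scale-ups always go through and refresh the timestamp. A sketch of that gating logic under the constant defined here, with a hypothetical helper name, not the service's actual control flow:)

import java.util.HashMap;
import java.util.Map;

public class CooldownGatingSketch {
    static final long SCALE_UP_COOLDOWN_MILLIS = 5 * 60 * 1000;
    static final Map<String, Long> lastScaleUpMillis = new HashMap<>();

    // Returns the allocation count to request, or null when the change is suppressed.
    static Integer gateScaleDecision(String deploymentId, int current, Integer proposed, long nowMillis) {
        if (proposed == null) {
            return null;                                   // estimator not warmed up yet
        }
        Long lastUp = lastScaleUpMillis.get(deploymentId);
        // Scale-downs are suppressed shortly after a scale-up to avoid flapping.
        if (proposed < current && lastUp != null && nowMillis < lastUp + SCALE_UP_COOLDOWN_MILLIS) {
            return null;
        }
        if (proposed > current) {
            lastScaleUpMillis.put(deploymentId, nowMillis);
        }
        return proposed;                                   // would be sent as an internal update request
    }

    public static void main(String[] args) {
        long t0 = 0;
        System.out.println(gateScaleDecision("e5", 1, 2, t0));                            // 2: scale up
        System.out.println(gateScaleDecision("e5", 2, 1, t0 + 60_000));                   // null: inside cooldown
        System.out.println(gateScaleDecision("e5", 2, 1, t0 + SCALE_UP_COOLDOWN_MILLIS)); // 1: cooldown passed
    }
}
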
+ */ + private static final long SCALE_UP_COOLDOWN_TIME_MILLIS = TimeValue.timeValueMinutes(5).getMillis(); + + private static final Logger logger = LogManager.getLogger(AdaptiveAllocationsScalerService.class); + + private final int timeIntervalSeconds; + private final ThreadPool threadPool; + private final ClusterService clusterService; + private final Client client; + private final InferenceAuditor inferenceAuditor; + private final boolean isNlpEnabled; + private final Map> lastInferenceStatsByDeploymentAndNode; + private Long lastInferenceStatsTimestampMillis; + private final Map scalers; + private final Map lastScaleUpTimesMillis; + + private volatile Scheduler.Cancellable cancellable; + private final AtomicBoolean busy; + + public AdaptiveAllocationsScalerService( + ThreadPool threadPool, + ClusterService clusterService, + Client client, + InferenceAuditor inferenceAuditor, + boolean isNlpEnabled + ) { + this(threadPool, clusterService, client, inferenceAuditor, isNlpEnabled, DEFAULT_TIME_INTERVAL_SECONDS); + } + + // visible for testing + AdaptiveAllocationsScalerService( + ThreadPool threadPool, + ClusterService clusterService, + Client client, + InferenceAuditor inferenceAuditor, + boolean isNlpEnabled, + int timeIntervalSeconds + ) { + this.threadPool = threadPool; + this.clusterService = clusterService; + this.client = client; + this.inferenceAuditor = inferenceAuditor; + this.isNlpEnabled = isNlpEnabled; + this.timeIntervalSeconds = timeIntervalSeconds; + + lastInferenceStatsByDeploymentAndNode = new HashMap<>(); + lastInferenceStatsTimestampMillis = null; + lastScaleUpTimesMillis = new HashMap<>(); + scalers = new HashMap<>(); + busy = new AtomicBoolean(false); + } + + public synchronized void start() { + updateAutoscalers(clusterService.state()); + clusterService.addListener(this); + if (scalers.isEmpty() == false) { + startScheduling(); + } + } + + public synchronized void stop() { + stopScheduling(); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + updateAutoscalers(event.state()); + if (scalers.isEmpty() == false) { + startScheduling(); + } else { + stopScheduling(); + } + } + + private synchronized void updateAutoscalers(ClusterState state) { + if (isNlpEnabled == false) { + return; + } + Set deploymentIds = new HashSet<>(); + TrainedModelAssignmentMetadata assignments = TrainedModelAssignmentMetadata.fromState(state); + for (TrainedModelAssignment assignment : assignments.allAssignments().values()) { + deploymentIds.add(assignment.getDeploymentId()); + if (assignment.getAdaptiveAllocationsSettings() != null && assignment.getAdaptiveAllocationsSettings().getEnabled()) { + AdaptiveAllocationsScaler adaptiveAllocationsScaler = scalers.computeIfAbsent( + assignment.getDeploymentId(), + key -> new AdaptiveAllocationsScaler(assignment.getDeploymentId(), assignment.totalTargetAllocations()) + ); + adaptiveAllocationsScaler.setMinMaxNumberOfAllocations( + assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations(), + assignment.getAdaptiveAllocationsSettings().getMaxNumberOfAllocations() + ); + } else { + scalers.remove(assignment.getDeploymentId()); + lastInferenceStatsByDeploymentAndNode.remove(assignment.getDeploymentId()); + } + } + scalers.keySet().removeIf(key -> deploymentIds.contains(key) == false); + } + + private synchronized void startScheduling() { + if (cancellable == null) { + logger.debug("Starting ML adaptive allocations scaler"); + try { + cancellable = threadPool.scheduleWithFixedDelay( + this::trigger, + 
TimeValue.timeValueSeconds(timeIntervalSeconds), + threadPool.generic() + ); + } catch (EsRejectedExecutionException e) { + if (e.isExecutorShutdown() == false) { + throw e; + } + } + } + } + + private synchronized void stopScheduling() { + if (cancellable != null && cancellable.isCancelled() == false) { + logger.debug("Stopping ML adaptive allocations scaler"); + cancellable.cancel(); + cancellable = null; + } + } + + private void trigger() { + if (busy.getAndSet(true)) { + logger.debug("Skipping inference adaptive allocations scaling, because it's still busy."); + return; + } + ActionListener listener = ActionListener.runAfter( + ActionListener.wrap(this::processDeploymentStats, e -> logger.warn("Error in inference adaptive allocations scaling", e)), + () -> busy.set(false) + ); + getDeploymentStats(listener); + } + + private void getDeploymentStats(ActionListener processDeploymentStats) { + String deploymentIds = String.join(",", scalers.keySet()); + ClientHelper.executeAsyncWithOrigin( + client, + ClientHelper.ML_ORIGIN, + GetDeploymentStatsAction.INSTANCE, + // TODO(dave/jan): create a lightweight version of this request, because the current one + // collects too much data for the adaptive allocations scaler. + new GetDeploymentStatsAction.Request(deploymentIds), + processDeploymentStats + ); + } + + private void processDeploymentStats(GetDeploymentStatsAction.Response statsResponse) { + Double statsTimeInterval; + long now = System.currentTimeMillis(); + if (lastInferenceStatsTimestampMillis != null) { + statsTimeInterval = (now - lastInferenceStatsTimestampMillis) / 1000.0; + } else { + statsTimeInterval = null; + } + lastInferenceStatsTimestampMillis = now; + + Map recentStatsByDeployment = new HashMap<>(); + Map numberOfAllocations = new HashMap<>(); + + for (AssignmentStats assignmentStats : statsResponse.getStats().results()) { + String deploymentId = assignmentStats.getDeploymentId(); + numberOfAllocations.put(deploymentId, assignmentStats.getNumberOfAllocations()); + Map deploymentStats = lastInferenceStatsByDeploymentAndNode.computeIfAbsent( + deploymentId, + key -> new HashMap<>() + ); + for (AssignmentStats.NodeStats nodeStats : assignmentStats.getNodeStats()) { + String nodeId = nodeStats.getNode().getId(); + Stats lastStats = deploymentStats.get(nodeId); + Stats nextStats = new Stats( + nodeStats.getInferenceCount().orElse(0L), + nodeStats.getPendingCount() == null ? 0 : nodeStats.getPendingCount(), + nodeStats.getErrorCount() + nodeStats.getTimeoutCount() + nodeStats.getRejectedExecutionCount(), + nodeStats.getAvgInferenceTime().orElse(0.0) / 1000.0 + ); + deploymentStats.put(nodeId, nextStats); + if (lastStats != null) { + Stats recentStats = nextStats.sub(lastStats); + recentStatsByDeployment.compute( + assignmentStats.getDeploymentId(), + (key, value) -> value == null ? 
recentStats : value.add(recentStats) + ); + } + } + } + + if (statsTimeInterval == null) { + return; + } + + for (Map.Entry deploymentAndStats : recentStatsByDeployment.entrySet()) { + String deploymentId = deploymentAndStats.getKey(); + Stats stats = deploymentAndStats.getValue(); + AdaptiveAllocationsScaler adaptiveAllocationsScaler = scalers.get(deploymentId); + adaptiveAllocationsScaler.process(stats, statsTimeInterval, numberOfAllocations.get(deploymentId)); + Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale(); + if (newNumberOfAllocations != null) { + Long lastScaleUpTimeMillis = lastScaleUpTimesMillis.get(deploymentId); + if (newNumberOfAllocations < numberOfAllocations.get(deploymentId) + && lastScaleUpTimeMillis != null + && now < lastScaleUpTimeMillis + SCALE_UP_COOLDOWN_TIME_MILLIS) { + logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId); + continue; + } + if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) { + lastScaleUpTimesMillis.put(deploymentId, now); + } + UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId); + updateRequest.setNumberOfAllocations(newNumberOfAllocations); + updateRequest.setIsInternal(true); + ClientHelper.executeAsyncWithOrigin( + client, + ClientHelper.ML_ORIGIN, + UpdateTrainedModelDeploymentAction.INSTANCE, + updateRequest, + ActionListener.wrap(updateResponse -> { + logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", deploymentId, newNumberOfAllocations); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME) + .execute( + () -> inferenceAuditor.info( + deploymentId, + Strings.format( + "adaptive allocations scaler: scaled [%s] to [%s] allocations.", + deploymentId, + newNumberOfAllocations + ) + ) + ); + }, e -> { + logger.atLevel(Level.WARN) + .withThrowable(e) + .log( + "adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", + deploymentId, + newNumberOfAllocations + ); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME) + .execute( + () -> inferenceAuditor.warning( + deploymentId, + Strings.format( + "adaptive allocations scaler: scaling [%s] to [%s] allocations failed.", + deploymentId, + newNumberOfAllocations + ) + ) + ); + }) + ); + } + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java new file mode 100644 index 0000000000000..ad3e66fc3e8e2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; + +/** + * Estimator for the mean value and stderr of a series of measurements. + *
+ * This implements a 1d Kalman filter with manoeuvre detection. Rather than a derived
+ * dynamics model we simply fix how much we want to smooth in the steady state.
+ * See also: Wikipedia.
+ */
+class KalmanFilter1d {
+
+    private static final Logger logger = LogManager.getLogger(KalmanFilter1d.class);
+
+    private final String name;
+    private final double smoothingFactor;
+    private final boolean autodetectDynamicsChange;
+
+    private double value;
+    private double variance;
+    private boolean dynamicsChangedLastTime;
+
+    KalmanFilter1d(String name, double smoothingFactor, boolean autodetectDynamicsChange) {
+        this.name = name;
+        this.smoothingFactor = smoothingFactor;
+        this.autodetectDynamicsChange = autodetectDynamicsChange;
+        this.value = Double.MAX_VALUE;
+        this.variance = Double.MAX_VALUE;
+        this.dynamicsChangedLastTime = false;
+    }
+
+    /**
+     * Adds a measurement (value, variance) to the estimator.
+     * dynamicChangedExternal indicates whether the underlying dynamics possibly changed before this measurement.
+     */
+    void add(double value, double variance, boolean dynamicChangedExternal) {
+        boolean dynamicChanged;
+        if (hasValue() == false) {
+            dynamicChanged = true;
+            this.value = value;
+            this.variance = variance;
+        } else {
+            double processVariance = variance / smoothingFactor;
+            dynamicChanged = dynamicChangedExternal || detectDynamicsChange(value, variance);
+            if (dynamicChanged || dynamicsChangedLastTime) {
+                // If we know we likely had a change in the quantity we're estimating or the prediction
+                // is 10 stddev off, we inject extra noise in the dynamics for this step.
+                processVariance = Math.pow(value, 2);
+            }
+
+            double gain = (this.variance + processVariance) / (this.variance + processVariance + variance);
+            this.value += gain * (value - this.value);
+            this.variance = (1 - gain) * (this.variance + processVariance);
+        }
+        dynamicsChangedLastTime = dynamicChanged;
+        logger.debug(
+            () -> Strings.format(
+                "[%s] measurement %.3f ± %.3f: estimate %.3f ± %.3f (dynamic changed: %s).",
+                name,
+                value,
+                Math.sqrt(variance),
+                this.value,
+                Math.sqrt(this.variance),
+                dynamicChanged
+            )
+        );
+    }
+
+    /**
+     * Returns whether the estimator has received data and contains a value.
+     */
+    boolean hasValue() {
+        return this.value < Double.MAX_VALUE && this.variance < Double.MAX_VALUE;
+    }
+
+    /**
+     * Returns the estimate of the mean value.
+     */
+    double estimate() {
+        return value;
+    }
+
+    /**
+     * Returns the stderr of the estimate.
+     */
+    double error() {
+        return Math.sqrt(this.variance);
+    }
+
+    /**
+     * Returns the lower bound of the 1 stddev confidence interval of the estimate.
+     */
+    double lower() {
+        return value - error();
+    }
+
+    /**
+     * Returns the upper bound of the 1 stddev confidence interval of the estimate.
+     */
+    double upper() {
+        return value + error();
+    }
+
+    /**
+     * Returns whether (value, variance) is very unlikely, indicating that
+     * the underlying dynamics have changed.
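
(Aside: the steady-state update in add() above is gain = (P + Q) / (P + Q + R), with the process noise Q fixed at R divided by the smoothing factor. A self-contained numeric sketch of just that step, without the manoeuvre detection, showing how a large smoothing factor damps noisy measurements:)

public class KalmanUpdateSketch {
    public static void main(String[] args) {
        double smoothingFactor = 100.0;
        double value = Double.NaN;   // estimate
        double variance = 0.0;       // estimate variance (P)
        boolean hasValue = false;
        // Noisy measurements around a true rate of ~10, each with measurement variance 4.0 (R).
        double[] measurements = { 12.0, 9.0, 11.0, 8.0, 10.5 };
        for (double z : measurements) {
            double r = 4.0;
            if (!hasValue) {
                value = z;               // first measurement initializes the filter
                variance = r;
                hasValue = true;
            } else {
                double q = r / smoothingFactor;   // process noise fixed relative to measurement noise
                double gain = (variance + q) / (variance + q + r);
                value = value + gain * (z - value);
                variance = (1 - gain) * (variance + q);
            }
            System.out.printf("measurement %.1f -> estimate %.3f (stderr %.3f)%n", z, value, Math.sqrt(variance));
        }
    }
}
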
+ */ + private boolean detectDynamicsChange(double value, double variance) { + return hasValue() && autodetectDynamicsChange && Math.pow(Math.abs(value - this.value), 2) / (variance + this.variance) > 100.0; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index f468e5239fd29..e86a9cfe94045 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -14,6 +14,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; @@ -26,6 +27,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; @@ -38,8 +40,10 @@ import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; @@ -68,6 +72,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils.NODES_CHANGED_REASON; import static org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils.createShuttingDownRoute; @@ -393,7 +398,7 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) } public void createNewModelAssignment( - StartTrainedModelDeploymentAction.TaskParams params, + CreateTrainedModelAssignmentAction.Request request, ActionListener listener ) { if (clusterService.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) { @@ -401,8 +406,8 @@ public void createNewModelAssignment( new ElasticsearchStatusException( "cannot create new assignment [{}] for model [{}] while cluster upgrade is in progress", RestStatus.CONFLICT, - params.getDeploymentId(), - params.getModelId() + request.getTaskParams().getDeploymentId(), + request.getTaskParams().getModelId() ) ); return; @@ -413,20 +418,20 @@ public void createNewModelAssignment( new 
ElasticsearchStatusException( "cannot create new assignment [{}] for model [{}] while feature reset is in progress.", RestStatus.CONFLICT, - params.getDeploymentId(), - params.getModelId() + request.getTaskParams().getDeploymentId(), + request.getTaskParams().getModelId() ) ); return; } - rebalanceAssignments(clusterService.state(), Optional.of(params), "model deployment started", ActionListener.wrap(newMetadata -> { - TrainedModelAssignment assignment = newMetadata.getDeploymentAssignment(params.getDeploymentId()); + rebalanceAssignments(clusterService.state(), Optional.of(request), "model deployment started", ActionListener.wrap(newMetadata -> { + TrainedModelAssignment assignment = newMetadata.getDeploymentAssignment(request.getTaskParams().getDeploymentId()); if (assignment == null) { // If we could not allocate the model anywhere then it is possible the assignment // here is null. We should notify the listener of an empty assignment as the // handling of this is done elsewhere with the wait-to-start predicate. - assignment = TrainedModelAssignment.Builder.empty(params).build(); + assignment = TrainedModelAssignment.Builder.empty(request).build(); } listener.onResponse(assignment); }, listener::onFailure)); @@ -528,13 +533,13 @@ private static ClusterState forceUpdate(ClusterState currentState, TrainedModelA return ClusterState.builder(currentState).metadata(metadata).build(); } - ClusterState createModelAssignment(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) throws Exception { - return update(currentState, rebalanceAssignments(currentState, Optional.of(params))); + ClusterState createModelAssignment(ClusterState currentState, CreateTrainedModelAssignmentAction.Request request) throws Exception { + return update(currentState, rebalanceAssignments(currentState, Optional.of(request))); } private void rebalanceAssignments( ClusterState clusterState, - Optional modelToAdd, + Optional createAssignmentRequest, String reason, ActionListener listener ) { @@ -544,7 +549,7 @@ private void rebalanceAssignments( TrainedModelAssignmentMetadata.Builder rebalancedMetadata; try { - rebalancedMetadata = rebalanceAssignments(clusterState, modelToAdd); + rebalancedMetadata = rebalanceAssignments(clusterState, createAssignmentRequest); } catch (Exception e) { listener.onFailure(e); return; @@ -561,7 +566,7 @@ public ClusterState execute(ClusterState currentState) { currentState = stopPlatformSpecificModelsInHeterogeneousClusters( currentState, mlNodesArchitectures, - modelToAdd, + createAssignmentRequest.map(CreateTrainedModelAssignmentAction.Request::getTaskParams), clusterState ); @@ -572,7 +577,7 @@ public ClusterState execute(ClusterState currentState) { return updatedState; } - rebalanceAssignments(currentState, modelToAdd, reason, listener); + rebalanceAssignments(currentState, createAssignmentRequest, reason, listener); return currentState; } @@ -639,7 +644,7 @@ && detectNodeLoads(sourceNodes, source).equals(detectNodeLoads(targetNodes, targ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( ClusterState currentState, - Optional modelToAdd + Optional createAssignmentRequest ) throws Exception { List nodes = getAssignableNodes(currentState); logger.debug(() -> format("assignable nodes are %s", nodes.stream().map(DiscoveryNode::getId).toList())); @@ -651,7 +656,7 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( currentMetadata, nodeLoads, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState), - modelToAdd, + 
createAssignmentRequest, allocatedProcessorsScale, useNewMemoryFields ); @@ -668,8 +673,12 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( rebalancer.rebalance() ); - if (modelToAdd.isPresent()) { - checkModelIsFullyAllocatedIfScalingIsNotPossible(modelToAdd.get().getDeploymentId(), rebalanced, nodes); + if (createAssignmentRequest.isPresent()) { + checkModelIsFullyAllocatedIfScalingIsNotPossible( + createAssignmentRequest.get().getTaskParams().getDeploymentId(), + rebalanced, + nodes + ); } return rebalanced; @@ -795,14 +804,22 @@ private boolean isScalingPossible(List nodes) { || (smallestMLNode.isPresent() && smallestMLNode.getAsLong() < maxMLNodeSize); } - public void updateNumberOfAllocations(String deploymentId, int numberOfAllocations, ActionListener listener) { - updateNumberOfAllocations(clusterService.state(), deploymentId, numberOfAllocations, listener); + public void updateDeployment( + String deploymentId, + Integer numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + boolean isInternal, + ActionListener listener + ) { + updateDeployment(clusterService.state(), deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener); } - private void updateNumberOfAllocations( + private void updateDeployment( ClusterState clusterState, String deploymentId, - int numberOfAllocations, + Integer numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettingsUpdates, + boolean isInternal, ActionListener listener ) { TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState); @@ -811,7 +828,27 @@ private void updateNumberOfAllocations( listener.onFailure(ExceptionsHelper.missingModelDeployment(deploymentId)); return; } - if (existingAssignment.getTaskParams().getNumberOfAllocations() == numberOfAllocations) { + AdaptiveAllocationsSettings adaptiveAllocationsSettings = getAdaptiveAllocationsSettings( + existingAssignment.getAdaptiveAllocationsSettings(), + adaptiveAllocationsSettingsUpdates + ); + if (adaptiveAllocationsSettings != null) { + if (isInternal == false && adaptiveAllocationsSettings.getEnabled() && numberOfAllocations != null) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled"); + listener.onFailure(validationException); + return; + } + ActionRequestValidationException validationException = adaptiveAllocationsSettings.validate(); + if (validationException != null) { + listener.onFailure(validationException); + return; + } + } + boolean hasUpdates = (numberOfAllocations != null + && Objects.equals(numberOfAllocations, existingAssignment.getTaskParams().getNumberOfAllocations()) == false) + || Objects.equals(adaptiveAllocationsSettings, existingAssignment.getAdaptiveAllocationsSettings()) == false; + if (hasUpdates == false) { listener.onResponse(existingAssignment); return; } @@ -828,7 +865,7 @@ private void updateNumberOfAllocations( if (clusterState.getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) { listener.onFailure( new ElasticsearchStatusException( - "cannot update number_of_allocations for deployment with model id [{}] while cluster upgrade is in progress.", + "cannot update deployment with model id [{}] while cluster upgrade is in progress.", RestStatus.CONFLICT, deploymentId ) @@ -837,7 +874,7 @@ private void updateNumberOfAllocations( } ActionListener 
updatedStateListener = ActionListener.wrap( - updatedState -> submitUnbatchedTask("update model deployment number_of_allocations", new ClusterStateUpdateTask() { + updatedState -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() { private volatile boolean isUpdated; @@ -848,7 +885,7 @@ public ClusterState execute(ClusterState currentState) { return updatedState; } logger.debug(() -> format("[%s] Retrying update as cluster state has been modified", deploymentId)); - updateNumberOfAllocations(currentState, deploymentId, numberOfAllocations, listener); + updateDeployment(currentState, deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener); return currentState; } @@ -877,38 +914,69 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) listener::onFailure ); - adjustNumberOfAllocations(clusterState, existingAssignment, numberOfAllocations, updatedStateListener); + updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedStateListener); + } + + private AdaptiveAllocationsSettings getAdaptiveAllocationsSettings( + AdaptiveAllocationsSettings original, + AdaptiveAllocationsSettings updates + ) { + if (updates == null) { + return original; + } else if (updates == AdaptiveAllocationsSettings.RESET_PLACEHOLDER) { + return null; + } else if (original == null) { + return updates; + } else { + return original.merge(updates); + } } - private void adjustNumberOfAllocations( + private void updateAssignment( ClusterState clusterState, TrainedModelAssignment assignment, - int numberOfAllocations, + Integer numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, ActionListener listener ) { threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { - if (numberOfAllocations > assignment.getTaskParams().getNumberOfAllocations()) { - increaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, listener); + if (numberOfAllocations == null || numberOfAllocations == assignment.getTaskParams().getNumberOfAllocations()) { + updateAndKeepNumberOfAllocations(clusterState, assignment, adaptiveAllocationsSettings, listener); + } else if (numberOfAllocations > assignment.getTaskParams().getNumberOfAllocations()) { + increaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, adaptiveAllocationsSettings, listener); } else { - decreaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, listener); + decreaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, adaptiveAllocationsSettings, listener); } }); } + private void updateAndKeepNumberOfAllocations( + ClusterState clusterState, + TrainedModelAssignment assignment, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + ActionListener listener + ) { + TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment) + .setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); + TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState); + builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment); + listener.onResponse(update(clusterState, builder)); + } + private void increaseNumberOfAllocations( ClusterState clusterState, TrainedModelAssignment assignment, int numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, ActionListener listener ) { try { + TrainedModelAssignment.Builder updatedAssignment = 
TrainedModelAssignment.Builder.fromAssignment(assignment) + .setNumberOfAllocations(numberOfAllocations) + .setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); final ClusterState updatedClusterState = update( clusterState, - TrainedModelAssignmentMetadata.builder(clusterState) - .updateAssignment( - assignment.getDeploymentId(), - TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations) - ) + TrainedModelAssignmentMetadata.builder(clusterState).updateAssignment(assignment.getDeploymentId(), updatedAssignment) ); TrainedModelAssignmentMetadata.Builder rebalancedMetadata = rebalanceAssignments(updatedClusterState, Optional.empty()); if (isScalingPossible(getAssignableNodes(clusterState)) == false @@ -931,6 +999,7 @@ private void decreaseNumberOfAllocations( ClusterState clusterState, TrainedModelAssignment assignment, int numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, ActionListener listener ) { TrainedModelAssignment.Builder updatedAssignment = numberOfAllocations < assignment.totalTargetAllocations() @@ -938,7 +1007,7 @@ private void decreaseNumberOfAllocations( numberOfAllocations ) : TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations); - + updatedAssignment.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); // We have now reduced allocations to a number we can be sure it is satisfied // and thus we should clear the assignment reason. if (numberOfAllocations <= assignment.totalTargetAllocations()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index ef8af6af445fb..624ef5434e2a0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -14,6 +14,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; @@ -50,7 +51,7 @@ class TrainedModelAssignmentRebalancer { private final TrainedModelAssignmentMetadata currentMetadata; private final Map nodeLoads; private final Map, Collection> mlNodesByZone; - private final Optional deploymentToAdd; + private final Optional createAssignmentRequest; private final int allocatedProcessorsScale; private final boolean useNewMemoryFields; @@ -59,28 +60,29 @@ class TrainedModelAssignmentRebalancer { TrainedModelAssignmentMetadata currentMetadata, Map nodeLoads, Map, Collection> mlNodesByZone, - Optional deploymentToAdd, + Optional createAssignmentRequest, int allocatedProcessorsScale, boolean useNewMemoryFields ) { this.currentMetadata = Objects.requireNonNull(currentMetadata); this.nodeLoads = Objects.requireNonNull(nodeLoads); this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone); - this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd); + this.createAssignmentRequest = 
Objects.requireNonNull(createAssignmentRequest); this.allocatedProcessorsScale = allocatedProcessorsScale; this.useNewMemoryFields = useNewMemoryFields; } TrainedModelAssignmentMetadata.Builder rebalance() { - if (deploymentToAdd.isPresent() && currentMetadata.hasDeployment(deploymentToAdd.get().getDeploymentId())) { + if (createAssignmentRequest.isPresent() + && currentMetadata.hasDeployment(createAssignmentRequest.get().getTaskParams().getDeploymentId())) { throw new ResourceAlreadyExistsException( "[{}] assignment for deployment with model [{}] already exists", - deploymentToAdd.get().getDeploymentId(), - deploymentToAdd.get().getModelId() + createAssignmentRequest.get().getTaskParams().getDeploymentId(), + createAssignmentRequest.get().getTaskParams().getModelId() ); } - if (deploymentToAdd.isEmpty() && areAllModelsSatisfiedAndNoOutdatedRoutingEntries()) { + if (createAssignmentRequest.isEmpty() && areAllModelsSatisfiedAndNoOutdatedRoutingEntries()) { logger.trace(() -> "No need to rebalance as all model deployments are satisfied"); return TrainedModelAssignmentMetadata.Builder.fromMetadata(currentMetadata); } @@ -176,14 +178,15 @@ private AssignmentPlan computePlanForNormalPriorityModels( assignment.getTaskParams().getThreadsPerAllocation(), currentAssignments, assignment.getMaxAssignedAllocations(), + assignment.getAdaptiveAllocationsSettings(), // in the mixed cluster state use old memory fields to avoid unstable assignment plans useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ); }) .forEach(planDeployments::add); - if (deploymentToAdd.isPresent() && deploymentToAdd.get().getPriority() != Priority.LOW) { - StartTrainedModelDeploymentAction.TaskParams taskParams = deploymentToAdd.get(); + if (createAssignmentRequest.isPresent() && createAssignmentRequest.get().getTaskParams().getPriority() != Priority.LOW) { + StartTrainedModelDeploymentAction.TaskParams taskParams = createAssignmentRequest.get().getTaskParams(); planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), @@ -192,6 +195,7 @@ private AssignmentPlan computePlanForNormalPriorityModels( taskParams.getThreadsPerAllocation(), Map.of(), 0, + createAssignmentRequest.get().getAdaptiveAllocationsSettings(), // in the mixed cluster state use old memory fields to avoid unstable assignment plans useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0, useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0 @@ -231,14 +235,15 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod assignment.getTaskParams().getThreadsPerAllocation(), findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory), assignment.getMaxAssignedAllocations(), + assignment.getAdaptiveAllocationsSettings(), Priority.LOW, (useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, (useNewMemoryFields == false) ? 
assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ) ) .forEach(planDeployments::add); - if (deploymentToAdd.isPresent() && deploymentToAdd.get().getPriority() == Priority.LOW) { - StartTrainedModelDeploymentAction.TaskParams taskParams = deploymentToAdd.get(); + if (createAssignmentRequest.isPresent() && createAssignmentRequest.get().getTaskParams().getPriority() == Priority.LOW) { + StartTrainedModelDeploymentAction.TaskParams taskParams = createAssignmentRequest.get().getTaskParams(); planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), @@ -247,6 +252,7 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod taskParams.getThreadsPerAllocation(), Map.of(), 0, + createAssignmentRequest.get().getAdaptiveAllocationsSettings(), Priority.LOW, (useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0, (useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0 @@ -325,11 +331,12 @@ private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(Assignme for (AssignmentPlan.Deployment deployment : assignmentPlan.models()) { TrainedModelAssignment existingAssignment = currentMetadata.getDeploymentAssignment(deployment.id()); - TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty( - existingAssignment == null && deploymentToAdd.isPresent() - ? deploymentToAdd.get() - : currentMetadata.getDeploymentAssignment(deployment.id()).getTaskParams() - ); + TrainedModelAssignment.Builder assignmentBuilder = existingAssignment == null && createAssignmentRequest.isPresent() + ? TrainedModelAssignment.Builder.empty(createAssignmentRequest.get()) + : TrainedModelAssignment.Builder.empty( + currentMetadata.getDeploymentAssignment(deployment.id()).getTaskParams(), + currentMetadata.getDeploymentAssignment(deployment.id()).getAdaptiveAllocationsSettings() + ); if (existingAssignment != null) { assignmentBuilder.setStartTime(existingAssignment.getStartTime()); assignmentBuilder.setMaxAssignedAllocations(existingAssignment.getMaxAssignedAllocations()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java index 0609e0e6ff916..bf19b505e5cfe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java @@ -30,7 +30,6 @@ import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAssignmentAction; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; @@ -85,10 +84,10 @@ public void updateModelAssignmentState( } public void createNewModelAssignment( - StartTrainedModelDeploymentAction.TaskParams taskParams, + CreateTrainedModelAssignmentAction.Request request, ActionListener listener ) { - client.execute(CreateTrainedModelAssignmentAction.INSTANCE, new 
CreateTrainedModelAssignmentAction.Request(taskParams), listener); + client.execute(CreateTrainedModelAssignmentAction.INSTANCE, request, listener); } public void deleteModelAssignment(String modelId, ActionListener listener) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 98988ffa11055..0151c8f5ee9c8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -60,6 +60,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) { m.threadsPerAllocation(), calculateAllocationsPerNodeToPreserve(m), m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 123c728587604..7fc16394ed85c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import java.util.ArrayList; @@ -37,11 +38,11 @@ public record Deployment( int threadsPerAllocation, Map currentAllocationsByNodeId, int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, Priority priority, long perDeploymentMemoryBytes, long perAllocationMemoryBytes ) { - public Deployment( String id, long modelBytes, @@ -49,6 +50,7 @@ public Deployment( int threadsPerAllocation, Map currentAllocationsByNodeId, int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, long perDeploymentMemoryBytes, long perAllocationMemoryBytes ) { @@ -59,12 +61,17 @@ public Deployment( threadsPerAllocation, currentAllocationsByNodeId, maxAssignedAllocations, + adaptiveAllocationsSettings, Priority.NORMAL, perDeploymentMemoryBytes, perAllocationMemoryBytes ); } + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + int getCurrentAssignedAllocations() { return currentAllocationsByNodeId.values().stream().mapToInt(Integer::intValue).sum(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index b1c017b1a784c..38279a2fd6c03 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -118,6 +118,7 @@ private AssignmentPlan 
solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat // don't rely on the current allocation new HashMap<>(), m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ) @@ -149,6 +150,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.threadsPerAllocation(), currentAllocationsByNodeId, m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index 9af2e4cd49b17..1f0857391598f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -129,6 +129,7 @@ private AssignmentPlan computeZonePlan( (tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations()) ? m.maxAssignedAllocations() : 0, + m.getAdaptiveAllocationsSettings(), // Only force assigning at least once previously assigned models that have not had any allocation yet m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() @@ -154,6 +155,7 @@ private AssignmentPlan computePlanAcrossAllNodes(List plans) { m.threadsPerAllocation(), allocationsByNodeIdByModelId.get(m.id()), m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 87fad19ab87fc..1bb2f1006822e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -160,11 +160,11 @@ void processInferenceResult(PyTorchResult result) { } logger.debug(() -> format("[%s] Parsed inference result with id [%s]", modelId, result.requestId())); - updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit())); PendingResult pendingResult = pendingResults.remove(result.requestId()); if (pendingResult == null) { logger.debug(() -> format("[%s] no pending result for inference [%s]", modelId, result.requestId())); } else { + updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit())); pendingResult.listener.onResponse(result); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java index 1a9fc6ce99823..e308eb6007973 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -94,7 +94,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient NUMBER_OF_ALLOCATIONS.getPreferredName(), RestApiVersion.V_8, restRequest, - (r, s) -> 
r.paramAsInt(s, request.getNumberOfAllocations()), + // This is to propagate a null value, which paramAsInt does not support. + (r, s) -> r.hasParam(s) ? (Integer) r.paramAsInt(s, 0) : request.getNumberOfAllocations(), request::setNumberOfAllocations ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java index 084a9d95939c5..afa372fb94527 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java @@ -1015,6 +1015,7 @@ private Map setupComplexMocks() { null, null, null, + null, Instant.now(), List.of( AssignmentStats.NodeStats.forStartedState( @@ -1064,6 +1065,7 @@ private Map setupComplexMocks() { "model_4", 2, 2, + null, 1000, ByteSizeValue.ofBytes(1000), Instant.now(), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java index 2f30d131021b4..2f251e3b0aee6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java @@ -13,11 +13,14 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.IndicesAdminClient; import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.Before; import java.util.Map; @@ -36,6 +39,7 @@ public class MlInitializationServiceTests extends ESTestCase { private ThreadPool threadPool; private ClusterService clusterService; private Client client; + private InferenceAuditor inferenceAuditor; private MlAssignmentNotifier mlAssignmentNotifier; @Before @@ -44,9 +48,11 @@ public void setUpMocks() { threadPool = deterministicTaskQueue.getThreadPool(); clusterService = mock(ClusterService.class); client = mock(Client.class); + inferenceAuditor = mock(InferenceAuditor.class); mlAssignmentNotifier = mock(MlAssignmentNotifier.class); when(clusterService.getClusterName()).thenReturn(CLUSTER_NAME); + when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE); @SuppressWarnings("unchecked") ActionFuture getSettingsResponseActionFuture = mock(ActionFuture.class); @@ -68,6 +74,7 @@ public void testInitialize() { threadPool, clusterService, client, + inferenceAuditor, mlAssignmentNotifier, true, true, @@ -83,6 +90,7 @@ public void testInitialize_noMasterNode() { threadPool, clusterService, client, + inferenceAuditor, mlAssignmentNotifier, true, true, @@ -94,11 +102,13 @@ public void testInitialize_noMasterNode() { public void testNodeGoesFromMasterToNonMasterAndBack() { MlDailyMaintenanceService initialDailyMaintenanceService = mock(MlDailyMaintenanceService.class); + AdaptiveAllocationsScalerService adaptiveAllocationsScalerService = 
mock(AdaptiveAllocationsScalerService.class); MlInitializationService initializationService = new MlInitializationService( client, threadPool, initialDailyMaintenanceService, + adaptiveAllocationsScalerService, clusterService ); initializationService.offMaster(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java index 2b206de4cf42f..bdabb42ecd467 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java @@ -191,7 +191,7 @@ public void testIsNodeSafeToShutdownReturnsFalseWhenStartingDeploymentExists() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( "1", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -215,12 +215,12 @@ public void testIsNodeSafeToShutdownReturnsFalseWhenStoppingAndStoppedDeployment TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( "1", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .addNewAssignment( "2", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -244,12 +244,12 @@ public void testIsNodeSafeToShutdownReturnsTrueWhenStoppedDeploymentsExist() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( "1", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .addNewAssignment( "2", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .build() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java index 2262c21070e75..5fb1381b881ea 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java @@ -132,18 +132,18 @@ public void testFindTrainedModelAllocationCounts() { TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.Builder.empty(); metadataBuilder.addNewAssignment( "model1", - TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class)) + TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class), null) .addRoutingEntry("node1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) 
.addRoutingEntry("node2", new RoutingInfo(0, 1, RoutingState.FAILED, "")) ); metadataBuilder.addNewAssignment( "model2", - TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class)) + TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class), null) .addRoutingEntry("node1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) ); metadataBuilder.addNewAssignment( "model3", - TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class)) + TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class), null) .addRoutingEntry("node2", new RoutingInfo(0, 1, RoutingState.STARTING, "")) ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java index b8dd3559253ee..4a66be4a773f5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java @@ -83,6 +83,7 @@ public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostExceptio "deployment1", randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -121,6 +122,7 @@ public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostExce "deployment1", randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? 
null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -169,7 +171,8 @@ private static TrainedModelAssignment createAssignment(String modelId) { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java index 0d91ce45c46ba..41a86e436f468 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java @@ -1143,7 +1143,8 @@ public void testGetMemoryAndProcessorsScaleDown() throws InterruptedException { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1158,7 +1159,8 @@ public void testGetMemoryAndProcessorsScaleDown() throws InterruptedException { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1242,7 +1244,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByMinNodes() throws Inte Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry("ml-node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("ml-node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")) @@ -1260,7 +1263,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByMinNodes() throws Inte Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1334,7 +1338,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByDummyEntityMemory() th Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1349,7 +1354,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByDummyEntityMemory() th Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1432,7 +1438,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityProcesso Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1447,7 +1454,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityProcesso Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1525,7 +1533,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityAsMemory Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1540,7 +1549,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityAsMemory Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java index a916900b199ce..970044c188849 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java @@ -1069,7 +1069,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(), TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams( @@ -1083,7 +1084,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build() ), withMlNodes("ml_node_1", "ml_node_2"), @@ -1105,7 +1107,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(), TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams( @@ -1119,7 +1122,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build() ), withMlNodes("ml_node_1", "ml_node_2"), @@ -1141,7 +1145,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(), TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams( @@ -1155,7 +1160,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build() ), withMlNodes("ml_node_1", "ml_node_2", "ml_node_3", "ml_node_4"), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java index 97fd66e284010..ba40dc0bfdda7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java @@ -79,7 +79,8 @@ public void testScale_GivenCurrentCapacityIsUsedExactly() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -96,7 +97,8 @@ public void testScale_GivenCurrentCapacityIsUsedExactly() { Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry(mlNodeId2, new RoutingInfo(8, 8, RoutingState.STARTED, "")) @@ -153,7 +155,8 @@ public void testScale_GivenUnsatisfiedDeployments() { Priority.NORMAL, 0L, 0L - ) + ), + null ) ) .addNewAssignment( @@ -170,7 +173,8 @@ public void testScale_GivenUnsatisfiedDeployments() { Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -227,7 +231,8 @@ public void testScale_GivenUnsatisfiedDeploymentIsLowPriority_ShouldNotScaleUp() Priority.LOW, 0L, 0L - ) + ), + null ) ) .addNewAssignment( @@ -244,7 +249,8 @@ public void testScale_GivenUnsatisfiedDeploymentIsLowPriority_ShouldNotScaleUp() Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -301,7 +307,8 @@ public void testScale_GivenMoreThanHalfProcessorsAreUsed() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new 
RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -318,7 +325,8 @@ public void testScale_GivenMoreThanHalfProcessorsAreUsed() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -386,7 +394,8 @@ public void testScale_GivenDownScalePossible_DelayNotSatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -403,7 +412,8 @@ public void testScale_GivenDownScalePossible_DelayNotSatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -459,7 +469,8 @@ public void testScale_GivenDownScalePossible_DelaySatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -476,7 +487,8 @@ public void testScale_GivenDownScalePossible_DelaySatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -536,7 +548,8 @@ public void testScale_GivenLowPriorityDeploymentsOnly() { Priority.LOW, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -553,7 +566,8 @@ public void testScale_GivenLowPriorityDeploymentsOnly() { Priority.LOW, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java new file mode 100644 index 0000000000000..3ad44f256dc66 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java @@ -0,0 +1,239 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AdaptiveAllocationsScalerServiceTests extends ESTestCase { + + private TestThreadPool threadPool; + private ClusterService clusterService; + private Client client; + private InferenceAuditor inferenceAuditor; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool( + new ScalingExecutorBuilder(MachineLearning.UTILITY_THREAD_POOL_NAME, 0, 1, TimeValue.timeValueMinutes(10), false) + ); + clusterService = mock(ClusterService.class); + client = mock(Client.class); + inferenceAuditor = mock(InferenceAuditor.class); + } + + @Override + @After + public void tearDown() throws Exception { + this.threadPool.close(); + super.tearDown(); + } + + private ClusterState getClusterState(int numAllocations) { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.custom("trained_model_assignment")).thenReturn( + new TrainedModelAssignmentMetadata( + Map.of( + "test-deployment", + TrainedModelAssignment.Builder.empty( + new StartTrainedModelDeploymentAction.TaskParams( + "model-id", + "test-deployment", + 100_000_000, + numAllocations, + 1, + 1024, + ByteSizeValue.ZERO, + Priority.NORMAL, + 100_000_000, + 100_000_000 + ), + new AdaptiveAllocationsSettings(true, null, null) + ).build() 
+                )
+            )
+        );
+        return clusterState;
+    }
+
+    private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllocations, int inferenceCount, double latency) {
+        return new GetDeploymentStatsAction.Response(
+            List.of(),
+            List.of(),
+            List.of(
+                new AssignmentStats(
+                    "test-deployment",
+                    "model-id",
+                    1,
+                    numAllocations,
+                    new AdaptiveAllocationsSettings(true, null, null),
+                    1024,
+                    ByteSizeValue.ZERO,
+                    Instant.now(),
+                    List.of(
+                        AssignmentStats.NodeStats.forStartedState(
+                            DiscoveryNodeUtils.create("node_1"),
+                            inferenceCount,
+                            latency,
+                            latency,
+                            0,
+                            0,
+                            0,
+                            0,
+                            0,
+                            Instant.now(),
+                            Instant.now(),
+                            1,
+                            numAllocations,
+                            inferenceCount,
+                            inferenceCount,
+                            latency,
+                            0
+                        )
+                    ),
+                    Priority.NORMAL
+                )
+            ),
+            0
+        );
+    }
+
+    public void test() throws IOException {
+        // Initialize the cluster with a deployment with 1 allocation.
+        ClusterState clusterState = getClusterState(1);
+        when(clusterService.state()).thenReturn(clusterState);
+
+        AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
+            threadPool,
+            clusterService,
+            client,
+            inferenceAuditor,
+            true,
+            1
+        );
+        service.start();
+
+        verify(clusterService).state();
+        verify(clusterService).addListener(same(service));
+        verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // First cycle: 1 inference request, so no need for scaling.
+        when(client.threadPool()).thenReturn(threadPool);
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener) invocationOnMock.getArguments()[2];
+            listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0));
+            return Void.TYPE;
+        }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
+
+        safeSleep(1200);
+
+        verify(client, times(1)).threadPool();
+        verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
+        verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // Second cycle: 150 inference requests with a latency of 10ms, so scale up to 2 allocations.
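+        // (Roughly: 150 requests at ~10 ms each amount to ~1.5 sec of compute per
+        // 1 sec measurement interval, which is more than a single allocation can serve.)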
+ when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(2)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment"); + updateRequest.setNumberOfAllocations(2); + updateRequest.setIsInternal(true); + verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any()); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + clusterState = getClusterState(2); + ClusterChangedEvent clusterChangedEvent = mock(ClusterChangedEvent.class); + when(clusterChangedEvent.state()).thenReturn(clusterState); + service.clusterChanged(clusterChangedEvent); + + // Third cycle: 0 inference requests, but keep 2 allocations, because of cooldown. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + + service.stop(); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java new file mode 100644 index 0000000000000..9758d00627efe --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.ml.inference.adaptiveallocations;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Random;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.nullValue;
+
+public class AdaptiveAllocationsScalerTests extends ESTestCase {
+
+    public void testAutoscaling_scaleUpAndDown() {
+        AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1);
+
+        // With 1 allocation the system can handle 500 requests * 0.020 sec/request.
+        // To handle the remaining requests the system should scale to 2 allocations.
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(500, 100, 100, 0.020), 10, 1);
+        assertThat(adaptiveAllocationsScaler.scale(), equalTo(2));
+
+        // With 2 allocations the system can handle 800 requests * 0.025 sec/request.
+        // To handle the remaining requests the system should scale to 3 allocations.
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(800, 100, 50, 0.025), 10, 2);
+        assertThat(adaptiveAllocationsScaler.scale(), equalTo(3));
+
+        // With 3 allocations the system can handle the load.
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(1000, 0, 0, 0.025), 10, 3);
+        assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+
+        // No load anymore, so the system should gradually scale down to 1 allocation.
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 3);
+        assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 3);
+        assertThat(adaptiveAllocationsScaler.scale(), equalTo(2));
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 2);
+        assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 2);
+        assertThat(adaptiveAllocationsScaler.scale(), equalTo(1));
+    }
+
+    public void testAutoscaling_noOscillating() {
+        AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1);
+
+        // With 1 allocation the system can handle 880 requests * 0.010 sec/request.
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 1);
+        assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+        adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 1);
+        assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+
+        // Increase the load to 980 requests * 0.010 sec/request, and the system
+        // should scale to 2 allocations to have some spare capacity.
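+        // (880 * 0.010 sec of work per 10 sec interval is ~0.88 of one allocation's
+        // capacity, which stays below the scale-up threshold; 980 requests push the
+        // utilization to ~0.98, which exceeds it.)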
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(920, 0, 0, 0.010), 10, 1);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(950, 0, 0, 0.010), 10, 1);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(980, 0, 0, 0.010), 10, 1);
+ assertThat(adaptiveAllocationsScaler.scale(), equalTo(2));
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(980, 0, 0, 0.010), 10, 2);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+
+ // Reducing the load to just 880 requests * 0.010 sec/request should not
+ // trigger scaling down again, to prevent oscillating.
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 2);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 2);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+ }
+
+ public void testAutoscaling_respectMinMaxAllocations() {
+ AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1);
+ adaptiveAllocationsScaler.setMinMaxNumberOfAllocations(2, 5);
+
+ // Even though there are no requests, scale to the minimum of 2 allocations.
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 1);
+ assertThat(adaptiveAllocationsScaler.scale(), equalTo(2));
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 2);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+
+ // Even though there are many requests, scale only to the maximum of 5 allocations.
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(100, 10000, 1000, 0.010), 10, 2);
+ assertThat(adaptiveAllocationsScaler.scale(), equalTo(5));
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(500, 10000, 1000, 0.010), 10, 5);
+ assertThat(adaptiveAllocationsScaler.scale(), nullValue());
+
+ // After a while of no requests, scale to the minimum of 2 allocations.
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 5);
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 5);
+ adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 5);
+ assertThat(adaptiveAllocationsScaler.scale(), equalTo(2));
+ }
+
+ public void testEstimation_highVariance() {
+ AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1);
+
+ Random random = new Random(42);
+
+ double averageLoadMean = 0.0;
+ double averageLoadError = 0.0;
+
+ double time = 0.0;
+ for (int nextMeasurementTime = 1; nextMeasurementTime <= 100; nextMeasurementTime++) {
+ // Sample one second of data (until the next measurement time).
+ // This contains approximately 100 requests with high-variance inference times.
+ AdaptiveAllocationsScalerService.Stats stats = new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0);
+ while (time < nextMeasurementTime) {
+ // Draw inference times from a log-normal distribution, which has high variance.
+ // This distribution approximately has: mean=3.40, variance=98.4.
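+ // (Checking those moments with the standard log-normal formulas, using
+ // mu = 0.1 and sigma = 1.5 from the call below:
+ //   mean = exp(mu + sigma^2 / 2) = exp(1.225) ~= 3.40
+ //   variance = (exp(sigma^2) - 1) * exp(2 * mu + sigma^2)
+ //            = (exp(2.25) - 1) * exp(2.45) ~= 8.49 * 11.59 ~= 98.4.)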
+ double inferenceTime = Math.exp(random.nextGaussian(0.1, 1.5));
+ stats = stats.add(new AdaptiveAllocationsScalerService.Stats(1, 0, 0, inferenceTime));
+
+ // The requests are Poisson distributed, which means the time between
+ // requests follows an exponential distribution.
+ // This distribution has on average 100 requests per second.
+ double dt = 0.01 * random.nextExponential();
+ time += dt;
+ }
+
+ adaptiveAllocationsScaler.process(stats, 1, 1);
+ double lower = adaptiveAllocationsScaler.getLoadLower();
+ double upper = adaptiveAllocationsScaler.getLoadUpper();
+ averageLoadMean += (upper + lower) / 2.0;
+ averageLoadError += (upper - lower) / 2.0;
+ }
+
+ averageLoadMean /= 100;
+ averageLoadError /= 100;
+
+ double expectedLoad = 100 * 3.40;
+ assertThat(averageLoadMean - averageLoadError, lessThan(expectedLoad));
+ assertThat(averageLoadMean + averageLoadError, greaterThan(expectedLoad));
+ assertThat(averageLoadError / averageLoadMean, lessThan(1 - AdaptiveAllocationsScaler.SCALE_UP_THRESHOLD));
+ }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java
new file mode 100644
index 0000000000000..f9b3a8966b627
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.adaptiveallocations;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+
+public class KalmanFilter1dTests extends ESTestCase {
+
+ public void testEstimation_equalValues() {
+ KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false);
+ assertThat(filter.hasValue(), equalTo(false));
+
+ filter.add(42.0, 9.0, false);
+ assertThat(filter.hasValue(), equalTo(true));
+ assertThat(filter.estimate(), equalTo(42.0));
+ assertThat(filter.error(), equalTo(3.0));
+ assertThat(filter.lower(), equalTo(39.0));
+ assertThat(filter.upper(), equalTo(45.0));
+
+ // With more data the estimation error should go down.
+ double previousError = filter.error();
+ for (int i = 0; i < 20; i++) {
+ filter.add(42.0, 9.0, false);
+ assertThat(filter.estimate(), equalTo(42.0));
+ assertThat(filter.error(), lessThan(previousError));
+ previousError = filter.error();
+ }
+ }
+
+ public void testEstimation_increasingValues() {
+ KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false);
+ filter.add(10.0, 1.0, false);
+ assertThat(filter.estimate(), equalTo(10.0));
+
+ // As the measured values increase, the estimated value should increase too,
+ // but it should lag behind.
+ double previousEstimate = filter.estimate();
+ for (double value = 11.0; value < 20.0; value += 1.0) {
+ filter.add(value, 1.0, false);
+ assertThat(filter.estimate(), greaterThan(previousEstimate));
+ assertThat(filter.estimate(), lessThan(value));
+ previousEstimate = filter.estimate();
+ }
+
+ // Feeding in more measurements at the final value should bring the estimate close to it.
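+ // (This is the standard scalar Kalman behaviour rather than anything
+ // specific to this class: each update moves the estimate by the gain
+ // k = p / (p + r) times the gap to the measurement, where p is the
+ // estimate variance and r the measurement variance, and since k < 1 the
+ // estimate only ever closes a fraction of the remaining gap, approaching
+ // 20.0 without reaching it.)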
+ for (int i = 0; i < 20; i++) {
+ filter.add(20.0, 1.0, false);
+ }
+ assertThat(filter.estimate(), greaterThan(19.0));
+ assertThat(filter.estimate(), lessThan(20.0));
+ }
+
+ public void testEstimation_bigJumpNoAutoDetectDynamicsChanges() {
+ KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false);
+ filter.add(0.0, 100.0, false);
+ filter.add(0.0, 1.0, false);
+ assertThat(filter.estimate(), equalTo(0.0));
+
+ // Without dynamics change autodetection the estimated value should be
+ // somewhere between the old and the new value.
+ filter.add(100.0, 1.0, false);
+ assertThat(filter.estimate(), greaterThan(49.0));
+ assertThat(filter.estimate(), lessThan(51.0));
+ }
+
+ public void testEstimation_bigJumpWithAutoDetectDynamicsChanges() {
+ KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, true);
+ filter.add(0.0, 100.0, false);
+ filter.add(0.0, 1.0, false);
+ assertThat(filter.estimate(), equalTo(0.0));
+
+ // With dynamics change autodetection the estimated value should jump
+ // instantly to almost the new value.
+ filter.add(100.0, 1.0, false);
+ assertThat(filter.estimate(), greaterThan(99.0));
+ assertThat(filter.estimate(), lessThan(100.0));
+ }
+
+ public void testEstimation_bigJumpWithExternalDetectDynamicsChange() {
+ KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false);
+ filter.add(0.0, 100.0, false);
+ filter.add(0.0, 1.0, false);
+ assertThat(filter.estimate(), equalTo(0.0));
+
+ // For external dynamics changes the estimated value should jump
+ // instantly to almost the new value.
+ filter.add(100.0, 1.0, true);
+ assertThat(filter.estimate(), greaterThan(99.0));
+ assertThat(filter.estimate(), lessThan(100.0));
+ }
+
+ public void testEstimation_differentSmoothing() {
+ KalmanFilter1d quickFilter = new KalmanFilter1d("test-filter", 1e-3, false);
+ for (int i = 0; i < 100; i++) {
+ quickFilter.add(42.0, 1.0, false);
+ }
+ assertThat(quickFilter.estimate(), equalTo(42.0));
+ // With low smoothing, the value should be close to the new value.
+ quickFilter.add(77.0, 1.0, false);
+ assertThat(quickFilter.estimate(), greaterThan(75.0));
+ assertThat(quickFilter.estimate(), lessThan(77.0));
+
+ KalmanFilter1d slowFilter = new KalmanFilter1d("test-filter", 1e3, false);
+ for (int i = 0; i < 100; i++) {
+ slowFilter.add(42.0, 1.0, false);
+ }
+ assertThat(slowFilter.estimate(), equalTo(42.0));
+ // With high smoothing, the value should be close to the old value.
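+ // (The second constructor argument, 1e-3 for quickFilter and 1e3 for
+ // slowFilter, presumably controls how much process noise the filter
+ // assumes: a small value keeps the Kalman gain high, so the estimate
+ // tracks new measurements almost directly, while a large value keeps the
+ // gain low, so the estimate barely moves from 42.0 towards 77.0.)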
+ slowFilter.add(77.0, 1.0, false); + assertThat(slowFilter.estimate(), greaterThan(42.0)); + assertThat(slowFilter.estimate(), lessThan(44.0)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java index f08d2735be8a5..1dc44582492aa 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java @@ -48,6 +48,7 @@ import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; @@ -277,7 +278,7 @@ public void testUpdateModelRoutingTable() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(startedNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -389,7 +390,10 @@ public void testRemoveAssignment() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()))) + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()), null) + ) .build() ) .build() @@ -450,7 +454,10 @@ public void testCreateAssignment_GivenModelCannotByFullyAllocated_AndScalingIsPo .build(); TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(5); - ClusterState newState = trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150, 4, 1)); + ClusterState newState = trainedModelAssignmentClusterService.createModelAssignment( + currentState, + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150, 4, 1), null) + ); TrainedModelAssignment createdAssignment = TrainedModelAssignmentMetadata.fromState(newState).getDeploymentAssignment("new-model"); assertThat(createdAssignment, is(not(nullValue()))); @@ -466,7 +473,10 @@ public void testCreateAssignment_GivenModelCannotByFullyAllocated_AndScalingIsPo expectThrows( ResourceAlreadyExistsException.class, - () -> trainedModelAssignmentClusterService.createModelAssignment(newState, newParams("new-model", 150)) + () -> trainedModelAssignmentClusterService.createModelAssignment( + newState, + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150), null) + ) ); } @@ -495,7 +505,10 @@ public void testCreateAssignment_GivenModelCannotByFullyAllocated_AndScalingIsNo TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(0); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, - () -> 
trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150, 4, 1)) + () -> trainedModelAssignmentClusterService.createModelAssignment( + currentState, + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150, 4, 1), null) + ) ); assertThat( @@ -528,7 +541,7 @@ public void testCreateAssignmentWhileResetModeIsTrue() throws InterruptedExcepti CountDownLatch latch = new CountDownLatch(1); trainedModelAssignmentClusterService.createNewModelAssignment( - newParams("new-model", 150), + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150), null), new LatchedActionListener<>( ActionListener.wrap( trainedModelAssignment -> fail("assignment should have failed to be created because reset mode is set"), @@ -560,7 +573,7 @@ public void testHaveMlNodesChanged_ReturnsFalseWhenPreviouslyShuttingDownNode_Is TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -597,7 +610,7 @@ public void testHaveMlNodesChanged_ReturnsTrueWhenNodeShutsDownAndWasRoutedTo() TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -614,7 +627,7 @@ public void testHaveMlNodesChanged_ReturnsTrueWhenNodeShutsDownAndWasRoutedTo() TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -641,7 +654,7 @@ public void testHaveMlNodesChanged_ReturnsFalseWhenNodeShutsDownAndWasRoutedTo_B TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -658,7 +671,7 @@ public void testHaveMlNodesChanged_ReturnsFalseWhenNodeShutsDownAndWasRoutedTo_B TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -700,7 +713,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -747,7 +760,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -759,7 +772,7 @@ public void 
testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -781,7 +794,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -793,7 +806,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -815,7 +828,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -827,7 +840,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -851,7 +864,7 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)).stopAssignment("test") + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null).stopAssignment("test") ) .build() ) @@ -864,7 +877,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -886,7 +899,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata(mlNode2)) @@ -899,7 +912,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -923,12 +936,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + 
TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -945,12 +958,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -977,12 +990,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .stopAssignment("test") @@ -1000,12 +1013,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -1032,7 +1045,7 @@ public void testDetectReasonToRebalanceModels_WithNodeShutdowns() { TrainedModelAssignmentMetadata fullModelAllocation = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1.getId(), new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(mlNode2.getId(), new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1227,7 +1240,7 @@ public void testDetectReasonToRebalanceModels_GivenSingleMlJobStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1242,7 +1255,7 @@ public void testDetectReasonToRebalanceModels_GivenSingleMlJobStopped() { .putCustom( 
TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1265,7 +1278,7 @@ public void testDetectReasonToRebalanceModels_GivenOutdatedAssignments() { TrainedModelAssignmentMetadata modelMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 100)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null) .addRoutingEntry(mlNodeId, new RoutingInfo(0, 0, RoutingState.STARTED, "")) ) .build(); @@ -1342,7 +1355,7 @@ public void testDetectReasonToRebalanceModels_GivenMultipleMlJobsStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1357,7 +1370,7 @@ public void testDetectReasonToRebalanceModels_GivenMultipleMlJobsStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1419,7 +1432,7 @@ public void testDetectReasonToRebalanceModels_GivenMlJobsStarted() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1434,7 +1447,7 @@ public void testDetectReasonToRebalanceModels_GivenMlJobsStarted() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1459,7 +1472,7 @@ public void testAreAssignedNodesRemoved_GivenRemovedNodeThatIsRouted() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1491,7 +1504,7 @@ public void testAreAssignedNodesRemoved_GivenRemovedNodeThatIsNotRouted() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -1519,7 +1532,7 @@ public void testAreAssignedNodesRemoved_GivenShuttingDownNodeThatIsRouted() { TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + 
TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1563,7 +1576,7 @@ public void testAreAssignedNodesRemoved_GivenShuttingDownNodeThatIsNotRouted() { TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1611,13 +1624,13 @@ public void testRemoveRoutingToUnassignableNodes_RemovesRouteForRemovedNodes() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1668,14 +1681,14 @@ public void testRemoveRoutingToUnassignableNodes_AddsAStoppingRouteForShuttingDo TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -1728,14 +1741,14 @@ public void testRemoveRoutingToUnassignableNodes_IgnoresARouteThatIsStoppedForSh TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) @@ -1789,12 +1802,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata currentMetadata = 
TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( availableNodeModelId, - TrainedModelAssignment.Builder.empty(taskParamsRunning) + TrainedModelAssignment.Builder.empty(taskParamsRunning, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(shuttingDownNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1802,12 +1815,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( availableNodeModelId, - TrainedModelAssignment.Builder.empty(taskParamsRunning) + TrainedModelAssignment.Builder.empty(taskParamsRunning, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsRunning) + TrainedModelAssignment.Builder.empty(taskParamsRunning, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ); @@ -1840,12 +1853,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(shuttingDownNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( notShuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1853,12 +1866,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notShuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ); @@ -1897,7 +1910,7 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(disappearingNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1905,7 +1918,7 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) 
.addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ); @@ -1933,7 +1946,7 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAssignmentDoesNotExist_ TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(shuttingDownNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -2006,7 +2019,10 @@ public void testSetAllocationToStopping() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()))) + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()), null) + ) .build() ) .build() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java index 6c5223eae4d99..dec85bff87d67 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java @@ -64,7 +64,7 @@ public void testIsAssigned() { TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( allocatedDeploymentId, - TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId)) + TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId), null) ) .build(); assertThat(metadata.isAssigned(allocatedDeploymentId), is(true)); @@ -78,7 +78,7 @@ public void testModelIsDeployed() { TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( allocatedDeploymentId, - TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId)) + TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId), null) ) .build(); assertThat(metadata.modelIsDeployed(allocatedDeploymentId), is(false)); @@ -92,9 +92,9 @@ public void testGetDeploymentsUsingModel() { String deployment2 = "test_deployment_2"; String deployment3 = "test_deployment_3"; TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(deployment1, TrainedModelAssignment.Builder.empty(randomParams(deployment1, modelId1))) - .addNewAssignment(deployment2, TrainedModelAssignment.Builder.empty(randomParams(deployment2, modelId1))) - .addNewAssignment(deployment3, TrainedModelAssignment.Builder.empty(randomParams(deployment3, "different_model"))) + .addNewAssignment(deployment1, TrainedModelAssignment.Builder.empty(randomParams(deployment1, modelId1), null)) + .addNewAssignment(deployment2, TrainedModelAssignment.Builder.empty(randomParams(deployment2, modelId1), null)) + .addNewAssignment(deployment3, TrainedModelAssignment.Builder.empty(randomParams(deployment3, "different_model"), null)) .build(); var assignments = metadata.getDeploymentsUsingModel(modelId1); assertThat(assignments, hasSize(2)); diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index a5bba21d9e778..9fbc2b43f1137 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -353,17 +353,17 @@ public void testClusterChangedWithResetMode() throws InterruptedException { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( modelTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notUsedModel, - TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -411,7 +411,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -464,7 +464,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) .addRoutingEntry(node2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -507,7 +507,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -548,7 +548,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -590,7 +590,7 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .stopAssignment("stopping") ) @@ -639,12 +639,12 @@ public void testClusterChanged() throws 
Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( deploymentTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( NODE_ID, @@ -658,7 +658,7 @@ public void testClusterChanged() throws Exception { ) .addNewAssignment( previouslyUsedDeployment, - TrainedModelAssignment.Builder.empty(newParams(previouslyUsedDeployment, previouslyUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(previouslyUsedDeployment, previouslyUsedModel), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( NODE_ID, @@ -672,7 +672,7 @@ public void testClusterChanged() throws Exception { ) .addNewAssignment( notUsedDeployment, - TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -697,17 +697,17 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( deploymentTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notUsedDeployment, - TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -751,7 +751,7 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -793,12 +793,12 @@ public void testClusterChanged_GivenAllStartedAssignments_AndNonMatchingTargetAl TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 3, RoutingState.STARTED, "")) ) .addNewAssignment( deploymentTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry(NODE_ID, new RoutingInfo(2, 1, RoutingState.STARTED, "")) ) .build() @@ -845,7 +845,7 @@ private void 
givenAssignmentsInClusterStateForModels(List deploymentIds, for (int i = 0; i < modelIds.size(); i++) { builder.addNewAssignment( deploymentIds.get(i), - TrainedModelAssignment.Builder.empty(newParams(deploymentIds.get(i), modelIds.get(i))) + TrainedModelAssignment.Builder.empty(newParams(deploymentIds.get(i), modelIds.get(i)), null) .addRoutingEntry("test-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index 53b737b38c284..65a974e04045e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; @@ -61,11 +62,12 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_ShouldMakeNoChanges() TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( deploymentId2, - TrainedModelAssignment.Builder.empty(taskParams2) + TrainedModelAssignment.Builder.empty(taskParams2, null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(3, 3, RoutingState.STARTED, "")) ) @@ -101,11 +103,12 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_GivenOutdatedRoutingEn TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")) ) .addNewAssignment( deploymentId2, - TrainedModelAssignment.Builder.empty(taskParams2) + TrainedModelAssignment.Builder.empty(taskParams2, null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(3, 3, RoutingState.STARTED, "")) ) @@ -140,11 +143,18 @@ public void testRebalance_GivenModelToAddAlreadyExists() { String modelId = "model-to-add"; StartTrainedModelDeploymentAction.TaskParams taskParams = normalPriorityParams(modelId, modelId, 1024L, 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(taskParams)) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(taskParams, null)) .build(); 
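+        // (Note on the pattern recurring throughout this patch: the new second
+        // argument to TrainedModelAssignment.Builder.empty and to
+        // CreateTrainedModelAssignmentAction.Request, passed as null in these
+        // tests, presumably carries the adaptive allocations settings that this
+        // change introduces; null keeps adaptive scaling out of the picture so
+        // the pre-existing assertions hold unchanged.)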
expectThrows( ResourceAlreadyExistsException.class, - () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1, false).rebalance() + () -> new TrainedModelAssignmentRebalancer( + currentMetadata, + Map.of(), + Map.of(), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), + 1, + false + ).rebalance() ); } @@ -157,7 +167,7 @@ public void testRebalance_GivenFirstModelToAdd_NoMLNodes() throws Exception { currentMetadata, Map.of(), Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -185,7 +195,7 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exce currentMetadata, nodeLoads, Map.of(List.of(), List.of(node)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -222,7 +232,7 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughMemory() throws Exceptio currentMetadata, nodeLoads, Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -259,7 +269,7 @@ public void testRebalance_GivenFirstModelToAdd_ErrorDetectingNodeLoad() throws E currentMetadata, nodeLoads, Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -296,7 +306,7 @@ public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -330,7 +340,7 @@ public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -357,7 +367,7 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeploymentId, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeploymentId, previousDeploymentId, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeploymentId, previousDeploymentId, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -370,7 +380,7 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -416,13 +426,13 @@ public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeployment1Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2), null) .addRoutingEntry("node-1", new 
RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( previousDeployment2Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1), null) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -483,13 +493,13 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeployment1Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( previousDeployment2Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1), null) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -554,13 +564,13 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeployment1Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( previousDeployment2Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 1, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 1, 1), null) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -610,7 +620,7 @@ public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exce TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(normalPriorityParams(modelId, 1024L, 1, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(modelId, 1024L, 1, 1), null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.FAILED, "some error")) ) .build(); @@ -656,7 +666,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() currentMetadata, nodeLoads, Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -693,7 +703,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deployment2, - TrainedModelAssignment.Builder.empty(taskParams2) + TrainedModelAssignment.Builder.empty(taskParams2, null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, 
RoutingState.STARTED, "")) ) @@ -703,7 +713,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor currentMetadata, nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), - Optional.of(taskParams1), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams1, null)), 1, false ).rebalance().build(); @@ -735,8 +745,8 @@ public void testRebalance_GivenMixedPriorityModels_NotEnoughMemoryForLowPriority String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = normalPriorityParams(modelId2, ByteSizeValue.ofMb(300).getBytes(), 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) - .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) + .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2, null)) .build(); TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( @@ -786,10 +796,11 @@ public void testRebalance_GivenMixedPriorityModels_TwoZones_EachNodeCanHoldOneMo String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = normalPriorityParams(modelId2, ByteSizeValue.ofMb(300).getBytes(), 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(taskParams2).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams2, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -844,8 +855,8 @@ public void testRebalance_GivenModelUsingAllCpu_FittingLowPriorityModelCanStart( String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = normalPriorityParams(modelId2, ByteSizeValue.ofMb(300).getBytes(), 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) - .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) + .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2, null)) .build(); TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( @@ -895,8 +906,8 @@ public void testRebalance_GivenMultipleLowPriorityModels_AndMultipleNodes() thro String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = lowPriorityParams(modelId2, ByteSizeValue.ofMb(100).getBytes()); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) - .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) + .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2, null)) .build(); TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( @@ 
-946,7 +957,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -954,7 +966,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1)), - Optional.of(taskParams2), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams2, null)), 1, false ).rebalance().build(); @@ -999,7 +1011,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1007,7 +1020,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams2), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams2, null)), 1, false ).rebalance().build(); @@ -1052,7 +1065,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1060,7 +1074,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams2), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams2, null)), 1, false ).rebalance().build(); @@ -1107,7 +1121,7 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 2, false ).rebalance().build(); @@ -1130,7 +1144,7 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java index 85fc83f775670..603eda65fbd51 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java +++ 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java @@ -181,7 +181,8 @@ private static TrainedModelAssignment createAssignment( Priority.NORMAL, randomNonNegativeLong(), randomNonNegativeLong() - ) + ), + null ); allocationsByNode.entrySet() .stream() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index cbbb38f1d1ddd..d84c04f0c41f1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -25,14 +25,14 @@ public class AssignmentPlanTests extends ESTestCase { public void testBuilderCtor_GivenDuplicateNode() { Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, null, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n, n), List.of(m))); } public void testBuilderCtor_GivenDuplicateModel() { Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, null, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n), List.of(m, m))); } @@ -41,7 +41,17 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() { Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4); { // old memory format - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(40).getBytes(), 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(40).getBytes(), + 1, + 2, + Map.of(), + 0, + null, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -71,6 +81,7 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(30).getBytes() ); @@ -107,6 +118,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -134,6 +146,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { 2, Map.of("n_1", 1), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(25).getBytes() ); @@ -160,7 +173,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 4); { // old memory format - Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, 0, 0); + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -186,6 +199,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() 2, Map.of("n_1", 2), 0, + null, ByteSizeValue.ofMb(250).getBytes(), ByteSizeValue.ofMb(25).getBytes() ); @@ -209,7 +223,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() public void 
testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() { Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1)); @@ -227,6 +241,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -249,6 +264,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { 2, Map.of("n_1", 1), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(5).getBytes() ); @@ -266,7 +282,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocation() { Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4); - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5)); @@ -279,7 +295,17 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocati public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAllocation() { Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 2, + Map.of(), + 0, + null, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 3)); @@ -299,6 +325,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -335,7 +362,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -346,7 +373,7 @@ public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { // old memory format - Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, 0, 0); + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); assertThat(builder.canAssign(m, n, 1), is(true)); } @@ -359,6 +386,7 @@ public void 
testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { 1, Map.of("n_1", 1), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -369,7 +397,17 @@ public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { public void testCanAssign_GivenEnoughMemory() { Node n = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 2, + Map.of(), + 0, + null, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -384,7 +422,7 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() { Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planSatisfyingPreviousAssignments = builder.build(); @@ -397,6 +435,7 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() { 2, Map.of("n_1", 3), 0, + null, 0, 0 ); @@ -420,6 +459,7 @@ public void testCompareTo_GivenDifferenceInAllocations() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -445,7 +485,7 @@ public void testCompareTo_GivenDifferenceInMemory() { Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planUsingMoreMemory = builder.build(); @@ -458,6 +498,7 @@ public void testCompareTo_GivenDifferenceInMemory() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -482,6 +523,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 2, Map.of(), 0, + null, 0, 0 ); @@ -492,6 +534,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, 0, 0 ); @@ -502,6 +545,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, 0, 0 ); @@ -522,6 +566,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -532,6 +577,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -542,6 +588,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -558,9 +605,9 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); - Deployment 
deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 2) @@ -573,9 +620,9 @@ public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 1) @@ -586,8 +633,8 @@ public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) .assignModelToNode(deployment1, node1, 1) .build(); @@ -597,7 +644,7 @@ public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { public void testCountPreviouslyAssignedThatAreStillAssigned() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0); AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( "m_2", ByteSizeValue.ofMb(30).getBytes(), @@ -605,6 +652,7 @@ public void 
testCountPreviouslyAssignedThatAreStillAssigned() { 1, Map.of(), 4, + null, 0, 0 ); @@ -615,6 +663,7 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() { 1, Map.of(), 1, + null, 0, 0 ); @@ -625,6 +674,7 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() { 1, Map.of(), 0, + null, 0, 0 ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java index bc94144bce1c5..ef76c388b81a1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java @@ -42,7 +42,7 @@ private static long scaleNodeSize(long nodeMemory) { public void testModelThatDoesNotFitInMemory() { { // Without perDeploymentMemory and perAllocationMemory specified List nodes = List.of(new Node("n_1", scaleNodeSize(50), 4)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment), isEmpty()); } @@ -55,6 +55,7 @@ public void testModelThatDoesNotFitInMemory() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(250).getBytes(), ByteSizeValue.ofMb(51).getBytes() ); @@ -65,7 +66,7 @@ public void testModelThatDoesNotFitInMemory() { public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() { List nodes = List.of(new Node("n_1", scaleNodeSize(100), 4), new Node("n_2", scaleNodeSize(100), 5)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment), isEmpty()); } @@ -73,13 +74,13 @@ public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() { public void testSingleModelThatFitsFullyOnSingleNode() { { Node node = new Node("n_1", scaleNodeSize(100), 4); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } { Node node = new Node("n_1", scaleNodeSize(1000), 8); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, 0, 0); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } @@ -92,6 +93,7 @@ public void testSingleModelThatFitsFullyOnSingleNode() { 16, Map.of(), 0, + null, 0, 0 ); @@ -100,7 +102,7 @@ public void testSingleModelThatFitsFullyOnSingleNode() { } { Node node = new Node("n_1", 
scaleNodeSize(100), 4); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } @@ -116,6 +118,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -131,6 +134,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -142,7 +146,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode() { Node node1 = new Node("n_1", scaleNodeSize(100), 4); Node node2 = new Node("n_2", scaleNodeSize(100), 4); - AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); @@ -164,6 +168,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(150).getBytes() ); @@ -179,7 +184,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully } public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation() { - AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, null, 0, 0); // Single node { Node node = new Node("n_1", scaleNodeSize(100), 4); @@ -220,6 +225,7 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -260,10 +266,10 @@ public void testMultipleModelsAndNodesWithSingleSolution() { Node node2 = new Node("n_2", 2 * scaleNodeSize(50), 7); Node node3 = new Node("n_3", 2 * scaleNodeSize(50), 2); Node node4 = new Node("n_4", 2 * scaleNodeSize(50), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); - Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0); + Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); 
AssignmentPlan plan = new AssignmentPlanner( List.of(node1, node2, node3, node4), @@ -322,6 +328,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 4, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -332,6 +339,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 3, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -342,6 +350,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -352,6 +361,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -402,7 +412,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { } public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation() { - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, null, 0, 0); // Single node { Node node = new Node("n_1", scaleNodeSize(100), 4); @@ -443,6 +453,7 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA 3, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -487,6 +498,7 @@ public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { 1, Map.of("n_1", 4), 0, + null, 0, 0 ); @@ -506,18 +518,18 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation() { new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16) ); List deployments = List.of( - new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, 0, 0), - new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), - new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, 0, 0), - new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, 0, 0), - new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, 0, 0), - new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, 0, 0), - new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, 0, 0), - new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), - new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), - new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) + new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, null, 0, 0), + new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_6", 
ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, null, 0, 0), + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0) ); AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); @@ -550,10 +562,11 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of("n_1", 5), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(100).getBytes() ), - new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), + new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0), new Deployment( "m_3", ByteSizeValue.ofMb(50).getBytes(), @@ -561,6 +574,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ), @@ -571,6 +585,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of("n_3", 2), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(100).getBytes() ), @@ -581,6 +596,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of(), 0, + null, ByteSizeValue.ofMb(800).getBytes(), ByteSizeValue.ofMb(100).getBytes() ), @@ -591,6 +607,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of(), 0, + null, ByteSizeValue.ofMb(50).getBytes(), ByteSizeValue.ofMb(20).getBytes() ), @@ -601,14 +618,15 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of("n_2", 6), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ), - new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), - new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), - new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0) ); AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); @@ -718,6 +736,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode m.threadsPerAllocation(), previousAssignments, 0, + null, 0, 0 ) @@ -741,10 +760,11 @@ public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAss 1, Map.of("n_1", 2, "n_2", 1), 0, + null, 0, 0 ); - Deployment deployment2 = new Deployment("m_2", 
ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2)) .computePlan(); assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); @@ -776,6 +796,7 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( 1, Map.of("n_1", 2, "n_2", 1), 3, + null, 0, 0 ); @@ -786,6 +807,7 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( 2, Map.of(), 1, + null, 0, 0 ); @@ -807,8 +829,8 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() { Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, 0, 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, 0, 0); + AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, null, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, null, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1), List.of(deployment1, deployment2)).computePlan(); @@ -818,9 +840,9 @@ public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() { public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); @@ -860,9 +882,9 @@ public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(2600).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2600).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", 
ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); @@ -931,9 +953,9 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels() { // Ensure that plan is removing previously allocated models if not enough memory is available Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); // Create a plan where all deployments are assigned at least once AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) @@ -965,6 +987,7 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() 1, Map.of(), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -975,6 +998,7 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() 1, Map.of(), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(150).getBytes() ); @@ -985,6 +1009,7 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() 1, Map.of(), 0, + null, ByteSizeValue.ofMb(250).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -1028,6 +1053,7 @@ public static List createModelsFromPlan(AssignmentPlan plan) { m.threadsPerAllocation(), currentAllocations, Math.max(m.maxAssignedAllocations(), totalAllocations), + null, 0, 0 ) @@ -1096,6 +1122,7 @@ public static Deployment randomModel(String idSuffix) { randomIntBetween(1, 4), Map.of(), 0, + null, 0, 0 ); @@ -1107,6 +1134,7 @@ public static Deployment randomModel(String idSuffix) { randomIntBetween(1, 4), Map.of(), 0, + null, randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()), randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()) ); @@ -1137,7 +1165,7 @@ private void runTooManyNodesAndModels(int nodesSize, int modelsSize) { } List deployments = new ArrayList<>(); for (int i = 0; i < modelsSize; i++) { - deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, 0, 0)); + deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, null, 0, 0)); } // Check plan is computed without OOM exception diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index 7f83df5835494..9885c4d583198 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -25,8 +25,8 @@ public class PreserveAllAllocationsTests extends ESTestCase { public void testGivenNoPreviousAssignments() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), List.of(deployment1, deployment2) @@ -45,10 +45,21 @@ public void testGivenPreviousAssignments() { 1, Map.of("n_1", 1), 1, + null, + 0, + 0 + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 6, + 4, + Map.of("n_1", 1, "n_2", 2), + 3, + null, 0, 0 ); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), List.of(deployment1, deployment2) @@ -117,6 +128,7 @@ public void testGivenPreviousAssignments() { 1, Map.of("n_1", 1), 1, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -127,6 +139,7 @@ public void testGivenPreviousAssignments() { 4, Map.of("n_1", 1, "n_2", 2), 3, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -195,7 +208,7 @@ public void testGivenPreviousAssignments() { public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(List.of(node), List.of(deployment)); AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java index d2907eb31160b..50ba8763c690d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java @@ -26,8 +26,8 @@ public class PreserveOneAllocationTests extends ESTestCase { public void testGivenNoPreviousAssignments() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, 
Map.of(), 0, 0, 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); @@ -42,8 +42,18 @@ public void testGivenPreviousAssignments() { // old memory format Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8); Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, null, 0, 0); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 6, + 4, + Map.of("n_1", 1, "n_2", 2), + 3, + null, + 0, + 0 + ); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( List.of(node1, node2), List.of(deployment1, deployment2) @@ -117,6 +127,7 @@ public void testGivenPreviousAssignments() { 1, Map.of("n_1", 1), 1, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -127,6 +138,7 @@ public void testGivenPreviousAssignments() { 4, Map.of("n_1", 1, "n_2", 2), 3, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -199,7 +211,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments { // old memory format Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); @@ -221,6 +233,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments 2, Map.of("n_1", 2), 2, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java index 651e4764cb894..4993600d0d3b3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java @@ -36,7 +36,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase { public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { Node node = new Node("n_1", 100, 1); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, 0, 0); + 
AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -52,6 +52,7 @@ public void testGivenOneModel_OneNode_OneZone_FullyFits() { 2, Map.of(), 0, + null, 0, 0 ); @@ -70,6 +71,7 @@ public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { 2, Map.of(), 0, + null, 0, 0 ); @@ -91,6 +93,7 @@ public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { 2, Map.of(), 0, + null, 0, 0 ); @@ -118,6 +121,7 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { 2, Map.of(), 0, + null, 0, 0 ); @@ -144,6 +148,7 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { 3, Map.of(), 0, + null, 0, 0 ); @@ -168,9 +173,9 @@ public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { Node node4 = new Node("n_4", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node5 = new Node("n_5", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node6 = new Node("n_6", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, null, 0, 0); Map, List> nodesByZone = Map.of( List.of("z_1"), @@ -216,8 +221,8 @@ public void testGivenTwoModelsWithSingleAllocation_OneNode_ThreeZones() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)), @@ -255,6 +260,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode m.threadsPerAllocation(), previousAssignments, 0, + null, 0, 0 ) @@ -270,9 +276,9 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(2580).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", 
ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(deployment1)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java index 860da3140f4fe..7eb9d7e940dda 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java @@ -276,10 +276,15 @@ public void testsTimeDependentStats() { var timeSupplier = new TimeSupplier(resultTimestamps); var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier); + for (int i = 0; i < 10; i++) { + processor.registerRequest("foo" + i, ActionListener.noop()); + } + // 1st period - processor.processInferenceResult(wrapInferenceResult("foo", false, 200L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 200L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 200L)); + processor.processInferenceResult(wrapInferenceResult("foo0", false, 200L)); + processor.processInferenceResult(wrapInferenceResult("foo1", false, 200L)); + processor.processInferenceResult(wrapInferenceResult("foo2", false, 200L)); + // first call has no results as is in the same period var stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); @@ -293,7 +298,7 @@ public void testsTimeDependentStats() { assertThat(stats.peakThroughput(), equalTo(3L)); // 2nd period - processor.processInferenceResult(wrapInferenceResult("foo", false, 100L)); + processor.processInferenceResult(wrapInferenceResult("foo3", false, 100L)); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -305,7 +310,7 @@ public void testsTimeDependentStats() { assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); // 4th period - processor.processInferenceResult(wrapInferenceResult("foo", false, 300L)); + processor.processInferenceResult(wrapInferenceResult("foo4", false, 300L)); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -313,8 +318,8 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9]))); // 7th period - processor.processInferenceResult(wrapInferenceResult("foo", false, 410L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 390L)); + processor.processInferenceResult(wrapInferenceResult("foo5", false, 410L)); + processor.processInferenceResult(wrapInferenceResult("foo6", false, 390L)); stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); assertThat(stats.recentStats().avgInferenceTime(), nullValue()); @@ -325,9 +330,9 @@ public void 
testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12]))); // 8th period - processor.processInferenceResult(wrapInferenceResult("foo", false, 510L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 500L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 490L)); + processor.processInferenceResult(wrapInferenceResult("foo7", false, 510L)); + processor.processInferenceResult(wrapInferenceResult("foo8", false, 500L)); + processor.processInferenceResult(wrapInferenceResult("foo9", false, 490L)); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(3L)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index fef9b07429702..c3ad54427f70c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -133,7 +133,8 @@ public void testNodeLoadDetection() { Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test")) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java index 2bb10d66d3d58..cce6b284a524d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java @@ -30,7 +30,7 @@ public void testNumberOfAllocationInParam() { assertThat(actionRequest, instanceOf(UpdateTrainedModelDeploymentAction.Request.class)); var request = (UpdateTrainedModelDeploymentAction.Request) actionRequest; - assertEquals(request.getNumberOfAllocations(), 5); + assertEquals(request.getNumberOfAllocations().intValue(), 5); executeCalled.set(true); return mock(CreateTrainedModelAssignmentAction.Response.class); @@ -53,7 +53,7 @@ public void testNumberOfAllocationInBody() { assertThat(actionRequest, instanceOf(UpdateTrainedModelDeploymentAction.Request.class)); var request = (UpdateTrainedModelDeploymentAction.Request) actionRequest; - assertEquals(request.getNumberOfAllocations(), 6); + assertEquals(request.getNumberOfAllocations().intValue(), 6); executeCalled.set(true); return mock(CreateTrainedModelAssignmentAction.Response.class); From d5958cd72e154f5675d97a574a1b108ee14ec637 Mon Sep 17 00:00:00 2001 From: Ievgen Degtiarenko Date: Tue, 9 Jul 2024 11:57:08 +0200 Subject: [PATCH 38/64] Reword exception message (#110481) Rewords the exception message to make it clear the documents limit is per shard, not per index. 
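For reference, a minimal self-contained sketch of the guard whose message this commit rewords (the real logic lives in InternalEngine#tryAcquireInFlightDocs, shown in the hunk below; the standalone method, the AtomicLong parameter, and the inline rollback are simplifications for illustration, not the actual API):

    import java.util.concurrent.atomic.AtomicLong;

    class MaxDocsGuardSketch {
        // Reserve `addingDocs` in-flight documents, or return the reworded error
        // when the reservation would push the shard past `maxDocs`.
        static IllegalArgumentException tryAcquireInFlightDocs(AtomicLong inFlightDocCount, long pendingNumDocs, int addingDocs, long maxDocs) {
            final long totalDocs = pendingNumDocs + inFlightDocCount.addAndGet(addingDocs);
            if (totalDocs > maxDocs) {
                inFlightDocCount.addAndGet(-addingDocs); // roll back the failed reservation
                return new IllegalArgumentException("Number of documents in the shard cannot exceed [" + maxDocs + "]");
            }
            return null; // reservation admitted
        }
    }

A caller treats a non-null return as a rejected operation and surfaces it to the client, which is why both the index and delete paths in the tests below assert the new per-shard wording.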
--- .../java/org/elasticsearch/index/engine/MaxDocsLimitIT.java | 4 ++-- .../java/org/elasticsearch/index/engine/InternalEngine.java | 2 +- .../org/elasticsearch/index/engine/InternalEngineTests.java | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java index d475208d7e1ff..be7610e55b8e6 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java @@ -102,7 +102,7 @@ public void testMaxDocsLimit() throws Exception { assertThat(indexingResult.numFailures, equalTo(rejectedRequests)); assertThat(indexingResult.numSuccess, equalTo(0)); final IllegalArgumentException deleteError = expectThrows(IllegalArgumentException.class, client().prepareDelete("test", "any-id")); - assertThat(deleteError.getMessage(), containsString("Number of documents in the index can't exceed [" + maxDocs.get() + "]")); + assertThat(deleteError.getMessage(), containsString("Number of documents in the shard cannot exceed [" + maxDocs.get() + "]")); indicesAdmin().prepareRefresh("test").get(); assertNoFailuresAndResponse( prepareSearch("test").setQuery(new MatchAllQueryBuilder()).setTrackTotalHitsUpTo(Integer.MAX_VALUE).setSize(0), @@ -162,7 +162,7 @@ static IndexingResult indexDocs(int numRequests, int numThreads) throws Exceptio assertThat(resp.status(), equalTo(RestStatus.CREATED)); } catch (IllegalArgumentException e) { numFailure.incrementAndGet(); - assertThat(e.getMessage(), containsString("Number of documents in the index can't exceed [" + maxDocs.get() + "]")); + assertThat(e.getMessage(), containsString("Number of documents in the shard cannot exceed [" + maxDocs.get() + "]")); } } }); diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index a991c5544a1e1..88712a6cf28d2 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -1688,7 +1688,7 @@ private Exception tryAcquireInFlightDocs(Operation operation, int addingDocs) { final long totalDocs = indexWriter.getPendingNumDocs() + inFlightDocCount.addAndGet(addingDocs); if (totalDocs > maxDocs) { releaseInFlightDocs(addingDocs); - return new IllegalArgumentException("Number of documents in the index can't exceed [" + maxDocs + "]"); + return new IllegalArgumentException("Number of documents in the shard cannot exceed [" + maxDocs + "]"); } else { return null; } diff --git a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java index a89ac5bc5b74e..c08e47ea906c3 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java @@ -7417,7 +7417,7 @@ public void testMaxDocsOnPrimary() throws Exception { assertNotNull(result.getFailure()); assertThat( result.getFailure().getMessage(), - containsString("Number of documents in the index can't exceed [" + maxDocs + "]") + containsString("Number of documents in the shard cannot exceed [" + maxDocs + "]") ); assertThat(result.getSeqNo(), equalTo(UNASSIGNED_SEQ_NO)); 
assertThat(engine.getLocalCheckpointTracker().getMaxSeqNo(), equalTo(maxSeqNo)); From 38cd0b333e5372b724ab9084a9b5c849a0997804 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 9 Jul 2024 12:01:46 +0200 Subject: [PATCH 39/64] ESQL: AVG aggregation tests and ignore complex surrogates (#110579) Some work around aggregation tests, with AVG as an example: - Added tests and autogenerated docs for AVG - As AVG uses "complex" surrogates (A combination of functions), we can't trivially execute them without a complete plan. As I'm not sure it's worth it for most aggregations, I'm skipping those cases for now, as to avoid blocking other aggs tests. The bad side effect of skipping those tests is that most tests in AvgTests are actually ignored (74 of 100) --- .../functions/aggregation-functions.asciidoc | 4 +- docs/reference/esql/functions/avg.asciidoc | 47 --------- .../esql/functions/description/avg.asciidoc | 5 + .../esql/functions/examples/avg.asciidoc | 22 +++++ .../esql/functions/kibana/definition/avg.json | 48 ++++++++++ .../esql/functions/kibana/docs/avg.md | 11 +++ .../esql/functions/layout/avg.asciidoc | 15 +++ .../esql/functions/parameters/avg.asciidoc | 6 ++ .../esql/functions/signature/avg.svg | 1 + .../esql/functions/types/avg.asciidoc | 11 +++ .../expression/function/aggregate/Avg.java | 16 +++- .../function/AbstractAggregationTestCase.java | 9 ++ .../function/aggregate/AvgTests.java | 95 +++++++++++++++++++ .../function/aggregate/TopTests.java | 10 +- 14 files changed, 246 insertions(+), 54 deletions(-) delete mode 100644 docs/reference/esql/functions/avg.asciidoc create mode 100644 docs/reference/esql/functions/description/avg.asciidoc create mode 100644 docs/reference/esql/functions/examples/avg.asciidoc create mode 100644 docs/reference/esql/functions/kibana/definition/avg.json create mode 100644 docs/reference/esql/functions/kibana/docs/avg.md create mode 100644 docs/reference/esql/functions/layout/avg.asciidoc create mode 100644 docs/reference/esql/functions/parameters/avg.asciidoc create mode 100644 docs/reference/esql/functions/signature/avg.svg create mode 100644 docs/reference/esql/functions/types/avg.asciidoc create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 11fcd576d336e..7bd2fa08b7c7c 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -8,7 +8,7 @@ The <> command supports these aggregate functions: // tag::agg_list[] -* <> +* <> * <> * <> * <> @@ -23,7 +23,6 @@ The <> command supports these aggregate functions: * experimental:[] <> // end::agg_list[] -include::avg.asciidoc[] include::count.asciidoc[] include::count-distinct.asciidoc[] include::max.asciidoc[] @@ -33,6 +32,7 @@ include::min.asciidoc[] include::percentile.asciidoc[] include::st_centroid_agg.asciidoc[] include::sum.asciidoc[] +include::layout/avg.asciidoc[] include::layout/top.asciidoc[] include::values.asciidoc[] include::weighted-avg.asciidoc[] diff --git a/docs/reference/esql/functions/avg.asciidoc b/docs/reference/esql/functions/avg.asciidoc deleted file mode 100644 index 7eadff29f1bfc..0000000000000 --- a/docs/reference/esql/functions/avg.asciidoc +++ /dev/null @@ -1,47 +0,0 @@ -[discrete] -[[esql-agg-avg]] -=== `AVG` - -*Syntax* - -[source,esql] ----- -AVG(expression) 
----- - -`expression`:: -Numeric expression. -//If `null`, the function returns `null`. -// TODO: Remove comment when https://github.com/elastic/elasticsearch/issues/104900 is fixed. - -*Description* - -The average of a numeric expression. - -*Supported types* - -The result is always a `double` no matter the input type. - -*Examples* - -[source.merge.styled,esql] ----- -include::{esql-specs}/stats.csv-spec[tag=avg] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/stats.csv-spec[tag=avg-result] -|=== - -The expression can use inline functions. For example, to calculate the average -over a multivalued column, first use `MV_AVG` to average the multiple values per -row, and use the result with the `AVG` function: - -[source.merge.styled,esql] ----- -include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression-result] -|=== diff --git a/docs/reference/esql/functions/description/avg.asciidoc b/docs/reference/esql/functions/description/avg.asciidoc new file mode 100644 index 0000000000000..545d7e8394e8b --- /dev/null +++ b/docs/reference/esql/functions/description/avg.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The average of a numeric field. diff --git a/docs/reference/esql/functions/examples/avg.asciidoc b/docs/reference/esql/functions/examples/avg.asciidoc new file mode 100644 index 0000000000000..b6193ad50ed21 --- /dev/null +++ b/docs/reference/esql/functions/examples/avg.asciidoc @@ -0,0 +1,22 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Examples* + +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=avg] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=avg-result] +|=== +The expression can use inline functions. For example, to calculate the average over a multivalued column, first use `MV_AVG` to average the multiple values per row, and use the result with the `AVG` function +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/avg.json b/docs/reference/esql/functions/kibana/definition/avg.json new file mode 100644 index 0000000000000..eb0be684a468e --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/avg.json @@ -0,0 +1,48 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. 
See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "avg", + "description" : "The average of a numeric field.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "FROM employees\n| STATS AVG(height)", + "FROM employees\n| STATS avg_salary_change = ROUND(AVG(MV_AVG(salary_change)), 10)" + ] +} diff --git a/docs/reference/esql/functions/kibana/docs/avg.md b/docs/reference/esql/functions/kibana/docs/avg.md new file mode 100644 index 0000000000000..54006a0556175 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/avg.md @@ -0,0 +1,11 @@ + + +### AVG +The average of a numeric field. + +``` +FROM employees +| STATS AVG(height) +``` diff --git a/docs/reference/esql/functions/layout/avg.asciidoc b/docs/reference/esql/functions/layout/avg.asciidoc new file mode 100644 index 0000000000000..8292af8e75554 --- /dev/null +++ b/docs/reference/esql/functions/layout/avg.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-avg]] +=== `AVG` + +*Syntax* + +[.text-center] +image::esql/functions/signature/avg.svg[Embedded,opts=inline] + +include::../parameters/avg.asciidoc[] +include::../description/avg.asciidoc[] +include::../types/avg.asciidoc[] +include::../examples/avg.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/avg.asciidoc b/docs/reference/esql/functions/parameters/avg.asciidoc new file mode 100644 index 0000000000000..91c56709d182a --- /dev/null +++ b/docs/reference/esql/functions/parameters/avg.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: + diff --git a/docs/reference/esql/functions/signature/avg.svg b/docs/reference/esql/functions/signature/avg.svg new file mode 100644 index 0000000000000..f325358aff960 --- /dev/null +++ b/docs/reference/esql/functions/signature/avg.svg @@ -0,0 +1 @@ +AVG(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/avg.asciidoc b/docs/reference/esql/functions/types/avg.asciidoc new file mode 100644 index 0000000000000..273dae4af76c2 --- /dev/null +++ b/docs/reference/esql/functions/types/avg.asciidoc @@ -0,0 +1,11 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. 
+ +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +double | double +integer | double +long | double +|=== diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java index cb70b73117397..b5c0b8e5ffdc8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg; @@ -28,7 +29,20 @@ public class Avg extends AggregateFunction implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new); - @FunctionInfo(returnType = "double", description = "The average of a numeric field.", isAggregation = true) + @FunctionInfo( + returnType = "double", + description = "The average of a numeric field.", + isAggregation = true, + examples = { + @Example(file = "stats", tag = "avg"), + @Example( + description = "The expression can use inline functions. For example, to calculate the average " + + "over a multivalued column, first use `MV_AVG` to average the multiple values per row, " + + "and use the result with the `AVG` function", + file = "stats", + tag = "docsStatsAvgNestedExpression" + ) } + ) public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { super(source, field); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index e20b9a987f5ef..4fcbde6573f92 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -15,6 +15,8 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; @@ -251,6 +253,13 @@ private void resolveExpression(Expression expression, Consumer onAgg expression = new FoldNull().rule(expression); assertThat(expression.dataType(), equalTo(testCase.expectedType())); + assumeTrue( + "Surrogate expression with non-trivial children cannot be evaluated", + expression.children() + .stream() + .allMatch(child -> child instanceof FieldAttribute || child instanceof DeepCopy || child instanceof Literal) + ); + if (expression instanceof 
AggregateFunction == false) { onEvaluableExpression.accept(expression); return; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java new file mode 100644 index 0000000000000..f456bd409059a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class AvgTests extends AbstractAggregationTestCase { + public AvgTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) + ).flatMap(List::stream).map(AvgTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + suppliers.add( + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")), + "Avg[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) 
+ ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Avg(source, args.get(0)); + } + + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + + Object expected = switch (fieldTypedData.type().widenSmallNumeric()) { + case INTEGER -> fieldTypedData.multiRowData() + .stream() + .map(v -> (Integer) v) + .collect(Collectors.summarizingInt(Integer::intValue)) + .getAverage(); + case LONG -> fieldTypedData.multiRowData() + .stream() + .map(v -> (Long) v) + .collect(Collectors.summarizingLong(Long::longValue)) + .getAverage(); + case DOUBLE -> fieldTypedData.multiRowData() + .stream() + .map(v -> (Double) v) + .collect(Collectors.summarizingDouble(Double::doubleValue)) + .getAverage(); + default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type()); + }; + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "Avg[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(expected) + ); + }); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java index 00457f46266d8..b7b7e7ce84756 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java @@ -22,6 +22,7 @@ import java.util.Comparator; import java.util.List; import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -37,14 +38,15 @@ public static Iterable parameters() { for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 1000, false)) { for (String order : List.of("asc", "desc")) { - for (var fieldCaseSupplier : Stream.of( + Stream.of( MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), MultiRowTestCaseSupplier.dateCases(1, 1000) - ).flatMap(List::stream).toList()) { - suppliers.add(TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order)); - } + ) + .flatMap(List::stream) + .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order)) + .collect(Collectors.toCollection(() -> suppliers)); } } From 5d3512fb33fea104c0581e6463fd8169b1f02f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Tue, 9 Jul 2024 13:05:00 +0200 Subject: [PATCH 40/64] ESQL: Fix Max doubles bug with negatives and add tests for Max and Min (#110586) `MAX()` currently doesn't work with doubles smaller than `Double.MIN_VALUE` (Note that `Double.MIN_VALUE` returns the smallest non-zero positive, not the smallest double). This PR adds tests for Max and Min, and fixes the bug (Detected by the tests). Also, as the tests now generate the docs, replaced the old docs with the generated ones, and updated the Max&Min examples. 
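To make the bug concrete, here is a standalone sketch (not part of the patch): `Double.MIN_VALUE` is the smallest *positive* double, roughly 4.9e-324, so seeding a running maximum with it silently beats every negative input. The identity element for a max over doubles is `-Double.MAX_VALUE`, which is what the aggregator below now uses:

```java
// Standalone demonstration of the MaxDoubleAggregator bug fixed below.
double buggyMax = Double.MIN_VALUE;   // smallest positive double, not the most negative value
double fixedMax = -Double.MAX_VALUE;  // most negative finite double: the identity for max
for (double v : new double[] { -5.0, -2.5, -10.0 }) {
    buggyMax = Math.max(buggyMax, v);
    fixedMax = Math.max(fixedMax, v);
}
// buggyMax is still 4.9E-324 (no input ever exceeded the seed) -- wrong
// fixedMax is -2.5 -- the actual maximum of the inputs
```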
--- docs/changelog/110586.yaml | 5 + .../functions/aggregation-functions.asciidoc | 8 +- .../esql/functions/description/max.asciidoc | 5 + .../esql/functions/description/min.asciidoc | 5 + .../functions/{ => examples}/max.asciidoc | 31 +--- .../functions/{ => examples}/min.asciidoc | 29 +--- .../esql/functions/kibana/definition/max.json | 60 +++++++ .../esql/functions/kibana/definition/min.json | 60 +++++++ .../esql/functions/kibana/docs/max.md | 11 ++ .../esql/functions/kibana/docs/min.md | 11 ++ .../esql/functions/layout/max.asciidoc | 15 ++ .../esql/functions/layout/min.asciidoc | 15 ++ .../esql/functions/parameters/max.asciidoc | 6 + .../esql/functions/parameters/min.asciidoc | 6 + .../esql/functions/signature/max.svg | 1 + .../esql/functions/signature/min.svg | 1 + .../esql/functions/types/max.asciidoc | 12 ++ .../esql/functions/types/min.asciidoc | 12 ++ .../aggregation/MaxDoubleAggregator.java | 2 +- .../expression/function/aggregate/Max.java | 12 +- .../expression/function/aggregate/Min.java | 12 +- .../function/AbstractAggregationTestCase.java | 26 +-- .../function/AbstractFunctionTestCase.java | 49 +++--- .../function/aggregate/MaxTests.java | 151 ++++++++++++++++++ .../function/aggregate/MinTests.java | 151 ++++++++++++++++++ 25 files changed, 609 insertions(+), 87 deletions(-) create mode 100644 docs/changelog/110586.yaml create mode 100644 docs/reference/esql/functions/description/max.asciidoc create mode 100644 docs/reference/esql/functions/description/min.asciidoc rename docs/reference/esql/functions/{ => examples}/max.asciidoc (55%) rename docs/reference/esql/functions/{ => examples}/min.asciidoc (55%) create mode 100644 docs/reference/esql/functions/kibana/definition/max.json create mode 100644 docs/reference/esql/functions/kibana/definition/min.json create mode 100644 docs/reference/esql/functions/kibana/docs/max.md create mode 100644 docs/reference/esql/functions/kibana/docs/min.md create mode 100644 docs/reference/esql/functions/layout/max.asciidoc create mode 100644 docs/reference/esql/functions/layout/min.asciidoc create mode 100644 docs/reference/esql/functions/parameters/max.asciidoc create mode 100644 docs/reference/esql/functions/parameters/min.asciidoc create mode 100644 docs/reference/esql/functions/signature/max.svg create mode 100644 docs/reference/esql/functions/signature/min.svg create mode 100644 docs/reference/esql/functions/types/max.asciidoc create mode 100644 docs/reference/esql/functions/types/min.asciidoc create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java diff --git a/docs/changelog/110586.yaml b/docs/changelog/110586.yaml new file mode 100644 index 0000000000000..cc2bcb85a2dac --- /dev/null +++ b/docs/changelog/110586.yaml @@ -0,0 +1,5 @@ +pr: 110586 +summary: "ESQL: Fix Max doubles bug with negatives and add tests for Max and Min" +area: ES|QL +type: bug +issues: [] diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 7bd2fa08b7c7c..82931b84fd44a 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -11,10 +11,10 @@ The <> command supports these aggregate functions: * <> * <> * <> -* <> +* <> * <> * <> -* <> +* <> * <> * experimental:[] <> * <> @@ -25,14 +25,14 @@ The <> 
command supports these aggregate functions: include::count.asciidoc[] include::count-distinct.asciidoc[] -include::max.asciidoc[] include::median.asciidoc[] include::median-absolute-deviation.asciidoc[] -include::min.asciidoc[] include::percentile.asciidoc[] include::st_centroid_agg.asciidoc[] include::sum.asciidoc[] include::layout/avg.asciidoc[] +include::layout/max.asciidoc[] +include::layout/min.asciidoc[] include::layout/top.asciidoc[] include::values.asciidoc[] include::weighted-avg.asciidoc[] diff --git a/docs/reference/esql/functions/description/max.asciidoc b/docs/reference/esql/functions/description/max.asciidoc new file mode 100644 index 0000000000000..ffc15dcd4c8bd --- /dev/null +++ b/docs/reference/esql/functions/description/max.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The maximum value of a numeric field. diff --git a/docs/reference/esql/functions/description/min.asciidoc b/docs/reference/esql/functions/description/min.asciidoc new file mode 100644 index 0000000000000..4f640854dbd37 --- /dev/null +++ b/docs/reference/esql/functions/description/min.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The minimum value of a numeric field. diff --git a/docs/reference/esql/functions/max.asciidoc b/docs/reference/esql/functions/examples/max.asciidoc similarity index 55% rename from docs/reference/esql/functions/max.asciidoc rename to docs/reference/esql/functions/examples/max.asciidoc index f2e0d0a0205b3..dc57118931ef7 100644 --- a/docs/reference/esql/functions/max.asciidoc +++ b/docs/reference/esql/functions/examples/max.asciidoc @@ -1,24 +1,6 @@ -[discrete] -[[esql-agg-max]] -=== `MAX` +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. -*Syntax* - -[source,esql] ----- -MAX(expression) ----- - -*Parameters* - -`expression`:: -Expression from which to return the maximum value. - -*Description* - -Returns the maximum value of a numeric expression. - -*Example* +*Examples* [source.merge.styled,esql] ---- @@ -28,11 +10,7 @@ include::{esql-specs}/stats.csv-spec[tag=max] |=== include::{esql-specs}/stats.csv-spec[tag=max-result] |=== - -The expression can use inline functions. For example, to calculate the maximum -over an average of a multivalued column, use `MV_AVG` to first average the -multiple values per row, and use the result with the `MAX` function: - +The expression can use inline functions. 
For example, to calculate the maximum over an average of a multivalued column, use `MV_AVG` to first average the multiple values per row, and use the result with the `MAX` function [source.merge.styled,esql] ---- include::{esql-specs}/stats.csv-spec[tag=docsStatsMaxNestedExpression] @@ -40,4 +18,5 @@ include::{esql-specs}/stats.csv-spec[tag=docsStatsMaxNestedExpression] [%header.monospaced.styled,format=dsv,separator=|] |=== include::{esql-specs}/stats.csv-spec[tag=docsStatsMaxNestedExpression-result] -|=== \ No newline at end of file +|=== + diff --git a/docs/reference/esql/functions/min.asciidoc b/docs/reference/esql/functions/examples/min.asciidoc similarity index 55% rename from docs/reference/esql/functions/min.asciidoc rename to docs/reference/esql/functions/examples/min.asciidoc index 313822818128c..b4088196d750b 100644 --- a/docs/reference/esql/functions/min.asciidoc +++ b/docs/reference/esql/functions/examples/min.asciidoc @@ -1,24 +1,6 @@ -[discrete] -[[esql-agg-min]] -=== `MIN` +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. -*Syntax* - -[source,esql] ----- -MIN(expression) ----- - -*Parameters* - -`expression`:: -Expression from which to return the minimum value. - -*Description* - -Returns the minimum value of a numeric expression. - -*Example* +*Examples* [source.merge.styled,esql] ---- @@ -28,11 +10,7 @@ include::{esql-specs}/stats.csv-spec[tag=min] |=== include::{esql-specs}/stats.csv-spec[tag=min-result] |=== - -The expression can use inline functions. For example, to calculate the minimum -over an average of a multivalued column, use `MV_AVG` to first average the -multiple values per row, and use the result with the `MIN` function: - +The expression can use inline functions. For example, to calculate the minimum over an average of a multivalued column, use `MV_AVG` to first average the multiple values per row, and use the result with the `MIN` function [source.merge.styled,esql] ---- include::{esql-specs}/stats.csv-spec[tag=docsStatsMinNestedExpression] @@ -41,3 +19,4 @@ include::{esql-specs}/stats.csv-spec[tag=docsStatsMinNestedExpression] |=== include::{esql-specs}/stats.csv-spec[tag=docsStatsMinNestedExpression-result] |=== + diff --git a/docs/reference/esql/functions/kibana/definition/max.json b/docs/reference/esql/functions/kibana/definition/max.json new file mode 100644 index 0000000000000..aaa765ea79ce4 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/max.json @@ -0,0 +1,60 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. 
See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "max", + "description" : "The maximum value of a numeric field.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "datetime", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "datetime" + }, + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "long" + } + ], + "examples" : [ + "FROM employees\n| STATS MAX(languages)", + "FROM employees\n| STATS max_avg_salary_change = MAX(MV_AVG(salary_change))" + ] +} diff --git a/docs/reference/esql/functions/kibana/definition/min.json b/docs/reference/esql/functions/kibana/definition/min.json new file mode 100644 index 0000000000000..ff48c87ecb8ea --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/min.json @@ -0,0 +1,60 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "min", + "description" : "The minimum value of a numeric field.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "datetime", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "datetime" + }, + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "long" + } + ], + "examples" : [ + "FROM employees\n| STATS MIN(languages)", + "FROM employees\n| STATS min_avg_salary_change = MIN(MV_AVG(salary_change))" + ] +} diff --git a/docs/reference/esql/functions/kibana/docs/max.md b/docs/reference/esql/functions/kibana/docs/max.md new file mode 100644 index 0000000000000..9bda0fbbe972d --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/max.md @@ -0,0 +1,11 @@ + + +### MAX +The maximum value of a numeric field. + +``` +FROM employees +| STATS MAX(languages) +``` diff --git a/docs/reference/esql/functions/kibana/docs/min.md b/docs/reference/esql/functions/kibana/docs/min.md new file mode 100644 index 0000000000000..100abf0260d0d --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/min.md @@ -0,0 +1,11 @@ + + +### MIN +The minimum value of a numeric field. + +``` +FROM employees +| STATS MIN(languages) +``` diff --git a/docs/reference/esql/functions/layout/max.asciidoc b/docs/reference/esql/functions/layout/max.asciidoc new file mode 100644 index 0000000000000..a4eb3d99c0d02 --- /dev/null +++ b/docs/reference/esql/functions/layout/max.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. 
+ +[discrete] +[[esql-max]] +=== `MAX` + +*Syntax* + +[.text-center] +image::esql/functions/signature/max.svg[Embedded,opts=inline] + +include::../parameters/max.asciidoc[] +include::../description/max.asciidoc[] +include::../types/max.asciidoc[] +include::../examples/max.asciidoc[] diff --git a/docs/reference/esql/functions/layout/min.asciidoc b/docs/reference/esql/functions/layout/min.asciidoc new file mode 100644 index 0000000000000..60ad2cc21b561 --- /dev/null +++ b/docs/reference/esql/functions/layout/min.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-min]] +=== `MIN` + +*Syntax* + +[.text-center] +image::esql/functions/signature/min.svg[Embedded,opts=inline] + +include::../parameters/min.asciidoc[] +include::../description/min.asciidoc[] +include::../types/min.asciidoc[] +include::../examples/min.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/max.asciidoc b/docs/reference/esql/functions/parameters/max.asciidoc new file mode 100644 index 0000000000000..91c56709d182a --- /dev/null +++ b/docs/reference/esql/functions/parameters/max.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: + diff --git a/docs/reference/esql/functions/parameters/min.asciidoc b/docs/reference/esql/functions/parameters/min.asciidoc new file mode 100644 index 0000000000000..91c56709d182a --- /dev/null +++ b/docs/reference/esql/functions/parameters/min.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: + diff --git a/docs/reference/esql/functions/signature/max.svg b/docs/reference/esql/functions/signature/max.svg new file mode 100644 index 0000000000000..cfc7bfda2c0a0 --- /dev/null +++ b/docs/reference/esql/functions/signature/max.svg @@ -0,0 +1 @@ +MAX(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/signature/min.svg b/docs/reference/esql/functions/signature/min.svg new file mode 100644 index 0000000000000..31660b1490e7e --- /dev/null +++ b/docs/reference/esql/functions/signature/min.svg @@ -0,0 +1 @@ +MIN(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/max.asciidoc b/docs/reference/esql/functions/types/max.asciidoc new file mode 100644 index 0000000000000..cec61a56db87a --- /dev/null +++ b/docs/reference/esql/functions/types/max.asciidoc @@ -0,0 +1,12 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +datetime | datetime +double | double +integer | integer +long | long +|=== diff --git a/docs/reference/esql/functions/types/min.asciidoc b/docs/reference/esql/functions/types/min.asciidoc new file mode 100644 index 0000000000000..cec61a56db87a --- /dev/null +++ b/docs/reference/esql/functions/types/min.asciidoc @@ -0,0 +1,12 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. 
+ +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +datetime | datetime +double | double +integer | integer +long | long +|=== diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java index ee6555c4af67d..f0804278e5002 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java @@ -16,7 +16,7 @@ class MaxDoubleAggregator { public static double init() { - return Double.MIN_VALUE; + return -Double.MAX_VALUE; } public static double combine(double current, double v) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java index 97a6f6b4b5e1f..44954a1cfea8b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; @@ -31,7 +32,16 @@ public class Max extends NumericAggregate implements SurrogateExpression { @FunctionInfo( returnType = { "double", "integer", "long", "date" }, description = "The maximum value of a numeric field.", - isAggregation = true + isAggregation = true, + examples = { + @Example(file = "stats", tag = "max"), + @Example( + description = "The expression can use inline functions. 
For example, to calculate the maximum " + + "over an average of a multivalued column, use `MV_AVG` to first average the " + + "multiple values per row, and use the result with the `MAX` function", + file = "stats", + tag = "docsStatsMaxNestedExpression" + ) } ) public Max(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) { super(source, field); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java index 2dd3e973937f5..b9f71d86a6fb1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; @@ -31,7 +32,16 @@ public class Min extends NumericAggregate implements SurrogateExpression { @FunctionInfo( returnType = { "double", "integer", "long", "date" }, description = "The minimum value of a numeric field.", - isAggregation = true + isAggregation = true, + examples = { + @Example(file = "stats", tag = "min"), + @Example( + description = "The expression can use inline functions. For example, to calculate the minimum " + + "over an average of a multivalued column, use `MV_AVG` to first average the " + + "multiple values per row, and use the result with the `MIN` function", + file = "stats", + tag = "docsStatsMinNestedExpression" + ) } ) public Min(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) { super(source, field); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index 4fcbde6573f92..792c6b5139796 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -133,11 +133,12 @@ public void testFold() { private void aggregateSingleMode(Expression expression) { Object result; try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) { - Page inputPage = rows(testCase.getMultiRowFields()); - try { - aggregator.processPage(inputPage); - } finally { - inputPage.releaseBlocks(); + for (Page inputPage : rows(testCase.getMultiRowFields())) { + try { + aggregator.processPage(inputPage); + } finally { + inputPage.releaseBlocks(); + } } result = extractResultFromAggregator(aggregator, PlannerUtils.toElementType(testCase.expectedType())); @@ -166,11 +167,12 @@ private void aggregateWithIntermediates(Expression expression) { int intermediateBlockExtraSize = randomIntBetween(0, 10); intermediateBlocks = new Block[intermediateBlockOffset + intermediateStates + intermediateBlockExtraSize]; - Page 
inputPage = rows(testCase.getMultiRowFields()); - try { - aggregator.processPage(inputPage); - } finally { - inputPage.releaseBlocks(); + for (Page inputPage : rows(testCase.getMultiRowFields())) { + try { + aggregator.processPage(inputPage); + } finally { + inputPage.releaseBlocks(); + } } aggregator.evaluate(intermediateBlocks, intermediateBlockOffset, driverContext()); @@ -197,7 +199,9 @@ private void aggregateWithIntermediates(Expression expression) { ) { Page inputPage = new Page(intermediateBlocks); try { - aggregator.processPage(inputPage); + if (inputPage.getPositionCount() > 0) { + aggregator.processPage(inputPage); + } } finally { inputPage.releaseBlocks(); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index f8a5d997f4c54..80dc2e434ab0f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -215,11 +215,11 @@ protected final Page row(List values) { } /** - * Creates a page based on a list of multi-row fields. + * Creates a list of pages based on a list of multi-row fields. */ - protected final Page rows(List multirowFields) { + protected final List rows(List multirowFields) { if (multirowFields.isEmpty()) { - return new Page(0, BlockUtils.NO_BLOCKS); + return List.of(); } var rowsCount = multirowFields.get(0).multiRowData().size(); @@ -230,27 +230,40 @@ protected final Page rows(List multirowFields) { field -> assertThat("All multi-row fields must have the same number of rows", field.multiRowData(), hasSize(rowsCount)) ); - var blocks = new Block[multirowFields.size()]; + List pages = new ArrayList<>(); - for (int i = 0; i < multirowFields.size(); i++) { - var field = multirowFields.get(i); - try ( - var wrapper = BlockUtils.wrapperFor( - TestBlockFactory.getNonBreakingInstance(), - PlannerUtils.toElementType(field.type()), - rowsCount - ) - ) { + int pageSize = randomIntBetween(1, 100); + for (int initialRow = 0; initialRow < rowsCount;) { + if (pageSize > rowsCount - initialRow) { + pageSize = rowsCount - initialRow; + } - for (var row : field.multiRowData()) { - wrapper.accept(row); - } + var blocks = new Block[multirowFields.size()]; - blocks[i] = wrapper.builder().build(); + for (int i = 0; i < multirowFields.size(); i++) { + var field = multirowFields.get(i); + try ( + var wrapper = BlockUtils.wrapperFor( + TestBlockFactory.getNonBreakingInstance(), + PlannerUtils.toElementType(field.type()), + pageSize + ) + ) { + var multiRowData = field.multiRowData(); + for (int row = initialRow; row < initialRow + pageSize; row++) { + wrapper.accept(multiRowData.get(row)); + } + + blocks[i] = wrapper.builder().build(); + } } + + pages.add(new Page(pageSize, blocks)); + initialRow += pageSize; + pageSize = randomIntBetween(1, 100); } - return new Page(rowsCount, blocks); + return pages; } /** diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java new file mode 100644 index 0000000000000..ddff3bc3a8138 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java @@ -0,0 +1,151 
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class MaxTests extends AbstractAggregationTestCase { + public MaxTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000) + ).flatMap(List::stream).map(MaxTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + suppliers.addAll( + List.of( + // Surrogates + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")), + "Max[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")), + "Max[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) 
+ ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(200L) + ) + ), + + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")), + "Max[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")), + "Max[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(200L) + ) + ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Max(source, args.get(0)); + } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var expected = fieldTypedData.multiRowData() + .stream() + .map(v -> (Comparable>) v) + .max(Comparator.naturalOrder()) + .orElse(null); + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "Max[field=Attribute[channel=0]]", + fieldSupplier.type(), + equalTo(expected) + ); + }); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java new file mode 100644 index 0000000000000..fdacf448d52a0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java @@ -0,0 +1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class MinTests extends AbstractAggregationTestCase { + public MinTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000) + ).flatMap(List::stream).map(MinTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + suppliers.addAll( + List.of( + // Surrogates + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")), + "Min[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(-2) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")), + "Min[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(-2L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(-2.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(0L) + ) + ), + + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")), + "Min[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")), + "Min[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) 
+ ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(200L) + ) + ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Min(source, args.get(0)); + } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var expected = fieldTypedData.multiRowData() + .stream() + .map(v -> (Comparable>) v) + .min(Comparator.naturalOrder()) + .orElse(null); + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "Min[field=Attribute[channel=0]]", + fieldSupplier.type(), + equalTo(expected) + ); + }); + } +} From 8e04af986aaff939190d24cfc793083399a5032d Mon Sep 17 00:00:00 2001 From: Valeriy Khakhutskyy <1292899+valeriy42@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:32:18 +0200 Subject: [PATCH 41/64] [ML] Updated filtering in DetectionRulesIt.testCondition() (#110628) While working on elastic/ml-cpp#2677, I encountered a failure in the integration test DetectionRulesIt.testCondition(). It checks the number of return records. With the new change in ml-cpp the native code returns two more values that have no significant score. I added filtering those out in the integration test code so it continues working as expected. --- .../elasticsearch/xpack/ml/integration/DetectionRulesIT.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java index 8cb13398a70ae..fec85730aaf2b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java @@ -95,6 +95,9 @@ public void testCondition() throws Exception { closeJob(job.getId()); List records = getRecords(job.getId()); + // remove records that are not anomalies + records.removeIf(record -> record.getInitialRecordScore() < 1e-5); + assertThat(records.size(), equalTo(1)); assertThat(records.get(0).getByFieldValue(), equalTo("high")); long firstRecordTimestamp = records.get(0).getTimestamp().getTime(); From 20565137a1193dc44833f732b3aa6575655c42c3 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Tue, 9 Jul 2024 07:55:54 -0400 Subject: [PATCH 42/64] ESQL: Remove unused option for lexer util (#110583) This removes the option to match `UPPER_CASE` tokens from a lexer utility. We only match `lower_case` tokens. 
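As a usage sketch (matching the EsqlParser call site in the diff below, and assuming ANTLR's `CharStreams` plus the generated `EsqlBaseLexer`): only the look-ahead the lexer sees is lower-cased, so grammar literals are written in lower case while token text keeps the caller's original casing.

```java
// The lexer matches against lower-cased look-ahead, but getText() still
// returns the original input unchanged.
CharStream in = new CaseChangingCharStream(CharStreams.fromString("FROM employees"));
EsqlBaseLexer lexer = new EsqlBaseLexer(in);
Token first = lexer.nextToken();
// first matches the lower-case 'from' rule, yet first.getText() returns "FROM"
```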
--- .../core/parser/CaseChangingCharStream.java | 19 ++++++++----------- .../xpack/esql/parser/EsqlParser.java | 2 +- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java index f38daa472ddff..6248004d73dac 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java @@ -18,27 +18,24 @@ /** * This class supports case-insensitive lexing by wrapping an existing - * {@link CharStream} and forcing the lexer to see either upper or - * lowercase characters. Grammar literals should then be either upper or - * lower case such as 'BEGIN' or 'begin'. The text of the character - * stream is unaffected. Example: input 'BeGiN' would match lexer rule - * 'BEGIN' if constructor parameter upper=true but getText() would return - * 'BeGiN'. + * {@link CharStream} and forcing the lexer to see lowercase characters + * Grammar literals should then be lower case such as {@code begin}. + * The text of the character stream is unaffected. + * <p> + * Example: input {@code BeGiN} would match lexer rule {@code begin} + * but {@link CharStream#getText} will return {@code BeGiN}. + * </p>
*/ public class CaseChangingCharStream implements CharStream { private final CharStream stream; - private final boolean upper; /** * Constructs a new CaseChangingCharStream wrapping the given {@link CharStream} forcing * all characters to upper case or lower case. * @param stream The stream to wrap. - * @param upper If true force each symbol to upper case, otherwise force to lower. */ - public CaseChangingCharStream(CharStream stream, boolean upper) { + public CaseChangingCharStream(CharStream stream) { this.stream = stream; - this.upper = upper; } @Override @@ -57,7 +54,7 @@ public int LA(int i) { if (c <= 0) { return c; } - return upper ? Character.toUpperCase(c) : Character.toLowerCase(c); + return Character.toLowerCase(c); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java index 70daa5a535fa7..ebbcfa3b2863b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java @@ -51,7 +51,7 @@ private T invokeParser( BiFunction result ) { try { - EsqlBaseLexer lexer = new EsqlBaseLexer(new CaseChangingCharStream(CharStreams.fromString(query), false)); + EsqlBaseLexer lexer = new EsqlBaseLexer(new CaseChangingCharStream(CharStreams.fromString(query))); lexer.removeErrorListeners(); lexer.addErrorListener(ERROR_LISTENER); From 5d26c67d22c5ecd9abd635908c8436d6cc85ef29 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 9 Jul 2024 12:56:10 +0100 Subject: [PATCH 43/64] Convert some internal engine classes to records (#110329) --- .../ReplicaShardAllocatorSyncIdIT.java | 4 +- .../index/shard/IndexShardIT.java | 6 +- .../indices/recovery/IndexRecoveryIT.java | 8 +-- .../index/engine/CombinedDeletionPolicy.java | 4 +- .../index/engine/InternalEngine.java | 14 ++-- .../index/engine/ReadOnlyEngine.java | 4 +- .../index/engine/SafeCommitInfo.java | 11 +--- .../index/seqno/ReplicationTracker.java | 2 +- .../index/seqno/SequenceNumbers.java | 10 +-- .../elasticsearch/index/shard/IndexShard.java | 8 +-- .../RemoveCorruptedShardDataCommand.java | 2 +- .../org/elasticsearch/index/store/Store.java | 2 +- .../index/translog/BaseTranslogReader.java | 6 +- .../index/translog/Translog.java | 57 ++--------------- .../recovery/RecoverySourceHandler.java | 2 +- .../snapshots/SnapshotShardsService.java | 4 +- .../TransportWriteActionTests.java | 13 ++-- .../index/engine/FlushListenersTests.java | 16 ++--- .../index/engine/InternalEngineTests.java | 10 +-- .../RecoveryDuringReplicationTests.java | 4 +- .../index/shard/IndexShardTests.java | 8 +-- .../index/translog/TranslogTests.java | 64 +++++++++---------- .../PeerRecoveryTargetServiceTests.java | 8 +-- .../indices/recovery/RecoveryTests.java | 2 +- 24 files changed, 104 insertions(+), 165 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java b/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java index 27e63e5614744..13886cba9084c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java @@ -100,8 +100,8 @@ void syncFlush(String syncId) throws IOException { assertThat(getTranslogStats().getUncommittedOperations(), equalTo(0)); Map 
userData = new HashMap<>(getLastCommittedSegmentInfos().userData); SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(userData.entrySet()); - assertThat(commitInfo.localCheckpoint, equalTo(getLastSyncedGlobalCheckpoint())); - assertThat(commitInfo.maxSeqNo, equalTo(getLastSyncedGlobalCheckpoint())); + assertThat(commitInfo.localCheckpoint(), equalTo(getLastSyncedGlobalCheckpoint())); + assertThat(commitInfo.maxSeqNo(), equalTo(getLastSyncedGlobalCheckpoint())); userData.put(Engine.SYNC_COMMIT_ID, syncId); indexWriter.setLiveCommitData(userData.entrySet()); indexWriter.commit(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java index b9850bc95275c..5d996e44c6868 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java @@ -164,7 +164,7 @@ public void testDurableFlagHasEffect() { try { // the lastWriteLocaltion has a Integer.MAX_VALUE size so we have to create a new one return tlog.ensureSynced( - new Translog.Location(lastWriteLocation.generation, lastWriteLocation.translogLocation, 0), + new Translog.Location(lastWriteLocation.generation(), lastWriteLocation.translogLocation(), 0), SequenceNumbers.UNASSIGNED_SEQ_NO ); } catch (IOException e) { @@ -389,7 +389,7 @@ public void testMaybeFlush() throws Exception { logger.info( "--> translog stats [{}] gen [{}] commit_stats [{}] flush_stats [{}/{}]", Strings.toString(translogStats), - translog.getGeneration().translogFileGeneration, + translog.getGeneration().translogFileGeneration(), commitStats.getUserData(), flushStats.getPeriodic(), flushStats.getTotal() @@ -428,7 +428,7 @@ public void testMaybeRollTranslogGeneration() throws Exception { ); final Translog.Location location = result.getTranslogLocation(); shard.afterWriteOperation(); - if (location.translogLocation + location.size > generationThreshold) { + if (location.translogLocation() + location.size() > generationThreshold) { // wait until the roll completes assertBusy(() -> assertFalse(shard.shouldRollTranslogGeneration())); rolls++; diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java index 204d7131c44d2..d56e4a372c17c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java @@ -1210,8 +1210,8 @@ public void testRecoverLocallyUpToGlobalCheckpoint() throws Exception { SequenceNumbers.CommitInfo commitInfoAfterLocalRecovery = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( startRecoveryRequest.metadataSnapshot().commitUserData().entrySet() ); - assertThat(commitInfoAfterLocalRecovery.localCheckpoint, equalTo(lastSyncedGlobalCheckpoint)); - assertThat(commitInfoAfterLocalRecovery.maxSeqNo, equalTo(lastSyncedGlobalCheckpoint)); + assertThat(commitInfoAfterLocalRecovery.localCheckpoint(), equalTo(lastSyncedGlobalCheckpoint)); + assertThat(commitInfoAfterLocalRecovery.maxSeqNo(), equalTo(lastSyncedGlobalCheckpoint)); assertThat(startRecoveryRequest.startingSeqNo(), equalTo(lastSyncedGlobalCheckpoint + 1)); ensureGreen(indexName); assertThat((long) localRecoveredOps.get(), equalTo(lastSyncedGlobalCheckpoint - 
localCheckpointOfSafeCommit)); @@ -2011,8 +2011,8 @@ private long getLocalCheckpointOfSafeCommit(IndexCommit safeIndexCommit) throws final SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( safeIndexCommit.getUserData().entrySet() ); - final long commitLocalCheckpoint = commitInfo.localCheckpoint; - final long maxSeqNo = commitInfo.maxSeqNo; + final long commitLocalCheckpoint = commitInfo.localCheckpoint(); + final long maxSeqNo = commitInfo.maxSeqNo(); final LocalCheckpointTracker localCheckpointTracker = new LocalCheckpointTracker(maxSeqNo, commitLocalCheckpoint); // In certain scenarios it is possible that the local checkpoint captured during commit lags behind, diff --git a/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java b/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java index a69cc42163dd2..22bab1742589e 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java +++ b/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java @@ -153,7 +153,7 @@ private SafeCommitInfo getNewSafeCommitInfo(IndexCommit newSafeCommit) { return currentSafeCommitInfo; } - if (currentSafeCommitInfo.localCheckpoint == newSafeCommitLocalCheckpoint) { + if (currentSafeCommitInfo.localCheckpoint() == newSafeCommitLocalCheckpoint) { // the new commit could in principle have the same LCP but a different doc count due to extra operations between its LCP and // MSN, but that is a transient state since we'll eventually advance the LCP. The doc count is only used for heuristics around // expiring excessively-lagging retention leases, so a little inaccuracy is tolerable here. @@ -164,7 +164,7 @@ private SafeCommitInfo getNewSafeCommitInfo(IndexCommit newSafeCommit) { return new SafeCommitInfo(newSafeCommitLocalCheckpoint, getDocCountOfCommit(newSafeCommit)); } catch (IOException ex) { logger.info("failed to get the total docs from the safe commit; use the total docs from the previous safe commit", ex); - return new SafeCommitInfo(newSafeCommitLocalCheckpoint, currentSafeCommitInfo.docCount); + return new SafeCommitInfo(newSafeCommitLocalCheckpoint, currentSafeCommitInfo.docCount()); } } diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index 88712a6cf28d2..03d244cd8e4ef 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -344,8 +344,8 @@ private LocalCheckpointTracker createLocalCheckpointTracker( final SequenceNumbers.CommitInfo seqNoStats = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( store.readLastCommittedSegmentsInfo().userData.entrySet() ); - maxSeqNo = seqNoStats.maxSeqNo; - localCheckpoint = seqNoStats.localCheckpoint; + maxSeqNo = seqNoStats.maxSeqNo(); + localCheckpoint = seqNoStats.localCheckpoint(); logger.trace("recovered maximum sequence number [{}] and local checkpoint [{}]", maxSeqNo, localCheckpoint); return localCheckpointTrackerSupplier.apply(maxSeqNo, localCheckpoint); } @@ -2143,9 +2143,8 @@ private boolean shouldPeriodicallyFlush(long flushThresholdSizeInBytes, long flu final long localCheckpointOfLastCommit = Long.parseLong( lastCommittedSegmentInfos.userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY) ); - final long translogGenerationOfLastCommit = translog.getMinGenerationForSeqNo( - localCheckpointOfLastCommit + 1 - 
).translogFileGeneration; + final long translogGenerationOfLastCommit = translog.getMinGenerationForSeqNo(localCheckpointOfLastCommit + 1) + .translogFileGeneration(); if (translog.sizeInBytesByMinGen(translogGenerationOfLastCommit) < flushThresholdSizeInBytes && relativeTimeInNanosSupplier.getAsLong() - lastFlushTimestamp < flushThresholdAgeInNanos) { return false; @@ -2165,9 +2164,8 @@ private boolean shouldPeriodicallyFlush(long flushThresholdSizeInBytes, long flu * * This method is to maintain translog only, thus IndexWriter#hasUncommittedChanges condition is not considered. */ - final long translogGenerationOfNewCommit = translog.getMinGenerationForSeqNo( - localCheckpointTracker.getProcessedCheckpoint() + 1 - ).translogFileGeneration; + final long translogGenerationOfNewCommit = translog.getMinGenerationForSeqNo(localCheckpointTracker.getProcessedCheckpoint() + 1) + .translogFileGeneration(); return translogGenerationOfLastCommit < translogGenerationOfNewCommit || localCheckpointTracker.getProcessedCheckpoint() == localCheckpointTracker.getMaxSeqNo(); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java index eda408a9c8fde..c9474b58ef447 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java @@ -244,8 +244,8 @@ protected void closeNoLock(String reason, CountDownLatch closedLatch) { private static SeqNoStats buildSeqNoStats(EngineConfig config, SegmentInfos infos) { final SequenceNumbers.CommitInfo seqNoStats = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(infos.userData.entrySet()); - long maxSeqNo = seqNoStats.maxSeqNo; - long localCheckpoint = seqNoStats.localCheckpoint; + long maxSeqNo = seqNoStats.maxSeqNo(); + long localCheckpoint = seqNoStats.localCheckpoint(); return new SeqNoStats(maxSeqNo, localCheckpoint, config.getGlobalCheckpointSupplier().getAsLong()); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java b/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java index 6858315f5b37f..5b206ecfd90dc 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java +++ b/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java @@ -12,15 +12,6 @@ /** * Information about the safe commit, for making decisions about recoveries. 
*/ -public class SafeCommitInfo { - - public final long localCheckpoint; - public final int docCount; - - public SafeCommitInfo(long localCheckpoint, int docCount) { - this.localCheckpoint = localCheckpoint; - this.docCount = docCount; - } - +public record SafeCommitInfo(long localCheckpoint, int docCount) { public static final SafeCommitInfo EMPTY = new SafeCommitInfo(SequenceNumbers.NO_OPS_PERFORMED, 0); } diff --git a/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java b/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java index 0b3b15670ef78..247c2fd70761e 100644 --- a/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java +++ b/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java @@ -280,7 +280,7 @@ public synchronized RetentionLeases getRetentionLeases(final boolean expireLease private long getMinimumReasonableRetainedSeqNo() { final SafeCommitInfo safeCommitInfo = safeCommitInfoSupplier.get(); - return safeCommitInfo.localCheckpoint + 1 - Math.round(Math.ceil(safeCommitInfo.docCount * fileBasedRecoveryThreshold)); + return safeCommitInfo.localCheckpoint() + 1 - Math.round(Math.ceil(safeCommitInfo.docCount() * fileBasedRecoveryThreshold)); // NB safeCommitInfo.docCount is a very low-level count of the docs in the index, and in particular if this shard contains nested // docs then safeCommitInfo.docCount counts every child doc separately from the parent doc. However every part of a nested document // has the same seqno, so we may be overestimating the cost of a file-based recovery when compared to an ops-based recovery and diff --git a/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java b/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java index 0cd451f6be2cf..bb4ef40d28129 100644 --- a/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java +++ b/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java @@ -103,15 +103,7 @@ public static long max(final long maxSeqNo, final long seqNo) { } } - public static final class CommitInfo { - public final long maxSeqNo; - public final long localCheckpoint; - - public CommitInfo(long maxSeqNo, long localCheckpoint) { - this.maxSeqNo = maxSeqNo; - this.localCheckpoint = localCheckpoint; - } - + public record CommitInfo(long maxSeqNo, long localCheckpoint) { @Override public String toString() { return "CommitInfo{maxSeqNo=" + maxSeqNo + ", localCheckpoint=" + localCheckpoint + '}'; diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index 881f4602be1c7..73cbca36a69c8 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -1856,8 +1856,8 @@ private void doLocalRecovery( return; } - assert safeCommit.get().localCheckpoint <= globalCheckpoint : safeCommit.get().localCheckpoint + " > " + globalCheckpoint; - if (safeCommit.get().localCheckpoint == globalCheckpoint) { + assert safeCommit.get().localCheckpoint() <= globalCheckpoint : safeCommit.get().localCheckpoint() + " > " + globalCheckpoint; + if (safeCommit.get().localCheckpoint() == globalCheckpoint) { logger.trace( "skip local recovery as the safe commit is up to date; safe commit {} global checkpoint {}", safeCommit.get(), @@ -1876,7 +1876,7 @@ private void doLocalRecovery( globalCheckpoint ); recoveryState.getTranslog().totalLocal(0); - 
recoveryStartingSeqNoListener.onResponse(safeCommit.get().localCheckpoint + 1); + recoveryStartingSeqNoListener.onResponse(safeCommit.get().localCheckpoint() + 1); return; } @@ -1915,7 +1915,7 @@ private void doLocalRecovery( // we need to find the safe commit again as we should have created a new one during the local recovery final Optional<SequenceNumbers.CommitInfo> newSafeCommit = store.findSafeIndexCommit(globalCheckpoint); assert newSafeCommit.isPresent() : "no safe commit found after local recovery"; - return newSafeCommit.get().localCheckpoint + 1; + return newSafeCommit.get().localCheckpoint() + 1; } catch (Exception e) { logger.debug( () -> format( diff --git a/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java b/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java index ace891f9aead6..3783b64a0a04f 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java +++ b/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java @@ -396,7 +396,7 @@ protected static void addNewHistoryCommit(Directory indexDirectory, Terminal ter // We can only safely do it because we will generate a new history uuid this shard. final SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(userData.entrySet()); // Also advances the local checkpoint of the last commit to its max_seqno. - userData.put(SequenceNumbers.LOCAL_CHECKPOINT_KEY, Long.toString(commitInfo.maxSeqNo)); + userData.put(SequenceNumbers.LOCAL_CHECKPOINT_KEY, Long.toString(commitInfo.maxSeqNo())); } // commit the new history id diff --git a/server/src/main/java/org/elasticsearch/index/store/Store.java b/server/src/main/java/org/elasticsearch/index/store/Store.java index 5a33084e3ea83..b9c50edf50216 100644 --- a/server/src/main/java/org/elasticsearch/index/store/Store.java +++ b/server/src/main/java/org/elasticsearch/index/store/Store.java @@ -1529,7 +1529,7 @@ public Optional<SequenceNumbers.CommitInfo> findSafeIndexCommit(long globalCheck final IndexCommit safeCommit = CombinedDeletionPolicy.findSafeCommitPoint(commits, globalCheckpoint); final SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(safeCommit.getUserData().entrySet()); // all operations of the safe commit must be at most the global checkpoint. - if (commitInfo.maxSeqNo <= globalCheckpoint) { + if (commitInfo.maxSeqNo() <= globalCheckpoint) { return Optional.of(commitInfo); } else { return Optional.empty(); diff --git a/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java b/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java index d2c862bbf35d7..3be2532e3c3aa 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java +++ b/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java @@ -149,8 +149,8 @@ public long getLastModifiedTime() throws IOException { * Reads a single operation from the given location.
 */ Translog.Operation read(Translog.Location location) throws IOException { - assert location.generation == this.generation : "generation mismatch expected: " + generation + " got: " + location.generation; - ByteBuffer buffer = ByteBuffer.allocate(location.size); - return read(checksummedStream(buffer, location.translogLocation, location.size, null)); + assert location.generation() == this.generation : "generation mismatch expected: " + generation + " got: " + location.generation(); + ByteBuffer buffer = ByteBuffer.allocate(location.size()); + return read(checksummedStream(buffer, location.translogLocation(), location.size(), null)); } } diff --git a/server/src/main/java/org/elasticsearch/index/translog/Translog.java b/server/src/main/java/org/elasticsearch/index/translog/Translog.java index a079a852021bd..c02a810ed4952 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/Translog.java +++ b/server/src/main/java/org/elasticsearch/index/translog/Translog.java @@ -964,20 +964,10 @@ public TranslogDeletionPolicy getDeletionPolicy() { return deletionPolicy; } - public static class Location implements Comparable<Location> { + public record Location(long generation, long translogLocation, int size) implements Comparable<Location> { public static Location EMPTY = new Location(0, 0, 0); - public final long generation; - public final long translogLocation; - public final int size; - - public Location(long generation, long translogLocation, int size) { - this.generation = generation; - this.translogLocation = translogLocation; - this.size = size; - } - @Override public String toString() { return "[generation: " + generation + ", location: " + translogLocation + ", size: " + size + "]"; @@ -985,38 +975,10 @@ public String toString() { @Override public int compareTo(Location o) { - if (generation == o.generation) { - return Long.compare(translogLocation, o.translogLocation); - } - return Long.compare(generation, o.generation); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - Location location = (Location) o; - - if (generation != location.generation) { - return false; + int result = Long.compare(generation, o.generation); + if (result == 0) { + result = Long.compare(translogLocation, o.translogLocation); } - if (translogLocation != location.translogLocation) { - return false; - } - return size == location.size; - - } - - @Override - public int hashCode() { - int result = Long.hashCode(generation); - result = 31 * result + Long.hashCode(translogLocation); - result = 31 * result + size; return result; } } @@ -1819,16 +1781,7 @@ void closeFilesIfNoPendingRetentionLocks() throws IOException { /** * References a transaction log generation */ - public static final class TranslogGeneration { - public final String translogUUID; - public final long translogFileGeneration; - - public TranslogGeneration(String translogUUID, long translogFileGeneration) { - this.translogUUID = translogUUID; - this.translogFileGeneration = translogFileGeneration; - } - - } + public record TranslogGeneration(String translogUUID, long translogFileGeneration) {} /** * Returns the current generation of this translog.
This corresponds to the latest uncommitted translog generation diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java index 538cfdabef324..df2a9d16ebd6a 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java @@ -1052,7 +1052,7 @@ boolean hasSameLegacySyncId(Store.MetadataSnapshot source, Store.MetadataSnapsho } SequenceNumbers.CommitInfo sourceSeqNos = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(source.commitUserData().entrySet()); SequenceNumbers.CommitInfo targetSeqNos = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(target.commitUserData().entrySet()); - if (sourceSeqNos.localCheckpoint != targetSeqNos.localCheckpoint || targetSeqNos.maxSeqNo != sourceSeqNos.maxSeqNo) { + if (sourceSeqNos.localCheckpoint() != targetSeqNos.localCheckpoint() || targetSeqNos.maxSeqNo() != sourceSeqNos.maxSeqNo()) { final String message = "try to recover " + request.shardId() + " with sync id but " diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 7606299c62bc8..1529ef556037a 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -545,8 +545,8 @@ private String description() { public static String getShardStateId(IndexShard indexShard, IndexCommit snapshotIndexCommit) throws IOException { final Map<String, String> userCommitData = snapshotIndexCommit.getUserData(); final SequenceNumbers.CommitInfo seqNumInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(userCommitData.entrySet()); - final long maxSeqNo = seqNumInfo.maxSeqNo; - if (maxSeqNo != seqNumInfo.localCheckpoint || maxSeqNo != indexShard.getLastSyncedGlobalCheckpoint()) { + final long maxSeqNo = seqNumInfo.maxSeqNo(); + if (maxSeqNo != seqNumInfo.localCheckpoint() || maxSeqNo != indexShard.getLastSyncedGlobalCheckpoint()) { return null; } return userCommitData.get(Engine.HISTORY_UUID_KEY) diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java index 5530ec61fea33..340ca87968db0 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java @@ -92,7 +92,6 @@ public class TransportWriteActionTests extends ESTestCase { private ClusterService clusterService; private IndexShard indexShard; - private Translog.Location location; @BeforeClass public static void beforeClass() { @@ -102,7 +101,6 @@ public static void beforeClass() { @Before public void initCommonMocks() { indexShard = mock(IndexShard.class); - location = mock(Translog.Location.class); clusterService = createClusterService(threadPool); when(indexShard.refresh(any())).thenReturn(new Engine.RefreshResult(true, randomNonNegativeLong(), 1)); ReplicationGroup replicationGroup = mock(ReplicationGroup.class); @@ -483,7 +481,7 @@ protected void dispatchedShardOperationOnPrimary( if (withDocumentFailureOnPrimary) { throw new RuntimeException("simulated"); } else { - return new WritePrimaryResult<>(request, new TestResponse(),
location, primary, logger, postWriteRefresh); + return new WritePrimaryResult<>( + request, + new TestResponse(), + Translog.Location.EMPTY, + primary, + logger, + postWriteRefresh + ); } }); } @@ -495,7 +500,7 @@ protected void dispatchedShardOperationOnReplica(TestRequest request, IndexShard if (withDocumentFailureOnReplica) { replicaResult = new WriteReplicaResult<>(request, null, new RuntimeException("simulated"), replica, logger); } else { - replicaResult = new WriteReplicaResult<>(request, location, null, replica, logger); + replicaResult = new WriteReplicaResult<>(request, Translog.Location.EMPTY, null, replica, logger); } return replicaResult; }); diff --git a/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java b/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java index 9c345eb923ab4..bff978f8e79d8 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java @@ -29,8 +29,8 @@ public void testFlushListenerCompletedImmediatelyIfFlushAlreadyOccurred() { ); flushListeners.afterFlush(generation, lastWriteLocation); Translog.Location waitLocation = new Translog.Location( - lastWriteLocation.generation - randomLongBetween(0, 2), - lastWriteLocation.generation - randomLongBetween(10, 90), + lastWriteLocation.generation() - randomLongBetween(0, 2), + lastWriteLocation.generation() - randomLongBetween(10, 90), 2 ); PlainActionFuture<Long> future = new PlainActionFuture<>(); @@ -48,8 +48,8 @@ public void testFlushListenerCompletedAfterLocationFlushed() { Integer.MAX_VALUE ); Translog.Location waitLocation = new Translog.Location( - lastWriteLocation.generation - randomLongBetween(0, 2), - lastWriteLocation.generation - randomLongBetween(10, 90), + lastWriteLocation.generation() - randomLongBetween(0, 2), + lastWriteLocation.generation() - randomLongBetween(10, 90), 2 ); PlainActionFuture<Long> future = new PlainActionFuture<>(); @@ -61,13 +61,13 @@ public void testFlushListenerCompletedAfterLocationFlushed() { long generation2 = generation + 1; Translog.Location secondLastWriteLocation = new Translog.Location( - lastWriteLocation.generation, - lastWriteLocation.translogLocation + 10, + lastWriteLocation.generation(), + lastWriteLocation.translogLocation() + 10, Integer.MAX_VALUE ); Translog.Location waitLocation2 = new Translog.Location( - lastWriteLocation.generation, - lastWriteLocation.translogLocation + 4, + lastWriteLocation.generation(), + lastWriteLocation.translogLocation() + 4, 2 ); diff --git a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java index c08e47ea906c3..c668cfbb502a2 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java @@ -1249,7 +1249,7 @@ public void testSyncTranslogConcurrently() throws Exception { SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( safeCommit.getIndexCommit().getUserData().entrySet() ); - assertThat(commitInfo.localCheckpoint, equalTo(engine.getProcessedLocalCheckpoint())); + assertThat(commitInfo.localCheckpoint(), equalTo(engine.getProcessedLocalCheckpoint())); } }; final Thread[] threads = new Thread[randomIntBetween(2, 4)]; @@ -3414,7 +3414,7 @@ protected void commitIndexWriter(IndexWriter writer, Translog translog) throws I final long
localCheckpoint = Long.parseLong( engine.getLastCommittedSegmentInfos().userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY) ); - final long committedGen = engine.getTranslog().getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + final long committedGen = engine.getTranslog().getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); for (int gen = 1; gen < committedGen; gen++) { final Path genFile = translogPath.resolve(Translog.getFilename(gen)); assertFalse(genFile + " wasn't cleaned up", Files.exists(genFile)); @@ -3601,7 +3601,7 @@ public void testRecoverFromForeignTranslog() throws IOException { seqNo -> {} ); translog.add(TranslogOperationsUtils.indexOp("SomeBogusId", 0, primaryTerm.get())); - assertEquals(generation.translogFileGeneration, translog.currentFileGeneration()); + assertEquals(generation.translogFileGeneration(), translog.currentFileGeneration()); translog.close(); EngineConfig config = engine.config(); @@ -5232,7 +5232,7 @@ public void testMinGenerationForSeqNo() throws IOException, BrokenBarrierExcepti * This sequence number landed in the last generation, but the lower and upper bounds for an earlier generation straddle * this sequence number. */ - assertThat(translog.getMinGenerationForSeqNo(3 * i + 1).translogFileGeneration, equalTo(i + generation)); + assertThat(translog.getMinGenerationForSeqNo(3 * i + 1).translogFileGeneration(), equalTo(i + generation)); } int i = 0; @@ -5855,7 +5855,7 @@ public void testShouldPeriodicallyFlushOnSize() throws Exception { final Translog translog = engine.getTranslog(); final IntSupplier uncommittedTranslogOperationsSinceLastCommit = () -> { long localCheckpoint = Long.parseLong(engine.getLastCommittedSegmentInfos().userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); - return translog.totalOperationsByMinGen(translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration); + return translog.totalOperationsByMinGen(translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration()); }; final long extraTranslogSizeInNewEngine = engine.getTranslog().stats().getUncommittedSizeInBytes() - Translog.DEFAULT_HEADER_SIZE_IN_BYTES; diff --git a/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java b/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java index ff6b27924404e..7d018c23597b7 100644 --- a/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java @@ -272,11 +272,11 @@ public void testRecoveryAfterPrimaryPromotion() throws Exception { assertThat(newReplica.recoveryState().getIndex().fileDetails(), empty()); assertThat( newReplica.recoveryState().getTranslog().totalLocal(), - equalTo(Math.toIntExact(globalCheckpointOnOldPrimary - safeCommitOnOldPrimary.get().localCheckpoint)) + equalTo(Math.toIntExact(globalCheckpointOnOldPrimary - safeCommitOnOldPrimary.get().localCheckpoint())) ); assertThat( newReplica.recoveryState().getTranslog().recoveredOperations(), - equalTo(Math.toIntExact(totalDocs - 1 - safeCommitOnOldPrimary.get().localCheckpoint)) + equalTo(Math.toIntExact(totalDocs - 1 - safeCommitOnOldPrimary.get().localCheckpoint())) ); } else { assertThat(newReplica.recoveryState().getIndex().fileDetails(), not(empty())); diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java 
b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index 9d53b95e01db3..29f39134d2bcf 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -659,7 +659,7 @@ public void testPrimaryFillsSeqNoGapsOnPromotion() throws Exception { public void testPrimaryPromotionRollsGeneration() throws Exception { final IndexShard indexShard = newStartedShard(false); - final long currentTranslogGeneration = getTranslog(indexShard).getGeneration().translogFileGeneration; + final long currentTranslogGeneration = getTranslog(indexShard).getGeneration().translogFileGeneration(); // promote the replica final ShardRouting replicaRouting = indexShard.routingEntry(); @@ -698,7 +698,7 @@ public void onFailure(Exception e) { }, threadPool.generic()); latch.await(); - assertThat(getTranslog(indexShard).getGeneration().translogFileGeneration, equalTo(currentTranslogGeneration + 1)); + assertThat(getTranslog(indexShard).getGeneration().translogFileGeneration(), equalTo(currentTranslogGeneration + 1)); assertThat(TestTranslog.getCurrentTerm(getTranslog(indexShard)), equalTo(newPrimaryTerm)); closeShards(indexShard); @@ -995,7 +995,7 @@ public void testOperationPermitOnReplicaShards() throws Exception { } final long primaryTerm = indexShard.getPendingPrimaryTerm(); - final long translogGen = engineClosed ? -1 : getTranslog(indexShard).getGeneration().translogFileGeneration; + final long translogGen = engineClosed ? -1 : getTranslog(indexShard).getGeneration().translogFileGeneration(); final Releasable operation1; final Releasable operation2; @@ -1115,7 +1115,7 @@ private void finish() { assertTrue(onResponse.get()); assertNull(onFailure.get()); assertThat( - getTranslog(indexShard).getGeneration().translogFileGeneration, + getTranslog(indexShard).getGeneration().translogFileGeneration(), // if rollback happens we roll translog twice: one when we flush a commit before opening a read-only engine // and one after replaying translog (upto the global checkpoint); otherwise we roll translog once. 
either(equalTo(translogGen + 1)).or(equalTo(translogGen + 2)) diff --git a/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java b/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java index cd7e637d58bcc..8a277e400ad6c 100644 --- a/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java +++ b/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java @@ -1250,7 +1250,7 @@ public void testLocationComparison() throws IOException { max = max(max, location); } - assertEquals(max.generation, translog.currentFileGeneration()); + assertEquals(max.generation(), translog.currentFileGeneration()); try (Translog.Snapshot snap = new SortedSnapshot(translog.newSnapshot())) { Translog.Operation next; Translog.Operation maxOp = null; @@ -1655,17 +1655,17 @@ public void testTranslogOperationListener() throws IOException { try (Translog translog = createTranslog(config)) { Location location1 = translog.add(indexOp(randomAlphaOfLength(10), 0, primaryTerm.get())); Location location2 = translog.add(TranslogOperationsUtils.indexOp(randomAlphaOfLength(10), 1, primaryTerm.get())); - long firstGeneration = translog.getGeneration().translogFileGeneration; - assertThat(location1.generation, equalTo(firstGeneration)); - assertThat(location2.generation, equalTo(firstGeneration)); + long firstGeneration = translog.getGeneration().translogFileGeneration(); + assertThat(location1.generation(), equalTo(firstGeneration)); + assertThat(location2.generation(), equalTo(firstGeneration)); translog.rollGeneration(); Location location3 = translog.add(TranslogOperationsUtils.indexOp(randomAlphaOfLength(10), 3, primaryTerm.get())); Location location4 = translog.add(TranslogOperationsUtils.indexOp(randomAlphaOfLength(10), 2, primaryTerm.get())); - long secondGeneration = translog.getGeneration().translogFileGeneration; - assertThat(location3.generation, equalTo(secondGeneration)); - assertThat(location4.generation, equalTo(secondGeneration)); + long secondGeneration = translog.getGeneration().translogFileGeneration(); + assertThat(location3.generation(), equalTo(secondGeneration)); + assertThat(location4.generation(), equalTo(secondGeneration)); assertThat(seqNos, equalTo(List.of(0L, 1L, 3L, 2L))); assertThat(locations, equalTo(List.of(location1, location2, location3, location4))); @@ -1741,7 +1741,7 @@ public void testBasicRecovery() throws IOException { } else { translog = new Translog( config, - translogGeneration.translogUUID, + translogGeneration.translogUUID(), translog.getDeletionPolicy(), () -> SequenceNumbers.NO_OPS_PERFORMED, primaryTerm::get, @@ -1749,7 +1749,7 @@ public void testBasicRecovery() throws IOException { ); assertEquals( "lastCommitted must be 1 less than current", - translogGeneration.translogFileGeneration + 1, + translogGeneration.translogFileGeneration() + 1, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1758,7 +1758,7 @@ public void testBasicRecovery() throws IOException { assertEquals( "expected operation" + i + " to be in the previous translog but wasn't", translog.currentFileGeneration() - 1, - locations.get(i).generation + locations.get(i).generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -1782,9 +1782,9 @@ public void testRecoveryUncommitted() throws IOException { assertEquals( "expected this to be the first roll (1 gen is on creation, 2 when opened)", 2L, - translogGeneration.translogFileGeneration + 
translogGeneration.translogFileGeneration() ); - assertNotNull(translogGeneration.translogUUID); + assertNotNull(translogGeneration.translogUUID()); } } if (sync) { @@ -1808,7 +1808,7 @@ public void testRecoveryUncommitted() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 2 less than current - we never finished the commit", - translogGeneration.translogFileGeneration + 2, + translogGeneration.translogFileGeneration() + 2, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1835,7 +1835,7 @@ public void testRecoveryUncommitted() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 3 less than current - we never finished the commit and run recovery twice", - translogGeneration.translogFileGeneration + 3, + translogGeneration.translogFileGeneration() + 3, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1869,9 +1869,9 @@ public void testRecoveryUncommittedFileExists() throws IOException { assertEquals( "expected this to be the first roll (1 gen is on creation, 2 when opened)", 2L, - translogGeneration.translogFileGeneration + translogGeneration.translogFileGeneration() ); - assertNotNull(translogGeneration.translogUUID); + assertNotNull(translogGeneration.translogUUID()); } } if (sync) { @@ -1899,7 +1899,7 @@ public void testRecoveryUncommittedFileExists() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 2 less than current - we never finished the commit", - translogGeneration.translogFileGeneration + 2, + translogGeneration.translogFileGeneration() + 2, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1927,7 +1927,7 @@ public void testRecoveryUncommittedFileExists() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 3 less than current - we never finished the commit and run recovery twice", - translogGeneration.translogFileGeneration + 3, + translogGeneration.translogFileGeneration() + 3, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1960,9 +1960,9 @@ public void testRecoveryUncommittedCorruptedCheckpoint() throws IOException { assertEquals( "expected this to be the first roll (1 gen is on creation, 2 when opened)", 2L, - translogGeneration.translogFileGeneration + translogGeneration.translogFileGeneration() ); - assertNotNull(translogGeneration.translogUUID); + assertNotNull(translogGeneration.translogUUID()); } } translog.sync(); @@ -2015,7 +2015,7 @@ public void testRecoveryUncommittedCorruptedCheckpoint() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 2 less than current - we never finished the commit", - translogGeneration.translogFileGeneration + 2, + translogGeneration.translogFileGeneration() + 2, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -2284,7 +2284,7 @@ public void testOpenForeignTranslog() throws IOException { Translog.TranslogGeneration translogGeneration = translog.getGeneration(); translog.close(); - final String foreignTranslog = randomRealisticUnicodeOfCodepointLengthBetween(1, translogGeneration.translogUUID.length()); + final String foreignTranslog = randomRealisticUnicodeOfCodepointLengthBetween(1, translogGeneration.translogUUID().length()); try { new Translog( config, @@ -2507,7 +2507,7 @@ public void testFailFlush() throws IOException { ) { assertEquals( "lastCommitted must be 1 less than current", - 
translogGeneration.translogFileGeneration + 1, + translogGeneration.translogFileGeneration() + 1, tlog.currentFileGeneration() ); assertFalse(tlog.syncNeeded()); @@ -2518,7 +2518,7 @@ public void testFailFlush() throws IOException { assertEquals( "expected operation" + i + " to be in the previous translog but wasn't", tlog.currentFileGeneration() - 1, - locations.get(i).generation + locations.get(i).generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -2540,7 +2540,7 @@ public void testTranslogOpsCountIsCorrect() throws IOException { assertEquals( "expected operation" + i + " to be in the current translog but wasn't", translog.currentFileGeneration(), - locations.get(i).generation + locations.get(i).generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -2640,7 +2640,7 @@ protected void afterAdd() throws IOException { assertFalse(translog.isOpen()); final Checkpoint checkpoint = Checkpoint.read(config.getTranslogPath().resolve(Translog.CHECKPOINT_FILE_NAME)); // drop all that haven't been synced - writtenOperations.removeIf(next -> checkpoint.offset < (next.location.translogLocation + next.location.size)); + writtenOperations.removeIf(next -> checkpoint.offset < (next.location.translogLocation() + next.location.size())); try ( Translog tlog = new Translog( config, @@ -2664,7 +2664,7 @@ protected void afterAdd() throws IOException { assertEquals( "expected operation" + i + " to be in the previous translog but wasn't", tlog.currentFileGeneration() - 1, - writtenOperations.get(i).location.generation + writtenOperations.get(i).location.generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -2695,7 +2695,7 @@ public void testRecoveryFromAFutureGenerationCleansUp() throws IOException { translog.rollGeneration(); } } - long minRetainedGen = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + long minRetainedGen = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); // engine blows up, after committing the above generation translog.close(); TranslogConfig config = translog.getConfig(); @@ -2753,7 +2753,7 @@ public void testRecoveryFromFailureOnTrimming() throws IOException { } } deletionPolicy.setLocalCheckpointOfSafeCommit(localCheckpoint); - minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); fail.failRandomly(); try { translog.trimUnreferencedReaders(); @@ -2777,7 +2777,7 @@ public void testRecoveryFromFailureOnTrimming() throws IOException { assertThat(translog.getMinFileGeneration(), greaterThanOrEqualTo(1L)); assertThat(translog.getMinFileGeneration(), lessThanOrEqualTo(minGenForRecovery)); assertFilePresences(translog); - minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); translog.trimUnreferencedReaders(); assertThat(translog.getMinFileGeneration(), equalTo(minGenForRecovery)); assertFilePresences(translog); @@ -3539,7 +3539,7 @@ public void testMinSeqNoBasedAPI() throws IOException { translog.rollGeneration(); for (long seqNo = 0; seqNo < operations; seqNo++) { final Set> seenSeqNos = new HashSet<>(); - final long generation = 
translog.getMinGenerationForSeqNo(seqNo).translogFileGeneration; + final long generation = translog.getMinGenerationForSeqNo(seqNo).translogFileGeneration(); int expectedSnapshotOps = 0; for (long g = generation; g < translog.currentFileGeneration(); g++) { if (seqNoPerGeneration.containsKey(g) == false) { @@ -3924,7 +3924,7 @@ public void testSyncConcurrently() throws Exception { assertThat("seq# " + op.seqNo() + " was not marked as persisted", persistedSeqNos, hasItem(op.seqNo())); } Checkpoint checkpoint = translog.getLastSyncedCheckpoint(); - assertThat(checkpoint.offset, greaterThanOrEqualTo(location.translogLocation)); + assertThat(checkpoint.offset, greaterThanOrEqualTo(location.translogLocation())); for (Translog.Operation op : ops) { assertThat(checkpoint.minSeqNo, lessThanOrEqualTo(op.seqNo())); assertThat(checkpoint.maxSeqNo, greaterThanOrEqualTo(op.seqNo())); diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java index 4266b514bf544..8001c8c901829 100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java @@ -223,8 +223,8 @@ public void testPrepareIndexForPeerRecovery() throws Exception { Optional<SequenceNumbers.CommitInfo> safeCommit = shard.store().findSafeIndexCommit(globalCheckpoint); assertTrue(safeCommit.isPresent()); int expectedTotalLocal = 0; - if (safeCommit.get().localCheckpoint < globalCheckpoint) { - try (Translog.Snapshot snapshot = getTranslog(shard).newSnapshot(safeCommit.get().localCheckpoint + 1, globalCheckpoint)) { + if (safeCommit.get().localCheckpoint() < globalCheckpoint) { + try (Translog.Snapshot snapshot = getTranslog(shard).newSnapshot(safeCommit.get().localCheckpoint() + 1, globalCheckpoint)) { Translog.Operation op; while ((op = snapshot.next()) != null) { if (op.seqNo() <= globalCheckpoint) { @@ -276,7 +276,7 @@ public void testPrepareIndexForPeerRecovery() throws Exception { replica.markAsRecovering("for testing", new RecoveryState(replica.routingEntry(), localNode, localNode)); replica.prepareForIndexRecovery(); if (safeCommit.isPresent()) { - assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint + 1)); + assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint() + 1)); assertThat(replica.recoveryState().getTranslog().totalLocal(), equalTo(0)); } else { assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(UNASSIGNED_SEQ_NO)); @@ -313,7 +313,7 @@ public void testClosedIndexSkipsLocalRecovery() throws Exception { ); replica.markAsRecovering("for testing", new RecoveryState(replica.routingEntry(), localNode, localNode)); replica.prepareForIndexRecovery(); - assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint + 1)); + assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint() + 1)); assertThat(replica.recoveryState().getTranslog().totalLocal(), equalTo(0)); assertThat(replica.recoveryState().getTranslog().recoveredOperations(), equalTo(0)); assertThat(replica.getLastKnownGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO)); diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java index fc8f1988a732b..47c9c5e85f7b9
100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java @@ -252,7 +252,7 @@ public void testDifferentHistoryUUIDDisablesOPsRecovery() throws Exception { replica.getPendingPrimaryTerm() ); } else { - translogUUIDtoUse = translogGeneration.translogUUID; + translogUUIDtoUse = translogGeneration.translogUUID(); } try (IndexWriter writer = new IndexWriter(replica.store().directory(), iwc)) { userData.put(Engine.HISTORY_UUID_KEY, historyUUIDtoUse); From 7e0222df32c6911610b8ec2924c8b84c5d105d3b Mon Sep 17 00:00:00 2001 From: Albert Zaharovits Date: Tue, 9 Jul 2024 15:11:04 +0300 Subject: [PATCH 44/64] Remove dep com.nimbusds:nimbus-jose-jwt from module org.elasticsearch.xcore (#110565) Almost none of the com.nimbusds.jwt types are needed in x-pack/plugin/core; they are only needed in the org.elasticsearch.security module (the x-pack:plugin:security project). --- x-pack/plugin/core/build.gradle | 23 +----------- .../core/src/main/java/module-info.java | 1 - .../xpack/core/security/action/Grant.java | 30 --------------- .../licenses/nimbus-jose-jwt-LICENSE.txt | 0 .../licenses/nimbus-jose-jwt-NOTICE.txt | 0 .../authc/jwt/JwtRealmSingleNodeTests.java | 1 - .../security/action/TransportGrantAction.java | 37 ++++++++++++++++++- .../security/authc/jwt/JwkSetLoader.java | 1 - .../security/authc/jwt/JwkValidateUtil.java | 1 - .../authc/jwt/JwtAuthenticationToken.java | 2 +- .../security/authc/jwt/JwtAuthenticator.java | 1 - .../xpack/security/authc/jwt/JwtRealm.java | 2 - .../authc/jwt/JwtSignatureValidator.java | 3 +- .../xpack}/security/authc/jwt/JwtUtil.java | 3 +- .../oidc/OpenIdConnectAuthenticator.java | 2 +- .../authc/jwt/JwtAuthenticatorTests.java | 1 - .../xpack/security/authc/jwt/JwtIssuer.java | 1 - .../authc/jwt/JwtRealmAuthenticateTests.java | 1 - .../authc/jwt/JwtRealmGenerateTests.java | 1 - .../security/authc/jwt/JwtRealmInspector.java | 1 - .../security/authc/jwt/JwtRealmTestCase.java | 1 - .../security/authc/jwt/JwtUtilTests.java | 1 - 22 files changed, 42 insertions(+), 72 deletions(-) rename x-pack/plugin/{core => security}/licenses/nimbus-jose-jwt-LICENSE.txt (100%) rename x-pack/plugin/{core => security}/licenses/nimbus-jose-jwt-NOTICE.txt (100%) rename x-pack/plugin/{core/src/main/java/org/elasticsearch/xpack/core => security/src/main/java/org/elasticsearch/xpack}/security/authc/jwt/JwtAuthenticationToken.java (98%) rename x-pack/plugin/{core/src/main/java/org/elasticsearch/xpack/core => security/src/main/java/org/elasticsearch/xpack}/security/authc/jwt/JwtUtil.java (99%) diff --git a/x-pack/plugin/core/build.gradle b/x-pack/plugin/core/build.gradle index 0c65c7e4b6d29..1ed59d6fe3581 100644 --- a/x-pack/plugin/core/build.gradle +++ b/x-pack/plugin/core/build.gradle @@ -51,7 +51,6 @@ dependencies { // security deps api 'com.unboundid:unboundid-ldapsdk:6.0.3' - api "com.nimbusds:nimbus-jose-jwt:9.23" implementation project(":x-pack:plugin:core:template-resources") @@ -135,27 +134,7 @@ tasks.named("thirdPartyAudit").configure { //commons-logging provided dependencies 'javax.servlet.ServletContextEvent', 'javax.servlet.ServletContextListener', - 'javax.jms.Message', - // Optional dependency of nimbus-jose-jwt for handling Ed25519 signatures and ECDH with X25519 (RFC 8037) - 'com.google.crypto.tink.subtle.Ed25519Sign', - 'com.google.crypto.tink.subtle.Ed25519Sign$KeyPair', - 'com.google.crypto.tink.subtle.Ed25519Verify', - 'com.google.crypto.tink.subtle.X25519', -
'com.google.crypto.tink.subtle.XChaCha20Poly1305', - // optional dependencies for nimbus-jose-jwt - 'org.bouncycastle.asn1.pkcs.PrivateKeyInfo', - 'org.bouncycastle.asn1.x509.AlgorithmIdentifier', - 'org.bouncycastle.asn1.x509.SubjectPublicKeyInfo', - 'org.bouncycastle.cert.X509CertificateHolder', - 'org.bouncycastle.cert.jcajce.JcaX509CertificateHolder', - 'org.bouncycastle.crypto.InvalidCipherTextException', - 'org.bouncycastle.crypto.engines.AESEngine', - 'org.bouncycastle.crypto.modes.GCMBlockCipher', - 'org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider', - 'org.bouncycastle.jce.provider.BouncyCastleProvider', - 'org.bouncycastle.openssl.PEMKeyPair', - 'org.bouncycastle.openssl.PEMParser', - 'org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter' + 'javax.jms.Message' ) } diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index 282072417875b..72436bb9d5171 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -22,7 +22,6 @@ requires unboundid.ldapsdk; requires org.elasticsearch.tdigest; requires org.elasticsearch.xcore.templates; - requires com.nimbusds.jose.jwt; exports org.elasticsearch.index.engine.frozen; exports org.elasticsearch.license; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java index b186ab45a7dc7..c98564251cd43 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java @@ -7,19 +7,13 @@ package org.elasticsearch.xpack.core.security.action; -import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.support.BearerToken; -import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import java.io.IOException; @@ -136,30 +130,6 @@ public void setClientAuthentication(ClientAuthentication clientAuthentication) { this.clientAuthentication = clientAuthentication; } - public AuthenticationToken getAuthenticationToken() { - assert validate(null) == null : "grant is invalid"; - return switch (type) { - case PASSWORD_GRANT_TYPE -> new UsernamePasswordToken(username, password); - case ACCESS_TOKEN_GRANT_TYPE -> { - SecureString clientAuthentication = this.clientAuthentication != null ? 
this.clientAuthentication.value() : null; - AuthenticationToken token = JwtAuthenticationToken.tryParseJwt(accessToken, clientAuthentication); - if (token != null) { - yield token; - } - if (clientAuthentication != null) { - clientAuthentication.close(); - throw new ElasticsearchSecurityException( - "[client_authentication] not supported with the supplied access_token type", - RestStatus.BAD_REQUEST - ); - } - // here we effectively assume it's an ES access token (from the {@code TokenService}) - yield new BearerToken(accessToken); - } - default -> throw new ElasticsearchSecurityException("the grant type [{}] is not supported", type); - }; - } - public ActionRequestValidationException validate(ActionRequestValidationException validationException) { if (type == null) { validationException = addValidationError("[grant_type] is required", validationException); diff --git a/x-pack/plugin/core/licenses/nimbus-jose-jwt-LICENSE.txt b/x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt similarity index 100% rename from x-pack/plugin/core/licenses/nimbus-jose-jwt-LICENSE.txt rename to x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt diff --git a/x-pack/plugin/core/licenses/nimbus-jose-jwt-NOTICE.txt b/x-pack/plugin/security/licenses/nimbus-jose-jwt-NOTICE.txt similarity index 100% rename from x-pack/plugin/core/licenses/nimbus-jose-jwt-NOTICE.txt rename to x-pack/plugin/security/licenses/nimbus-jose-jwt-NOTICE.txt diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java index 2ced54a513146..435706dce7019 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java @@ -52,7 +52,6 @@ import org.elasticsearch.xpack.core.security.action.user.AuthenticateResponse; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.Realm; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.security.LocalStateSecurity; import org.elasticsearch.xpack.security.Security; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java index 667b513555594..fffcb476abaa4 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java @@ -7,24 +7,33 @@ package org.elasticsearch.xpack.security.action; +import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.rest.RestStatus; import 
org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.security.action.Grant; import org.elasticsearch.xpack.core.security.action.GrantRequest; import org.elasticsearch.xpack.core.security.action.user.AuthenticateAction; import org.elasticsearch.xpack.core.security.action.user.AuthenticateRequest; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; +import org.elasticsearch.xpack.core.security.authc.support.BearerToken; +import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import org.elasticsearch.xpack.security.authc.AuthenticationService; +import org.elasticsearch.xpack.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.security.authz.AuthorizationService; +import static org.elasticsearch.xpack.core.security.action.Grant.ACCESS_TOKEN_GRANT_TYPE; +import static org.elasticsearch.xpack.core.security.action.Grant.PASSWORD_GRANT_TYPE; + public abstract class TransportGrantAction extends TransportAction< Request, Response> { @@ -50,7 +59,7 @@ public TransportGrantAction( @Override public final void doExecute(Task task, Request request, ActionListener listener) { try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - final AuthenticationToken authenticationToken = request.getGrant().getAuthenticationToken(); + final AuthenticationToken authenticationToken = getAuthenticationToken(request.getGrant()); assert authenticationToken != null : "authentication token must not be null"; final String runAsUsername = request.getGrant().getRunAsUsername(); @@ -109,4 +118,30 @@ protected abstract void doExecuteWithGrantAuthentication( Authentication authentication, ActionListener listener ); + + public static AuthenticationToken getAuthenticationToken(Grant grant) { + assert grant.validate(null) == null : "grant is invalid"; + return switch (grant.getType()) { + case PASSWORD_GRANT_TYPE -> new UsernamePasswordToken(grant.getUsername(), grant.getPassword()); + case ACCESS_TOKEN_GRANT_TYPE -> { + SecureString clientAuthentication = grant.getClientAuthentication() != null + ? 
grant.getClientAuthentication().value() + : null; + AuthenticationToken token = JwtAuthenticationToken.tryParseJwt(grant.getAccessToken(), clientAuthentication); + if (token != null) { + yield token; + } + if (clientAuthentication != null) { + clientAuthentication.close(); + throw new ElasticsearchSecurityException( + "[client_authentication] not supported with the supplied access_token type", + RestStatus.BAD_REQUEST + ); + } + // here we effectively assume it's an ES access token (from the {@code TokenService}) + yield new BearerToken(grant.getAccessToken()); + } + default -> throw new ElasticsearchSecurityException("the grant type [{}] is not supported", grant.getType()); + }; + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java index 0266fc7488e29..063cc85ea0187 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.ssl.SSLService; import java.io.IOException; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java index cc07b7dfa8381..89391f91a2731 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java @@ -24,7 +24,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import java.nio.charset.StandardCharsets; import java.security.PublicKey; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java similarity index 98% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java rename to x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java index ebfaae72b9df2..cfef9aed5967a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -package org.elasticsearch.xpack.core.security.authc.jwt; +package org.elasticsearch.xpack.security.authc.jwt; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java index b06aba1c9d87a..2345add07ba51 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java @@ -19,7 +19,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.core.security.authc.RealmConfig; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java index 30a7e438e70b0..7613e7b3972af 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java @@ -31,9 +31,7 @@ import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.CachingRealm; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java index e183ee7d73ac2..b1ee1b77998ec 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java @@ -35,14 +35,13 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.ssl.SSLService; import java.util.Arrays; import java.util.List; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil.toStringRedactSignature; +import static org.elasticsearch.xpack.security.authc.jwt.JwtUtil.toStringRedactSignature; public interface JwtSignatureValidator extends Releasable { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java similarity index 99% rename from 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java rename to x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java index d70b76f8bc574..928ecd7fa265d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.core.security.authc.jwt; +package org.elasticsearch.xpack.security.authc.jwt; import com.nimbusds.jose.JWSObject; import com.nimbusds.jose.jwk.JWK; @@ -47,6 +47,7 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; import java.io.InputStream; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java index e637bda19d886..0f34850b861b7 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java @@ -91,9 +91,9 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.security.authc.jwt.JwtUtil; import java.io.IOException; import java.net.URI; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java index 7a44ebae95738..6d4861212e286 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; import org.junit.Before; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java index 3d4d9eae6acd0..789ac04c40622 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java @@ -14,7 +14,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import 
org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.user.User; import java.io.Closeable; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java index bf6c64242701b..4f7b82a16e8f1 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.user.User; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java index 7a0e138305b83..8a5daa642002e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.user.User; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java index 40a613a0907c8..7697849179acf 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.ClaimSetting; import java.net.URI; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java index 1bc49cb628464..ffc1fec1f5788 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.core.security.authc.Realm; import 
org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings.ClientAuthenticationType; import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java index 7d90dffd7517c..6fab33b4d6adf 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; From bfc32a2acc8942dea31c7ac1f893805906b833fb Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Tue, 9 Jul 2024 14:52:13 +0200 Subject: [PATCH 45/64] ES|QL: better validation for GROK patterns (#110574) --- docs/changelog/110574.yaml | 6 ++++ .../xpack/esql/action/EsqlCapabilities.java | 8 ++++- .../xpack/esql/parser/LogicalPlanBuilder.java | 20 ++++++++++- .../xpack/esql/plan/logical/Grok.java | 2 +- .../esql/parser/StatementParserTests.java | 18 ++++++++-- .../rest-api-spec/test/esql/100_bug_fix.yml | 35 +++++++++++++++++++ 6 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 docs/changelog/110574.yaml diff --git a/docs/changelog/110574.yaml b/docs/changelog/110574.yaml new file mode 100644 index 0000000000000..1840838500151 --- /dev/null +++ b/docs/changelog/110574.yaml @@ -0,0 +1,6 @@ +pr: 110574 +summary: "ES|QL: better validation for GROK patterns" +area: ES|QL +type: bug +issues: + - 110533 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 88f6ff0c95b05..fa822b50ffcf5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -111,7 +111,13 @@ public enum Cap { /** * Fix for union-types when aggregating over an inline conversion with casting operator. Done in #110476. 
*/ - UNION_TYPES_AGG_CAST; + UNION_TYPES_AGG_CAST, + + /** + * Fix to GROK validation in case of multiple fields with same name and different types + * https://github.com/elastic/elasticsearch/issues/110533 + */ + GROK_VALIDATION; private final boolean snapshotOnly; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index 9ee5931c85c36..e97323f963887 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -146,12 +146,30 @@ public PlanFactory visitEvalCommand(EsqlBaseParser.EvalCommandContext ctx) { @Override public PlanFactory visitGrokCommand(EsqlBaseParser.GrokCommandContext ctx) { return p -> { + Source source = source(ctx); String pattern = visitString(ctx.string()).fold().toString(); - Grok result = new Grok(source(ctx), p, expression(ctx.primaryExpression()), Grok.pattern(source(ctx), pattern)); + Grok.Parser grokParser = Grok.pattern(source, pattern); + validateGrokPattern(source, grokParser, pattern); + Grok result = new Grok(source(ctx), p, expression(ctx.primaryExpression()), grokParser); return result; }; } + private void validateGrokPattern(Source source, Grok.Parser grokParser, String pattern) { + Map definedAttributes = new HashMap<>(); + for (Attribute field : grokParser.extractedFields()) { + String name = field.name(); + DataType type = field.dataType(); + DataType prev = definedAttributes.put(name, type); + if (prev != null) { + throw new ParsingException( + source, + "Invalid GROK pattern [" + pattern + "]: the attribute [" + name + "] is defined multiple times with different types" + ); + } + } + } + @Override public PlanFactory visitDissectCommand(EsqlBaseParser.DissectCommandContext ctx) { return p -> { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java index e084f6d3e5e3a..963fd318f814c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java @@ -30,7 +30,7 @@ public class Grok extends RegexExtract { public record Parser(String pattern, org.elasticsearch.grok.Grok grok) { - private List extractedFields() { + public List extractedFields() { return grok.captureConfig() .stream() .sorted(Comparator.comparing(GrokCaptureConfig::name)) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java index eee40b25176ab..2f76cb2049820 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java @@ -758,15 +758,27 @@ public void testDissectPattern() { public void testGrokPattern() { LogicalPlan cmd = processingCommand("grok a \"%{WORD:foo}\""); assertEquals(Grok.class, cmd.getClass()); - Grok dissect = (Grok) cmd; - assertEquals("%{WORD:foo}", dissect.parser().pattern()); - assertEquals(List.of(referenceAttribute("foo", KEYWORD)), dissect.extractedFields()); + Grok grok = (Grok) cmd; + assertEquals("%{WORD:foo}", 
grok.parser().pattern()); + assertEquals(List.of(referenceAttribute("foo", KEYWORD)), grok.extractedFields()); ParsingException pe = expectThrows(ParsingException.class, () -> statement("row a = \"foo bar\" | grok a \"%{_invalid_:x}\"")); assertThat( pe.getMessage(), containsString("Invalid pattern [%{_invalid_:x}] for grok: Unable to find pattern [_invalid_] in Grok's pattern dictionary") ); + + cmd = processingCommand("grok a \"%{WORD:foo} %{WORD:foo}\""); + assertEquals(Grok.class, cmd.getClass()); + grok = (Grok) cmd; + assertEquals("%{WORD:foo} %{WORD:foo}", grok.parser().pattern()); + assertEquals(List.of(referenceAttribute("foo", KEYWORD)), grok.extractedFields()); + + expectError( + "row a = \"foo bar\" | GROK a \"%{NUMBER:foo} %{WORD:foo}\"", + "line 1:22: Invalid GROK pattern [%{NUMBER:foo} %{WORD:foo}]:" + + " the attribute [foo] is defined multiple times with different types" + ); } public void testLikeRLike() { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml index b91343d03d3d4..cffc161b11539 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml @@ -303,3 +303,38 @@ - match: { values.0.2: [1, 2] } - match: { values.0.3: [1, 2] } - match: { values.0.4: [1.1, 2.2] } + + +--- +"grok with duplicate names and different types #110533": + - requires: + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_query + parameters: [] + capabilities: [grok_validation] + reason: "fixed grok validation with patterns containing the same attribute multiple times with different types" + - do: + indices.create: + index: test_grok + body: + mappings: + properties: + first_name : + type : keyword + last_name: + type: keyword + + - do: + bulk: + refresh: true + body: + - { "index": { "_index": "test_grok" } } + - { "first_name": "Georgi", "last_name":"Facello" } + + - do: + catch: '/Invalid GROK pattern \[%\{NUMBER:foo\} %\{WORD:foo\}\]: the attribute \[foo\] is defined multiple times with different types/' + esql.query: + body: + query: 'FROM test_grok | KEEP name | WHERE last_name == "Facello" | EVAL name = concat("1 ", last_name) | GROK name "%{NUMBER:foo} %{WORD:foo}"' From 1b6d44b55d68b9b2efc03b5894d10aafdf70837d Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 9 Jul 2024 15:30:42 +0100 Subject: [PATCH 46/64] [DOCS] Fix typo: though -> through (#110636) --- docs/reference/inference/delete-inference.asciidoc | 2 +- docs/reference/inference/get-inference.asciidoc | 2 +- docs/reference/inference/inference-apis.asciidoc | 2 +- docs/reference/inference/post-inference.asciidoc | 2 +- docs/reference/inference/put-inference.asciidoc | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/reference/inference/delete-inference.asciidoc b/docs/reference/inference/delete-inference.asciidoc index 2f9d9511e6326..4df72ba672092 100644 --- a/docs/reference/inference/delete-inference.asciidoc +++ b/docs/reference/inference/delete-inference.asciidoc @@ -8,7 +8,7 @@ Deletes an {infer} endpoint. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. 
For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/get-inference.asciidoc b/docs/reference/inference/get-inference.asciidoc index 7f4dc1c496837..c3fe841603bcc 100644 --- a/docs/reference/inference/get-inference.asciidoc +++ b/docs/reference/inference/get-inference.asciidoc @@ -8,7 +8,7 @@ Retrieves {infer} endpoint information. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/inference-apis.asciidoc b/docs/reference/inference/inference-apis.asciidoc index 896cb02a9e699..02a57504da1cf 100644 --- a/docs/reference/inference/inference-apis.asciidoc +++ b/docs/reference/inference/inference-apis.asciidoc @@ -6,7 +6,7 @@ experimental[] IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/post-inference.asciidoc b/docs/reference/inference/post-inference.asciidoc index 3ad23ac3300cc..52131c0b10776 100644 --- a/docs/reference/inference/post-inference.asciidoc +++ b/docs/reference/inference/post-inference.asciidoc @@ -8,7 +8,7 @@ Performs an inference task on an input text by using an {infer} endpoint. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index 101c0a24b66b7..656feb54ffe42 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -8,7 +8,7 @@ Creates an {infer} endpoint to perform an {infer} task. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Mistral, Azure OpenAI, Google AI Studio, Google Vertex AI or Hugging Face. -For built-in models and models uploaded though Eland, the {infer} APIs offer an alternative way to use and manage trained models. 
+For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. From 3a82bfb9c46db5ff420c3edb19b322828a956b20 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Tue, 9 Jul 2024 11:00:24 -0400 Subject: [PATCH 47/64] ESQL: Move `Failures` into the esql proper (#110585) This moves the `Failures` and `Failure` classes into ESQL's main package, further slimming down our custom fork of the shared ql code. Slowly slowly slowly, it will be no more. --- .../org/elasticsearch/xpack/esql/VerificationException.java | 4 ++-- .../java/org/elasticsearch/xpack/esql/analysis/Analyzer.java | 2 +- .../java/org/elasticsearch/xpack/esql/analysis/Verifier.java | 4 ++-- .../elasticsearch/xpack/esql/capabilities/Validatable.java | 2 +- .../java/org/elasticsearch/xpack/esql}/common/Failure.java | 2 +- .../java/org/elasticsearch/xpack/esql}/common/Failures.java | 2 +- .../org/elasticsearch/xpack/esql/expression/Validations.java | 2 +- .../xpack/esql/expression/function/grouping/Bucket.java | 2 +- .../esql/expression/function/scalar/multivalue/MvSort.java | 4 ++-- .../xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java | 2 +- .../xpack/esql/optimizer/LogicalPlanOptimizer.java | 2 +- .../elasticsearch/xpack/esql/optimizer/LogicalVerifier.java | 2 +- .../elasticsearch/xpack/esql/optimizer/OptimizerRules.java | 4 ++-- .../xpack/esql/optimizer/PhysicalPlanOptimizer.java | 2 +- .../elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java | 4 ++-- .../elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java | 2 +- .../predicate/operator/AbstractBinaryOperatorTestCase.java | 2 +- 17 files changed, 22 insertions(+), 22 deletions(-) rename x-pack/plugin/{esql-core/src/main/java/org/elasticsearch/xpack/esql/core => esql/src/main/java/org/elasticsearch/xpack/esql}/common/Failure.java (97%) rename x-pack/plugin/{esql-core/src/main/java/org/elasticsearch/xpack/esql/core => esql/src/main/java/org/elasticsearch/xpack/esql}/common/Failures.java (96%) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java index 99e4a57757e38..8443b8d99d04a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.esql; -import org.elasticsearch.xpack.esql.core.common.Failure; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failure; +import org.elasticsearch.xpack.esql.common.Failures; import java.util.Collection; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 30ffffd4770a9..fbc98e093c0fb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -16,8 +16,8 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.analysis.AnalyzerRules.BaseAnalyzerRule; import org.elasticsearch.xpack.esql.analysis.AnalyzerRules.ParameterizedAnalyzerRule; import
org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; -import org.elasticsearch.xpack.esql.core.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 9b90f411c4eb8..a4e0d99b0d3fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.esql.analysis; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable; -import org.elasticsearch.xpack.esql.core.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; @@ -53,7 +53,7 @@ import java.util.function.Consumer; import java.util.stream.Stream; -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; +import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java index 4d30f32af5f15..f6733fa3f175c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.capabilities; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; /** * Interface implemented by expressions that require validation post logical optimization, diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failure.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failure.java similarity index 97% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failure.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failure.java index 719ae7ffbd1ca..e5d0fb7ba0b3d 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failure.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failure.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.esql.core.common; +package org.elasticsearch.xpack.esql.common; import org.elasticsearch.xpack.esql.core.tree.Location; import org.elasticsearch.xpack.esql.core.tree.Node; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failures.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failures.java similarity index 96% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failures.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failures.java index c06fe94c9a338..fd25cb427d95b 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failures.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failures.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.core.common; +package org.elasticsearch.xpack.esql.common; import java.util.Collection; import java.util.LinkedHashSet; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java index dffa723a1f3dd..ffcc26cb6f188 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.expression; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expression.TypeResolution; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index 40e927404befd..3ce51b8086dd0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -16,7 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.capabilities.Validatable; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Foldables; import org.elasticsearch.xpack.esql.core.expression.Literal; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java index 199dc49b46097..ee83236ac6a63 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java @@ -30,8 +30,8 @@ import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeLong; import 
org.elasticsearch.xpack.esql.capabilities.Validatable; -import org.elasticsearch.xpack.esql.core.common.Failure; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failure; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java index 1b40a1c2b02ad..c03dc46216621 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java @@ -17,7 +17,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 284f264b85e1c..50819b8ee7480 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -9,7 +9,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java index 007fb3939db0c..cd61b4eb8892c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.capabilities.Validatable; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.optimizer.OptimizerRules.LogicalPlanDependencyCheck; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java index ecd83fbba022c..bff76fb1a706e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.plan.QueryPlan; @@ -36,7 +36,7 @@ import org.elasticsearch.xpack.esql.plan.physical.RowExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; +import static org.elasticsearch.xpack.esql.common.Failure.fail; class OptimizerRules { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java index 70c2a9007408a..e9fd6a713945c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java index 77c8e7da5d895..7843464650e37 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.optimizer.OptimizerRules.PhysicalPlanDependencyCheck; @@ -18,7 +18,7 @@ import java.util.LinkedHashSet; import java.util.Set; -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; +import static org.elasticsearch.xpack.esql.common.Failure.fail; /** Physical plan verifier. 
*/ public final class PhysicalVerifier { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index e97323f963887..d1e0bdac0bf2f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -16,7 +16,7 @@ import org.elasticsearch.dissect.DissectParser; import org.elasticsearch.index.IndexMode; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java index a9663f9e37852..974c8703b2a09 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java @@ -9,7 +9,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.xpack.esql.analysis.Verifier; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; From 3004d14ccbaa47334d3e6ec02d2b5fcb06a8e0ee Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 9 Jul 2024 08:28:33 -0700 Subject: [PATCH 48/64] Adjust exchange timeout in tests (#110569) Some of our tests use an exchange timeout that is too short, causing the exchange sinks to expire before the fetch page requests arrive. This change adjusts the exchange timeout to between 3 and 4 seconds, which should be sufficient without increasing the execution time of the disruption tests. 
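For reference, the reaper in ExchangeService only removes a sink once it has been idle for longer than the configured keepalive, so the interval the tests pick has to comfortably exceed the longest gap expected between fetch-page requests. A minimal standalone sketch of that decision (the class and method here are hypothetical; only the elapsed-time comparison mirrors the production code in the diff below):

    // Hypothetical sketch; keepAliveMillis plays the role of the value the
    // tests pick for ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING.
    class SinkReaperSketch {
        static boolean shouldReap(long nowInMillis, long lastUpdatedInMillis, long keepAliveMillis) {
            long elapsedInMillis = nowInMillis - lastUpdatedInMillis;
            // A sink is only reaped after being idle for longer than the keepalive.
            return elapsedInMillis > keepAliveMillis;
        }
    }

With the previous bounds (as low as 500-2000 ms in some suites) a fetch delayed by a couple of seconds could find its sink already reaped; bounds of 3000-4000 ms leave enough slack without noticeably slowing the disruption tests.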
Closes #109944 Closes #106641 --- muted-tests.yml | 3 --- .../compute/operator/exchange/ExchangeService.java | 13 ++++++------- .../esql/action/AbstractEsqlIntegTestCase.java | 7 ++++++- .../xpack/esql/action/AsyncEsqlQueryActionIT.java | 2 +- .../esql/action/CrossClustersCancellationIT.java | 2 +- .../elasticsearch/xpack/esql/action/EnrichIT.java | 2 +- .../xpack/esql/action/EsqlActionBreakerIT.java | 4 +--- .../xpack/esql/action/EsqlDisruptionIT.java | 2 +- 8 files changed, 17 insertions(+), 18 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index ccbdb68fbb8c7..34fb81a01590f 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -53,9 +53,6 @@ tests: - class: "org.elasticsearch.xpack.security.ScrollHelperIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/109905" method: "testFetchAllEntities" -- class: "org.elasticsearch.xpack.esql.action.AsyncEsqlQueryActionIT" - issue: "https://github.com/elastic/elasticsearch/issues/109944" - method: "testBasicAsyncExecution" - class: "org.elasticsearch.action.admin.indices.rollover.RolloverIT" issue: "https://github.com/elastic/elasticsearch/issues/110034" method: "testRolloverWithClosedWriteIndex" diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index f647f4fba0225..a365a655370a2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -250,21 +250,20 @@ public boolean isForceExecution() { protected void doRun() { assert Transports.assertNotTransportThread("reaping inactive exchanges can be expensive"); assert ThreadPool.assertNotScheduleThread("reaping inactive exchanges can be expensive"); + logger.debug("start removing inactive sinks"); final long nowInMillis = threadPool.relativeTimeInMillis(); for (Map.Entry e : sinks.entrySet()) { ExchangeSinkHandler sink = e.getValue(); if (sink.hasData() && sink.hasListeners()) { continue; } - long elapsed = nowInMillis - sink.lastUpdatedTimeInMillis(); - if (elapsed > keepAlive.millis()) { + long elapsedInMillis = nowInMillis - sink.lastUpdatedTimeInMillis(); + if (elapsedInMillis > keepAlive.millis()) { + TimeValue elapsedTime = TimeValue.timeValueMillis(elapsedInMillis); + logger.debug("removed sink {} inactive for {}", e.getKey(), elapsedTime); finishSinkHandler( e.getKey(), - new ElasticsearchTimeoutException( - "Exchange sink {} has been inactive for {}", - e.getKey(), - TimeValue.timeValueMillis(elapsed) - ) + new ElasticsearchTimeoutException("Exchange sink {} has been inactive for {}", e.getKey(), elapsedTime) ); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java index 22e3de8499bc1..84738f733f86b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java @@ -11,6 +11,7 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction; import 
org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -44,7 +45,11 @@ public void ensureExchangesAreReleased() throws Exception { for (String node : internalCluster().getNodeNames()) { TransportEsqlQueryAction esqlQueryAction = internalCluster().getInstance(TransportEsqlQueryAction.class, node); ExchangeService exchangeService = esqlQueryAction.exchangeService(); - assertBusy(() -> assertTrue("Leftover exchanges " + exchangeService + " on node " + node, exchangeService.isEmpty())); + assertBusy(() -> { + if (exchangeService.lifecycleState() == Lifecycle.State.STARTED) { + assertTrue("Leftover exchanges " + exchangeService + " on node " + node, exchangeService.isEmpty()); + } + }); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java index da9aa96876fd7..f85de51101af5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java @@ -54,7 +54,7 @@ protected Collection> nodePlugins() { @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return Settings.builder() - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(500, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .build(); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java index 800067fef8b1c..df6a1e00b0212 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java @@ -68,7 +68,7 @@ public List> getSettings() { return List.of( Setting.timeSetting( ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, - TimeValue.timeValueMillis(between(1000, 3000)), + TimeValue.timeValueMillis(between(3000, 4000)), Setting.Property.NodeScope ) ); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java index 5be816712cf20..cdfa6eb2d03f3 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java @@ -111,7 +111,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getKey(), HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getDefault(Settings.EMPTY) ) - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(500, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) 
.put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 256))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_MAX_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 1024))) // allow reading pages from network can trip the circuit breaker diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java index 089cb4a9a5084..37833d8aed2d3 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.action; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.DocWriteResponse; @@ -35,7 +34,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105543") @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug") public class EsqlActionBreakerIT extends EsqlActionIT { @@ -72,7 +70,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getKey(), HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getDefault(Settings.EMPTY) ) - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(500, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 256))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_MAX_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 1024))) // allow reading pages from network can trip the circuit breaker diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java index df1b2c9f00f49..e9eada5def0dc 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java @@ -52,7 +52,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings settings = Settings.builder() .put(super.nodeSettings(nodeOrdinal, otherSettings)) .put(DEFAULT_SETTINGS) - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(1000, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .build(); logger.info("settings {}", settings); return settings; From 763f2f1600162502c447ad04af3c149921ef8b0b Mon Sep 17 00:00:00 2001 From: Stanislav Malyshev Date: Tue, 9 Jul 2024 09:32:21 -0600 Subject: [PATCH 49/64] Add feature flag ccs_telemetry (#110619) --- .../org/elasticsearch/action/search/TransportSearchAction.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java 
b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index c2d1cdae85cd9..0368dec76df0e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -49,6 +49,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -121,6 +122,8 @@ public class TransportSearchAction extends HandledTransportAction SHARD_COUNT_LIMIT_SETTING = Setting.longSetting( "action.search.shard_count.limit", From 1e7e42d0ec8a99183ab1a844b7002829f4f1cf30 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 9 Jul 2024 13:20:02 -0400 Subject: [PATCH 50/64] [Transform] Unmute testMaxPageSearchSizeIsResetToConfiguredValue (#110537) This was fixed as part of PR#109876. Relate #109844 Co-authored-by: Elastic Machine --- muted-tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 34fb81a01590f..612d3bdc72ca1 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -56,9 +56,6 @@ tests: - class: "org.elasticsearch.action.admin.indices.rollover.RolloverIT" issue: "https://github.com/elastic/elasticsearch/issues/110034" method: "testRolloverWithClosedWriteIndex" -- class: org.elasticsearch.xpack.transform.transforms.TransformIndexerTests - method: testMaxPageSearchSizeIsResetToConfiguredValue - issue: https://github.com/elastic/elasticsearch/issues/109844 - class: org.elasticsearch.index.store.FsDirectoryFactoryTests method: testStoreDirectory issue: https://github.com/elastic/elasticsearch/issues/110210 From 4ec94be1df93ec40e67484d77533f2eb62383bf4 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 9 Jul 2024 13:20:31 -0400 Subject: [PATCH 51/64] [Transform] log search payload for preview (#110653) A quick change to help debug search requests performed by the validate and preview API. 
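Reviewer aside (not part of the patch): the new payload line only appears when the transform package logger is at DEBUG. A minimal sketch of flipping that level from test code, assuming Elasticsearch's log4j-based Loggers helper is on the classpath:

    import org.apache.logging.log4j.LogManager;
    import org.elasticsearch.common.logging.Loggers;

    class EnableTransformDebugLogging {
        static void enable() {
            // AbstractCompositeAggFunction declares its logger under the
            // org.elasticsearch.xpack.transform package, so raising that package
            // to DEBUG surfaces the new "Querying ... for data: ..." line.
            Loggers.setLevel(LogManager.getLogger("org.elasticsearch.xpack.transform"), "DEBUG");
        }
    }

In a running cluster the equivalent is setting "logger.org.elasticsearch.xpack.transform" to "DEBUG" via the cluster settings API.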
Co-authored-by: Elastic Machine --- .../common/AbstractCompositeAggFunction.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java index 3412be813dcf6..23bab56de5ec9 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.transform.transforms.common; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; @@ -45,6 +47,7 @@ * Basic abstract class for implementing a transform function that utilizes composite aggregations */ public abstract class AbstractCompositeAggFunction implements Function { + private static final Logger logger = LogManager.getLogger(AbstractCompositeAggFunction.class); public static final int TEST_QUERY_PAGE_SIZE = 50; public static final String COMPOSITE_AGGREGATION_NAME = "_transform"; @@ -78,7 +81,7 @@ public void preview( ClientHelper.TRANSFORM_ORIGIN, client, TransportSearchAction.TYPE, - buildSearchRequest(sourceConfig, timeout, numberOfBuckets), + buildSearchRequestForValidation("preview", sourceConfig, timeout, numberOfBuckets), ActionListener.wrap(r -> { try { final InternalAggregations aggregations = r.getAggregations(); @@ -116,7 +119,7 @@ public void validateQuery( TimeValue timeout, ActionListener listener ) { - SearchRequest searchRequest = buildSearchRequest(sourceConfig, timeout, TEST_QUERY_PAGE_SIZE); + SearchRequest searchRequest = buildSearchRequestForValidation("validate", sourceConfig, timeout, TEST_QUERY_PAGE_SIZE); ClientHelper.executeWithHeadersAsync( headers, ClientHelper.TRANSFORM_ORIGIN, @@ -193,11 +196,12 @@ protected abstract Stream> extractResults( TransformProgress progress ); - private SearchRequest buildSearchRequest(SourceConfig sourceConfig, TimeValue timeout, int pageSize) { + private SearchRequest buildSearchRequestForValidation(String logId, SourceConfig sourceConfig, TimeValue timeout, int pageSize) { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(sourceConfig.getQueryConfig().getQuery()) .runtimeMappings(sourceConfig.getRuntimeMappings()) .timeout(timeout); buildSearchQuery(sourceBuilder, null, pageSize); + logger.debug("[{}] Querying {} for data: {}", logId, sourceConfig.getIndex(), sourceBuilder); return new SearchRequest(sourceConfig.getIndex()).source(sourceBuilder).indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); } From 4f25a395aacbe4b24363db4a212b5d1c97a6b36e Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Tue, 9 Jul 2024 14:19:37 -0500 Subject: [PATCH 52/64] Adding a unit test for GeoIpDownloader.cleanDatabases (#110650) Co-authored-by: Joe Gallo --- .../ingest/geoip/GeoIpDownloaderTests.java | 95 ++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java index 4834c581e9386..4d5070d96683e 
100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java @@ -30,11 +30,17 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.ingest.geoip.stats.GeoIpDownloaderStats; import org.elasticsearch.node.Node; +import org.elasticsearch.persistent.PersistentTaskResponse; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction; import org.elasticsearch.telemetry.metric.MeterRegistry; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; @@ -49,6 +55,9 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -63,6 +72,8 @@ import static org.elasticsearch.ingest.geoip.GeoIpDownloader.MAX_CHUNK_SIZE; import static org.elasticsearch.tasks.TaskId.EMPTY_TASK_ID; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -76,8 +87,9 @@ public class GeoIpDownloaderTests extends ESTestCase { private GeoIpDownloader geoIpDownloader; @Before - public void setup() { + public void setup() throws IOException { httpClient = mock(HttpClient.class); + when(httpClient.getBytes(anyString())).thenReturn("[]".getBytes(StandardCharsets.UTF_8)); clusterService = mock(ClusterService.class); threadPool = new ThreadPool(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(), MeterRegistry.NOOP); when(clusterService.getClusterSettings()).thenReturn( @@ -109,7 +121,13 @@ public void setup() { () -> GeoIpDownloaderTaskExecutor.POLL_INTERVAL_SETTING.getDefault(Settings.EMPTY), () -> GeoIpDownloaderTaskExecutor.EAGER_DOWNLOAD_SETTING.getDefault(Settings.EMPTY), () -> true - ); + ) { + { + GeoIpTaskParams geoIpTaskParams = mock(GeoIpTaskParams.class); + when(geoIpTaskParams.getWriteableName()).thenReturn(GeoIpDownloader.GEOIP_DOWNLOADER); + init(new PersistentTasksService(clusterService, threadPool, client), null, null, 0); + } + }; } @After @@ -541,6 +559,79 @@ public void testUpdateDatabasesIndexNotReady() { verifyNoInteractions(httpClient); } + public void testThatRunDownloaderDeletesExpiredDatabases() { + /* + * This test puts some expired databases and some non-expired ones into the GeoIpTaskState, and then calls runDownloader(), making + * sure that the expired databases have been deleted. 
+ */ + AtomicInteger updatePersistentTaskStateCount = new AtomicInteger(0); + AtomicInteger deleteCount = new AtomicInteger(0); + int expiredDatabasesCount = randomIntBetween(1, 100); + int unexpiredDatabasesCount = randomIntBetween(0, 100); + Map databases = new HashMap<>(); + for (int i = 0; i < expiredDatabasesCount; i++) { + databases.put("expiredDatabase" + i, newGeoIpTaskStateMetadata(true)); + } + for (int i = 0; i < unexpiredDatabasesCount; i++) { + databases.put("unexpiredDatabase" + i, newGeoIpTaskStateMetadata(false)); + } + GeoIpTaskState geoIpTaskState = new GeoIpTaskState(databases); + geoIpDownloader.setState(geoIpTaskState); + client.addHandler( + UpdatePersistentTaskStatusAction.INSTANCE, + (UpdatePersistentTaskStatusAction.Request request, ActionListener taskResponseListener) -> { + + PersistentTasksCustomMetadata.Assignment assignment = mock(PersistentTasksCustomMetadata.Assignment.class); + PersistentTasksCustomMetadata.PersistentTask persistentTask = new PersistentTasksCustomMetadata.PersistentTask<>( + GeoIpDownloader.GEOIP_DOWNLOADER, + GeoIpDownloader.GEOIP_DOWNLOADER, + new GeoIpTaskParams(), + request.getAllocationId(), + assignment + ); + taskResponseListener.onResponse(new PersistentTaskResponse(new PersistentTask<>(persistentTask, request.getState()))); + updatePersistentTaskStateCount.incrementAndGet(); + } + ); + client.addHandler( + DeleteByQueryAction.INSTANCE, + (DeleteByQueryRequest request, ActionListener flushResponseActionListener) -> { + deleteCount.incrementAndGet(); + } + ); + geoIpDownloader.runDownloader(); + assertThat(geoIpDownloader.getStatus().getExpiredDatabases(), equalTo(expiredDatabasesCount)); + for (int i = 0; i < expiredDatabasesCount; i++) { + // This currently fails because we subtract one millisecond from the lastChecked time + // assertThat(geoIpDownloader.state.getDatabases().get("expiredDatabase" + i).lastCheck(), equalTo(-1L)); + } + for (int i = 0; i < unexpiredDatabasesCount; i++) { + assertThat( + geoIpDownloader.state.getDatabases().get("unexpiredDatabase" + i).lastCheck(), + greaterThanOrEqualTo(Instant.now().minus(30, ChronoUnit.DAYS).toEpochMilli()) + ); + } + assertThat(deleteCount.get(), equalTo(expiredDatabasesCount)); + assertThat(updatePersistentTaskStateCount.get(), equalTo(expiredDatabasesCount)); + geoIpDownloader.runDownloader(); + /* + * The following two lines assert current behavior that might not be desirable -- we continue to delete expired databases every + * time that runDownloader runs. This seems unnecessary. + */ + assertThat(deleteCount.get(), equalTo(expiredDatabasesCount * 2)); + assertThat(updatePersistentTaskStateCount.get(), equalTo(expiredDatabasesCount * 2)); + } + + private GeoIpTaskState.Metadata newGeoIpTaskStateMetadata(boolean expired) { + Instant lastChecked; + if (expired) { + lastChecked = Instant.now().minus(randomIntBetween(31, 100), ChronoUnit.DAYS); + } else { + lastChecked = Instant.now().minus(randomIntBetween(0, 29), ChronoUnit.DAYS); + } + return new GeoIpTaskState.Metadata(0, 0, 0, randomAlphaOfLength(20), lastChecked.toEpochMilli()); + } + private static class MockClient extends NoOpClient { private final Map, BiConsumer>> handlers = new HashMap<>(); From c6f82604d740c3fa77d5c6944a8f1abb5b17e0c9 Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Tue, 9 Jul 2024 12:25:27 -0700 Subject: [PATCH 53/64] Move exec syscall filtering to NativeAccess (#108970) This commit moves the system call filtering initialization into NativeAccess. 
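Reviewer aside (not part of the patch): a minimal sketch of the call path after this move, using only the entry points visible in the diff below (NativeAccess.instance(), tryInstallExecSandbox(), getExecSandboxState()); the wrapper class and printing are illustrative:

    import org.elasticsearch.nativeaccess.NativeAccess;

    class ExecSandboxSketch {
        static void install() {
            NativeAccess nativeAccess = NativeAccess.instance();
            try {
                // Platform-specific under the hood: seccomp(2)/prctl(2) on Linux,
                // setrlimit(RLIMIT_NPROC) plus Seatbelt on macOS, job objects on Windows.
                nativeAccess.tryInstallExecSandbox();
            } catch (UnsupportedOperationException e) {
                // Each implementation throws with a descriptive message when filtering
                // cannot be installed (e.g. unsupported arch or missing kernel config).
                System.err.println("exec sandbox unavailable: " + e.getMessage());
            }
            // NONE, EXISTING_THREADS, or ALL_THREADS, depending on platform and method.
            System.out.println("exec sandbox state: " + nativeAccess.getExecSandboxState());
        }
    }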
The code is essentially unmodified from its existing state, now existing within the *NativeAccess implementations. relates #104876 --- .../nativeaccess/jna/JnaKernel32Library.java | 85 +++ .../nativeaccess/jna/JnaLinuxCLibrary.java | 94 +++ .../nativeaccess/jna/JnaMacCLibrary.java | 59 ++ .../jna/JnaNativeLibraryProvider.java | 6 + .../nativeaccess/jna/JnaPosixCLibrary.java | 67 ++ .../nativeaccess/AbstractNativeAccess.java | 6 + .../nativeaccess/LinuxNativeAccess.java | 268 +++++++- .../nativeaccess/MacNativeAccess.java | 83 +++ .../nativeaccess/NativeAccess.java | 22 + .../nativeaccess/NoopNativeAccess.java | 10 + .../nativeaccess/WindowsNativeAccess.java | 51 ++ .../nativeaccess/lib/Kernel32Library.java | 61 ++ .../nativeaccess/lib/LinuxCLibrary.java | 38 ++ .../nativeaccess/lib/MacCLibrary.java | 25 + .../nativeaccess/lib/NativeLibrary.java | 3 +- .../nativeaccess/lib/PosixCLibrary.java | 22 + .../nativeaccess/jdk/JdkKernel32Library.java | 116 ++++ .../nativeaccess/jdk/JdkLinuxCLibrary.java | 103 +++ .../nativeaccess/jdk/JdkMacCLibrary.java | 73 ++ .../jdk/JdkNativeLibraryProvider.java | 6 + .../nativeaccess/jdk/JdkPosixCLibrary.java | 83 +++ .../nativeaccess}/SystemCallFilterTests.java | 15 +- .../build.gradle | 6 +- .../bootstrap/BootstrapChecks.java | 4 +- .../bootstrap/BootstrapInfo.java | 17 - .../bootstrap/Elasticsearch.java | 9 +- .../bootstrap/JNAKernel32Library.java | 255 ------- .../elasticsearch/bootstrap/JNANatives.java | 50 -- .../org/elasticsearch/bootstrap/Natives.java | 69 -- .../bootstrap/SystemCallFilter.java | 641 ------------------ 30 files changed, 1298 insertions(+), 1049 deletions(-) create mode 100644 libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java create mode 100644 libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java create mode 100644 libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java create mode 100644 libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java create mode 100644 libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java create mode 100644 libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java rename {qa/evil-tests/src/test/java/org/elasticsearch/bootstrap => libs/native/src/test/java/org/elasticsearch/nativeaccess}/SystemCallFilterTests.java (84%) delete mode 100644 server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java delete mode 100644 server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java delete mode 100644 server/src/main/java/org/elasticsearch/bootstrap/Natives.java delete mode 100644 server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java index 0bfdf959f7b58..2c7ec70f36eb3 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java @@ -13,6 +13,7 @@ import com.sun.jna.NativeLong; import com.sun.jna.Pointer; import com.sun.jna.Structure; +import com.sun.jna.Structure.ByReference; import com.sun.jna.WString; import com.sun.jna.win32.StdCallLibrary; @@ -98,6 +99,38 @@ public long Type() { } } + /** + * Basic limit information for a job object + * + * 
https://msdn.microsoft.com/en-us/library/windows/desktop/ms684147%28v=vs.85%29.aspx + */ + public static class JnaJobObjectBasicLimitInformation extends Structure implements ByReference, JobObjectBasicLimitInformation { + public byte[] _ignore1 = new byte[16]; + public int LimitFlags; + public byte[] _ignore2 = new byte[20]; + public int ActiveProcessLimit; + public byte[] _ignore3 = new byte[20]; + + public JnaJobObjectBasicLimitInformation() { + super(8); + } + + @Override + protected List getFieldOrder() { + return List.of("_ignore1", "LimitFlags", "_ignore2", "ActiveProcessLimit", "_ignore3"); + } + + @Override + public void setLimitFlags(int v) { + LimitFlags = v; + } + + @Override + public void setActiveProcessLimit(int v) { + ActiveProcessLimit = v; + } + } + /** * JNA adaptation of {@link ConsoleCtrlHandler} */ @@ -128,6 +161,20 @@ private interface NativeFunctions extends StdCallLibrary { int GetShortPathNameW(WString lpszLongPath, char[] lpszShortPath, int cchBuffer); boolean SetConsoleCtrlHandler(StdCallLibrary.StdCallCallback handler, boolean add); + + Pointer CreateJobObjectW(Pointer jobAttributes, String name); + + boolean AssignProcessToJobObject(Pointer job, Pointer process); + + boolean QueryInformationJobObject( + Pointer job, + int infoClass, + JnaJobObjectBasicLimitInformation info, + int infoLength, + Pointer returnLength + ); + + boolean SetInformationJobObject(Pointer job, int infoClass, JnaJobObjectBasicLimitInformation info, int infoLength); } private final NativeFunctions functions; @@ -197,4 +244,42 @@ public boolean SetConsoleCtrlHandler(ConsoleCtrlHandler handler, boolean add) { consoleCtrlHandlerCallback = new NativeHandlerCallback(handler); return functions.SetConsoleCtrlHandler(consoleCtrlHandlerCallback, true); } + + @Override + public Handle CreateJobObjectW() { + return new JnaHandle(functions.CreateJobObjectW(null, null)); + } + + @Override + public boolean AssignProcessToJobObject(Handle job, Handle process) { + assert job instanceof JnaHandle; + assert process instanceof JnaHandle; + var jnaJob = (JnaHandle) job; + var jnaProcess = (JnaHandle) process; + return functions.AssignProcessToJobObject(jnaJob.pointer, jnaProcess.pointer); + } + + @Override + public JobObjectBasicLimitInformation newJobObjectBasicLimitInformation() { + return new JnaJobObjectBasicLimitInformation(); + } + + @Override + public boolean QueryInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JnaHandle; + assert info instanceof JnaJobObjectBasicLimitInformation; + var jnaJob = (JnaHandle) job; + var jnaInfo = (JnaJobObjectBasicLimitInformation) info; + var ret = functions.QueryInformationJobObject(jnaJob.pointer, infoClass, jnaInfo, jnaInfo.size(), null); + return ret; + } + + @Override + public boolean SetInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JnaHandle; + assert info instanceof JnaJobObjectBasicLimitInformation; + var jnaJob = (JnaHandle) job; + var jnaInfo = (JnaJobObjectBasicLimitInformation) info; + return functions.SetInformationJobObject(jnaJob.pointer, infoClass, jnaInfo, jnaInfo.size()); + } } diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java new file mode 100644 index 0000000000000..742c666d59c23 --- /dev/null +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java @@ 
-0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jna; + +import com.sun.jna.Library; +import com.sun.jna.Memory; +import com.sun.jna.Native; +import com.sun.jna.NativeLong; +import com.sun.jna.Pointer; +import com.sun.jna.Structure; + +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +class JnaLinuxCLibrary implements LinuxCLibrary { + + @Structure.FieldOrder({ "len", "filter" }) + public static final class JnaSockFProg extends Structure implements Structure.ByReference, SockFProg { + public short len; // number of filters + public Pointer filter; // filters + + JnaSockFProg(SockFilter filters[]) { + len = (short) filters.length; + // serialize struct sock_filter * explicitly, its less confusing than the JNA magic we would need + Memory filter = new Memory(len * 8); + ByteBuffer bbuf = filter.getByteBuffer(0, len * 8); + bbuf.order(ByteOrder.nativeOrder()); // little endian + for (SockFilter f : filters) { + bbuf.putShort(f.code()); + bbuf.put(f.jt()); + bbuf.put(f.jf()); + bbuf.putInt(f.k()); + } + this.filter = filter; + } + + @Override + public long address() { + return Pointer.nativeValue(getPointer()); + } + } + + private interface NativeFunctions extends Library { + + /** + * maps to prctl(2) + */ + int prctl(int option, NativeLong arg2, NativeLong arg3, NativeLong arg4, NativeLong arg5); + + /** + * used to call seccomp(2), its too new... + * this is the only way, DON'T use it on some other architecture unless you know wtf you are doing + */ + NativeLong syscall(NativeLong number, Object... args); + } + + private final NativeFunctions functions; + + JnaLinuxCLibrary() { + try { + this.functions = Native.load("c", NativeFunctions.class); + } catch (UnsatisfiedLinkError e) { + throw new UnsupportedOperationException( + "seccomp unavailable: could not link methods. requires kernel 3.5+ " + + "with CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" + ); + } + } + + @Override + public SockFProg newSockFProg(SockFilter[] filters) { + var prog = new JnaSockFProg(filters); + prog.write(); + return prog; + } + + @Override + public int prctl(int option, long arg2, long arg3, long arg4, long arg5) { + return functions.prctl(option, new NativeLong(arg2), new NativeLong(arg3), new NativeLong(arg4), new NativeLong(arg5)); + } + + @Override + public long syscall(long number, int operation, int flags, long address) { + return functions.syscall(new NativeLong(number), operation, flags, address).longValue(); + } +} diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java new file mode 100644 index 0000000000000..f416cf862b417 --- /dev/null +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jna; + +import com.sun.jna.Library; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import com.sun.jna.ptr.PointerByReference; + +import org.elasticsearch.nativeaccess.lib.MacCLibrary; + +class JnaMacCLibrary implements MacCLibrary { + static class JnaErrorReference implements ErrorReference { + final PointerByReference ref = new PointerByReference(); + + @Override + public String toString() { + return ref.getValue().getString(0); + } + } + + private interface NativeFunctions extends Library { + int sandbox_init(String profile, long flags, PointerByReference errorbuf); + + void sandbox_free_error(Pointer errorbuf); + } + + private final NativeFunctions functions; + + JnaMacCLibrary() { + this.functions = Native.load("c", NativeFunctions.class); + } + + @Override + public ErrorReference newErrorReference() { + return new JnaErrorReference(); + } + + @Override + public int sandbox_init(String profile, long flags, ErrorReference errorbuf) { + assert errorbuf instanceof JnaErrorReference; + var jnaErrorbuf = (JnaErrorReference) errorbuf; + return functions.sandbox_init(profile, flags, jnaErrorbuf.ref); + } + + @Override + public void sandbox_free_error(ErrorReference errorbuf) { + assert errorbuf instanceof JnaErrorReference; + var jnaErrorbuf = (JnaErrorReference) errorbuf; + functions.sandbox_free_error(jnaErrorbuf.ref.getValue()); + } + +} diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java index 9d34b1ba617e8..454581ae70b51 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java @@ -10,6 +10,8 @@ import org.elasticsearch.nativeaccess.lib.JavaLibrary; import org.elasticsearch.nativeaccess.lib.Kernel32Library; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; +import org.elasticsearch.nativeaccess.lib.MacCLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; import org.elasticsearch.nativeaccess.lib.PosixCLibrary; @@ -30,6 +32,10 @@ public JnaNativeLibraryProvider() { JnaJavaLibrary::new, PosixCLibrary.class, JnaPosixCLibrary::new, + LinuxCLibrary.class, + JnaLinuxCLibrary::new, + MacCLibrary.class, + JnaMacCLibrary::new, Kernel32Library.class, JnaKernel32Library::new, SystemdLibrary.class, diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java index 7e8e4f23ab034..03a7b9c0869be 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java @@ -39,6 +39,50 @@ public long rlim_cur() { public long rlim_max() { return rlim_max.longValue(); } + + @Override + public void rlim_cur(long v) { + rlim_cur.setValue(v); + } + + @Override + public void rlim_max(long v) { + rlim_max.setValue(v); + } + } + + public static class JnaFStore extends Structure implements 
Structure.ByReference, FStore { + + public int fst_flags = 0; + public int fst_posmode = 0; + public NativeLong fst_offset = new NativeLong(0); + public NativeLong fst_length = new NativeLong(0); + public NativeLong fst_bytesalloc = new NativeLong(0); + + @Override + public void set_flags(int flags) { + this.fst_flags = flags; + } + + @Override + public void set_posmode(int posmode) { + this.fst_posmode = posmode; + } + + @Override + public void set_offset(long offset) { + fst_offset.setValue(offset); + } + + @Override + public void set_length(long length) { + fst_length.setValue(length); + } + + @Override + public long bytesalloc() { + return fst_bytesalloc.longValue(); + } } private interface NativeFunctions extends Library { @@ -46,8 +90,12 @@ private interface NativeFunctions extends Library { int getrlimit(int resource, JnaRLimit rlimit); + int setrlimit(int resource, JnaRLimit rlimit); + int mlockall(int flags); + int fcntl(int fd, int cmd, JnaFStore fst); + String strerror(int errno); } @@ -74,11 +122,30 @@ public int getrlimit(int resource, RLimit rlimit) { return functions.getrlimit(resource, jnaRlimit); } + @Override + public int setrlimit(int resource, RLimit rlimit) { + assert rlimit instanceof JnaRLimit; + var jnaRlimit = (JnaRLimit) rlimit; + return functions.setrlimit(resource, jnaRlimit); + } + @Override public int mlockall(int flags) { return functions.mlockall(flags); } + @Override + public FStore newFStore() { + return new JnaFStore(); + } + + @Override + public int fcntl(int fd, int cmd, FStore fst) { + assert fst instanceof JnaFStore; + var jnaFst = (JnaFStore) fst; + return functions.fcntl(fd, cmd, jnaFst); + } + @Override public String strerror(int errno) { return functions.strerror(errno); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java index 80a18a2bc8aa0..c10f57a900ff7 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java @@ -22,6 +22,7 @@ abstract class AbstractNativeAccess implements NativeAccess { private final JavaLibrary javaLib; private final Zstd zstd; protected boolean isMemoryLocked = false; + protected ExecSandboxState execSandboxState = ExecSandboxState.NONE; protected AbstractNativeAccess(String name, NativeLibraryProvider libraryProvider) { this.name = name; @@ -53,4 +54,9 @@ public CloseableByteBuffer newBuffer(int len) { public boolean isMemoryLocked() { return isMemoryLocked; } + + @Override + public ExecSandboxState getExecSandboxState() { + return execSandboxState; + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java index 7948dad1df4ad..c50e639c94d27 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java @@ -8,15 +8,88 @@ package org.elasticsearch.nativeaccess; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary.SockFProg; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary.SockFilter; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; import org.elasticsearch.nativeaccess.lib.SystemdLibrary; +import java.util.Map; + class LinuxNativeAccess extends 
PosixNativeAccess { - Systemd systemd; + /** the preferred method is seccomp(2), since we can apply to all threads of the process */ + static final int SECCOMP_SET_MODE_FILTER = 1; // since Linux 3.17 + static final int SECCOMP_FILTER_FLAG_TSYNC = 1; // since Linux 3.17 + + /** otherwise, we can use prctl(2), which will at least protect ES application threads */ + static final int PR_GET_NO_NEW_PRIVS = 39; // since Linux 3.5 + static final int PR_SET_NO_NEW_PRIVS = 38; // since Linux 3.5 + static final int PR_GET_SECCOMP = 21; // since Linux 2.6.23 + static final int PR_SET_SECCOMP = 22; // since Linux 2.6.23 + static final long SECCOMP_MODE_FILTER = 2; // since Linux Linux 3.5 + + // BPF "macros" and constants + static final int BPF_LD = 0x00; + static final int BPF_W = 0x00; + static final int BPF_ABS = 0x20; + static final int BPF_JMP = 0x05; + static final int BPF_JEQ = 0x10; + static final int BPF_JGE = 0x30; + static final int BPF_JGT = 0x20; + static final int BPF_RET = 0x06; + static final int BPF_K = 0x00; + + static SockFilter BPF_STMT(int code, int k) { + return new SockFilter((short) code, (byte) 0, (byte) 0, k); + } + + static SockFilter BPF_JUMP(int code, int k, int jt, int jf) { + return new SockFilter((short) code, (byte) jt, (byte) jf, k); + } + + static final int SECCOMP_RET_ERRNO = 0x00050000; + static final int SECCOMP_RET_DATA = 0x0000FFFF; + static final int SECCOMP_RET_ALLOW = 0x7FFF0000; + + // some errno constants for error checking/handling + static final int EACCES = 0x0D; + static final int EFAULT = 0x0E; + static final int EINVAL = 0x16; + static final int ENOSYS = 0x26; + + // offsets that our BPF checks + // check with offsetof() when adding a new arch, move to Arch if different. + static final int SECCOMP_DATA_NR_OFFSET = 0x00; + static final int SECCOMP_DATA_ARCH_OFFSET = 0x04; + + record Arch( + int audit, // AUDIT_ARCH_XXX constant from linux/audit.h + int limit, // syscall limit (necessary for blacklisting on amd64, to ban 32-bit syscalls) + int fork, // __NR_fork + int vfork, // __NR_vfork + int execve, // __NR_execve + int execveat, // __NR_execveat + int seccomp // __NR_seccomp + ) {} + + /** supported architectures for seccomp keyed by os.arch */ + private static final Map ARCHITECTURES; + static { + ARCHITECTURES = Map.of( + "amd64", + new Arch(0xC000003E, 0x3FFFFFFF, 57, 58, 59, 322, 317), + "aarch64", + new Arch(0xC00000B7, 0xFFFFFFFF, 1079, 1071, 221, 281, 277) + ); + } + + private final LinuxCLibrary linuxLibc; + private final Systemd systemd; LinuxNativeAccess(NativeLibraryProvider libraryProvider) { super("Linux", libraryProvider, new PosixConstants(-1L, 9, 1, 8)); + this.linuxLibc = libraryProvider.getLibrary(LinuxCLibrary.class); this.systemd = new Systemd(libraryProvider.getLibrary(SystemdLibrary.class)); } @@ -46,4 +119,197 @@ protected void logMemoryLimitInstructions() { \t{} hard memlock unlimited""", user, user, user); logger.warn("If you are logged in interactively, you will have to re-login for the new limits to take effect."); } + + /** + * Installs exec system call filtering for Linux. + *
<p>
+ * On Linux exec system call filtering currently supports amd64 and aarch64 architectures. + * It requires Linux kernel 3.5 or above, and {@code CONFIG_SECCOMP} and {@code CONFIG_SECCOMP_FILTER} + * compiled into the kernel. + *
<p>
+ * On Linux BPF Filters are installed using either {@code seccomp(2)} (3.17+) or {@code prctl(2)} (3.5+). {@code seccomp(2)} + * is preferred, as it allows filters to be applied to any existing threads in the process, and one motivation + * here is to protect against bugs in the JVM. Otherwise, code will fall back to the {@code prctl(2)} method + * which will at least protect elasticsearch application threads. + *
<p>
+ * Linux BPF filters will return {@code EACCES} (Access Denied) for the following system calls: + *
<ul>
+ *   <li>{@code execve}</li>
+ *   <li>{@code fork}</li>
+ *   <li>{@code vfork}</li>
+ *   <li>{@code execveat}</li>
+ * </ul>
+ * @see + * * http://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt + */ + @Override + public void tryInstallExecSandbox() { + // first be defensive: we can give nice errors this way, at the very least. + // also, some of these security features get backported to old versions, checking kernel version here is a big no-no! + String archId = System.getProperty("os.arch"); + final Arch arch = ARCHITECTURES.get(archId); + if (arch == null) { + throw new UnsupportedOperationException("seccomp unavailable: '" + archId + "' architecture unsupported"); + } + + // try to check system calls really are who they claim + // you never know (e.g. https://chromium.googlesource.com/chromium/src.git/+/master/sandbox/linux/seccomp-bpf/sandbox_bpf.cc#57) + final int bogusArg = 0xf7a46a5c; + + // test seccomp(BOGUS) + long ret = linuxLibc.syscall(arch.seccomp, bogusArg, 0, 0); + if (ret != -1) { + throw new UnsupportedOperationException("seccomp unavailable: seccomp(BOGUS_OPERATION) returned " + ret); + } else { + int errno = libc.errno(); + switch (errno) { + case ENOSYS: + break; // ok + case EINVAL: + break; // ok + default: + throw new UnsupportedOperationException("seccomp(BOGUS_OPERATION): " + libc.strerror(errno)); + } + } + + // test seccomp(VALID, BOGUS) + ret = linuxLibc.syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, bogusArg, 0); + if (ret != -1) { + throw new UnsupportedOperationException("seccomp unavailable: seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG) returned " + ret); + } else { + int errno = libc.errno(); + switch (errno) { + case ENOSYS: + break; // ok + case EINVAL: + break; // ok + default: + throw new UnsupportedOperationException("seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG): " + libc.strerror(errno)); + } + } + + // test prctl(BOGUS) + ret = linuxLibc.prctl(bogusArg, 0, 0, 0, 0); + if (ret != -1) { + throw new UnsupportedOperationException("seccomp unavailable: prctl(BOGUS_OPTION) returned " + ret); + } else { + int errno = libc.errno(); + switch (errno) { + case ENOSYS: + break; // ok + case EINVAL: + break; // ok + default: + throw new UnsupportedOperationException("prctl(BOGUS_OPTION): " + libc.strerror(errno)); + } + } + + // now just normal defensive checks + + // check for GET_NO_NEW_PRIVS + switch (linuxLibc.prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)) { + case 0: + break; // not yet set + case 1: + break; // already set by caller + default: + int errno = libc.errno(); + if (errno == EINVAL) { + // friendly error, this will be the typical case for an old kernel + throw new UnsupportedOperationException( + "seccomp unavailable: requires kernel 3.5+ with" + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" + ); + } else { + throw new UnsupportedOperationException("prctl(PR_GET_NO_NEW_PRIVS): " + libc.strerror(errno)); + } + } + // check for SECCOMP + switch (linuxLibc.prctl(PR_GET_SECCOMP, 0, 0, 0, 0)) { + case 0: + break; // not yet set + case 2: + break; // already in filter mode by caller + default: + int errno = libc.errno(); + if (errno == EINVAL) { + throw new UnsupportedOperationException( + "seccomp unavailable: CONFIG_SECCOMP not compiled into kernel," + + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" + ); + } else { + throw new UnsupportedOperationException("prctl(PR_GET_SECCOMP): " + libc.strerror(errno)); + } + } + // check for SECCOMP_MODE_FILTER + if (linuxLibc.prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, 0, 0, 0) != 0) { + int errno = libc.errno(); + switch (errno) { + case EFAULT: + break; // available + case EINVAL: + throw new 
UnsupportedOperationException( + "seccomp unavailable: CONFIG_SECCOMP_FILTER not" + + " compiled into kernel, CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" + ); + default: + throw new UnsupportedOperationException("prctl(PR_SET_SECCOMP): " + libc.strerror(errno)); + } + } + + // ok, now set PR_SET_NO_NEW_PRIVS, needed to be able to set a seccomp filter as ordinary user + if (linuxLibc.prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) != 0) { + throw new UnsupportedOperationException("prctl(PR_SET_NO_NEW_PRIVS): " + libc.strerror(libc.errno())); + } + + // check it worked + if (linuxLibc.prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0) != 1) { + throw new UnsupportedOperationException( + "seccomp filter did not really succeed: prctl(PR_GET_NO_NEW_PRIVS): " + libc.strerror(libc.errno()) + ); + } + + // BPF installed to check arch, limit, then syscall. + // See https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt for details. + SockFilter insns[] = { + /* 1 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_ARCH_OFFSET), // + /* 2 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.audit, 0, 7), // if (arch != audit) goto fail; + /* 3 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_NR_OFFSET), // + /* 4 */ BPF_JUMP(BPF_JMP + BPF_JGT + BPF_K, arch.limit, 5, 0), // if (syscall > LIMIT) goto fail; + /* 5 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.fork, 4, 0), // if (syscall == FORK) goto fail; + /* 6 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.vfork, 3, 0), // if (syscall == VFORK) goto fail; + /* 7 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execve, 2, 0), // if (syscall == EXECVE) goto fail; + /* 8 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execveat, 1, 0), // if (syscall == EXECVEAT) goto fail; + /* 9 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), // pass: return OK; + /* 10 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ERRNO | (EACCES & SECCOMP_RET_DATA)), // fail: return EACCES; + }; + // seccomp takes a long, so we pass it one explicitly to keep the JNA simple + SockFProg prog = linuxLibc.newSockFProg(insns); + + int method = 1; + // install filter, if this works, after this there is no going back! + // first try it with seccomp(SECCOMP_SET_MODE_FILTER), falling back to prctl() + if (linuxLibc.syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_TSYNC, prog.address()) != 0) { + method = 0; + int errno1 = libc.errno(); + if (logger.isDebugEnabled()) { + logger.debug("seccomp(SECCOMP_SET_MODE_FILTER): {}, falling back to prctl(PR_SET_SECCOMP)...", libc.strerror(errno1)); + } + if (linuxLibc.prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, prog.address(), 0, 0) != 0) { + int errno2 = libc.errno(); + throw new UnsupportedOperationException( + "seccomp(SECCOMP_SET_MODE_FILTER): " + libc.strerror(errno1) + ", prctl(PR_SET_SECCOMP): " + libc.strerror(errno2) + ); + } + } + + // now check that the filter was really installed, we should be in filter mode. + if (linuxLibc.prctl(PR_GET_SECCOMP, 0, 0, 0, 0) != 2) { + throw new UnsupportedOperationException( + "seccomp filter installation did not really succeed. seccomp(PR_GET_SECCOMP): " + libc.strerror(libc.errno()) + ); + } + + logger.debug("Linux seccomp filter installation successful, threads: [{}]", method == 1 ? "all" : "app"); + execSandboxState = method == 1 ? 
ExecSandboxState.ALL_THREADS : ExecSandboxState.EXISTING_THREADS; + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java index 0388c66d3962f..c53b7ba6ac2f0 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java @@ -8,12 +8,30 @@ package org.elasticsearch.nativeaccess; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.nativeaccess.lib.MacCLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; +import org.elasticsearch.nativeaccess.lib.PosixCLibrary.RLimit; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; class MacNativeAccess extends PosixNativeAccess { + /** The only supported flag... */ + static final int SANDBOX_NAMED = 1; + /** Allow everything except process fork and execution */ + static final String SANDBOX_RULES = "(version 1) (allow default) (deny process-fork) (deny process-exec)"; + + private final MacCLibrary macLibc; + MacNativeAccess(NativeLibraryProvider libraryProvider) { super("MacOS", libraryProvider, new PosixConstants(9223372036854775807L, 5, 1, 6)); + this.macLibc = libraryProvider.getLibrary(MacCLibrary.class); } @Override @@ -25,4 +43,69 @@ protected long getMaxThreads() { protected void logMemoryLimitInstructions() { // we don't have instructions for macos } + + /** + * Installs exec system call filtering on MacOS. + *
<p>
+ * Two different methods of filtering are used. Since MacOS is BSD based, process creation + * is first restricted with {@code setrlimit(RLIMIT_NPROC)}. + *
<p>
+ * Additionally, on Mac OS X Leopard or above, a custom {@code sandbox(7)} ("Seatbelt") profile is installed that + * denies the following rules: + *
<ul>
+ *   <li>{@code process-fork}</li>
+ *   <li>{@code process-exec}</li>
+ * </ul>
+ * @see + * * https://reverse.put.as/wp-content/uploads/2011/06/The-Apple-Sandbox-BHDC2011-Paper.pdf + */ + @Override + public void tryInstallExecSandbox() { + initBsdSandbox(); + initMacSandbox(); + execSandboxState = ExecSandboxState.ALL_THREADS; + } + + @SuppressForbidden(reason = "Java tmp dir is ok") + private static Path createTempRulesFile() throws IOException { + return Files.createTempFile("es", "sb"); + } + + private void initMacSandbox() { + // write rules to a temporary file, which will be passed to sandbox_init() + Path rules; + try { + rules = createTempRulesFile(); + Files.write(rules, Collections.singleton(SANDBOX_RULES)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + try { + var errorRef = macLibc.newErrorReference(); + int ret = macLibc.sandbox_init(rules.toAbsolutePath().toString(), SANDBOX_NAMED, errorRef); + // if sandbox_init() fails, add the message from the OS (e.g. syntax error) and free the buffer + if (ret != 0) { + RuntimeException e = new UnsupportedOperationException("sandbox_init(): " + errorRef.toString()); + macLibc.sandbox_free_error(errorRef); + throw e; + } + logger.debug("OS X seatbelt initialization successful"); + } finally { + IOUtils.deleteFilesIgnoringExceptions(rules); + } + } + + private void initBsdSandbox() { + RLimit limit = libc.newRLimit(); + limit.rlim_cur(0); + limit.rlim_max(0); + // not a standard limit, means something different on linux, etc! + final int RLIMIT_NPROC = 7; + if (libc.setrlimit(RLIMIT_NPROC, limit) != 0) { + throw new UnsupportedOperationException("RLIMIT_NPROC unavailable: " + libc.strerror(libc.errno())); + } + + logger.debug("BSD RLIMIT_NPROC initialization successful"); + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java index 7f91d0425af47..61935ac93c5a3 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java @@ -44,6 +44,16 @@ static NativeAccess instance() { */ boolean isMemoryLocked(); + /** + * Attempts to install a system call filter to block process execution. + */ + void tryInstallExecSandbox(); + + /** + * Return whether installing the exec system call filters was successful, and to what degree. + */ + ExecSandboxState getExecSandboxState(); + Systemd systemd(); /** @@ -71,4 +81,16 @@ default WindowsFunctions getWindowsFunctions() { * @return the buffer */ CloseableByteBuffer newBuffer(int len); + + /** + * Possible stats for execution filtering. 
+ */ + enum ExecSandboxState { + /** No execution filtering */ + NONE, + /** Exec is blocked for threads that were already created */ + EXISTING_THREADS, + /** Exec is blocked for all current and future threads */ + ALL_THREADS + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java index c0eed4a9ce09b..fc186cb03b0d9 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java @@ -41,6 +41,16 @@ public boolean isMemoryLocked() { return false; } + @Override + public void tryInstallExecSandbox() { + logger.warn("Cannot install system call filter because native access is not available"); + } + + @Override + public ExecSandboxState getExecSandboxState() { + return ExecSandboxState.NONE; + } + @Override public Systemd systemd() { logger.warn("Cannot get systemd access because native access is not available"); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java index 843cc73fbed02..a9ccd15330595 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java @@ -27,6 +27,16 @@ class WindowsNativeAccess extends AbstractNativeAccess { public static final int PAGE_GUARD = 0x0100; public static final int MEM_COMMIT = 0x1000; + /** + * Constant for JOBOBJECT_BASIC_LIMIT_INFORMATION in Query/Set InformationJobObject + */ + private static final int JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS = 2; + + /** + * Constant for LimitFlags, indicating a process limit has been set + */ + private static final int JOB_OBJECT_LIMIT_ACTIVE_PROCESS = 8; + private final Kernel32Library kernel; private final WindowsFunctions windowsFunctions; @@ -68,6 +78,47 @@ public void tryLockMemory() { // note: no need to close the process handle because GetCurrentProcess returns a pseudo handle } + /** + * Install exec system call filtering on Windows. + *
<p>
+ * Process creation is restricted with {@code SetInformationJobObject/ActiveProcessLimit}. + *
<p>
+ * Note: This is not intended as a real sandbox. It is another level of security, mostly intended to annoy + * security researchers and make their lives more difficult in achieving "remote execution" exploits. + */ + @Override + public void tryInstallExecSandbox() { + // create a new Job + Handle job = kernel.CreateJobObjectW(); + if (job == null) { + throw new UnsupportedOperationException("CreateJobObject: " + kernel.GetLastError()); + } + + try { + // retrieve the current basic limits of the job + int clazz = JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS; + var info = kernel.newJobObjectBasicLimitInformation(); + if (kernel.QueryInformationJobObject(job, clazz, info) == false) { + throw new UnsupportedOperationException("QueryInformationJobObject: " + kernel.GetLastError()); + } + // modify the number of active processes to be 1 (exactly the one process we will add to the job). + info.setActiveProcessLimit(1); + info.setLimitFlags(JOB_OBJECT_LIMIT_ACTIVE_PROCESS); + if (kernel.SetInformationJobObject(job, clazz, info) == false) { + throw new UnsupportedOperationException("SetInformationJobObject: " + kernel.GetLastError()); + } + // assign ourselves to the job + if (kernel.AssignProcessToJobObject(job, kernel.GetCurrentProcess()) == false) { + throw new UnsupportedOperationException("AssignProcessToJobObject: " + kernel.GetLastError()); + } + } finally { + kernel.CloseHandle(job); + } + + execSandboxState = ExecSandboxState.ALL_THREADS; + logger.debug("Windows ActiveProcessLimit initialization successful"); + } + @Override public ProcessLimits getProcessLimits() { return new ProcessLimits(ProcessLimits.UNKNOWN, ProcessLimits.UNKNOWN, ProcessLimits.UNKNOWN); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java index 43337f4532bed..dd786b56087e2 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java @@ -101,4 +101,65 @@ interface MemoryBasicInformation { * @see SetConsoleCtrlHandler docs */ boolean SetConsoleCtrlHandler(ConsoleCtrlHandler handler, boolean add); + + /** + * Creates or opens a new job object + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms682409%28v=vs.85%29.aspx + * Note: the two params to this are omitted because all implementations pass null for them both + * + * @return job handle if the function succeeds + */ + Handle CreateJobObjectW(); + + /** + * Associates a process with an existing job + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms681949%28v=vs.85%29.aspx + * + * @param job job handle + * @param process process handle + * @return true if the function succeeds + */ + boolean AssignProcessToJobObject(Handle job, Handle process); + + /** + * Basic limit information for a job object + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684147%28v=vs.85%29.aspx + */ + interface JobObjectBasicLimitInformation { + void setLimitFlags(int v); + + void setActiveProcessLimit(int v); + } + + JobObjectBasicLimitInformation newJobObjectBasicLimitInformation(); + + /** + * Get job limit and state information + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684925%28v=vs.85%29.aspx + * Note: The infoLength parameter is omitted because implementions handle passing it + * Note: The returnLength parameter is omitted because all implementations pass null + * + * @param 
job job handle + * @param infoClass information class constant + * @param info pointer to information structure + * @return true if the function succeeds + */ + boolean QueryInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info); + + /** + * Set job limit and state information + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms686216%28v=vs.85%29.aspx + * Note: The infoLength parameter is omitted because implementions handle passing it + * + * @param job job handle + * @param infoClass information class constant + * @param info pointer to information structure + * @return true if the function succeeds + */ + boolean SetInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info); } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java new file mode 100644 index 0000000000000..2a7b10ff3588f --- /dev/null +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.lib; + +public non-sealed interface LinuxCLibrary extends NativeLibrary { + + /** + * Corresponds to struct sock_filter + * @param code insn + * @param jt number of insn to jump (skip) if true + * @param jf number of insn to jump (skip) if false + * @param k additional data + */ + record SockFilter(short code, byte jt, byte jf, int k) {} + + interface SockFProg { + long address(); + } + + SockFProg newSockFProg(SockFilter filters[]); + + /** + * maps to prctl(2) + */ + int prctl(int option, long arg2, long arg3, long arg4, long arg5); + + /** + * used to call seccomp(2), its too new... + * this is the only way, DON'T use it on some other architecture unless you know wtf you are doing + */ + long syscall(long number, int operation, int flags, long address); +} diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java new file mode 100644 index 0000000000000..b2b2db9c71c90 --- /dev/null +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.lib; + +public non-sealed interface MacCLibrary extends NativeLibrary { + interface ErrorReference {} + + ErrorReference newErrorReference(); + + /** + * maps to sandbox_init(3), since Leopard + */ + int sandbox_init(String profile, long flags, ErrorReference errorbuf); + + /** + * releases memory when an error occurs during initialization (e.g. 
syntax bug) + */ + void sandbox_free_error(ErrorReference errorbuf); +} diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java index d8098a78935b8..faa0e861dc63f 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java @@ -9,4 +9,5 @@ package org.elasticsearch.nativeaccess.lib; /** A marker interface for libraries that can be loaded by {@link org.elasticsearch.nativeaccess.lib.NativeLibraryProvider} */ -public sealed interface NativeLibrary permits JavaLibrary, PosixCLibrary, Kernel32Library, SystemdLibrary, VectorLibrary, ZstdLibrary {} +public sealed interface NativeLibrary permits JavaLibrary, PosixCLibrary, LinuxCLibrary, MacCLibrary, Kernel32Library, SystemdLibrary, + VectorLibrary, ZstdLibrary {} diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java index 96e2a0d0e1cdf..d8db5fa070126 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java @@ -26,6 +26,10 @@ interface RLimit { long rlim_cur(); long rlim_max(); + + void rlim_cur(long v); + + void rlim_max(long v); } /** @@ -41,6 +45,8 @@ interface RLimit { */ int getrlimit(int resource, RLimit rlimit); + int setrlimit(int resource, RLimit rlimit); + /** * Lock all the current process's virtual address space into RAM. * @param flags flags determining how memory will be locked @@ -49,6 +55,22 @@ interface RLimit { */ int mlockall(int flags); + interface FStore { + void set_flags(int flags); /* IN: flags word */ + + void set_posmode(int posmode); /* IN: indicates offset field */ + + void set_offset(long offset); /* IN: start of the region */ + + void set_length(long length); /* IN: size of the region */ + + long bytesalloc(); /* OUT: number of bytes allocated */ + } + + FStore newFStore(); + + int fcntl(int fd, int cmd, FStore fst); + /** * Return a string description for an error. 
* diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java index bbfd26bd061d0..f5eb5238dad93 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java @@ -72,6 +72,22 @@ class JdkKernel32Library implements Kernel32Library { "handle", ConsoleCtrlHandler_handle$fd ); + private static final MethodHandle CreateJobObjectW$mh = downcallHandleWithError( + "CreateJobObjectW", + FunctionDescriptor.of(ADDRESS, ADDRESS, ADDRESS) + ); + private static final MethodHandle AssignProcessToJobObject$mh = downcallHandleWithError( + "AssignProcessToJobObject", + FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, ADDRESS) + ); + private static final MethodHandle QueryInformationJobObject$mh = downcallHandleWithError( + "QueryInformationJobObject", + FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS) + ); + private static final MethodHandle SetInformationJobObject$mh = downcallHandleWithError( + "SetInformationJobObject", + FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT) + ); private static MethodHandle downcallHandleWithError(String function, FunctionDescriptor functionDescriptor) { return downcallHandle(function, functionDescriptor, CAPTURE_GETLASTERROR_OPTION); @@ -146,6 +162,37 @@ public long Type() { } } + static class JdkJobObjectBasicLimitInformation implements JobObjectBasicLimitInformation { + private static final MemoryLayout layout = MemoryLayout.structLayout( + paddingLayout(16), + JAVA_INT, + paddingLayout(20), + JAVA_INT, + paddingLayout(20) + ).withByteAlignment(8); + + private static final VarHandle LimitFlags$vh = varHandleWithoutOffset(layout, groupElement(1)); + private static final VarHandle ActiveProcessLimit$vh = varHandleWithoutOffset(layout, groupElement(3)); + + private final MemorySegment segment; + + JdkJobObjectBasicLimitInformation() { + var arena = Arena.ofAuto(); + this.segment = arena.allocate(layout); + segment.fill((byte) 0); + } + + @Override + public void setLimitFlags(int v) { + LimitFlags$vh.set(segment, v); + } + + @Override + public void setActiveProcessLimit(int v) { + ActiveProcessLimit$vh.set(segment, v); + } + } + private final MemorySegment lastErrorState; JdkKernel32Library() { @@ -262,4 +309,73 @@ public boolean SetConsoleCtrlHandler(ConsoleCtrlHandler handler, boolean add) { throw new AssertionError(t); } } + + @Override + public Handle CreateJobObjectW() { + try { + return new JdkHandle((MemorySegment) CreateJobObjectW$mh.invokeExact(lastErrorState, MemorySegment.NULL, MemorySegment.NULL)); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public boolean AssignProcessToJobObject(Handle job, Handle process) { + assert job instanceof JdkHandle; + assert process instanceof JdkHandle; + var jdkJob = (JdkHandle) job; + var jdkProcess = (JdkHandle) process; + + try { + return (boolean) AssignProcessToJobObject$mh.invokeExact(lastErrorState, jdkJob.address, jdkProcess.address); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public JobObjectBasicLimitInformation newJobObjectBasicLimitInformation() { + return new JdkJobObjectBasicLimitInformation(); + } + + @Override + public boolean QueryInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job 
instanceof JdkHandle; + assert info instanceof JdkJobObjectBasicLimitInformation; + var jdkJob = (JdkHandle) job; + var jdkInfo = (JdkJobObjectBasicLimitInformation) info; + + try { + return (boolean) QueryInformationJobObject$mh.invokeExact( + lastErrorState, + jdkJob.address, + infoClass, + jdkInfo.segment, + (int) jdkInfo.segment.byteSize(), + MemorySegment.NULL + ); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public boolean SetInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JdkHandle; + assert info instanceof JdkJobObjectBasicLimitInformation; + var jdkJob = (JdkHandle) job; + var jdkInfo = (JdkJobObjectBasicLimitInformation) info; + + try { + return (boolean) SetInformationJobObject$mh.invokeExact( + lastErrorState, + jdkJob.address, + infoClass, + jdkInfo.segment, + (int) jdkInfo.segment.byteSize() + ); + } catch (Throwable t) { + throw new AssertionError(t); + } + } } diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java new file mode 100644 index 0000000000000..700941e7e1db0 --- /dev/null +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jdk; + +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; + +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemoryLayout; +import java.lang.foreign.MemorySegment; +import java.lang.invoke.MethodHandle; + +import static java.lang.foreign.MemoryLayout.paddingLayout; +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static java.lang.foreign.ValueLayout.JAVA_SHORT; +import static org.elasticsearch.nativeaccess.jdk.JdkPosixCLibrary.CAPTURE_ERRNO_OPTION; +import static org.elasticsearch.nativeaccess.jdk.JdkPosixCLibrary.downcallHandleWithErrno; +import static org.elasticsearch.nativeaccess.jdk.JdkPosixCLibrary.errnoState; +import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; + +class JdkLinuxCLibrary implements LinuxCLibrary { + private static final MethodHandle prctl$mh; + static { + try { + prctl$mh = downcallHandleWithErrno( + "prctl", + FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_LONG, JAVA_LONG, JAVA_LONG, JAVA_LONG) + ); + } catch (UnsatisfiedLinkError e) { + throw new UnsupportedOperationException( + "seccomp unavailable: could not link methods. 
requires kernel 3.5+ " + + "with CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" + ); + } + } + private static final MethodHandle syscall$mh = downcallHandle( + "syscall", + FunctionDescriptor.of(JAVA_LONG, JAVA_LONG, JAVA_INT, JAVA_INT, JAVA_LONG), + CAPTURE_ERRNO_OPTION, + Linker.Option.firstVariadicArg(1) + ); + + private static class JdkSockFProg implements SockFProg { + private static final MemoryLayout layout = MemoryLayout.structLayout(JAVA_SHORT, paddingLayout(6), ADDRESS); + + private final MemorySegment segment; + + JdkSockFProg(SockFilter filters[]) { + Arena arena = Arena.ofAuto(); + this.segment = arena.allocate(layout); + var instSegment = arena.allocate(filters.length * 8L); + segment.set(JAVA_SHORT, 0, (short) filters.length); + segment.set(ADDRESS, 8, instSegment); + + int offset = 0; + for (SockFilter f : filters) { + instSegment.set(JAVA_SHORT, offset, f.code()); + instSegment.set(JAVA_BYTE, offset + 2, f.jt()); + instSegment.set(JAVA_BYTE, offset + 3, f.jf()); + instSegment.set(JAVA_INT, offset + 4, f.k()); + offset += 8; + } + } + + @Override + public long address() { + return segment.address(); + } + } + + @Override + public SockFProg newSockFProg(SockFilter[] filters) { + return new JdkSockFProg(filters); + } + + @Override + public int prctl(int option, long arg2, long arg3, long arg4, long arg5) { + try { + return (int) prctl$mh.invokeExact(errnoState, option, arg2, arg3, arg4, arg5); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public long syscall(long number, int operation, int flags, long address) { + try { + return (long) syscall$mh.invokeExact(errnoState, number, operation, flags, address); + } catch (Throwable t) { + throw new AssertionError(t); + } + } +} diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java new file mode 100644 index 0000000000000..b946ca3ca4353 --- /dev/null +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.nativeaccess.jdk; + +import org.elasticsearch.nativeaccess.lib.MacCLibrary; + +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandle; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; + +class JdkMacCLibrary implements MacCLibrary { + + private static final MethodHandle sandbox_init$mh = downcallHandle( + "sandbox_init", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS) + ); + private static final MethodHandle sandbox_free_error$mh = downcallHandle("sandbox_free_error", FunctionDescriptor.ofVoid(ADDRESS)); + + private static class JdkErrorReference implements ErrorReference { + final Arena arena = Arena.ofConfined(); + final MemorySegment segment = arena.allocate(ValueLayout.ADDRESS); + + MemorySegment deref() { + return segment.get(ADDRESS, 0); + } + + @Override + public String toString() { + return deref().reinterpret(Long.MAX_VALUE).getUtf8String(0); + } + } + + @Override + public ErrorReference newErrorReference() { + return new JdkErrorReference(); + } + + @Override + public int sandbox_init(String profile, long flags, ErrorReference errorbuf) { + assert errorbuf instanceof JdkErrorReference; + var jdkErrorbuf = (JdkErrorReference) errorbuf; + try (Arena arena = Arena.ofConfined()) { + MemorySegment nativeProfile = MemorySegmentUtil.allocateString(arena, profile); + return (int) sandbox_init$mh.invokeExact(nativeProfile, flags, jdkErrorbuf.segment); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public void sandbox_free_error(ErrorReference errorbuf) { + assert errorbuf instanceof JdkErrorReference; + var jdkErrorbuf = (JdkErrorReference) errorbuf; + try { + sandbox_free_error$mh.invokeExact(jdkErrorbuf.deref()); + } catch (Throwable t) { + throw new AssertionError(t); + } + } +} diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java index d76170a55284c..cbd43a394379b 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java @@ -10,6 +10,8 @@ import org.elasticsearch.nativeaccess.lib.JavaLibrary; import org.elasticsearch.nativeaccess.lib.Kernel32Library; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; +import org.elasticsearch.nativeaccess.lib.MacCLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; import org.elasticsearch.nativeaccess.lib.PosixCLibrary; import org.elasticsearch.nativeaccess.lib.SystemdLibrary; @@ -28,6 +30,10 @@ public JdkNativeLibraryProvider() { JdkJavaLibrary::new, PosixCLibrary.class, JdkPosixCLibrary::new, + LinuxCLibrary.class, + JdkLinuxCLibrary::new, + MacCLibrary.class, + JdkMacCLibrary::new, Kernel32Library.class, JdkKernel32Library::new, SystemdLibrary.class, diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java index 43ec9425ccfaa..1a65225873c1d 100644 --- 
a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java @@ -43,7 +43,12 @@ class JdkPosixCLibrary implements PosixCLibrary { "getrlimit", FunctionDescriptor.of(JAVA_INT, JAVA_INT, ADDRESS) ); + private static final MethodHandle setrlimit$mh = downcallHandleWithErrno( + "setrlimit", + FunctionDescriptor.of(JAVA_INT, JAVA_INT, ADDRESS) + ); private static final MethodHandle mlockall$mh = downcallHandleWithErrno("mlockall", FunctionDescriptor.of(JAVA_INT, JAVA_INT)); + private static final MethodHandle fcntl$mh = downcallHandle("fcntl", FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_INT, ADDRESS)); static final MemorySegment errnoState = Arena.ofAuto().allocate(CAPTURE_ERRNO_LAYOUT); @@ -91,6 +96,17 @@ public int getrlimit(int resource, RLimit rlimit) { } } + @Override + public int setrlimit(int resource, RLimit rlimit) { + assert rlimit instanceof JdkRLimit; + var jdkRlimit = (JdkRLimit) rlimit; + try { + return (int) setrlimit$mh.invokeExact(errnoState, resource, jdkRlimit.segment); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + @Override public int mlockall(int flags) { try { @@ -100,6 +116,22 @@ public int mlockall(int flags) { } } + @Override + public FStore newFStore() { + return new JdkFStore(); + } + + @Override + public int fcntl(int fd, int cmd, FStore fst) { + assert fst instanceof JdkFStore; + var jdkFst = (JdkFStore) fst; + try { + return (int) fcntl$mh.invokeExact(errnoState, fd, cmd, jdkFst.segment); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + static class JdkRLimit implements RLimit { private static final MemoryLayout layout = MemoryLayout.structLayout(JAVA_LONG, JAVA_LONG); private static final VarHandle rlim_cur$vh = varHandleWithoutOffset(layout, groupElement(0)); @@ -122,9 +154,60 @@ public long rlim_max() { return (long) rlim_max$vh.get(segment); } + @Override + public void rlim_cur(long v) { + rlim_cur$vh.set(segment, v); + } + + @Override + public void rlim_max(long v) { + rlim_max$vh.set(segment, v); + } + @Override public String toString() { return "JdkRLimit[rlim_cur=" + rlim_cur() + ", rlim_max=" + rlim_max(); } } + + private static class JdkFStore implements FStore { + private static final MemoryLayout layout = MemoryLayout.structLayout(JAVA_INT, JAVA_INT, JAVA_LONG, JAVA_LONG, JAVA_LONG); + private static final VarHandle st_flags$vh = layout.varHandle(groupElement(0)); + private static final VarHandle st_posmode$vh = layout.varHandle(groupElement(1)); + private static final VarHandle st_offset$vh = layout.varHandle(groupElement(2)); + private static final VarHandle st_length$vh = layout.varHandle(groupElement(3)); + private static final VarHandle st_bytesalloc$vh = layout.varHandle(groupElement(4)); + + private final MemorySegment segment; + + JdkFStore() { + var arena = Arena.ofAuto(); + this.segment = arena.allocate(layout); + } + + @Override + public void set_flags(int flags) { + st_flags$vh.set(segment, flags); + } + + @Override + public void set_posmode(int posmode) { + st_posmode$vh.set(segment, posmode); + } + + @Override + public void set_offset(long offset) { + st_offset$vh.set(segment, offset); + } + + @Override + public void set_length(long length) { + st_length$vh.set(segment, length); + } + + @Override + public long bytesalloc() { + return (long) st_bytesalloc$vh.get(segment); + } + } } diff --git a/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/SystemCallFilterTests.java
b/libs/native/src/test/java/org/elasticsearch/nativeaccess/SystemCallFilterTests.java similarity index 84% rename from qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/SystemCallFilterTests.java rename to libs/native/src/test/java/org/elasticsearch/nativeaccess/SystemCallFilterTests.java index c62522880869b..d4bac13990898 100644 --- a/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/SystemCallFilterTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/SystemCallFilterTests.java @@ -6,12 +6,16 @@ * Side Public License, v 1. */ -package org.elasticsearch.bootstrap; +package org.elasticsearch.nativeaccess; import org.apache.lucene.util.Constants; import org.elasticsearch.test.ESTestCase; +import static org.apache.lucene.tests.util.LuceneTestCase.assumeTrue; +import static org.junit.Assert.fail; + /** Simple tests system call filter is working. */ +@ESTestCase.WithoutSecurityManager public class SystemCallFilterTests extends ESTestCase { /** command to try to run in tests */ @@ -20,15 +24,18 @@ public class SystemCallFilterTests extends ESTestCase { @Override public void setUp() throws Exception { super.setUp(); - assumeTrue("requires system call filter installation", Natives.isSystemCallFilterInstalled()); + assumeTrue( + "requires system call filter installation", + NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.NONE + ); // otherwise security manager will block the execution, no fun assumeTrue("cannot test with security manager enabled", System.getSecurityManager() == null); // otherwise, since we don't have TSYNC support, rules are not applied to the test thread // (randomizedrunner class initialization happens in its own thread, after the test thread is created) // instead we just forcefully run it for the test thread here. 
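The hunk below swaps the old JNA flag check for the new ExecSandboxState enum. A condensed sketch of the resulting logic, assuming the enum distinguishes "no filter", "new threads only", and "all threads" states (only NONE and ALL_THREADS are visible in this diff; the intermediate state name is an assumption):

```java
import org.elasticsearch.nativeaccess.NativeAccess;

class SandboxStateSketch {
    static void ensureFilterOnTestThread() {
        NativeAccess access = NativeAccess.instance();
        if (access.getExecSandboxState() == NativeAccess.ExecSandboxState.NONE) {
            return; // nothing installed, nothing to force
        }
        if (access.getExecSandboxState() != NativeAccess.ExecSandboxState.ALL_THREADS) {
            // without TSYNC the filter only applies to threads created afterwards,
            // so apply it explicitly to the current (test) thread
            access.tryInstallExecSandbox();
        }
    }
}
```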
- if (JNANatives.LOCAL_SYSTEM_CALL_FILTER_ALL == false) { + if (NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.ALL_THREADS) { try { - SystemCallFilter.init(createTempDir()); + NativeAccess.instance().tryInstallExecSandbox(); } catch (Exception e) { throw new RuntimeException("unable to forcefully apply system call filter to test thread", e); } diff --git a/qa/ccs-rolling-upgrade-remote-cluster/build.gradle b/qa/ccs-rolling-upgrade-remote-cluster/build.gradle index c48674831c422..b63522daa4b4c 100644 --- a/qa/ccs-rolling-upgrade-remote-cluster/build.gradle +++ b/qa/ccs-rolling-upgrade-remote-cluster/build.gradle @@ -58,7 +58,11 @@ BuildParams.bwcVersions.withWireCompatible { bwcVersion, baseName -> dependsOn "processTestResources" mustRunAfter("precommit") doFirst { - localCluster.get().nextNodeToNextVersion() + def cluster = localCluster.get() + cluster.nodes.forEach { node -> + node.getAllTransportPortURI() + } + cluster.nextNodeToNextVersion() } } diff --git a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java index a60262ff4a097..84811362c08e6 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java @@ -584,7 +584,7 @@ public BootstrapCheckResult check(BootstrapContext context) { // visible for testing boolean isSystemCallFilterInstalled() { - return Natives.isSystemCallFilterInstalled(); + return NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.NONE; } @Override @@ -608,7 +608,7 @@ public BootstrapCheckResult check(BootstrapContext context) { // visible for testing boolean isSystemCallFilterInstalled() { - return Natives.isSystemCallFilterInstalled(); + return NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.NONE; } // visible for testing diff --git a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java index f8ad9dd59650c..005375bf38540 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java @@ -27,16 +27,6 @@ public final class BootstrapInfo { /** no instantiation */ private BootstrapInfo() {} - /** - * Returns true if we successfully loaded native libraries. - * <p>
- * If this returns false, then native operations such as locking - * memory did not work. - */ - public static boolean isNativesAvailable() { - return Natives.JNA_AVAILABLE; - } - /** * Returns true if we were able to lock the process's address space. */ @@ -44,13 +34,6 @@ public static boolean isMemoryLocked() { return NativeAccess.instance().isMemoryLocked(); } - /** - * Returns true if system call filter is installed (supported systems only) - */ - public static boolean isSystemCallFilterInstalled() { - return Natives.isSystemCallFilterInstalled(); - } - /** * Returns information about the console (tty) attached to the server process, or {@code null} * if no console is attached. diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java index 082e1dd9257e0..3fc659cb8065d 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java @@ -293,7 +293,7 @@ static void initializeNatives(final Path tmpFile, final boolean mlockAll, final * * TODO: should we fail hard here if system call filters fail to install, or remain lenient in non-production environments? */ - Natives.tryInstallSystemCallFilter(tmpFile); + nativeAccess.tryInstallExecSandbox(); } // mlockall if requested @@ -316,13 +316,6 @@ static void initializeNatives(final Path tmpFile, final boolean mlockAll, final } } - // force remainder of JNA to be loaded (if available). - try { - JNAKernel32Library.getInstance(); - } catch (Exception ignored) { - // we've already logged this. - } - // init lucene random seed. it will use /dev/urandom where available: StringHelper.randomId(); diff --git a/server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java b/server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java deleted file mode 100644 index 01d9a122138f1..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import com.sun.jna.IntegerType; -import com.sun.jna.Native; -import com.sun.jna.NativeLong; -import com.sun.jna.Pointer; -import com.sun.jna.Structure; -import com.sun.jna.WString; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.Constants; - -import java.util.Arrays; -import java.util.List; - -/** - * Library for Windows/Kernel32 - */ -final class JNAKernel32Library { - - private static final Logger logger = LogManager.getLogger(JNAKernel32Library.class); - - // Native library instance must be kept around for the same reason. - private static final class Holder { - private static final JNAKernel32Library instance = new JNAKernel32Library(); - } - - private JNAKernel32Library() { - if (Constants.WINDOWS) { - try { - Native.register("kernel32"); - logger.debug("windows/Kernel32 library loaded"); - } catch (NoClassDefFoundError e) { - logger.warn("JNA not found. 
native methods and handlers will be disabled."); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link Windows/Kernel32 library. native methods and handlers will be disabled."); - } - } - } - - static JNAKernel32Library getInstance() { - return Holder.instance; - } - - /** - * Memory protection constraints - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366786%28v=vs.85%29.aspx - */ - public static final int PAGE_NOACCESS = 0x0001; - public static final int PAGE_GUARD = 0x0100; - public static final int MEM_COMMIT = 0x1000; - - /** - * Contains information about a range of pages in the virtual address space of a process. - * The VirtualQuery and VirtualQueryEx functions use this structure. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366775%28v=vs.85%29.aspx - */ - public static class MemoryBasicInformation extends Structure { - public Pointer BaseAddress; - public Pointer AllocationBase; - public NativeLong AllocationProtect; - public SizeT RegionSize; - public NativeLong State; - public NativeLong Protect; - public NativeLong Type; - - @Override - protected List getFieldOrder() { - return Arrays.asList("BaseAddress", "AllocationBase", "AllocationProtect", "RegionSize", "State", "Protect", "Type"); - } - } - - public static class SizeT extends IntegerType { - - // JNA requires this no-arg constructor to be public, - // otherwise it fails to register kernel32 library - public SizeT() { - this(0); - } - - SizeT(long value) { - super(Native.SIZE_T_SIZE, value); - } - - } - - /** - * Locks the specified region of the process's virtual address space into physical - * memory, ensuring that subsequent access to the region will not incur a page fault. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366895%28v=vs.85%29.aspx - * - * @param address A pointer to the base address of the region of pages to be locked. - * @param size The size of the region to be locked, in bytes. - * @return true if the function succeeds - */ - native boolean VirtualLock(Pointer address, SizeT size); - - /** - * Retrieves information about a range of pages within the virtual address space of a specified process. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366907%28v=vs.85%29.aspx - * - * @param handle A handle to the process whose memory information is queried. - * @param address A pointer to the base address of the region of pages to be queried. - * @param memoryInfo A pointer to a structure in which information about the specified page range is returned. - * @param length The size of the buffer pointed to by the memoryInfo parameter, in bytes. - * @return the actual number of bytes returned in the information buffer. - */ - native int VirtualQueryEx(Pointer handle, Pointer address, MemoryBasicInformation memoryInfo, int length); - - /** - * Sets the minimum and maximum working set sizes for the specified process. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms686234%28v=vs.85%29.aspx - * - * @param handle A handle to the process whose working set sizes is to be set. - * @param minSize The minimum working set size for the process, in bytes. - * @param maxSize The maximum working set size for the process, in bytes. - * @return true if the function succeeds. - */ - native boolean SetProcessWorkingSetSize(Pointer handle, SizeT minSize, SizeT maxSize); - - /** - * Retrieves a pseudo handle for the current process. 
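(The deleted JNA listing continues below.) For contrast, a hedged sketch of how the same ActiveProcessLimit trick could be expressed against the Handle-based Kernel32Library surface added earlier in this patch. The two constants mirror the deleted JNA code, the helper class is hypothetical, and obtaining the current-process handle is not part of the hunks shown here:

```java
import org.elasticsearch.nativeaccess.lib.Kernel32Library;

class JobObjectSketch {
    // values mirror the constants in the deleted JNA code below;
    // they are not (re)defined anywhere in the hunks shown in this patch
    private static final int JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS = 2;
    private static final int JOB_OBJECT_LIMIT_ACTIVE_PROCESS = 8;

    static void limitToOneProcess(Kernel32Library kernel, Kernel32Library.Handle currentProcess) {
        Kernel32Library.Handle job = kernel.CreateJobObjectW();
        var limits = kernel.newJobObjectBasicLimitInformation();
        limits.setLimitFlags(JOB_OBJECT_LIMIT_ACTIVE_PROCESS);
        limits.setActiveProcessLimit(1); // exactly the one process we add to the job
        if (kernel.SetInformationJobObject(job, JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS, limits) == false) {
            throw new UnsupportedOperationException("SetInformationJobObject failed");
        }
        if (kernel.AssignProcessToJobObject(job, currentProcess) == false) {
            throw new UnsupportedOperationException("AssignProcessToJobObject failed");
        }
    }
}
```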
- * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms683179%28v=vs.85%29.aspx - * - * @return a pseudo handle to the current process. - */ - native Pointer GetCurrentProcess(); - - /** - * Closes an open object handle. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms724211%28v=vs.85%29.aspx - * - * @param handle A valid handle to an open object. - * @return true if the function succeeds. - */ - native boolean CloseHandle(Pointer handle); - - /** - * Retrieves the short path form of the specified path. See - * {@code GetShortPathName}. - * - * @param lpszLongPath the path string - * @param lpszShortPath a buffer to receive the short name - * @param cchBuffer the size of the buffer - * @return the length of the string copied into {@code lpszShortPath}, otherwise zero for failure - */ - native int GetShortPathNameW(WString lpszLongPath, char[] lpszShortPath, int cchBuffer); - - /** - * Creates or opens a new job object - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms682409%28v=vs.85%29.aspx - * - * @param jobAttributes security attributes - * @param name job name - * @return job handle if the function succeeds - */ - native Pointer CreateJobObjectW(Pointer jobAttributes, String name); - - /** - * Associates a process with an existing job - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms681949%28v=vs.85%29.aspx - * - * @param job job handle - * @param process process handle - * @return true if the function succeeds - */ - native boolean AssignProcessToJobObject(Pointer job, Pointer process); - - /** - * Basic limit information for a job object - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684147%28v=vs.85%29.aspx - */ - public static class JOBOBJECT_BASIC_LIMIT_INFORMATION extends Structure implements Structure.ByReference { - public long PerProcessUserTimeLimit; - public long PerJobUserTimeLimit; - public int LimitFlags; - public SizeT MinimumWorkingSetSize; - public SizeT MaximumWorkingSetSize; - public int ActiveProcessLimit; - public Pointer Affinity; - public int PriorityClass; - public int SchedulingClass; - - @Override - protected List getFieldOrder() { - return Arrays.asList( - "PerProcessUserTimeLimit", - "PerJobUserTimeLimit", - "LimitFlags", - "MinimumWorkingSetSize", - "MaximumWorkingSetSize", - "ActiveProcessLimit", - "Affinity", - "PriorityClass", - "SchedulingClass" - ); - } - } - - /** - * Constant for JOBOBJECT_BASIC_LIMIT_INFORMATION in Query/Set InformationJobObject - */ - static final int JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS = 2; - - /** - * Constant for LimitFlags, indicating a process limit has been set - */ - static final int JOB_OBJECT_LIMIT_ACTIVE_PROCESS = 8; - - /** - * Get job limit and state information - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684925%28v=vs.85%29.aspx - * - * @param job job handle - * @param infoClass information class constant - * @param info pointer to information structure - * @param infoLength size of information structure - * @param returnLength length of data written back to structure (or null if not wanted) - * @return true if the function succeeds - */ - native boolean QueryInformationJobObject(Pointer job, int infoClass, Pointer info, int infoLength, Pointer returnLength); - - /** - * Set job limit and state information - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms686216%28v=vs.85%29.aspx - * - * @param job job handle - * @param infoClass information class constant - * @param info pointer to 
information structure - * @param infoLength size of information structure - * @return true if the function succeeds - */ - native boolean SetInformationJobObject(Pointer job, int infoClass, Pointer info, int infoLength); -} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java b/server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java deleted file mode 100644 index ba4e90ee2c6c1..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import java.nio.file.Path; - -/** - * This class performs the actual work with JNA and library bindings to call native methods. It should only be used after - * we are sure that the JNA classes are available to the JVM - */ -class JNANatives { - - /** no instantiation */ - private JNANatives() {} - - private static final Logger logger = LogManager.getLogger(JNANatives.class); - - // Set to true, in case native system call filter install was successful - static boolean LOCAL_SYSTEM_CALL_FILTER = false; - // Set to true, in case policy can be applied to all threads of the process (even existing ones) - // otherwise they are only inherited for new threads (ES app threads) - static boolean LOCAL_SYSTEM_CALL_FILTER_ALL = false; - - static void tryInstallSystemCallFilter(Path tmpFile) { - try { - int ret = SystemCallFilter.init(tmpFile); - LOCAL_SYSTEM_CALL_FILTER = true; - if (ret == 1) { - LOCAL_SYSTEM_CALL_FILTER_ALL = true; - } - } catch (Exception e) { - // this is likely to happen unless the kernel is newish, its a best effort at the moment - // so we log stacktrace at debug for now... - if (logger.isDebugEnabled()) { - logger.debug("unable to install syscall filter", e); - } - logger.warn("unable to install syscall filter: ", e); - } - } - -} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Natives.java b/server/src/main/java/org/elasticsearch/bootstrap/Natives.java deleted file mode 100644 index c792d1e0bfad0..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/Natives.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.ReferenceDocs; - -import java.lang.invoke.MethodHandles; -import java.nio.file.Path; -import java.util.Locale; - -/** - * The Natives class is a wrapper class that checks if the classes necessary for calling native methods are available on - * startup. If they are not available, this class will avoid calling code that loads these classes. 
- */ -final class Natives { - /** no instantiation */ - private Natives() {} - - private static final Logger logger = LogManager.getLogger(Natives.class); - - // marker to determine if the JNA class files are available to the JVM - static final boolean JNA_AVAILABLE; - - static { - boolean v = false; - try { - // load one of the main JNA classes to see if the classes are available. this does not ensure that all native - // libraries are available, only the ones necessary by JNA to function - MethodHandles.publicLookup().ensureInitialized(com.sun.jna.Native.class); - v = true; - } catch (IllegalAccessException e) { - throw new AssertionError(e); - } catch (UnsatisfiedLinkError e) { - logger.warn( - String.format( - Locale.ROOT, - "unable to load JNA native support library, native methods will be disabled. See %s", - ReferenceDocs.EXECUTABLE_JNA_TMPDIR - ), - e - ); - } - JNA_AVAILABLE = v; - } - - static void tryInstallSystemCallFilter(Path tmpFile) { - if (JNA_AVAILABLE == false) { - logger.warn("cannot install system call filter because JNA is not available"); - return; - } - JNANatives.tryInstallSystemCallFilter(tmpFile); - } - - static boolean isSystemCallFilterInstalled() { - if (JNA_AVAILABLE == false) { - return false; - } - return JNANatives.LOCAL_SYSTEM_CALL_FILTER; - } - -} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java b/server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java deleted file mode 100644 index 0ab855d1d5f3a..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java +++ /dev/null @@ -1,641 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import com.sun.jna.Library; -import com.sun.jna.Memory; -import com.sun.jna.Native; -import com.sun.jna.NativeLong; -import com.sun.jna.Pointer; -import com.sun.jna.Structure; -import com.sun.jna.ptr.PointerByReference; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.Constants; -import org.elasticsearch.core.IOUtils; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * Installs a system call filter to block process execution. - * <p>
- * This is supported on Linux, Solaris, FreeBSD, OpenBSD, Mac OS X, and Windows. - * <p> - * On Linux it currently supports amd64 and i386 architectures, requires Linux kernel 3.5 or above, and requires - * {@code CONFIG_SECCOMP} and {@code CONFIG_SECCOMP_FILTER} compiled into the kernel. - * <p> - * On Linux BPF Filters are installed using either {@code seccomp(2)} (3.17+) or {@code prctl(2)} (3.5+). {@code seccomp(2)} - * is preferred, as it allows filters to be applied to any existing threads in the process, and one motivation - * here is to protect against bugs in the JVM. Otherwise, code will fall back to the {@code prctl(2)} method - * which will at least protect elasticsearch application threads. - * <p> - * Linux BPF filters will return {@code EACCES} (Access Denied) for the following system calls: - * <ul> - * <li>{@code execve}</li> - * <li>{@code fork}</li> - * <li>{@code vfork}</li> - * <li>{@code execveat}</li> - * </ul> - * <p> - * On Solaris 10 or higher, the following privileges are dropped with {@code priv_set(3C)}: - * <ul> - * <li>{@code PRIV_PROC_FORK}</li> - * <li>{@code PRIV_PROC_EXEC}</li> - * </ul> - * <p> - * On BSD systems, process creation is restricted with {@code setrlimit(RLIMIT_NPROC)}. - * <p> - * On Mac OS X Leopard or above, a custom {@code sandbox(7)} ("Seatbelt") profile is installed that - * denies the following rules: - * <ul> - * <li>{@code process-fork}</li> - * <li>{@code process-exec}</li> - * </ul> - * <p> - * On Windows, process creation is restricted with {@code SetInformationJobObject/ActiveProcessLimit}. - * <p>
- * This is not intended as a sandbox. It is another level of security, mostly intended to annoy - * security researchers and make their lives more difficult in achieving "remote execution" exploits. - * @see - * http://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt - * @see - * https://reverse.put.as/wp-content/uploads/2011/06/The-Apple-Sandbox-BHDC2011-Paper.pdf - * @see - * https://docs.oracle.com/cd/E23824_01/html/821-1456/prbac-2.html - */ -// not an example of how to write code!!! -final class SystemCallFilter { - private static final Logger logger = LogManager.getLogger(SystemCallFilter.class); - - // Linux implementation, based on seccomp(2) or prctl(2) with bpf filtering - - /** Access to non-standard Linux libc methods */ - interface LinuxLibrary extends Library { - /** - * maps to prctl(2) - */ - int prctl(int option, NativeLong arg2, NativeLong arg3, NativeLong arg4, NativeLong arg5); - - /** - * used to call seccomp(2), its too new... - * this is the only way, DON'T use it on some other architecture unless you know wtf you are doing - */ - NativeLong syscall(NativeLong number, Object... args); - } - - // null if unavailable or something goes wrong. - private static final LinuxLibrary linux_libc; - - static { - LinuxLibrary lib = null; - if (Constants.LINUX) { - try { - lib = Native.loadLibrary("c", LinuxLibrary.class); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link C library. native methods (seccomp) will be disabled.", e); - } - } - linux_libc = lib; - } - - /** the preferred method is seccomp(2), since we can apply to all threads of the process */ - static final int SECCOMP_SET_MODE_FILTER = 1; // since Linux 3.17 - static final int SECCOMP_FILTER_FLAG_TSYNC = 1; // since Linux 3.17 - - /** otherwise, we can use prctl(2), which will at least protect ES application threads */ - static final int PR_GET_NO_NEW_PRIVS = 39; // since Linux 3.5 - static final int PR_SET_NO_NEW_PRIVS = 38; // since Linux 3.5 - static final int PR_GET_SECCOMP = 21; // since Linux 2.6.23 - static final int PR_SET_SECCOMP = 22; // since Linux 2.6.23 - static final long SECCOMP_MODE_FILTER = 2; // since Linux Linux 3.5 - - /** corresponds to struct sock_filter */ - static final class SockFilter { - short code; // insn - byte jt; // number of insn to jump (skip) if true - byte jf; // number of insn to jump (skip) if false - int k; // additional data - - SockFilter(short code, byte jt, byte jf, int k) { - this.code = code; - this.jt = jt; - this.jf = jf; - this.k = k; - } - } - - /** corresponds to struct sock_fprog */ - public static final class SockFProg extends Structure implements Structure.ByReference { - public short len; // number of filters - public Pointer filter; // filters - - SockFProg(SockFilter filters[]) { - len = (short) filters.length; - // serialize struct sock_filter * explicitly, its less confusing than the JNA magic we would need - Memory filter = new Memory(len * 8); - ByteBuffer bbuf = filter.getByteBuffer(0, len * 8); - bbuf.order(ByteOrder.nativeOrder()); // little endian - for (SockFilter f : filters) { - bbuf.putShort(f.code); - bbuf.put(f.jt); - bbuf.put(f.jf); - bbuf.putInt(f.k); - } - this.filter = filter; - } - - @Override - protected List getFieldOrder() { - return Arrays.asList("len", "filter"); - } - } - - // BPF "macros" and constants - static final int BPF_LD = 0x00; - static final int BPF_W = 0x00; - static final int BPF_ABS = 0x20; - static final int BPF_JMP = 0x05; - static final int BPF_JEQ = 0x10; - static final int BPF_JGE = 0x30; - 
static final int BPF_JGT = 0x20; - static final int BPF_RET = 0x06; - static final int BPF_K = 0x00; - - static SockFilter BPF_STMT(int code, int k) { - return new SockFilter((short) code, (byte) 0, (byte) 0, k); - } - - static SockFilter BPF_JUMP(int code, int k, int jt, int jf) { - return new SockFilter((short) code, (byte) jt, (byte) jf, k); - } - - static final int SECCOMP_RET_ERRNO = 0x00050000; - static final int SECCOMP_RET_DATA = 0x0000FFFF; - static final int SECCOMP_RET_ALLOW = 0x7FFF0000; - - // some errno constants for error checking/handling - static final int EACCES = 0x0D; - static final int EFAULT = 0x0E; - static final int EINVAL = 0x16; - static final int ENOSYS = 0x26; - - // offsets that our BPF checks - // check with offsetof() when adding a new arch, move to Arch if different. - static final int SECCOMP_DATA_NR_OFFSET = 0x00; - static final int SECCOMP_DATA_ARCH_OFFSET = 0x04; - - record Arch( - int audit, // AUDIT_ARCH_XXX constant from linux/audit.h - int limit, // syscall limit (necessary for blacklisting on amd64, to ban 32-bit syscalls) - int fork, // __NR_fork - int vfork, // __NR_vfork - int execve, // __NR_execve - int execveat, // __NR_execveat - int seccomp // __NR_seccomp - ) {} - - /** supported architectures map keyed by os.arch */ - private static final Map ARCHITECTURES; - static { - ARCHITECTURES = Map.of( - "amd64", - new Arch(0xC000003E, 0x3FFFFFFF, 57, 58, 59, 322, 317), - "aarch64", - new Arch(0xC00000B7, 0xFFFFFFFF, 1079, 1071, 221, 281, 277) - ); - } - - /** invokes prctl() from linux libc library */ - private static int linux_prctl(int option, long arg2, long arg3, long arg4, long arg5) { - return linux_libc.prctl(option, new NativeLong(arg2), new NativeLong(arg3), new NativeLong(arg4), new NativeLong(arg5)); - } - - /** invokes syscall() from linux libc library */ - private static long linux_syscall(long number, Object... args) { - return linux_libc.syscall(new NativeLong(number), args).longValue(); - } - - /** try to install our BPF filters via seccomp() or prctl() to block execution */ - private static int linuxImpl() { - // first be defensive: we can give nice errors this way, at the very least. - // also, some of these security features get backported to old versions, checking kernel version here is a big no-no! - final Arch arch = ARCHITECTURES.get(Constants.OS_ARCH); - boolean supported = Constants.LINUX && arch != null; - if (supported == false) { - throw new UnsupportedOperationException("seccomp unavailable: '" + Constants.OS_ARCH + "' architecture unsupported"); - } - - // we couldn't link methods, could be some really ancient kernel (e.g. < 2.1.57) or some bug - if (linux_libc == null) { - throw new UnsupportedOperationException( - "seccomp unavailable: could not link methods. requires kernel 3.5+ " - + "with CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" - ); - } - - // try to check system calls really are who they claim - // you never know (e.g. 
https://chromium.googlesource.com/chromium/src.git/+/master/sandbox/linux/seccomp-bpf/sandbox_bpf.cc#57) - final int bogusArg = 0xf7a46a5c; - - // test seccomp(BOGUS) - long ret = linux_syscall(arch.seccomp, bogusArg); - if (ret != -1) { - throw new UnsupportedOperationException("seccomp unavailable: seccomp(BOGUS_OPERATION) returned " + ret); - } else { - int errno = Native.getLastError(); - switch (errno) { - case ENOSYS: - break; // ok - case EINVAL: - break; // ok - default: - throw new UnsupportedOperationException("seccomp(BOGUS_OPERATION): " + JNACLibrary.strerror(errno)); - } - } - - // test seccomp(VALID, BOGUS) - ret = linux_syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, bogusArg); - if (ret != -1) { - throw new UnsupportedOperationException("seccomp unavailable: seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG) returned " + ret); - } else { - int errno = Native.getLastError(); - switch (errno) { - case ENOSYS: - break; // ok - case EINVAL: - break; // ok - default: - throw new UnsupportedOperationException("seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG): " + JNACLibrary.strerror(errno)); - } - } - - // test prctl(BOGUS) - ret = linux_prctl(bogusArg, 0, 0, 0, 0); - if (ret != -1) { - throw new UnsupportedOperationException("seccomp unavailable: prctl(BOGUS_OPTION) returned " + ret); - } else { - int errno = Native.getLastError(); - switch (errno) { - case ENOSYS: - break; // ok - case EINVAL: - break; // ok - default: - throw new UnsupportedOperationException("prctl(BOGUS_OPTION): " + JNACLibrary.strerror(errno)); - } - } - - // now just normal defensive checks - - // check for GET_NO_NEW_PRIVS - switch (linux_prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)) { - case 0: - break; // not yet set - case 1: - break; // already set by caller - default: - int errno = Native.getLastError(); - if (errno == EINVAL) { - // friendly error, this will be the typical case for an old kernel - throw new UnsupportedOperationException( - "seccomp unavailable: requires kernel 3.5+ with" + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" - ); - } else { - throw new UnsupportedOperationException("prctl(PR_GET_NO_NEW_PRIVS): " + JNACLibrary.strerror(errno)); - } - } - // check for SECCOMP - switch (linux_prctl(PR_GET_SECCOMP, 0, 0, 0, 0)) { - case 0: - break; // not yet set - case 2: - break; // already in filter mode by caller - default: - int errno = Native.getLastError(); - if (errno == EINVAL) { - throw new UnsupportedOperationException( - "seccomp unavailable: CONFIG_SECCOMP not compiled into kernel," - + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" - ); - } else { - throw new UnsupportedOperationException("prctl(PR_GET_SECCOMP): " + JNACLibrary.strerror(errno)); - } - } - // check for SECCOMP_MODE_FILTER - if (linux_prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, 0, 0, 0) != 0) { - int errno = Native.getLastError(); - switch (errno) { - case EFAULT: - break; // available - case EINVAL: - throw new UnsupportedOperationException( - "seccomp unavailable: CONFIG_SECCOMP_FILTER not" - + " compiled into kernel, CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" - ); - default: - throw new UnsupportedOperationException("prctl(PR_SET_SECCOMP): " + JNACLibrary.strerror(errno)); - } - } - - // ok, now set PR_SET_NO_NEW_PRIVS, needed to be able to set a seccomp filter as ordinary user - if (linux_prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) != 0) { - throw new UnsupportedOperationException("prctl(PR_SET_NO_NEW_PRIVS): " + JNACLibrary.strerror(Native.getLastError())); - } - - // check it worked - if 
(linux_prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0) != 1) { - throw new UnsupportedOperationException( - "seccomp filter did not really succeed: prctl(PR_GET_NO_NEW_PRIVS): " + JNACLibrary.strerror(Native.getLastError()) - ); - } - - // BPF installed to check arch, limit, then syscall. - // See https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt for details. - SockFilter insns[] = { - /* 1 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_ARCH_OFFSET), // - /* 2 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.audit, 0, 7), // if (arch != audit) goto fail; - /* 3 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_NR_OFFSET), // - /* 4 */ BPF_JUMP(BPF_JMP + BPF_JGT + BPF_K, arch.limit, 5, 0), // if (syscall > LIMIT) goto fail; - /* 5 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.fork, 4, 0), // if (syscall == FORK) goto fail; - /* 6 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.vfork, 3, 0), // if (syscall == VFORK) goto fail; - /* 7 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execve, 2, 0), // if (syscall == EXECVE) goto fail; - /* 8 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execveat, 1, 0), // if (syscall == EXECVEAT) goto fail; - /* 9 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), // pass: return OK; - /* 10 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ERRNO | (EACCES & SECCOMP_RET_DATA)), // fail: return EACCES; - }; - // seccomp takes a long, so we pass it one explicitly to keep the JNA simple - SockFProg prog = new SockFProg(insns); - prog.write(); - long pointer = Pointer.nativeValue(prog.getPointer()); - - int method = 1; - // install filter, if this works, after this there is no going back! - // first try it with seccomp(SECCOMP_SET_MODE_FILTER), falling back to prctl() - if (linux_syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_TSYNC, new NativeLong(pointer)) != 0) { - method = 0; - int errno1 = Native.getLastError(); - if (logger.isDebugEnabled()) { - logger.debug( - "seccomp(SECCOMP_SET_MODE_FILTER): {}, falling back to prctl(PR_SET_SECCOMP)...", - JNACLibrary.strerror(errno1) - ); - } - if (linux_prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, pointer, 0, 0) != 0) { - int errno2 = Native.getLastError(); - throw new UnsupportedOperationException( - "seccomp(SECCOMP_SET_MODE_FILTER): " - + JNACLibrary.strerror(errno1) - + ", prctl(PR_SET_SECCOMP): " - + JNACLibrary.strerror(errno2) - ); - } - } - - // now check that the filter was really installed, we should be in filter mode. - if (linux_prctl(PR_GET_SECCOMP, 0, 0, 0, 0) != 2) { - throw new UnsupportedOperationException( - "seccomp filter installation did not really succeed. seccomp(PR_GET_SECCOMP): " - + JNACLibrary.strerror(Native.getLastError()) - ); - } - - logger.debug("Linux seccomp filter installation successful, threads: [{}]", method == 1 ? "all" : "app"); - return method; - } - - // OS X implementation via sandbox(7) - - /** Access to non-standard OS X libc methods */ - interface MacLibrary extends Library { - /** - * maps to sandbox_init(3), since Leopard - */ - int sandbox_init(String profile, long flags, PointerByReference errorbuf); - - /** - * releases memory when an error occurs during initialization (e.g. syntax bug) - */ - void sandbox_free_error(Pointer errorbuf); - } - - // null if unavailable, or something goes wrong. - private static final MacLibrary libc_mac; - - static { - MacLibrary lib = null; - if (Constants.MAC_OS_X) { - try { - lib = Native.loadLibrary("c", MacLibrary.class); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link C library. 
native methods (seatbelt) will be disabled.", e); - } - } - libc_mac = lib; - } - - /** The only supported flag... */ - static final int SANDBOX_NAMED = 1; - /** Allow everything except process fork and execution */ - static final String SANDBOX_RULES = "(version 1) (allow default) (deny process-fork) (deny process-exec)"; - - /** try to install our custom rule profile into sandbox_init() to block execution */ - private static void macImpl(Path tmpFile) throws IOException { - // first be defensive: we can give nice errors this way, at the very least. - boolean supported = Constants.MAC_OS_X; - if (supported == false) { - throw new IllegalStateException("bug: should not be trying to initialize seatbelt for an unsupported OS"); - } - - // we couldn't link methods, could be some really ancient OS X (< Leopard) or some bug - if (libc_mac == null) { - throw new UnsupportedOperationException("seatbelt unavailable: could not link methods. requires Leopard or above."); - } - - // write rules to a temporary file, which will be passed to sandbox_init() - Path rules = Files.createTempFile(tmpFile, "es", "sb"); - Files.write(rules, Collections.singleton(SANDBOX_RULES)); - - boolean success = false; - try { - PointerByReference errorRef = new PointerByReference(); - int ret = libc_mac.sandbox_init(rules.toAbsolutePath().toString(), SANDBOX_NAMED, errorRef); - // if sandbox_init() fails, add the message from the OS (e.g. syntax error) and free the buffer - if (ret != 0) { - Pointer errorBuf = errorRef.getValue(); - RuntimeException e = new UnsupportedOperationException("sandbox_init(): " + errorBuf.getString(0)); - libc_mac.sandbox_free_error(errorBuf); - throw e; - } - logger.debug("OS X seatbelt initialization successful"); - success = true; - } finally { - if (success) { - Files.delete(rules); - } else { - IOUtils.deleteFilesIgnoringExceptions(rules); - } - } - } - - // Solaris implementation via priv_set(3C) - - /** Access to non-standard Solaris libc methods */ - interface SolarisLibrary extends Library { - /** - * see priv_set(3C), a convenience method for setppriv(2). - */ - int priv_set(int op, String which, String... privs); - } - - // null if unavailable, or something goes wrong. - private static final SolarisLibrary libc_solaris; - - static { - SolarisLibrary lib = null; - if (Constants.SUN_OS) { - try { - lib = Native.loadLibrary("c", SolarisLibrary.class); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link C library. native methods (priv_set) will be disabled.", e); - } - } - libc_solaris = lib; - } - - // constants for priv_set(2) - static final int PRIV_OFF = 1; - static final String PRIV_ALLSETS = null; - // see privileges(5) for complete list of these - static final String PRIV_PROC_FORK = "proc_fork"; - static final String PRIV_PROC_EXEC = "proc_exec"; - - static void solarisImpl() { - // first be defensive: we can give nice errors this way, at the very least. - boolean supported = Constants.SUN_OS; - if (supported == false) { - throw new IllegalStateException("bug: should not be trying to initialize priv_set for an unsupported OS"); - } - - // we couldn't link methods, could be some really ancient Solaris or some bug - if (libc_solaris == null) { - throw new UnsupportedOperationException("priv_set unavailable: could not link methods. 
requires Solaris 10+"); - } - - // drop a null-terminated list of privileges - if (libc_solaris.priv_set(PRIV_OFF, PRIV_ALLSETS, PRIV_PROC_FORK, PRIV_PROC_EXEC, null) != 0) { - throw new UnsupportedOperationException("priv_set unavailable: priv_set(): " + JNACLibrary.strerror(Native.getLastError())); - } - - logger.debug("Solaris priv_set initialization successful"); - } - - // BSD implementation via setrlimit(2) - - // TODO: add OpenBSD to Lucene Constants - // TODO: JNA doesn't have netbsd support, but this mechanism should work there too. - static final boolean OPENBSD = Constants.OS_NAME.startsWith("OpenBSD"); - - // not a standard limit, means something different on linux, etc! - static final int RLIMIT_NPROC = 7; - - static void bsdImpl() { - boolean supported = Constants.FREE_BSD || OPENBSD || Constants.MAC_OS_X; - if (supported == false) { - throw new IllegalStateException("bug: should not be trying to initialize RLIMIT_NPROC for an unsupported OS"); - } - - JNACLibrary.Rlimit limit = new JNACLibrary.Rlimit(); - limit.rlim_cur.setValue(0); - limit.rlim_max.setValue(0); - if (JNACLibrary.setrlimit(RLIMIT_NPROC, limit) != 0) { - throw new UnsupportedOperationException("RLIMIT_NPROC unavailable: " + JNACLibrary.strerror(Native.getLastError())); - } - - logger.debug("BSD RLIMIT_NPROC initialization successful"); - } - - // windows impl via job ActiveProcessLimit - - static void windowsImpl() { - if (Constants.WINDOWS == false) { - throw new IllegalStateException("bug: should not be trying to initialize ActiveProcessLimit for an unsupported OS"); - } - - JNAKernel32Library lib = JNAKernel32Library.getInstance(); - - // create a new Job - Pointer job = lib.CreateJobObjectW(null, null); - if (job == null) { - throw new UnsupportedOperationException("CreateJobObject: " + Native.getLastError()); - } - - try { - // retrieve the current basic limits of the job - int clazz = JNAKernel32Library.JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS; - JNAKernel32Library.JOBOBJECT_BASIC_LIMIT_INFORMATION limits = new JNAKernel32Library.JOBOBJECT_BASIC_LIMIT_INFORMATION(); - limits.write(); - if (lib.QueryInformationJobObject(job, clazz, limits.getPointer(), limits.size(), null) == false) { - throw new UnsupportedOperationException("QueryInformationJobObject: " + Native.getLastError()); - } - limits.read(); - // modify the number of active processes to be 1 (exactly the one process we will add to the job). - limits.ActiveProcessLimit = 1; - limits.LimitFlags = JNAKernel32Library.JOB_OBJECT_LIMIT_ACTIVE_PROCESS; - limits.write(); - if (lib.SetInformationJobObject(job, clazz, limits.getPointer(), limits.size()) == false) { - throw new UnsupportedOperationException("SetInformationJobObject: " + Native.getLastError()); - } - // assign ourselves to the job - if (lib.AssignProcessToJobObject(job, lib.GetCurrentProcess()) == false) { - throw new UnsupportedOperationException("AssignProcessToJobObject: " + Native.getLastError()); - } - } finally { - lib.CloseHandle(job); - } - - logger.debug("Windows ActiveProcessLimit initialization successful"); - } - - /** - * Attempt to drop the capability to execute for the process. - *
<p>
- * This is best effort and OS and architecture dependent. It may throw any Throwable. - * @return 0 if we can do this for application threads, 1 for the entire process - */ - static int init(Path tmpFile) throws Exception { - if (Constants.LINUX) { - return linuxImpl(); - } else if (Constants.MAC_OS_X) { - // try to enable both mechanisms if possible - bsdImpl(); - macImpl(tmpFile); - return 1; - } else if (Constants.SUN_OS) { - solarisImpl(); - return 1; - } else if (Constants.FREE_BSD || OPENBSD) { - bsdImpl(); - return 1; - } else if (Constants.WINDOWS) { - windowsImpl(); - return 1; - } else { - throw new UnsupportedOperationException("syscall filtering not supported for OS: '" + Constants.OS_NAME + "'"); - } - } -} From 5d6e6c12cacc02abebba95b69236784ced2d3b03 Mon Sep 17 00:00:00 2001 From: Mark Vieira Date: Tue, 9 Jul 2024 13:33:18 -0700 Subject: [PATCH 54/64] Remove Windows BWC pull request pipeline (#110664) We've already removed Windows-specific BWC jobs in our periodic pipelines. They shouldn't behave differently and are very prone to timeouts so let's just remove them from pull requests when the `test-windows` label is added. --- .../pull-request/bwc-snapshots-windows.yml | 20 ------------------- 1 file changed, 20 deletions(-) delete mode 100644 .buildkite/pipelines/pull-request/bwc-snapshots-windows.yml diff --git a/.buildkite/pipelines/pull-request/bwc-snapshots-windows.yml b/.buildkite/pipelines/pull-request/bwc-snapshots-windows.yml deleted file mode 100644 index d37bdf380f926..0000000000000 --- a/.buildkite/pipelines/pull-request/bwc-snapshots-windows.yml +++ /dev/null @@ -1,20 +0,0 @@ -config: - allow-labels: test-windows -steps: - - group: bwc-snapshots-windows - steps: - - label: "{{matrix.BWC_VERSION}} / bwc-snapshots-windows" - key: "bwc-snapshots-windows" - command: .\.buildkite\scripts\run-script.ps1 bash .buildkite/scripts/windows-run-gradle.sh - env: - GRADLE_TASK: "v{{matrix.BWC_VERSION}}#bwcTest" - timeout_in_minutes: 300 - matrix: - setup: - BWC_VERSION: $SNAPSHOT_BWC_VERSIONS - agents: - provider: gcp - image: family/elasticsearch-windows-2022 - machineType: custom-32-98304 - diskType: pd-ssd - diskSizeGb: 350 From c37fa227bd2ea2f16e013daa11d85bdf3c538d4e Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Tue, 9 Jul 2024 17:52:53 -0500 Subject: [PATCH 55/64] Removing the use of Stream::peek from GeoIpDownloader::cleanDatabases (#110666) --- docs/changelog/110666.yaml | 5 ++++ .../ingest/geoip/GeoIpDownloader.java | 23 ++++++++----------- .../ingest/geoip/GeoIpDownloaderTests.java | 3 +-- 3 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 docs/changelog/110666.yaml diff --git a/docs/changelog/110666.yaml b/docs/changelog/110666.yaml new file mode 100644 index 0000000000000..d96f8e2024c81 --- /dev/null +++ b/docs/changelog/110666.yaml @@ -0,0 +1,5 @@ +pr: 110666 +summary: Removing the use of Stream::peek from `GeoIpDownloader::cleanDatabases` +area: Ingest Node +type: bug +issues: [] diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java index 5239e96856b7f..13394a2a0c7cc 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java @@ -318,22 +318,19 @@ public void requestReschedule() { } private void cleanDatabases() { - long expiredDatabases = state.getDatabases() + List> 
expiredDatabases = state.getDatabases() .entrySet() .stream() .filter(e -> e.getValue().isValid(clusterService.state().metadata().settings()) == false) - .peek(e -> { - String name = e.getKey(); - Metadata meta = e.getValue(); - deleteOldChunks(name, meta.lastChunk() + 1); - state = state.put( - name, - new Metadata(meta.lastUpdate(), meta.firstChunk(), meta.lastChunk(), meta.md5(), meta.lastCheck() - 1) - ); - updateTaskState(); - }) - .count(); - stats = stats.expiredDatabases((int) expiredDatabases); + .toList(); + expiredDatabases.forEach(e -> { + String name = e.getKey(); + Metadata meta = e.getValue(); + deleteOldChunks(name, meta.lastChunk() + 1); + state = state.put(name, new Metadata(meta.lastUpdate(), meta.firstChunk(), meta.lastChunk(), meta.md5(), meta.lastCheck() - 1)); + updateTaskState(); + }); + stats = stats.expiredDatabases(expiredDatabases.size()); } @Override diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java index 4d5070d96683e..6a83fe69473f7 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java @@ -580,7 +580,6 @@ public void testThatRunDownloaderDeletesExpiredDatabases() { client.addHandler( UpdatePersistentTaskStatusAction.INSTANCE, (UpdatePersistentTaskStatusAction.Request request, ActionListener taskResponseListener) -> { - PersistentTasksCustomMetadata.Assignment assignment = mock(PersistentTasksCustomMetadata.Assignment.class); PersistentTasksCustomMetadata.PersistentTask persistentTask = new PersistentTasksCustomMetadata.PersistentTask<>( GeoIpDownloader.GEOIP_DOWNLOADER, @@ -589,8 +588,8 @@ public void testThatRunDownloaderDeletesExpiredDatabases() { request.getAllocationId(), assignment ); - taskResponseListener.onResponse(new PersistentTaskResponse(new PersistentTask<>(persistentTask, request.getState()))); updatePersistentTaskStateCount.incrementAndGet(); + taskResponseListener.onResponse(new PersistentTaskResponse(new PersistentTask<>(persistentTask, request.getState()))); } ); client.addHandler( From 6d85b38745e3582dcdb364df1a4dfaad6b62d0d3 Mon Sep 17 00:00:00 2001 From: Felix Barnsteiner Date: Wed, 10 Jul 2024 09:17:13 +0200 Subject: [PATCH 56/64] Remove `default_field: message` from metrics index templates (#110651) This is a follow-up from https://github.com/elastic/elasticsearch/pull/102456 --- docs/changelog/110651.yaml | 5 +++++ .../src/main/resources/metrics@settings.json | 3 --- .../src/main/resources/metrics@tsdb-settings.json | 3 --- .../org/elasticsearch/xpack/stack/StackTemplateRegistry.java | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 docs/changelog/110651.yaml diff --git a/docs/changelog/110651.yaml b/docs/changelog/110651.yaml new file mode 100644 index 0000000000000..c25c63ee0284a --- /dev/null +++ b/docs/changelog/110651.yaml @@ -0,0 +1,5 @@ +pr: 110651 +summary: "Remove `default_field: message` from metrics index templates" +area: Data streams +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json b/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json index 4f3fac1aed5ae..9960bd2e7fdac 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json +++ 
b/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json @@ -10,9 +10,6 @@ "total_fields": { "ignore_dynamic_beyond_limit": true } - }, - "query": { - "default_field": ["message"] } } } diff --git a/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json b/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json index b0db168e8189d..cb0e2cbffb50b 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json @@ -9,9 +9,6 @@ "total_fields": { "ignore_dynamic_beyond_limit": true } - }, - "query": { - "default_field": ["message"] } } } diff --git a/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java b/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java index aa1e8858163a5..648146ccdcc61 100644 --- a/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java +++ b/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java @@ -47,7 +47,7 @@ public class StackTemplateRegistry extends IndexTemplateRegistry { // The stack template registry version. This number must be incremented when we make changes // to built-in templates. - public static final int REGISTRY_VERSION = 11; + public static final int REGISTRY_VERSION = 12; public static final String TEMPLATE_VERSION_VARIABLE = "xpack.stack.template.version"; public static final Setting STACK_TEMPLATES_ENABLED = Setting.boolSetting( From a4b3e6ffb5be343fc93e04a9590d863772035e2d Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Wed, 10 Jul 2024 09:30:45 +0200 Subject: [PATCH 57/64] Use valid documentation url for capabilities in rest specs (#110657) --- .../src/main/resources/rest-api-spec/api/capabilities.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json b/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json index 28c341d9983cc..a96be0d63834e 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json @@ -1,7 +1,7 @@ { "capabilities": { "documentation": { - "url": "https://www.elastic.co/guide/en/elasticsearch/reference/master/capabilities.html", + "url": "https://github.com/elastic/elasticsearch/blob/main/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/README.asciidoc#require-or-skip-api-capabilities", "description": "Checks if the specified combination of method, API, parameters, and arbitrary capabilities are supported" }, "stability": "experimental", From 356af60b0a69e67a186972ceb2f43f7e61ebe666 Mon Sep 17 00:00:00 2001 From: Jan Kuipers <148754765+jan-elastic@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:42:09 +0200 Subject: [PATCH 58/64] Feature flag for adaptive allocations (#110639) * Feature flag for adaptive allocations * Update docs/changelog/110639.yaml * Delete docs/changelog/110639.yaml --- .../test/cluster/FeatureFlag.java | 7 +++++- .../StartTrainedModelDeploymentAction.java | 15 +++++++----- .../UpdateTrainedModelDeploymentAction.java | 15 +++++++----- .../AdaptiveAllocationsFeatureFlag.java | 24 +++++++++++++++++++ .../inference/services/ServiceUtils.java | 4 ++++ .../xpack/ml/MlInitializationService.java | 9 +++++-- 6 files changed, 59 insertions(+), 15 deletions(-) create 
mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index 49fb38b518dce..a8a33da27aebe 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -16,7 +16,12 @@ */ public enum FeatureFlag { TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null), - FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null); + FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null), + INFERENCE_ADAPTIVE_ALLOCATIONS_ENABLED( + "es.inference_adaptive_allocations_feature_flag_enabled=true", + Version.fromString("8.16.0"), + null + ); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index e635851a4c5e8..59eaf4affa9a8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; @@ -119,12 +120,14 @@ public static class Request extends MasterNodeRequest implements ToXCon ObjectParser.ValueType.VALUE ); PARSER.declareString(Request::setPriority, PRIORITY); - PARSER.declareObjectOrNull( - Request::setAdaptiveAllocationsSettings, - (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), - null, - ADAPTIVE_ALLOCATIONS - ); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + PARSER.declareObjectOrNull( + Request::setAdaptiveAllocationsSettings, + (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + null, + ADAPTIVE_ALLOCATIONS + ); + } } public static Request parseRequest(String modelId, String deploymentId, XContentParser parser) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java index c69a88600f915..28152bc0d5556 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import 
org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -49,12 +50,14 @@ public static class Request extends AcknowledgedRequest implements ToXC static { PARSER.declareString(Request::setDeploymentId, MODEL_ID); PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS); - PARSER.declareObjectOrNull( - Request::setAdaptiveAllocationsSettings, - (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), - AdaptiveAllocationsSettings.RESET_PLACEHOLDER, - ADAPTIVE_ALLOCATIONS - ); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + PARSER.declareObjectOrNull( + Request::setAdaptiveAllocationsSettings, + (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + AdaptiveAllocationsSettings.RESET_PLACEHOLDER, + ADAPTIVE_ALLOCATIONS + ); + } PARSER.declareString((r, val) -> r.ackTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java new file mode 100644 index 0000000000000..a3b508c0534f9 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * semantic_text feature flag. When the feature is complete, this flag will be removed. 
+ */ +public class AdaptiveAllocationsFeatureFlag { + + private AdaptiveAllocationsFeatureFlag() {} + + private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_adaptive_allocations"); + + public static boolean isEnabled() { + return FEATURE_FLAG.isEnabled(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 99779ac378d89..9f810b829bea9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbedding; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; @@ -131,6 +132,9 @@ public static Object removeAsOneOfTypes( } public static AdaptiveAllocationsSettings removeAsAdaptiveAllocationsSettings(Map sourceMap, String key) { + if (AdaptiveAllocationsFeatureFlag.isEnabled() == false) { + return null; + } Map settingsMap = ServiceUtils.removeFromMap(sourceMap, key); return settingsMap == null ? null diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java index 346b67a169912..a1664b7023fc0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java @@ -32,6 +32,7 @@ import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.annotations.AnnotationIndex; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -123,13 +124,17 @@ public void beforeStop() { public void onMaster() { mlDailyMaintenanceService.start(); - adaptiveAllocationsScalerService.start(); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + adaptiveAllocationsScalerService.start(); + } threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(this::makeMlInternalIndicesHidden); } public void offMaster() { mlDailyMaintenanceService.stop(); - adaptiveAllocationsScalerService.stop(); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + adaptiveAllocationsScalerService.stop(); + } } @Override From bdedced58664153576e25fd72ac5aecd92471466 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 10 Jul 2024 09:01:32 +0100 Subject: [PATCH 59/64] Add test for upgrading to ES with file settings (#110229) --- .../upgrades/FileSettingsUpgradeIT.java | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java diff --git 
a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java new file mode 100644 index 0000000000000..c80911fe5fbcf --- /dev/null +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.upgrades; + +import com.carrotsearch.randomizedtesting.annotations.Name; + +import org.elasticsearch.client.Request; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; +import org.elasticsearch.test.cluster.local.DefaultLocalClusterSpecBuilder; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.cluster.util.Version; +import org.elasticsearch.test.cluster.util.resource.Resource; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.rules.RuleChain; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestRule; + +import java.io.IOException; +import java.util.Map; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; + +public class FileSettingsUpgradeIT extends ParameterizedRollingUpgradeTestCase { + + @BeforeClass + public static void checkVersion() { + assumeTrue("Only valid when upgrading from pre-file settings", getOldClusterTestVersion().before(new Version(8, 4, 0))); + } + + private static final String settingsJSON = """ + { + "metadata": { + "version": "1", + "compatibility": "8.4.0" + }, + "state": { + "cluster_settings": { + "indices.recovery.max_bytes_per_sec": "50mb" + } + } + }"""; + + private static final TemporaryFolder repoDirectory = new TemporaryFolder(); + + private static final ElasticsearchCluster cluster = new DefaultLocalClusterSpecBuilder().distribution(DistributionType.DEFAULT) + .version(getOldClusterTestVersion()) + .nodes(NODE_NUM) + .setting("path.repo", new Supplier<>() { + @Override + @SuppressForbidden(reason = "TemporaryFolder only has io.File methods, not nio.File") + public String get() { + return repoDirectory.getRoot().getPath(); + } + }) + .setting("xpack.security.enabled", "false") + .feature(FeatureFlag.TIME_SERIES_MODE) + .configFile("operator/settings.json", Resource.fromString(settingsJSON)) + .build(); + + @ClassRule + public static TestRule ruleChain = RuleChain.outerRule(repoDirectory).around(cluster); + + public FileSettingsUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { + super(upgradedNodes); + } + + @Override + protected ElasticsearchCluster getUpgradeCluster() { + return cluster; + } + + public void testFileSettingsApplied() throws IOException { + if (isUpgradedCluster()) { + // the nodes have all been upgraded. 
Check they read the file settings ok + Map response = responseAsMap(adminClient().performRequest(new Request("GET", "/_cluster/settings"))); + assertThat(XContentMapValues.extractValue(response, "persistent", "indices", "recovery", "max_bytes_per_sec"), equalTo("50mb")); + } + } +} From 4ddca132cf7ecdadc6be230919986a0fbdbeb133 Mon Sep 17 00:00:00 2001 From: Craig Taverner Date: Wed, 10 Jul 2024 10:35:54 +0200 Subject: [PATCH 60/64] Fix union-types when aggregating on inline conversion function (#110652) A query like: ``` FROM sample_data, sample_data_str | STATS count=count(*) BY client_ip = TO_IP(client_ip) | SORT count DESC, client_ip ASC | KEEP count, client_ip ``` Failed due to unresolved aggregates from the union-type in the grouping key --- .../src/main/resources/union_types.csv-spec | 15 +++++++++++++++ .../xpack/esql/action/EsqlCapabilities.java | 7 ++++++- .../xpack/esql/analysis/Analyzer.java | 13 ++++++++----- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec index 5783489195458..349f968666132 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec @@ -514,6 +514,21 @@ mc:l | count:l 7 | 2 ; +multiIndexTsLongStatsStats +required_capability: union_types +required_capability: union_types_agg_cast + +FROM sample_data, sample_data_ts_long +| EVAL ts = TO_STRING(@timestamp) +| STATS count = COUNT(*) BY ts +| STATS mc = COUNT(count) BY count +| SORT mc DESC, count ASC +; + +mc:l | count:l +14 | 1 +; + multiIndexTsLongRenameStats required_capability: union_types diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index fa822b50ffcf5..5002a4a584954 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -117,7 +117,12 @@ public enum Cap { * Fix to GROK validation in case of multiple fields with same name and different types * https://github.com/elastic/elasticsearch/issues/110533 */ - GROK_VALIDATION; + GROK_VALIDATION, + + /** + * Fix for union-types when aggregating over an inline conversion with conversion function. Done in #110652. + */ + UNION_TYPES_INLINE_FIX; private final boolean snapshotOnly; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index fbc98e093c0fb..add1f74cc3f04 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -1088,7 +1088,11 @@ protected LogicalPlan doRule(LogicalPlan plan) { // In ResolveRefs the aggregates are resolved from the groupings, which might have an unresolved MultiTypeEsField. // Now that we have resolved those, we need to re-resolve the aggregates. 
- if (plan instanceof EsqlAggregate agg && agg.expressionsResolved() == false) { + if (plan instanceof EsqlAggregate agg) { + // If the union-types resolution occurred in a child of the aggregate, we need to check the groupings + plan = agg.transformExpressionsOnly(FieldAttribute.class, UnresolveUnionTypes::checkUnresolved); + + // Aggregates where the grouping key comes from a union-type field need to be resolved against the grouping key Map resolved = new HashMap<>(); for (Expression e : agg.groupings()) { Attribute attr = Expressions.attribute(e); @@ -1096,7 +1100,7 @@ protected LogicalPlan doRule(LogicalPlan plan) { resolved.put(attr, e); } } - plan = agg.transformExpressionsOnly(UnresolvedAttribute.class, ua -> resolveAttribute(ua, resolved)); + plan = plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> resolveAttribute(ua, resolved)); } // Otherwise drop the converted attributes after the alias function, as they are only needed for this function, and @@ -1222,9 +1226,8 @@ protected LogicalPlan rule(LogicalPlan plan) { return plan.transformExpressionsOnly(FieldAttribute.class, UnresolveUnionTypes::checkUnresolved); } - private static Attribute checkUnresolved(FieldAttribute fa) { - var field = fa.field(); - if (field instanceof InvalidMappedField imf) { + static Attribute checkUnresolved(FieldAttribute fa) { + if (fa.field() instanceof InvalidMappedField imf) { String unresolvedMessage = "Cannot use field [" + fa.name() + "] due to ambiguities being " + imf.errorMessage(); return new UnresolvedAttribute(fa.source(), fa.name(), fa.qualifier(), fa.id(), unresolvedMessage, null); } From 8e39f01723c82f00efd0a7f06f6967057ccfa1fd Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Wed, 10 Jul 2024 10:38:54 +0200 Subject: [PATCH 61/64] Remove version barrier for synthetic version based features in tests (#110656) --- .../test/rest/ESRestTestFeatureService.java | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java index 78a4126ec09db..92d72afbf9d52 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java @@ -86,19 +86,6 @@ public boolean clusterHasFeature(String featureId) { Matcher matcher = VERSION_FEATURE_PATTERN.matcher(featureId); if (matcher.matches()) { Version extractedVersion = Version.fromString(matcher.group(1)); - if (Version.V_8_15_0.before(extractedVersion)) { - // As of version 8.14.0 REST tests have been migrated to use features only. - // For migration purposes we provide a synthetic version feature gte_vX.Y.Z for any version at or before 8.15.0 - // allowing for some transition period. - throw new IllegalArgumentException( - Strings.format( - "Synthetic version features are only available before [%s] for migration purposes! 
" - + "Please add a cluster feature to an appropriate FeatureSpecification; test-only historical-features " - + "can be supplied via ESRestTestCase#additionalTestOnlyHistoricalFeatures()", - Version.V_8_15_0 - ) - ); - } return version.onOrAfter(extractedVersion); } From 6a50c45bfd9ff46d48f19023657acc793082589a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 10 Jul 2024 10:43:58 +0200 Subject: [PATCH 62/64] ESQL: Fix TOP agg tests expecting lists for single elements (#110658) --- .../xpack/esql/expression/function/aggregate/TopTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java index b7b7e7ce84756..c0c23ce29301e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java @@ -241,7 +241,7 @@ private static TestCaseSupplier makeSupplier( ), "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", fieldSupplier.type(), - equalTo(expected) + equalTo(expected.size() == 1 ? expected.get(0) : expected) ); }); } From 8d3b0ade5cb5ef6d3c02449414f05162b44aae19 Mon Sep 17 00:00:00 2001 From: David Turner Date: Wed, 10 Jul 2024 09:48:10 +0100 Subject: [PATCH 63/64] Docs links from repo analysis failures (#110681) We still get too many cases about snapshot repositories which claim to be S3-compatible but then fail repository analysis because of genuine incompatibilities or implementation bugs. The info is all there in the manual but it's not very easy to find. This commit adds more detail to the response message, including docs links, to help direct users to the information they need. 
--- .../repositories/s3/S3Repository.java | 16 ++++++++ .../repositories/s3/S3RepositoryTests.java | 22 ++++++++++ .../elasticsearch/common/ReferenceDocs.java | 2 + .../blobstore/BlobStoreRepository.java | 12 ++++++ .../common/reference-docs-links.json | 4 +- .../rest-api-spec/test/10_analyze.yml | 2 +- .../testkit/RepositoryAnalysisFailureIT.java | 24 +++++++++-- .../testkit/RepositoryAnalyzeAction.java | 40 ++++++++++++------- 8 files changed, 103 insertions(+), 19 deletions(-) diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java index d53c379a37644..72b48c5903629 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.cluster.metadata.RepositoryMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; @@ -443,4 +444,19 @@ protected void doClose() { } super.doClose(); } + + @Override + public String getAnalysisFailureExtraDetail() { + return Strings.format( + """ + Elasticsearch observed the storage system underneath this repository behaved incorrectly which indicates it is not \ + suitable for use with Elasticsearch snapshots. Typically this happens when using storage other than AWS S3 which \ + incorrectly claims to be S3-compatible. If so, please report this incompatibility to your storage supplier. Do not report \ + Elasticsearch issues involving storage systems which claim to be S3-compatible unless you can demonstrate that the same \ + issue exists when using a genuine AWS S3 repository. 
See [%s] for further information about repository analysis, and [%s] \ + for further information about support for S3-compatible repository implementations.""", + ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS, + ReferenceDocs.S3_COMPATIBLE_REPOSITORIES + ); + } } diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java index fcb0e82505dac..4bbc791e5fe21 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java @@ -11,6 +11,7 @@ import com.amazonaws.services.s3.AbstractAmazonS3; import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeUnit; @@ -28,6 +29,7 @@ import java.util.Map; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -152,4 +154,24 @@ private S3Repository createS3Repo(RepositoryMetadata metadata) { ); } + public void testAnalysisFailureDetail() { + try ( + S3Repository s3repo = createS3Repo( + new RepositoryMetadata("dummy-repo", "mock", Settings.builder().put(S3Repository.BUCKET_SETTING.getKey(), "bucket").build()) + ) + ) { + assertThat( + s3repo.getAnalysisFailureExtraDetail(), + allOf( + containsString("storage system underneath this repository behaved incorrectly"), + containsString("incorrectly claims to be S3-compatible"), + containsString("report this incompatibility to your storage supplier"), + containsString("unless you can demonstrate that the same issue exists when using a genuine AWS S3 repository"), + containsString(ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS.toString()), + containsString(ReferenceDocs.S3_COMPATIBLE_REPOSITORIES.toString()) + ) + ); + } + } + } diff --git a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java index 1953c1680040a..770ed4d213c55 100644 --- a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java +++ b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java @@ -75,6 +75,8 @@ public enum ReferenceDocs { NETWORK_THREADING_MODEL, ALLOCATION_EXPLAIN_API, NETWORK_BINDING_AND_PUBLISHING, + SNAPSHOT_REPOSITORY_ANALYSIS, + S3_COMPATIBLE_REPOSITORIES, // this comment keeps the ';' on the next line so every entry above has a trailing ',' which makes the diff for adding new links cleaner ; diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 8f55bf16c1674..5b7a11969973d 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -3946,4 +3946,16 @@ public boolean hasAtomicOverwrites() { public int getReadBufferSizeInBytes() { return bufferSize; } + + /** + * @return extra information to be included in the exception message emitted on failure of a repository analysis. 
+ */ + public String getAnalysisFailureExtraDetail() { + return Strings.format( + """ + Elasticsearch observed the storage system underneath this repository behaved incorrectly which indicates it is not \ + suitable for use with Elasticsearch snapshots. See [%s] for further information.""", + ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS + ); + } } diff --git a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json index 303ae22f16269..febcaec1ba057 100644 --- a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json +++ b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json @@ -35,5 +35,7 @@ "EXECUTABLE_JNA_TMPDIR": "executable-jna-tmpdir.html", "NETWORK_THREADING_MODEL": "modules-network.html#modules-network-threading-model", "ALLOCATION_EXPLAIN_API": "cluster-allocation-explain.html", - "NETWORK_BINDING_AND_PUBLISHING": "modules-network.html#modules-network-binding-publishing" + "NETWORK_BINDING_AND_PUBLISHING": "modules-network.html#modules-network-binding-publishing", + "SNAPSHOT_REPOSITORY_ANALYSIS": "repo-analysis-api.html", + "S3_COMPATIBLE_REPOSITORIES": "repository-s3.html#repository-s3-compatible-services" } diff --git a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml index e5babad76eb05..bcee1691e033c 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml +++ b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml @@ -175,6 +175,6 @@ setup: - match: { status: 500 } - match: { error.type: repository_verification_exception } - - match: { error.reason: "/.*test_repo_slow..analysis.failed.*/" } + - match: { error.reason: "/.*test_repo_slow..Repository.analysis.timed.out.*/" } - match: { error.root_cause.0.type: repository_verification_exception } - match: { error.root_cause.0.reason: "/.*test_repo_slow..analysis.timed.out.after..1s.*/" } diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java index 7715b9e8d42b8..2ca5685c83db3 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.metadata.RepositoryMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; @@ -363,6 +364,17 @@ public BytesReference onContendedCompareAndExchange(BytesRegister register, Byte } } + private static void assertAnalysisFailureMessage(String message) { + assertThat( + message, + allOf( + containsString("Elasticsearch observed the storage system underneath this repository 
behaved incorrectly"), + containsString("not suitable for use with Elasticsearch snapshots"), + containsString(ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS.toString()) + ) + ); + } + public void testTimesOutSpinningRegisterAnalysis() { final RepositoryAnalyzeAction.Request request = new RepositoryAnalyzeAction.Request("test-repo"); request.timeout(TimeValue.timeValueMillis(between(1, 1000))); @@ -375,7 +387,13 @@ public boolean compareAndExchangeReturnsWitness(String key) { } }); final var exception = expectThrows(RepositoryVerificationException.class, () -> analyseRepository(request)); - assertThat(exception.getMessage(), containsString("analysis failed")); + assertThat( + exception.getMessage(), + allOf( + containsString("Repository analysis timed out. Consider specifying a longer timeout"), + containsString(ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS.toString()) + ) + ); assertThat( asInstanceOf(RepositoryVerificationException.class, exception.getCause()).getMessage(), containsString("analysis timed out") @@ -391,7 +409,7 @@ public boolean compareAndExchangeReturnsWitness(String key) { } }); final var exception = expectThrows(RepositoryVerificationException.class, () -> analyseRepository(request)); - assertThat(exception.getMessage(), containsString("analysis failed")); + assertAnalysisFailureMessage(exception.getMessage()); assertThat( asInstanceOf(RepositoryVerificationException.class, ExceptionsHelper.unwrapCause(exception.getCause())).getMessage(), allOf(containsString("uncontended register operation failed"), containsString("did not observe any value")) @@ -407,7 +425,7 @@ public boolean acceptsEmptyRegister() { } }); final var exception = expectThrows(RepositoryVerificationException.class, () -> analyseRepository(request)); - assertThat(exception.getMessage(), containsString("analysis failed")); + assertAnalysisFailureMessage(exception.getMessage()); final var cause = ExceptionsHelper.unwrapCause(exception.getCause()); if (cause instanceof IOException ioException) { assertThat(ioException.getMessage(), containsString("empty register update rejected")); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java index 7b82b69a682fa..494d1d3fedcd9 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java @@ -28,6 +28,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.blobstore.BlobContainer; @@ -387,6 +388,9 @@ public static class AsyncAction { private final List responses; private final RepositoryPerformanceSummary.Builder summary = new RepositoryPerformanceSummary.Builder(); + private final RepositoryVerificationException analysisCancelledException; + private final RepositoryVerificationException analysisTimedOutException; + public AsyncAction( TransportService transportService, BlobStoreRepository repository, @@ -410,6 +414,12 @@ public AsyncAction( this.listener = 
ActionListener.runBefore(listener, () -> cancellationListener.onResponse(null)); responses = new ArrayList<>(request.blobCount); + + this.analysisCancelledException = new RepositoryVerificationException(request.repositoryName, "analysis cancelled"); + this.analysisTimedOutException = new RepositoryVerificationException( + request.repositoryName, + "analysis timed out after [" + request.getTimeout() + "]" + ); } private boolean setFirstFailure(Exception e) { @@ -453,12 +463,7 @@ public void onFailure(Exception e) { assert e instanceof ElasticsearchTimeoutException : e; if (isRunning()) { // if this CAS fails then we're already failing for some other reason, nbd - setFirstFailure( - new RepositoryVerificationException( - request.repositoryName, - "analysis timed out after [" + request.getTimeout() + "]" - ) - ); + setFirstFailure(analysisTimedOutException); } } } @@ -472,7 +477,7 @@ public void run() { cancellationListener.addTimeout(request.getTimeout(), repository.threadPool(), EsExecutors.DIRECT_EXECUTOR_SERVICE); cancellationListener.addListener(new CheckForCancelListener()); - task.addListener(() -> setFirstFailure(new RepositoryVerificationException(request.repositoryName, "analysis cancelled"))); + task.addListener(() -> setFirstFailure(analysisCancelledException)); final Random random = new Random(request.getSeed()); final List nodes = getSnapshotNodes(discoveryNodes); @@ -873,13 +878,20 @@ private void sendResponse(final long listingStartTimeNanos, final long deleteSta ); } else { logger.debug(() -> "analysis of repository [" + request.repositoryName + "] failed", exception); - listener.onFailure( - new RepositoryVerificationException( - request.getRepositoryName(), - "analysis failed, you may need to manually remove [" + blobPath + "]", - exception - ) - ); + + final String failureDetail; + if (exception == analysisCancelledException) { + failureDetail = "Repository analysis was cancelled."; + } else if (exception == analysisTimedOutException) { + failureDetail = Strings.format(""" + Repository analysis timed out. Consider specifying a longer timeout using the [?timeout] request parameter. See \ + [%s] for more information.""", ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS); + } else { + failureDetail = repository.getAnalysisFailureExtraDetail(); + } + listener.onFailure(new RepositoryVerificationException(request.getRepositoryName(), Strings.format(""" + %s Elasticsearch attempted to remove the data it wrote at [%s] but may have left some behind. 
If so, \ + please now remove this data manually.""", failureDetail, blobPath), exception)); } } } From f794966671b5a885533420d4907b059de7361787 Mon Sep 17 00:00:00 2001 From: Jedr Blaszyk Date: Wed, 10 Jul 2024 10:52:06 +0200 Subject: [PATCH 64/64] [Connector API] Don't index literal nulls to connector doc (#110543) --- .../entsearch/connector/10_connector_put.yml | 36 ++++++++++ .../application/connector/Connector.java | 72 ++++++++++++++----- 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml index 5cfb016e1b6df..b0f850d09f76d 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml @@ -76,6 +76,42 @@ setup: - match: { custom_scheduling: {} } - match: { filtering.0.domain: DEFAULT } + +--- +'Create Connector - Check for missing keys': + - do: + connector.put: + connector_id: test-connector + body: + index_name: search-test + name: my-connector + language: pl + is_native: false + service_type: super-connector + + - match: { result: 'created' } + + - do: + connector.get: + connector_id: test-connector + + - match: { id: test-connector } + - match: { index_name: search-test } + - match: { name: my-connector } + - match: { language: pl } + - match: { is_native: false } + - match: { service_type: super-connector } + + # check keys that are not populated upon connector creation + - is_false: api_key_id + - is_false: api_key_secret_id + - is_false: description + - is_false: error + - is_false: features + - is_false: last_seen + - is_false: sync_cursor + + --- 'Create Connector - Resource already exists': - do: diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java index a9c488b024d49..46275bb623b7a 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java @@ -377,25 +377,61 @@ public void toInnerXContent(XContentBuilder builder, Params params) throws IOExc if (connectorId != null) { builder.field(ID_FIELD.getPreferredName(), connectorId); } - builder.field(API_KEY_ID_FIELD.getPreferredName(), apiKeyId); - builder.field(API_KEY_SECRET_ID_FIELD.getPreferredName(), apiKeySecretId); - builder.xContentValuesMap(CONFIGURATION_FIELD.getPreferredName(), configuration); - builder.xContentValuesMap(CUSTOM_SCHEDULING_FIELD.getPreferredName(), customScheduling); - builder.field(DESCRIPTION_FIELD.getPreferredName(), description); - builder.field(ERROR_FIELD.getPreferredName(), error); - builder.field(FEATURES_FIELD.getPreferredName(), features); - builder.xContentList(FILTERING_FIELD.getPreferredName(), filtering); - builder.field(INDEX_NAME_FIELD.getPreferredName(), indexName); + if (apiKeyId != null) { + builder.field(API_KEY_ID_FIELD.getPreferredName(), apiKeyId); + } + if (apiKeySecretId != null) { + builder.field(API_KEY_SECRET_ID_FIELD.getPreferredName(), apiKeySecretId); + } + if (configuration != null) { + 
builder.xContentValuesMap(CONFIGURATION_FIELD.getPreferredName(), configuration);
+        }
+        if (customScheduling != null) {
+            builder.xContentValuesMap(CUSTOM_SCHEDULING_FIELD.getPreferredName(), customScheduling);
+        }
+        if (description != null) {
+            builder.field(DESCRIPTION_FIELD.getPreferredName(), description);
+        }
+        if (error != null) {
+            builder.field(ERROR_FIELD.getPreferredName(), error);
+        }
+        if (features != null) {
+            builder.field(FEATURES_FIELD.getPreferredName(), features);
+        }
+        if (filtering != null) {
+            builder.xContentList(FILTERING_FIELD.getPreferredName(), filtering);
+        }
+        if (indexName != null) {
+            builder.field(INDEX_NAME_FIELD.getPreferredName(), indexName);
+        }
         builder.field(IS_NATIVE_FIELD.getPreferredName(), isNative);
-        builder.field(LANGUAGE_FIELD.getPreferredName(), language);
-        builder.field(LAST_SEEN_FIELD.getPreferredName(), lastSeen);
-        syncInfo.toXContent(builder, params);
-        builder.field(NAME_FIELD.getPreferredName(), name);
-        builder.field(PIPELINE_FIELD.getPreferredName(), pipeline);
-        builder.field(SCHEDULING_FIELD.getPreferredName(), scheduling);
-        builder.field(SERVICE_TYPE_FIELD.getPreferredName(), serviceType);
-        builder.field(SYNC_CURSOR_FIELD.getPreferredName(), syncCursor);
-        builder.field(STATUS_FIELD.getPreferredName(), status.toString());
+        if (language != null) {
+            builder.field(LANGUAGE_FIELD.getPreferredName(), language);
+        }
+        if (lastSeen != null) {
+            builder.field(LAST_SEEN_FIELD.getPreferredName(), lastSeen);
+        }
+        if (syncInfo != null) {
+            syncInfo.toXContent(builder, params);
+        }
+        if (name != null) {
+            builder.field(NAME_FIELD.getPreferredName(), name);
+        }
+        if (pipeline != null) {
+            builder.field(PIPELINE_FIELD.getPreferredName(), pipeline);
+        }
+        if (scheduling != null) {
+            builder.field(SCHEDULING_FIELD.getPreferredName(), scheduling);
+        }
+        if (serviceType != null) {
+            builder.field(SERVICE_TYPE_FIELD.getPreferredName(), serviceType);
+        }
+        if (syncCursor != null) {
+            builder.field(SYNC_CURSOR_FIELD.getPreferredName(), syncCursor);
+        }
+        if (status != null) {
+            builder.field(STATUS_FIELD.getPreferredName(), status.toString());
+        }
         builder.field(SYNC_NOW_FIELD.getPreferredName(), syncNow);
     }
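
A closing note on the serialization pattern in this last change: guarding each `builder.field(...)` call means an absent value is omitted from the connector document altogether rather than indexed as a literal `null`, which is what the `is_false` assertions in the YAML test above verify. A small self-contained sketch of the behavioral difference follows, using a plain map as a hypothetical stand-in for `XContentBuilder`:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class NullSafeDocDemo {
    // Mirrors the guarded builder.field(...) calls above: skip absent values
    // entirely instead of writing a literal null into the document.
    static void putIfPresent(Map<String, Object> doc, String key, Object value) {
        if (value != null) {
            doc.put(key, value);
        }
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new LinkedHashMap<>();
        putIfPresent(doc, "index_name", "search-test");
        putIfPresent(doc, "description", null); // key omitted, not {"description": null}
        System.out.println(doc); // prints {index_name=search-test}
    }
}
```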