From 6c2a8893abe9b966b50eb2ae41557ca36bc84863 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 13 Apr 2021 15:51:33 +0100 Subject: [PATCH 1/9] Store binary data as bytes --- .../InferenceToXContentCompressor.java | 28 ++++---- .../core/ml/inference/TrainedModelConfig.java | 67 +++++++++++-------- .../core/ml/inference_index_mappings.json | 5 +- .../InferenceToXContentCompressorTests.java | 16 +++-- .../ml/inference/TrainedModelConfigTests.java | 13 ++-- .../inference/InferenceDefinitionTests.java | 7 +- .../ChunkedTrainedModelPersisterIT.java | 9 ++- .../integration/TrainedModelProviderIT.java | 12 ++-- .../TrainedModelDefinitionDoc.java | 32 ++++++--- .../persistence/TrainedModelProvider.java | 57 +++++++++------- .../pytorch/process/PyTorchStateStreamer.java | 6 +- .../TrainedModelProviderTests.java | 13 ++++ 12 files changed, 157 insertions(+), 108 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java index 70f4998a7f96c..b7f7532668b77 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -29,8 +28,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.Map; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -47,24 +44,24 @@ public final class InferenceToXContentCompressor { private InferenceToXContentCompressor() {} - public static String deflate(T objectToCompress) throws IOException { + public static BytesReference deflate(T objectToCompress) throws IOException { BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false); return deflate(reference); } - public static T inflate(String compressedString, + public static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry) throws IOException { - return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES); + return inflate(compressedBytes, parserFunction, xContentRegistry, MAX_INFLATED_BYTES); } - static T inflate(String compressedString, + static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry, long maxBytes) throws IOException { try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, maxBytes))) { + inflate(compressedBytes, maxBytes))) { return parserFunction.apply(parser); } catch (XContentParseException parseException) { SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause = @@ -82,32 +79,31 @@ static T inflate(String compressedString, } } - static Map inflateToMap(String compressedString) throws IOException { + static Map inflateToMap(BytesReference compressedBytes) throws IOException { // Don't need the xcontent registry as we are not deflating named objects. try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, MAX_INFLATED_BYTES))) { + inflate(compressedBytes, MAX_INFLATED_BYTES))) { return parser.mapOrdered(); } } - static InputStream inflate(String compressedString, long streamSize) throws IOException { - byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + static InputStream inflate(BytesReference compressedBytes, long streamSize) throws IOException { // If the compressed length is already too large, it make sense that the inflated length would be as well // In the extremely small string case, the compressed data could actually be longer than the compressed stream - if (compressedBytes.length > Math.max(100L, streamSize)) { + if (compressedBytes.length() > Math.max(100L, streamSize)) { throw new CircuitBreakingException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]", CircuitBreaker.Durability.PERMANENT); } - InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE); + InputStream gzipStream = new GZIPInputStream(compressedBytes.streamInput(), BUFFER_SIZE); return new SimpleBoundedInputStream(gzipStream, streamSize); } - private static String deflate(BytesReference reference) throws IOException { + private static BytesReference deflate(BytesReference reference) throws IOException { BytesStreamOutput out = new BytesStreamOutput(); try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) { reference.writeTo(compressedOutput); } - return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); + return out.bytes(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9781eaad70177..b97e74236460b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,8 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -36,8 +38,10 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -248,15 +252,15 @@ public InferenceConfig getInferenceConfig() { } @Nullable - public String getCompressedDefinition() throws IOException { + public BytesReference getCompressedDefinition() throws IOException { if (definition == null) { return null; } - return definition.getCompressedString(); + return definition.compressedRepresentation; } public void clearCompressed() { - definition.compressedString = null; + definition.compressedRepresentation = null; } public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { @@ -348,7 +352,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { - builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); + builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.compressedRepresentation); } } builder.field(TAGS.getPreferredName(), tags); @@ -564,11 +568,11 @@ public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) { return this; } - public Builder setDefinitionFromString(String definitionFromString) { - if (definitionFromString == null) { + public Builder setDefinitionFromBytes(BytesReference definition) { + if (definition == null) { return this; } - this.definition = LazyModelDefinition.fromCompressedString(definitionFromString); + this.definition = LazyModelDefinition.fromCompressedData(definition); return this; } @@ -605,7 +609,7 @@ private Builder setLazyDefinition(String compressedString) { DEFINITION.getPreferredName()) .getFormattedMessage()); } - this.definition = LazyModelDefinition.fromCompressedString(compressedString); + this.definition = LazyModelDefinition.fromBase64String(compressedString); return this; } @@ -762,54 +766,61 @@ public TrainedModelConfig build() { public static class LazyModelDefinition implements ToXContentObject, Writeable { - private String compressedString; + private BytesReference compressedRepresentation; private TrainedModelDefinition parsedDefinition; public static LazyModelDefinition fromParsedDefinition(TrainedModelDefinition definition) { return new LazyModelDefinition(null, definition); } - public static LazyModelDefinition fromCompressedString(String compressedString) { - return new LazyModelDefinition(compressedString, null); + public static LazyModelDefinition fromCompressedData(BytesReference compressed) { + return new LazyModelDefinition(compressed, null); + } + + public static LazyModelDefinition fromBase64String(String base64String) { + byte[] decodedBytes = Base64.getDecoder().decode(base64String); + return new LazyModelDefinition(new BytesArray(decodedBytes), null); } public static LazyModelDefinition fromStreamInput(StreamInput input) throws IOException { - return new LazyModelDefinition(input.readString(), null); + if (input.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport + return new LazyModelDefinition(input.readBytesReference(), null); + } else { + return fromBase64String(input.readString()); + } } private LazyModelDefinition(LazyModelDefinition definition) { if (definition != null) { - this.compressedString = definition.compressedString; + this.compressedRepresentation = definition.compressedRepresentation; this.parsedDefinition = definition.parsedDefinition; } } - private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) { - if (compressedString == null && trainedModelDefinition == null) { + private LazyModelDefinition(BytesReference compressedRepresentation, TrainedModelDefinition trainedModelDefinition) { + if (compressedRepresentation == null && trainedModelDefinition == null) { throw new IllegalArgumentException("unexpected null model definition"); } - this.compressedString = compressedString; + this.compressedRepresentation = compressedRepresentation; this.parsedDefinition = trainedModelDefinition; } public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { if (parsedDefinition == null) { - parsedDefinition = InferenceToXContentCompressor.inflate(compressedString, + parsedDefinition = InferenceToXContentCompressor.inflate(compressedRepresentation, parser -> TrainedModelDefinition.fromXContent(parser, true).build(), xContentRegistry); } } - public String getCompressedString() throws IOException { - if (compressedString == null) { - compressedString = InferenceToXContentCompressor.deflate(parsedDefinition); - } - return compressedString; - } - @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(getCompressedString()); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport + out.writeBytesReference(compressedRepresentation); + } else { + String base64String = new String(Base64.getEncoder().encode(compressedRepresentation.array()), StandardCharsets.UTF_8); + out.writeString(base64String); + } } @Override @@ -817,7 +828,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (parsedDefinition != null) { return parsedDefinition.toXContent(builder, params); } - Map map = InferenceToXContentCompressor.inflateToMap(compressedString); + Map map = InferenceToXContentCompressor.inflateToMap(compressedRepresentation); return builder.map(map); } @@ -826,13 +837,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LazyModelDefinition that = (LazyModelDefinition) o; - return Objects.equals(compressedString, that.compressedString) && + return Objects.equals(compressedRepresentation, that.compressedRepresentation) && Objects.equals(parsedDefinition, that.parsedDefinition); } @Override public int hashCode() { - return Objects.hash(compressedString, parsedDefinition); + return Objects.hash(compressedRepresentation, parsedDefinition); } } diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json index 171cbabc52c30..bf164f049dc43 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json @@ -44,6 +44,9 @@ "definition": { "enabled": false }, + "binary_definition": { + "type": "binary" + }, "compression_version": { "type": "long" }, @@ -135,7 +138,7 @@ "supplied": { "type": "boolean" } - } + } } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java index 47a131f5d758c..6a98cdb49550a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.test.ESTestCase; @@ -16,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -27,7 +28,7 @@ public class InferenceToXContentCompressorTests extends ESTestCase { public void testInflateAndDeflate() throws IOException { for(int i = 0; i < 10; i++) { TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); - String firstDeflate = InferenceToXContentCompressor.deflate(definition); + BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition); TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate, parser -> TrainedModelDefinition.fromXContent(parser, false).build(), xContentRegistry()); @@ -45,8 +46,8 @@ public void testInflateTooLargeStream() throws IOException { .limit(100) .collect(Collectors.toList())) .build(); - String firstDeflate = InferenceToXContentCompressor.deflate(definition); - int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10; + BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition); + int max = firstDeflate.length() + 10; IOException ex = expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max))); assertThat(ex.getMessage(), equalTo("" + @@ -54,7 +55,8 @@ public void testInflateTooLargeStream() throws IOException { } public void testInflateGarbage() { - expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L))); + expectThrows(IOException.class, () -> Streams.readFully( + InferenceToXContentCompressor.inflate(new BytesArray(randomByteArrayOfLength(10)), 100L))); } public void testInflateParsingTooLargeStream() throws IOException { @@ -65,8 +67,8 @@ public void testInflateParsingTooLargeStream() throws IOException { .limit(100) .collect(Collectors.toList())) .build(); - String compressedString = InferenceToXContentCompressor.deflate(definition); - int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10; + BytesReference compressedString = InferenceToXContentCompressor.deflate(definition); + int max = compressedString.length() + 10; CircuitBreakingException e = expectThrows(CircuitBreakingException.class, ()-> InferenceToXContentCompressor.inflate( compressedString, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index c2755b57014f7..12a8146c0fe61 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -115,8 +115,7 @@ protected NamedXContentRegistry xContentRegistry() { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - List entries = new ArrayList<>(); - entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + List entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(entries); } @@ -170,7 +169,7 @@ public void testToXContentWithParams() throws IOException { assertThat(reference.utf8ToString(), containsString("\"definition\"")); assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); } - + public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { TrainedModelConfig.LazyModelDefinition lazyModelDefinition = TrainedModelConfig.LazyModelDefinition .fromParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder().build()); @@ -263,9 +262,9 @@ public void testSerializationWithLazyDefinition() throws IOException { xContentTester(this::createParser, () -> { try { - String compressedString = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); + BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); return createTestInstance(randomAlphaOfLength(10)) - .setDefinitionFromString(compressedString) + .setDefinitionFromBytes(bytes) .build(); } catch (IOException ex) { fail(ex.getMessage()); @@ -294,10 +293,10 @@ public void testSerializationWithCompressedLazyDefinition() throws IOException { xContentTester(this::createParser, () -> { try { - String compressedString = + BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); return createTestInstance(randomAlphaOfLength(10)) - .setDefinitionFromString(compressedString) + .setDefinitionFromBytes(bytes) .build(); } catch (IOException ex) { fail(ex.getMessage()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java index d1a486dbe0b8a..58cd3eaad2d2b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; +import com.unboundid.util.Base64; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.DeprecationHandler; @@ -24,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import java.io.IOException; +import java.text.ParseException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -59,7 +61,7 @@ public void testTreeSchemaDeserialization() throws IOException { assertThat(definition.getTrainedModel().getClass(), equalTo(TreeInferenceModel.class)); } - public void testMultiClassIrisInference() throws IOException { + public void testMultiClassIrisInference() throws IOException, ParseException { // Fairly simple, random forest classification model built to fit in our format // Trained on the well known Iris dataset String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" + @@ -83,7 +85,8 @@ public void testMultiClassIrisInference() throws IOException { "aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" + "fdpTi9JB0sDp2JR7b309mn5HuPkEAAA=="; - InferenceDefinition definition = InferenceToXContentCompressor.inflate(compressedDef, + byte[] bytes = Base64.decode(compressedDef); + InferenceDefinition definition = InferenceToXContentCompressor.inflate(new BytesArray(bytes), InferenceDefinition::fromXContent, xContentRegistry()); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index c2779ac8bbb74..1209f072923b5 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.license.License; @@ -27,8 +28,8 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.HyperparametersTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister; @@ -46,7 +47,9 @@ import org.junit.Before; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; @@ -78,7 +81,9 @@ public void testStoreModelViaChunkedPersister() throws IOException { .build(); List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); - String compressedDefinition = configBuilder.build().getCompressedDefinition(); + // TODO where does the compressed def come from is parse set?? + BytesReference bytes = configBuilder.build().getCompressedDefinition(); + String compressedDefinition = new String(Base64.getEncoder().encode(bytes.array()), StandardCharsets.UTF_8); int totalSize = compressedDefinition.length(); List chunks = chunkStringWithSize(compressedDefinition, totalSize/3); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index a25a3d7bc8c0e..11ef3cca78ecd 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; @@ -46,7 +47,6 @@ import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; -import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; @@ -228,7 +228,7 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception { TrainedModelDefinitionDoc truncatedDoc = new TrainedModelDefinitionDoc.Builder() .setDocNum(0) - .setCompressedString(config.getCompressedDefinition().substring(0, config.getCompressedDefinition().length() - 10)) + .setBinaryData(config.getCompressedDefinition().slice(0, config.getCompressedDefinition().length() - 10).array()) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(config.getCompressedDefinition().length()) .setTotalDefinitionLength(config.getCompressedDefinition().length()) @@ -351,15 +351,15 @@ public void testGetTrainedModelForInference() throws InterruptedException, IOExc assertThat(definitionHolder.get(), is(not(nullValue()))); } - private List createModelDefinitionDocs(String compressedDefinition, String modelId) { - List chunks = chunkStringWithSize(compressedDefinition, compressedDefinition.length()/3); + private List createModelDefinitionDocs(BytesReference compressedDefinition, String modelId) { + List chunks = TrainedModelProvider.chunkDefinitionWithSize(compressedDefinition, compressedDefinition.length()/3); return IntStream.range(0, chunks.size()) .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() .setDocNum(i) - .setCompressedString(chunks.get(i)) + .setBinaryData(chunks.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunks.get(i).length()) + .setDefinitionLength(chunks.get(i).length) .setEos(i == chunks.size() - 1) .setModelId(modelId)) .collect(Collectors.toList()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java index efc3cc9eee865..c41ac67be5a4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java @@ -17,6 +17,8 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Objects; /** @@ -31,6 +33,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { public static final ParseField DOC_NUM = new ParseField("doc_num"); public static final ParseField DEFINITION = new ParseField("definition"); + public static final ParseField BINARY_DEFINITION = new ParseField("binary_definition"); public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version"); public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length"); public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length"); @@ -47,6 +50,8 @@ private static ObjectParser createParse parser.declareString((a, b) -> {}, InferenceIndexConstants.DOC_TYPE); // type is hard coded but must be parsed parser.declareString(TrainedModelDefinitionDoc.Builder::setModelId, TrainedModelConfig.MODEL_ID); parser.declareString(TrainedModelDefinitionDoc.Builder::setCompressedString, DEFINITION); + parser.declareField(TrainedModelDefinitionDoc.Builder::setBinaryData, (p, c) -> p.binaryValue(), + BINARY_DEFINITION, ObjectParser.ValueType.VALUE_ARRAY); parser.declareInt(TrainedModelDefinitionDoc.Builder::setDocNum, DOC_NUM); parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION); parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH); @@ -64,7 +69,7 @@ public static String docId(String modelId, int docNum) { return NAME + "-" + modelId + "-" + docNum; } - private final String compressedString; + private final byte[] binaryData; private final String modelId; private final int docNum; // for bwc @@ -73,14 +78,14 @@ public static String docId(String modelId, int docNum) { private final int compressionVersion; private final boolean eos; - private TrainedModelDefinitionDoc(String compressedString, + private TrainedModelDefinitionDoc(byte[] binaryData, String modelId, int docNum, Long totalDefinitionLength, long definitionLength, int compressionVersion, boolean eos) { - this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION); + this.binaryData = ExceptionsHelper.requireNonNull(binaryData, BINARY_DEFINITION); this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); if (docNum < 0) { throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0"); @@ -98,8 +103,8 @@ private TrainedModelDefinitionDoc(String compressedString, this.eos = eos; } - public String getCompressedString() { - return compressedString; + public byte[] getBinaryData() { + return binaryData; } public String getModelId() { @@ -141,7 +146,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength); } builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion); - builder.field(DEFINITION.getPreferredName(), compressedString); + builder.field(BINARY_DEFINITION.getPreferredName(), binaryData); builder.field(EOS.getPreferredName(), eos); builder.endObject(); return builder; @@ -163,18 +168,18 @@ public boolean equals(Object o) { Objects.equals(totalDefinitionLength, that.totalDefinitionLength) && Objects.equals(compressionVersion, that.compressionVersion) && Objects.equals(eos, that.eos) && - Objects.equals(compressedString, that.compressedString); + Objects.equals(binaryData, that.binaryData); } @Override public int hashCode() { - return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos); + return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, binaryData, eos); } public static class Builder { private String modelId; - private String compressedString; + private byte[] binaryData; private int docNum; private Long totalDefinitionLength; private long definitionLength; @@ -187,7 +192,12 @@ public Builder setModelId(String modelId) { } public Builder setCompressedString(String compressedString) { - this.compressedString = compressedString; + this.binaryData = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + return this; + } + + public Builder setBinaryData(byte [] binaryData) { + this.binaryData = binaryData; return this; } @@ -218,7 +228,7 @@ public Builder setEos(boolean eos) { public TrainedModelDefinitionDoc build() { return new TrainedModelDefinitionDoc( - this.compressedString, + this.binaryData, this.modelId, this.docNum, this.totalDefinitionLength, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 18473f1380610..a7ee020215c50 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -40,6 +40,8 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ByteArray; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -291,25 +293,25 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi List trainedModelDefinitionDocs = new ArrayList<>(); try { - String compressedString = trainedModelConfig.getCompressedDefinition(); - if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) { + BytesReference compressedDefinition = trainedModelConfig.getCompressedDefinition(); + if (compressedDefinition.length() > MAX_COMPRESSED_STRING_SIZE) { listener.onFailure( ExceptionsHelper.badRequestException( "Unable to store model as compressed definition has length [{}] the limit is [{}]", - compressedString.length(), + compressedDefinition.length(), MAX_COMPRESSED_STRING_SIZE)); return; } - List chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE); - for(int i = 0; i < chunkedStrings.size(); ++i) { + List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_STRING_CHUNK_SIZE); + for(int i = 0; i < chunkedDefinition.size(); ++i) { trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) .setModelId(trainedModelConfig.getModelId()) - .setCompressedString(chunkedStrings.get(i)) + .setBinaryData(chunkedDefinition.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunkedStrings.get(i).length()) + .setDefinitionLength(chunkedDefinition.get(i).length) // If it is the last doc, it is the EOS - .setEos(i == chunkedStrings.size() - 1) + .setEos(i == chunkedDefinition.size() - 1) .build()); } } catch (IOException ex) { @@ -416,9 +418,9 @@ public void getTrainedModelForInference(final String modelId, final ActionListen modelRestorer.restoreModelDefinition(docs::add, success -> { try { - String compressedString = getDefinitionFromDocs(docs, modelId); + BytesReference compressedData = getDefinitionFromDocs(docs, modelId); InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, + compressedData, InferenceDefinition::fromXContent, xContentRegistry); @@ -534,8 +536,8 @@ public void getTrainedModel(final String modelId, ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource( bytes, resourceId, xContentRegistry)); try { - String compressedString = getDefinitionFromDocs(docs, modelId); - builder.setDefinitionFromString(compressedString); + BytesReference compressedData = getDefinitionFromDocs(docs, modelId); + builder.setDefinitionFromBytes(compressedData); } catch (ElasticsearchException elasticsearchException) { getTrainedModelListener.onFailure(elasticsearchException); return; @@ -1086,14 +1088,22 @@ private static List handleHits(SearchHit[] hits, return results; } - private static String getDefinitionFromDocs(List docs, String modelId) throws ElasticsearchException { - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); + private static BytesReference getDefinitionFromDocs(List docs, + String modelId) throws ElasticsearchException { + + int size = docs.stream().map(doc -> doc.getBinaryData().length).reduce(0, Integer::sum); + ByteArray byteArray = BigArrays.NON_RECYCLING_INSTANCE.newByteArray(size); + int offset = 0; + for (TrainedModelDefinitionDoc doc : docs) { + byteArray.set(0L, doc.getBinaryData(), offset, doc.getBinaryData().length); + offset += doc.getBinaryData().length; + } + BytesReference bytes = BytesReference.fromByteArray(byteArray, size); + // BWC for when we tracked the total definition length // TODO: remove in 9 if (docs.get(0).getTotalDefinitionLength() != null) { - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + if (bytes.length() != docs.get(0).getTotalDefinitionLength()) { throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); } } else { @@ -1103,15 +1113,16 @@ private static String getDefinitionFromDocs(List docs throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); } } - return compressedString; + return bytes; } - static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); - for (int i = 0; i < str.length();i += chunkSize) { - subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + public static List chunkDefinitionWithSize(BytesReference definition, int chunkSize) { + List chunks = new ArrayList<>((int)Math.ceil(definition.length()/(double)chunkSize)); + for (int i = 0; i < definition.length();i += chunkSize) { + BytesReference chunk = definition.slice(i, Math.min(chunkSize, definition.length() - i)); + chunks.add(chunk.array()); } - return subStrings; + return chunks; } private TrainedModelConfig.Builder parseModelConfigLenientlyFromSource(BytesReference source, String modelId) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java index 1134d767d59a7..3ac6ed80f8033 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java @@ -19,8 +19,6 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.Locale; import java.util.Objects; import java.util.concurrent.ExecutorService; @@ -82,9 +80,7 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr modelSizeWritten = true; } - byte[] rawBytes = Base64.getDecoder().decode(doc.getCompressedString().getBytes(StandardCharsets.UTF_8)); - outputStream.write(rawBytes); - + outputStream.write(doc.getBinaryData()); return true; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index e447dc29de321..10460d2bd0699 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.ConstantScoreQueryBuilder; @@ -25,9 +26,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -146,6 +149,16 @@ public void testGetModelThatExistsAsResourceButIsMissing() { assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model"))); } + public void testChunkDefinitionWithSize() { + byte[] bytes = randomByteArrayOfLength(100); + List chunks = TrainedModelProvider.chunkDefinitionWithSize(new BytesArray(bytes), 30); + assertThat(chunks, hasSize(4)); + assertArrayEquals(Arrays.copyOfRange(bytes, 0, 30), chunks.get(0)); + assertArrayEquals(Arrays.copyOfRange(bytes, 30, 60), chunks.get(1)); + assertArrayEquals(Arrays.copyOfRange(bytes, 60, 90), chunks.get(2)); + assertArrayEquals(Arrays.copyOfRange(bytes, 90, 100), chunks.get(3)); + } + @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); From 4ca1e5c03b56514190f18714c6612447b09c41f8 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 13 Apr 2021 18:29:23 +0100 Subject: [PATCH 2/9] Switch to bytes ref --- .../integration/TrainedModelProviderIT.java | 7 +++--- .../TrainedModelDefinitionDoc.java | 19 ++++++++------- .../persistence/TrainedModelProvider.java | 24 ++++++++----------- .../pytorch/process/PyTorchStateStreamer.java | 4 +++- .../TrainedModelDefinitionDocTests.java | 8 ++++--- .../TrainedModelProviderTests.java | 21 +++++++++++----- 6 files changed, 48 insertions(+), 35 deletions(-) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 11ef3cca78ecd..a6e0912ac7d7e 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -228,7 +228,7 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception { TrainedModelDefinitionDoc truncatedDoc = new TrainedModelDefinitionDoc.Builder() .setDocNum(0) - .setBinaryData(config.getCompressedDefinition().slice(0, config.getCompressedDefinition().length() - 10).array()) + .setBinaryData(config.getCompressedDefinition().slice(0, config.getCompressedDefinition().length() - 10)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(config.getCompressedDefinition().length()) .setTotalDefinitionLength(config.getCompressedDefinition().length()) @@ -352,14 +352,15 @@ public void testGetTrainedModelForInference() throws InterruptedException, IOExc } private List createModelDefinitionDocs(BytesReference compressedDefinition, String modelId) { - List chunks = TrainedModelProvider.chunkDefinitionWithSize(compressedDefinition, compressedDefinition.length()/3); + List chunks = TrainedModelProvider.chunkDefinitionWithSize(compressedDefinition, compressedDefinition.length()/3); return IntStream.range(0, chunks.size()) .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() .setDocNum(i) .setBinaryData(chunks.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunks.get(i).length) + .setDefinitionLength(chunks.get(i).length()) + .setTotalDefinitionLength(compressedDefinition.length()) .setEos(i == chunks.size() - 1) .setModelId(modelId)) .collect(Collectors.toList()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java index c41ac67be5a4a..ea609d6129e00 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -50,8 +52,8 @@ private static ObjectParser createParse parser.declareString((a, b) -> {}, InferenceIndexConstants.DOC_TYPE); // type is hard coded but must be parsed parser.declareString(TrainedModelDefinitionDoc.Builder::setModelId, TrainedModelConfig.MODEL_ID); parser.declareString(TrainedModelDefinitionDoc.Builder::setCompressedString, DEFINITION); - parser.declareField(TrainedModelDefinitionDoc.Builder::setBinaryData, (p, c) -> p.binaryValue(), - BINARY_DEFINITION, ObjectParser.ValueType.VALUE_ARRAY); + parser.declareField(TrainedModelDefinitionDoc.Builder::setBinaryData, (p, c) -> new BytesArray(p.binaryValue()), + BINARY_DEFINITION, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); parser.declareInt(TrainedModelDefinitionDoc.Builder::setDocNum, DOC_NUM); parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION); parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH); @@ -69,7 +71,7 @@ public static String docId(String modelId, int docNum) { return NAME + "-" + modelId + "-" + docNum; } - private final byte[] binaryData; + private final BytesReference binaryData; private final String modelId; private final int docNum; // for bwc @@ -78,7 +80,7 @@ public static String docId(String modelId, int docNum) { private final int compressionVersion; private final boolean eos; - private TrainedModelDefinitionDoc(byte[] binaryData, + private TrainedModelDefinitionDoc(BytesReference binaryData, String modelId, int docNum, Long totalDefinitionLength, @@ -103,7 +105,7 @@ private TrainedModelDefinitionDoc(byte[] binaryData, this.eos = eos; } - public byte[] getBinaryData() { + public BytesReference getBinaryData() { return binaryData; } @@ -179,7 +181,7 @@ public int hashCode() { public static class Builder { private String modelId; - private byte[] binaryData; + private BytesReference binaryData; private int docNum; private Long totalDefinitionLength; private long definitionLength; @@ -192,11 +194,12 @@ public Builder setModelId(String modelId) { } public Builder setCompressedString(String compressedString) { - this.binaryData = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + this.binaryData = new BytesArray(Base64.getDecoder() + .decode(compressedString.getBytes(StandardCharsets.UTF_8))); return this; } - public Builder setBinaryData(byte [] binaryData) { + public Builder setBinaryData(BytesReference binaryData) { this.binaryData = binaryData; return this; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index a7ee020215c50..a9c596cb47eee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -38,10 +38,9 @@ import org.elasticsearch.common.Numbers; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.regex.Regex; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ByteArray; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -302,14 +301,14 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi MAX_COMPRESSED_STRING_SIZE)); return; } - List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_STRING_CHUNK_SIZE); + List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_STRING_CHUNK_SIZE); for(int i = 0; i < chunkedDefinition.size(); ++i) { trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) .setModelId(trainedModelConfig.getModelId()) .setBinaryData(chunkedDefinition.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunkedDefinition.get(i).length) + .setDefinitionLength(chunkedDefinition.get(i).length()) // If it is the last doc, it is the EOS .setEos(i == chunkedDefinition.size() - 1) .build()); @@ -1091,14 +1090,11 @@ private static List handleHits(SearchHit[] hits, private static BytesReference getDefinitionFromDocs(List docs, String modelId) throws ElasticsearchException { - int size = docs.stream().map(doc -> doc.getBinaryData().length).reduce(0, Integer::sum); - ByteArray byteArray = BigArrays.NON_RECYCLING_INSTANCE.newByteArray(size); - int offset = 0; - for (TrainedModelDefinitionDoc doc : docs) { - byteArray.set(0L, doc.getBinaryData(), offset, doc.getBinaryData().length); - offset += doc.getBinaryData().length; + BytesReference[] bb = new BytesReference[docs.size()]; + for (int i=0; i chunkDefinitionWithSize(BytesReference definition, int chunkSize) { - List chunks = new ArrayList<>((int)Math.ceil(definition.length()/(double)chunkSize)); + public static List chunkDefinitionWithSize(BytesReference definition, int chunkSize) { + List chunks = new ArrayList<>((int)Math.ceil(definition.length()/(double)chunkSize)); for (int i = 0; i < definition.length();i += chunkSize) { BytesReference chunk = definition.slice(i, Math.min(chunkSize, definition.length() - i)); - chunks.add(chunk.array()); + chunks.add(chunk); } return chunks; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java index 3ac6ed80f8033..338821aa72ab1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java @@ -80,7 +80,9 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr modelSizeWritten = true; } - outputStream.write(doc.getBinaryData()); + // The array backing the BytesReference may be bigger than what is + // referred to so write only what is after the offset + outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length()); return true; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java index 73cea41c1449a..e79a6ef02eca2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.persistence; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -14,7 +15,7 @@ public class TrainedModelDefinitionDocTests extends AbstractXContentTestCase { - private boolean isLenient = randomBoolean(); + private final boolean isLenient = randomBoolean(); @Override protected TrainedModelDefinitionDoc doParseInstance(XContentParser parser) throws IOException { @@ -28,12 +29,13 @@ protected boolean supportsUnknownFields() { @Override protected TrainedModelDefinitionDoc createTestInstance() { - int length = randomIntBetween(1, 10); + int length = randomIntBetween(4, 10); + return new TrainedModelDefinitionDoc.Builder() .setModelId(randomAlphaOfLength(6)) .setDefinitionLength(length) .setTotalDefinitionLength(randomIntBetween(length, length *2)) - .setCompressedString(randomAlphaOfLength(length)) + .setBinaryData(new BytesArray(randomByteArrayOfLength(length))) .setDocNum(randomIntBetween(0, 10)) .setCompressionVersion(randomIntBetween(1, 5)) .setEos(randomBoolean()) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 10460d2bd0699..f353a8221d459 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.ConstantScoreQueryBuilder; @@ -150,13 +151,21 @@ public void testGetModelThatExistsAsResourceButIsMissing() { } public void testChunkDefinitionWithSize() { - byte[] bytes = randomByteArrayOfLength(100); - List chunks = TrainedModelProvider.chunkDefinitionWithSize(new BytesArray(bytes), 30); + int totalLength = 100; + int size = 30; + + byte[] bytes = randomByteArrayOfLength(totalLength); + List chunks = TrainedModelProvider.chunkDefinitionWithSize(new BytesArray(bytes), size); assertThat(chunks, hasSize(4)); - assertArrayEquals(Arrays.copyOfRange(bytes, 0, 30), chunks.get(0)); - assertArrayEquals(Arrays.copyOfRange(bytes, 30, 60), chunks.get(1)); - assertArrayEquals(Arrays.copyOfRange(bytes, 60, 90), chunks.get(2)); - assertArrayEquals(Arrays.copyOfRange(bytes, 90, 100), chunks.get(3)); + int start = 0; + int end = size; + for (BytesReference chunk : chunks) { + assertArrayEquals(Arrays.copyOfRange(bytes, start, end), + Arrays.copyOfRange(chunk.array(), chunk.arrayOffset(), chunk.arrayOffset() + chunk.length())); + + start += size; + end = Math.min(end + size, totalLength); + } } @Override From 00590f3cf80d3e448e43f8c6f313dc49d7e988c4 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 14 Apr 2021 10:30:23 +0100 Subject: [PATCH 3/9] new index mappings version --- .../core/ml/inference/persistence/InferenceIndexConstants.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index 5f93c6d7b7cfc..f8e01c989d4ff 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -26,6 +26,9 @@ public final class InferenceIndexConstants { * * version: 7.10.0: 000003 * - adds trained_model_metadata object + * + * version: UNKNOWN_MERGED_ON_FEATURE_BRANCH: 000004 TODO + * - adds binary_definition for TrainedModelDefinitionDoc */ public static final String INDEX_VERSION = "000003"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; From 79afa7b4865feed5b2fb47a821c740bc0b7cfb86 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 14 Apr 2021 11:49:00 +0100 Subject: [PATCH 4/9] Fixing up tests --- .../core/ml/inference/TrainedModelConfig.java | 11 +++++++-- .../ChunkedTrainedModelPersisterIT.java | 23 +++++++++---------- .../ChunkedTrainedModelRestorerIT.java | 22 ++++++++++-------- .../TrainedModelDefinitionDocTests.java | 23 +++++++++++++++++++ 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index b97e74236460b..1de2c618fc55e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -256,7 +256,7 @@ public BytesReference getCompressedDefinition() throws IOException { if (definition == null) { return null; } - return definition.compressedRepresentation; + return definition.getCompressedDefinition(); } public void clearCompressed() { @@ -352,7 +352,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { - builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.compressedRepresentation); + builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedDefinition()); } } builder.field(TAGS.getPreferredName(), tags); @@ -805,6 +805,13 @@ private LazyModelDefinition(BytesReference compressedRepresentation, TrainedMode this.parsedDefinition = trainedModelDefinition; } + public BytesReference getCompressedDefinition() throws IOException { + if (compressedRepresentation == null) { + compressedRepresentation = InferenceToXContentCompressor.deflate(parsedDefinition); + } + return compressedRepresentation; + } + public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { if (parsedDefinition == null) { parsedDefinition = InferenceToXContentCompressor.inflate(compressedRepresentation, diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 1209f072923b5..e8da96e2da43e 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -49,6 +49,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.List; @@ -81,11 +82,8 @@ public void testStoreModelViaChunkedPersister() throws IOException { .build(); List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); - // TODO where does the compressed def come from is parse set?? - BytesReference bytes = configBuilder.build().getCompressedDefinition(); - String compressedDefinition = new String(Base64.getEncoder().encode(bytes.array()), StandardCharsets.UTF_8); - int totalSize = compressedDefinition.length(); - List chunks = chunkStringWithSize(compressedDefinition, totalSize/3); + BytesReference compressedDefinition = configBuilder.build().getCompressedDefinition(); + List base64Chunks = chunkBinaryDefinition(compressedDefinition, compressedDefinition.length() / 3); ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider, analyticsConfig, @@ -97,8 +95,8 @@ public void testStoreModelViaChunkedPersister() throws IOException { //Accuracy for size is not tested here ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); persister.createAndIndexInferenceModelConfig(modelSizeInfo, configBuilder.getModelType()); - for (int i = 0; i < chunks.size(); i++) { - persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1))); + for (int i = 0; i < base64Chunks.size(); i++) { + persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(base64Chunks.get(i), i, i == (base64Chunks.size() - 1))); } ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) .limit(randomIntBetween(1, 10)) @@ -158,14 +156,15 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setInput(TrainedModelInputTests.createRandomInput()); } - public static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize); - for (int i = 0; i < str.length(); i += chunkSize) { - subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + public static List chunkBinaryDefinition(BytesReference bytes, int chunkSize) { + List subStrings = new ArrayList<>((bytes.length() + chunkSize - 1) / chunkSize); + for (int i = 0; i < bytes.length(); i += chunkSize) { + subStrings.add( + Base64.getEncoder().encodeToString( + Arrays.copyOfRange(bytes.array(), i, Math.min(i + chunkSize, bytes.length())))); } return subStrings; } - @Override public NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java index 3effd05c9ef20..6919812c3000c 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java @@ -12,6 +12,8 @@ import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; @@ -23,7 +25,9 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Base64; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; @@ -36,11 +40,11 @@ public class ChunkedTrainedModelRestorerIT extends MlSingleNodeTestCase { public void testRestoreWithMultipleSearches() throws IOException, InterruptedException { String modelId = "test-multiple-searches"; int numDocs = 22; - List modelDefs = new ArrayList<>(numDocs); + List modelDefs = new ArrayList<>(numDocs); for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); @@ -72,11 +76,11 @@ public void testRestoreWithMultipleSearches() throws IOException, InterruptedExc public void testCancel() throws IOException, InterruptedException { String modelId = "test-cancel-search"; int numDocs = 6; - List modelDefs = new ArrayList<>(numDocs); + List modelDefs = new ArrayList<>(numDocs); for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); @@ -126,11 +130,11 @@ public void testRestoreWithDocumentsInMultipleIndices() throws IOException, Inte String modelId = "test-multiple-indices"; int numDocs = 24; - List modelDefs = new ArrayList<>(numDocs); + List modelDefs = new ArrayList<>(numDocs); for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); @@ -166,14 +170,14 @@ public void testRestoreWithDocumentsInMultipleIndices() throws IOException, Inte assertEquals(actualDocs, reorderedDocs); } - private List createModelDefinitionDocs(List compressedDefinitions, String modelId) { - int totalLength = compressedDefinitions.stream().map(String::length).reduce(0, Integer::sum); + private List createModelDefinitionDocs(List compressedDefinitions, String modelId) { + int totalLength = compressedDefinitions.stream().map(BytesReference::length).reduce(0, Integer::sum); List docs = new ArrayList<>(); for (int i = 0; i < compressedDefinitions.size(); i++) { docs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) - .setCompressedString(compressedDefinitions.get(i)) + .setBinaryData(compressedDefinitions.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setTotalDefinitionLength(totalLength) .setDefinitionLength(compressedDefinitions.get(i).length()) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java index e79a6ef02eca2..5fb179486b5cf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java @@ -9,14 +9,37 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; +import java.util.Base64; public class TrainedModelDefinitionDocTests extends AbstractXContentTestCase { private final boolean isLenient = randomBoolean(); + public void testParsingDocWithCompressedString() throws IOException { + byte[] bytes = randomByteArrayOfLength(50); + String base64 = Base64.getEncoder().encodeToString(bytes); + + // The previous storage format was a base64 encoded string. + // The new format should parse and decode the string storing the raw bytes. + String compressedStringDoc = "{\"doc_type\":\"trained_model_definition_doc\"," + + "\"model_id\":\"bntHUo\"," + + "\"doc_num\":6," + + "\"definition_length\":7," + + "\"total_definition_length\":13," + + "\"compression_version\":3," + + "\"definition\":\"" + base64 + "\"," + + "\"eos\":false}"; + + try (XContentParser parser = createParser(JsonXContent.jsonXContent, compressedStringDoc)) { + TrainedModelDefinitionDoc parsed = doParseInstance(parser); + assertArrayEquals(bytes, parsed.getBinaryData().array()); + } + } + @Override protected TrainedModelDefinitionDoc doParseInstance(XContentParser parser) throws IOException { return TrainedModelDefinitionDoc.fromXContent(parser, isLenient).build(); From 6b8b4f0a74ee18fd98214d64f0859ac9be15c6ec Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 14 Apr 2021 12:00:11 +0100 Subject: [PATCH 5/9] Actually update version instead of just the comment --- .../core/ml/inference/persistence/InferenceIndexConstants.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index f8e01c989d4ff..60ff6b56727e5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -30,7 +30,7 @@ public final class InferenceIndexConstants { * version: UNKNOWN_MERGED_ON_FEATURE_BRANCH: 000004 TODO * - adds binary_definition for TrainedModelDefinitionDoc */ - public static final String INDEX_VERSION = "000003"; + public static final String INDEX_VERSION = "000004"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; From 93fb22619bbeaccbdc7de27beb119cbc7ef00c5c Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 14 Apr 2021 12:44:21 +0100 Subject: [PATCH 6/9] More tests --- .../core/ml/inference/TrainedModelConfig.java | 25 ++++++++++++------- .../ChunkedTrainedModelPersisterIT.java | 4 +-- .../TrainedModelDefinitionDocTests.java | 2 +- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 1de2c618fc55e..38247509aeeb1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; @@ -352,7 +353,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { - builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedDefinition()); + builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getBase64CompressedDefinition()); } } builder.field(TAGS.getPreferredName(), tags); @@ -764,7 +765,7 @@ public TrainedModelConfig build() { } } - public static class LazyModelDefinition implements ToXContentObject, Writeable { + static class LazyModelDefinition implements ToXContentObject, Writeable { private BytesReference compressedRepresentation; private TrainedModelDefinition parsedDefinition; @@ -805,14 +806,23 @@ private LazyModelDefinition(BytesReference compressedRepresentation, TrainedMode this.parsedDefinition = trainedModelDefinition; } - public BytesReference getCompressedDefinition() throws IOException { + private BytesReference getCompressedDefinition() throws IOException { if (compressedRepresentation == null) { compressedRepresentation = InferenceToXContentCompressor.deflate(parsedDefinition); } return compressedRepresentation; } - public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { + private String getBase64CompressedDefinition() throws IOException { + BytesReference compressedDef = getCompressedDefinition(); + + ByteBuffer bb = Base64.getEncoder().encode( + ByteBuffer.wrap(compressedDef.array(), compressedDef.arrayOffset(), compressedDef.length())); + + return new String(bb.array(), StandardCharsets.UTF_8); + } + + private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { if (parsedDefinition == null) { parsedDefinition = InferenceToXContentCompressor.inflate(compressedRepresentation, parser -> TrainedModelDefinition.fromXContent(parser, true).build(), @@ -823,10 +833,9 @@ public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throw @Override public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport - out.writeBytesReference(compressedRepresentation); + out.writeBytesReference(getCompressedDefinition()); } else { - String base64String = new String(Base64.getEncoder().encode(compressedRepresentation.array()), StandardCharsets.UTF_8); - out.writeString(base64String); + out.writeString(getBase64CompressedDefinition()); } } @@ -852,7 +861,5 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(compressedRepresentation, parsedDefinition); } - } - } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index e8da96e2da43e..74a0cfa9161df 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -47,7 +47,6 @@ import org.junit.Before; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -96,7 +95,8 @@ public void testStoreModelViaChunkedPersister() throws IOException { ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); persister.createAndIndexInferenceModelConfig(modelSizeInfo, configBuilder.getModelType()); for (int i = 0; i < base64Chunks.size(); i++) { - persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(base64Chunks.get(i), i, i == (base64Chunks.size() - 1))); + persister.createAndIndexInferenceModelDoc( + new TrainedModelDefinitionChunk(base64Chunks.get(i), i, i == (base64Chunks.size() - 1))); } ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) .limit(randomIntBetween(1, 10)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java index 5fb179486b5cf..c44a988d6a52c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java @@ -19,7 +19,7 @@ public class TrainedModelDefinitionDocTests extends AbstractXContentTestCase Date: Thu, 15 Apr 2021 09:57:21 +0100 Subject: [PATCH 7/9] Always check for EOS --- .../persistence/TrainedModelProvider.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index a9c596cb47eee..497ed89b76c4f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -1096,18 +1096,16 @@ private static BytesReference getDefinitionFromDocs(List Date: Tue, 20 Apr 2021 15:27:23 +0100 Subject: [PATCH 8/9] Say size is in bytes --- .../inference/persistence/TrainedModelProvider.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 497ed89b76c4f..ca0945b5d4330 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -113,9 +113,9 @@ public class TrainedModelProvider { public static final Set MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1"); private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/"; private static final String MODEL_RESOURCE_FILE_EXT = ".json"; - private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024; + private static final int COMPRESSED_MODEL_CHUNK_SIZE = 16 * 1024 * 1024; private static final int MAX_NUM_DEFINITION_DOCS = 100; - private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; + private static final int MAX_COMPRESSED_MODEL_SIZE = COMPRESSED_MODEL_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; @@ -293,15 +293,15 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi List trainedModelDefinitionDocs = new ArrayList<>(); try { BytesReference compressedDefinition = trainedModelConfig.getCompressedDefinition(); - if (compressedDefinition.length() > MAX_COMPRESSED_STRING_SIZE) { + if (compressedDefinition.length() > MAX_COMPRESSED_MODEL_SIZE) { listener.onFailure( ExceptionsHelper.badRequestException( - "Unable to store model as compressed definition has length [{}] the limit is [{}]", + "Unable to store model as compressed definition of size [{}] bytes the limit is [{}] bytes", compressedDefinition.length(), - MAX_COMPRESSED_STRING_SIZE)); + MAX_COMPRESSED_MODEL_SIZE)); return; } - List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_STRING_CHUNK_SIZE); + List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_MODEL_CHUNK_SIZE); for(int i = 0; i < chunkedDefinition.size(); ++i) { trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) From bb2638c63c41328fe807febaea55a30ee8205a17 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 20 Apr 2021 15:30:24 +0100 Subject: [PATCH 9/9] review changes --- .../xpack/ml/inference/persistence/TrainedModelProvider.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index ca0945b5d4330..78ac16d42ba80 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -1091,7 +1091,7 @@ private static BytesReference getDefinitionFromDocs(List