diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java index 70f4998a7f96c..b7f7532668b77 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; @@ -29,8 +28,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.Map; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; @@ -47,24 +44,24 @@ public final class InferenceToXContentCompressor { private InferenceToXContentCompressor() {} - public static String deflate(T objectToCompress) throws IOException { + public static BytesReference deflate(T objectToCompress) throws IOException { BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false); return deflate(reference); } - public static T inflate(String compressedString, + public static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry) throws IOException { - return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES); + return inflate(compressedBytes, parserFunction, xContentRegistry, MAX_INFLATED_BYTES); } - static T inflate(String compressedString, + static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry, long maxBytes) throws IOException { try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, maxBytes))) { + inflate(compressedBytes, maxBytes))) { return parserFunction.apply(parser); } catch (XContentParseException parseException) { SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause = @@ -82,32 +79,31 @@ static T inflate(String compressedString, } } - static Map inflateToMap(String compressedString) throws IOException { + static Map inflateToMap(BytesReference compressedBytes) throws IOException { // Don't need the xcontent registry as we are not deflating named objects. try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, MAX_INFLATED_BYTES))) { + inflate(compressedBytes, MAX_INFLATED_BYTES))) { return parser.mapOrdered(); } } - static InputStream inflate(String compressedString, long streamSize) throws IOException { - byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + static InputStream inflate(BytesReference compressedBytes, long streamSize) throws IOException { // If the compressed length is already too large, it make sense that the inflated length would be as well // In the extremely small string case, the compressed data could actually be longer than the compressed stream - if (compressedBytes.length > Math.max(100L, streamSize)) { + if (compressedBytes.length() > Math.max(100L, streamSize)) { throw new CircuitBreakingException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]", CircuitBreaker.Durability.PERMANENT); } - InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE); + InputStream gzipStream = new GZIPInputStream(compressedBytes.streamInput(), BUFFER_SIZE); return new SimpleBoundedInputStream(gzipStream, streamSize); } - private static String deflate(BytesReference reference) throws IOException { + private static BytesReference deflate(BytesReference reference) throws IOException { BytesStreamOutput out = new BytesStreamOutput(); try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) { reference.writeTo(compressedOutput); } - return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8); + return out.bytes(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 9781eaad70177..38247509aeeb1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,8 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -36,8 +38,11 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -248,15 +253,15 @@ public InferenceConfig getInferenceConfig() { } @Nullable - public String getCompressedDefinition() throws IOException { + public BytesReference getCompressedDefinition() throws IOException { if (definition == null) { return null; } - return definition.getCompressedString(); + return definition.getCompressedDefinition(); } public void clearCompressed() { - definition.compressedString = null; + definition.compressedRepresentation = null; } public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { @@ -348,7 +353,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { builder.field(DEFINITION.getPreferredName(), definition); } else { - builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); + builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getBase64CompressedDefinition()); } } builder.field(TAGS.getPreferredName(), tags); @@ -564,11 +569,11 @@ public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) { return this; } - public Builder setDefinitionFromString(String definitionFromString) { - if (definitionFromString == null) { + public Builder setDefinitionFromBytes(BytesReference definition) { + if (definition == null) { return this; } - this.definition = LazyModelDefinition.fromCompressedString(definitionFromString); + this.definition = LazyModelDefinition.fromCompressedData(definition); return this; } @@ -605,7 +610,7 @@ private Builder setLazyDefinition(String compressedString) { DEFINITION.getPreferredName()) .getFormattedMessage()); } - this.definition = LazyModelDefinition.fromCompressedString(compressedString); + this.definition = LazyModelDefinition.fromBase64String(compressedString); return this; } @@ -760,56 +765,78 @@ public TrainedModelConfig build() { } } - public static class LazyModelDefinition implements ToXContentObject, Writeable { + static class LazyModelDefinition implements ToXContentObject, Writeable { - private String compressedString; + private BytesReference compressedRepresentation; private TrainedModelDefinition parsedDefinition; public static LazyModelDefinition fromParsedDefinition(TrainedModelDefinition definition) { return new LazyModelDefinition(null, definition); } - public static LazyModelDefinition fromCompressedString(String compressedString) { - return new LazyModelDefinition(compressedString, null); + public static LazyModelDefinition fromCompressedData(BytesReference compressed) { + return new LazyModelDefinition(compressed, null); + } + + public static LazyModelDefinition fromBase64String(String base64String) { + byte[] decodedBytes = Base64.getDecoder().decode(base64String); + return new LazyModelDefinition(new BytesArray(decodedBytes), null); } public static LazyModelDefinition fromStreamInput(StreamInput input) throws IOException { - return new LazyModelDefinition(input.readString(), null); + if (input.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport + return new LazyModelDefinition(input.readBytesReference(), null); + } else { + return fromBase64String(input.readString()); + } } private LazyModelDefinition(LazyModelDefinition definition) { if (definition != null) { - this.compressedString = definition.compressedString; + this.compressedRepresentation = definition.compressedRepresentation; this.parsedDefinition = definition.parsedDefinition; } } - private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) { - if (compressedString == null && trainedModelDefinition == null) { + private LazyModelDefinition(BytesReference compressedRepresentation, TrainedModelDefinition trainedModelDefinition) { + if (compressedRepresentation == null && trainedModelDefinition == null) { throw new IllegalArgumentException("unexpected null model definition"); } - this.compressedString = compressedString; + this.compressedRepresentation = compressedRepresentation; this.parsedDefinition = trainedModelDefinition; } - public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { - if (parsedDefinition == null) { - parsedDefinition = InferenceToXContentCompressor.inflate(compressedString, - parser -> TrainedModelDefinition.fromXContent(parser, true).build(), - xContentRegistry); + private BytesReference getCompressedDefinition() throws IOException { + if (compressedRepresentation == null) { + compressedRepresentation = InferenceToXContentCompressor.deflate(parsedDefinition); } + return compressedRepresentation; } - public String getCompressedString() throws IOException { - if (compressedString == null) { - compressedString = InferenceToXContentCompressor.deflate(parsedDefinition); + private String getBase64CompressedDefinition() throws IOException { + BytesReference compressedDef = getCompressedDefinition(); + + ByteBuffer bb = Base64.getEncoder().encode( + ByteBuffer.wrap(compressedDef.array(), compressedDef.arrayOffset(), compressedDef.length())); + + return new String(bb.array(), StandardCharsets.UTF_8); + } + + private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException { + if (parsedDefinition == null) { + parsedDefinition = InferenceToXContentCompressor.inflate(compressedRepresentation, + parser -> TrainedModelDefinition.fromXContent(parser, true).build(), + xContentRegistry); } - return compressedString; } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(getCompressedString()); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport + out.writeBytesReference(getCompressedDefinition()); + } else { + out.writeString(getBase64CompressedDefinition()); + } } @Override @@ -817,7 +844,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (parsedDefinition != null) { return parsedDefinition.toXContent(builder, params); } - Map map = InferenceToXContentCompressor.inflateToMap(compressedString); + Map map = InferenceToXContentCompressor.inflateToMap(compressedRepresentation); return builder.map(map); } @@ -826,15 +853,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LazyModelDefinition that = (LazyModelDefinition) o; - return Objects.equals(compressedString, that.compressedString) && + return Objects.equals(compressedRepresentation, that.compressedRepresentation) && Objects.equals(parsedDefinition, that.parsedDefinition); } @Override public int hashCode() { - return Objects.hash(compressedString, parsedDefinition); + return Objects.hash(compressedRepresentation, parsedDefinition); } - } - } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index 5f93c6d7b7cfc..60ff6b56727e5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -26,8 +26,11 @@ public final class InferenceIndexConstants { * * version: 7.10.0: 000003 * - adds trained_model_metadata object + * + * version: UNKNOWN_MERGED_ON_FEATURE_BRANCH: 000004 TODO + * - adds binary_definition for TrainedModelDefinitionDoc */ - public static final String INDEX_VERSION = "000003"; + public static final String INDEX_VERSION = "000004"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json index 171cbabc52c30..bf164f049dc43 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_mappings.json @@ -44,6 +44,9 @@ "definition": { "enabled": false }, + "binary_definition": { + "type": "binary" + }, "compression_version": { "type": "long" }, @@ -135,7 +138,7 @@ "supplied": { "type": "boolean" } - } + } } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java index 47a131f5d758c..6a98cdb49550a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.test.ESTestCase; @@ -16,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -27,7 +28,7 @@ public class InferenceToXContentCompressorTests extends ESTestCase { public void testInflateAndDeflate() throws IOException { for(int i = 0; i < 10; i++) { TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); - String firstDeflate = InferenceToXContentCompressor.deflate(definition); + BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition); TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate, parser -> TrainedModelDefinition.fromXContent(parser, false).build(), xContentRegistry()); @@ -45,8 +46,8 @@ public void testInflateTooLargeStream() throws IOException { .limit(100) .collect(Collectors.toList())) .build(); - String firstDeflate = InferenceToXContentCompressor.deflate(definition); - int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10; + BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition); + int max = firstDeflate.length() + 10; IOException ex = expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max))); assertThat(ex.getMessage(), equalTo("" + @@ -54,7 +55,8 @@ public void testInflateTooLargeStream() throws IOException { } public void testInflateGarbage() { - expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L))); + expectThrows(IOException.class, () -> Streams.readFully( + InferenceToXContentCompressor.inflate(new BytesArray(randomByteArrayOfLength(10)), 100L))); } public void testInflateParsingTooLargeStream() throws IOException { @@ -65,8 +67,8 @@ public void testInflateParsingTooLargeStream() throws IOException { .limit(100) .collect(Collectors.toList())) .build(); - String compressedString = InferenceToXContentCompressor.deflate(definition); - int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10; + BytesReference compressedString = InferenceToXContentCompressor.deflate(definition); + int max = compressedString.length() + 10; CircuitBreakingException e = expectThrows(CircuitBreakingException.class, ()-> InferenceToXContentCompressor.inflate( compressedString, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index c2755b57014f7..12a8146c0fe61 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -115,8 +115,7 @@ protected NamedXContentRegistry xContentRegistry() { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - List entries = new ArrayList<>(); - entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + List entries = new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(entries); } @@ -170,7 +169,7 @@ public void testToXContentWithParams() throws IOException { assertThat(reference.utf8ToString(), containsString("\"definition\"")); assertThat(reference.utf8ToString(), not(containsString("compressed_definition"))); } - + public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { TrainedModelConfig.LazyModelDefinition lazyModelDefinition = TrainedModelConfig.LazyModelDefinition .fromParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder().build()); @@ -263,9 +262,9 @@ public void testSerializationWithLazyDefinition() throws IOException { xContentTester(this::createParser, () -> { try { - String compressedString = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); + BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); return createTestInstance(randomAlphaOfLength(10)) - .setDefinitionFromString(compressedString) + .setDefinitionFromBytes(bytes) .build(); } catch (IOException ex) { fail(ex.getMessage()); @@ -294,10 +293,10 @@ public void testSerializationWithCompressedLazyDefinition() throws IOException { xContentTester(this::createParser, () -> { try { - String compressedString = + BytesReference bytes = InferenceToXContentCompressor.deflate(TrainedModelDefinitionTests.createRandomBuilder().build()); return createTestInstance(randomAlphaOfLength(10)) - .setDefinitionFromString(compressedString) + .setDefinitionFromBytes(bytes) .build(); } catch (IOException ex) { fail(ex.getMessage()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java index d1a486dbe0b8a..58cd3eaad2d2b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; +import com.unboundid.util.Base64; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.DeprecationHandler; @@ -24,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import java.io.IOException; +import java.text.ParseException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -59,7 +61,7 @@ public void testTreeSchemaDeserialization() throws IOException { assertThat(definition.getTrainedModel().getClass(), equalTo(TreeInferenceModel.class)); } - public void testMultiClassIrisInference() throws IOException { + public void testMultiClassIrisInference() throws IOException, ParseException { // Fairly simple, random forest classification model built to fit in our format // Trained on the well known Iris dataset String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" + @@ -83,7 +85,8 @@ public void testMultiClassIrisInference() throws IOException { "aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" + "fdpTi9JB0sDp2JR7b309mn5HuPkEAAA=="; - InferenceDefinition definition = InferenceToXContentCompressor.inflate(compressedDef, + byte[] bytes = Base64.decode(compressedDef); + InferenceDefinition definition = InferenceToXContentCompressor.inflate(new BytesArray(bytes), InferenceDefinition::fromXContent, xContentRegistry()); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index c2779ac8bbb74..74a0cfa9161df 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.license.License; @@ -27,8 +28,8 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaselineTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.HyperparametersTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister; @@ -47,6 +48,8 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; @@ -78,9 +81,8 @@ public void testStoreModelViaChunkedPersister() throws IOException { .build(); List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); - String compressedDefinition = configBuilder.build().getCompressedDefinition(); - int totalSize = compressedDefinition.length(); - List chunks = chunkStringWithSize(compressedDefinition, totalSize/3); + BytesReference compressedDefinition = configBuilder.build().getCompressedDefinition(); + List base64Chunks = chunkBinaryDefinition(compressedDefinition, compressedDefinition.length() / 3); ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider, analyticsConfig, @@ -92,8 +94,9 @@ public void testStoreModelViaChunkedPersister() throws IOException { //Accuracy for size is not tested here ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); persister.createAndIndexInferenceModelConfig(modelSizeInfo, configBuilder.getModelType()); - for (int i = 0; i < chunks.size(); i++) { - persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1))); + for (int i = 0; i < base64Chunks.size(); i++) { + persister.createAndIndexInferenceModelDoc( + new TrainedModelDefinitionChunk(base64Chunks.get(i), i, i == (base64Chunks.size() - 1))); } ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) .limit(randomIntBetween(1, 10)) @@ -153,14 +156,15 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setInput(TrainedModelInputTests.createRandomInput()); } - public static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize); - for (int i = 0; i < str.length(); i += chunkSize) { - subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + public static List chunkBinaryDefinition(BytesReference bytes, int chunkSize) { + List subStrings = new ArrayList<>((bytes.length() + chunkSize - 1) / chunkSize); + for (int i = 0; i < bytes.length(); i += chunkSize) { + subStrings.add( + Base64.getEncoder().encodeToString( + Arrays.copyOfRange(bytes.array(), i, Math.min(i + chunkSize, bytes.length())))); } return subStrings; } - @Override public NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java index 3effd05c9ef20..6919812c3000c 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelRestorerIT.java @@ -12,6 +12,8 @@ import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; @@ -23,7 +25,9 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Base64; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; @@ -36,11 +40,11 @@ public class ChunkedTrainedModelRestorerIT extends MlSingleNodeTestCase { public void testRestoreWithMultipleSearches() throws IOException, InterruptedException { String modelId = "test-multiple-searches"; int numDocs = 22; - List modelDefs = new ArrayList<>(numDocs); + List modelDefs = new ArrayList<>(numDocs); for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); @@ -72,11 +76,11 @@ public void testRestoreWithMultipleSearches() throws IOException, InterruptedExc public void testCancel() throws IOException, InterruptedException { String modelId = "test-cancel-search"; int numDocs = 6; - List modelDefs = new ArrayList<>(numDocs); + List modelDefs = new ArrayList<>(numDocs); for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); @@ -126,11 +130,11 @@ public void testRestoreWithDocumentsInMultipleIndices() throws IOException, Inte String modelId = "test-multiple-indices"; int numDocs = 24; - List modelDefs = new ArrayList<>(numDocs); + List modelDefs = new ArrayList<>(numDocs); for (int i=0; i expectedDocs = createModelDefinitionDocs(modelDefs, modelId); @@ -166,14 +170,14 @@ public void testRestoreWithDocumentsInMultipleIndices() throws IOException, Inte assertEquals(actualDocs, reorderedDocs); } - private List createModelDefinitionDocs(List compressedDefinitions, String modelId) { - int totalLength = compressedDefinitions.stream().map(String::length).reduce(0, Integer::sum); + private List createModelDefinitionDocs(List compressedDefinitions, String modelId) { + int totalLength = compressedDefinitions.stream().map(BytesReference::length).reduce(0, Integer::sum); List docs = new ArrayList<>(); for (int i = 0; i < compressedDefinitions.size(); i++) { docs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) - .setCompressedString(compressedDefinitions.get(i)) + .setBinaryData(compressedDefinitions.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setTotalDefinitionLength(totalLength) .setDefinitionLength(compressedDefinitions.get(i).length()) diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index a25a3d7bc8c0e..a6e0912ac7d7e 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; @@ -46,7 +47,6 @@ import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; -import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; @@ -228,7 +228,7 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception { TrainedModelDefinitionDoc truncatedDoc = new TrainedModelDefinitionDoc.Builder() .setDocNum(0) - .setCompressedString(config.getCompressedDefinition().substring(0, config.getCompressedDefinition().length() - 10)) + .setBinaryData(config.getCompressedDefinition().slice(0, config.getCompressedDefinition().length() - 10)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(config.getCompressedDefinition().length()) .setTotalDefinitionLength(config.getCompressedDefinition().length()) @@ -351,15 +351,16 @@ public void testGetTrainedModelForInference() throws InterruptedException, IOExc assertThat(definitionHolder.get(), is(not(nullValue()))); } - private List createModelDefinitionDocs(String compressedDefinition, String modelId) { - List chunks = chunkStringWithSize(compressedDefinition, compressedDefinition.length()/3); + private List createModelDefinitionDocs(BytesReference compressedDefinition, String modelId) { + List chunks = TrainedModelProvider.chunkDefinitionWithSize(compressedDefinition, compressedDefinition.length()/3); return IntStream.range(0, chunks.size()) .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() .setDocNum(i) - .setCompressedString(chunks.get(i)) + .setBinaryData(chunks.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(chunks.get(i).length()) + .setTotalDefinitionLength(compressedDefinition.length()) .setEos(i == chunks.size() - 1) .setModelId(modelId)) .collect(Collectors.toList()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java index efc3cc9eee865..ea609d6129e00 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -17,6 +19,8 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Objects; /** @@ -31,6 +35,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { public static final ParseField DOC_NUM = new ParseField("doc_num"); public static final ParseField DEFINITION = new ParseField("definition"); + public static final ParseField BINARY_DEFINITION = new ParseField("binary_definition"); public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version"); public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length"); public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length"); @@ -47,6 +52,8 @@ private static ObjectParser createParse parser.declareString((a, b) -> {}, InferenceIndexConstants.DOC_TYPE); // type is hard coded but must be parsed parser.declareString(TrainedModelDefinitionDoc.Builder::setModelId, TrainedModelConfig.MODEL_ID); parser.declareString(TrainedModelDefinitionDoc.Builder::setCompressedString, DEFINITION); + parser.declareField(TrainedModelDefinitionDoc.Builder::setBinaryData, (p, c) -> new BytesArray(p.binaryValue()), + BINARY_DEFINITION, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); parser.declareInt(TrainedModelDefinitionDoc.Builder::setDocNum, DOC_NUM); parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION); parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH); @@ -64,7 +71,7 @@ public static String docId(String modelId, int docNum) { return NAME + "-" + modelId + "-" + docNum; } - private final String compressedString; + private final BytesReference binaryData; private final String modelId; private final int docNum; // for bwc @@ -73,14 +80,14 @@ public static String docId(String modelId, int docNum) { private final int compressionVersion; private final boolean eos; - private TrainedModelDefinitionDoc(String compressedString, + private TrainedModelDefinitionDoc(BytesReference binaryData, String modelId, int docNum, Long totalDefinitionLength, long definitionLength, int compressionVersion, boolean eos) { - this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION); + this.binaryData = ExceptionsHelper.requireNonNull(binaryData, BINARY_DEFINITION); this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); if (docNum < 0) { throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0"); @@ -98,8 +105,8 @@ private TrainedModelDefinitionDoc(String compressedString, this.eos = eos; } - public String getCompressedString() { - return compressedString; + public BytesReference getBinaryData() { + return binaryData; } public String getModelId() { @@ -141,7 +148,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength); } builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion); - builder.field(DEFINITION.getPreferredName(), compressedString); + builder.field(BINARY_DEFINITION.getPreferredName(), binaryData); builder.field(EOS.getPreferredName(), eos); builder.endObject(); return builder; @@ -163,18 +170,18 @@ public boolean equals(Object o) { Objects.equals(totalDefinitionLength, that.totalDefinitionLength) && Objects.equals(compressionVersion, that.compressionVersion) && Objects.equals(eos, that.eos) && - Objects.equals(compressedString, that.compressedString); + Objects.equals(binaryData, that.binaryData); } @Override public int hashCode() { - return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos); + return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, binaryData, eos); } public static class Builder { private String modelId; - private String compressedString; + private BytesReference binaryData; private int docNum; private Long totalDefinitionLength; private long definitionLength; @@ -187,7 +194,13 @@ public Builder setModelId(String modelId) { } public Builder setCompressedString(String compressedString) { - this.compressedString = compressedString; + this.binaryData = new BytesArray(Base64.getDecoder() + .decode(compressedString.getBytes(StandardCharsets.UTF_8))); + return this; + } + + public Builder setBinaryData(BytesReference binaryData) { + this.binaryData = binaryData; return this; } @@ -218,7 +231,7 @@ public Builder setEos(boolean eos) { public TrainedModelDefinitionDoc build() { return new TrainedModelDefinitionDoc( - this.compressedString, + this.binaryData, this.modelId, this.docNum, this.totalDefinitionLength, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 18473f1380610..78ac16d42ba80 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -38,6 +38,7 @@ import org.elasticsearch.common.Numbers; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.util.set.Sets; @@ -112,9 +113,9 @@ public class TrainedModelProvider { public static final Set MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1"); private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/"; private static final String MODEL_RESOURCE_FILE_EXT = ".json"; - private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024; + private static final int COMPRESSED_MODEL_CHUNK_SIZE = 16 * 1024 * 1024; private static final int MAX_NUM_DEFINITION_DOCS = 100; - private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; + private static final int MAX_COMPRESSED_MODEL_SIZE = COMPRESSED_MODEL_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; @@ -291,25 +292,25 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi List trainedModelDefinitionDocs = new ArrayList<>(); try { - String compressedString = trainedModelConfig.getCompressedDefinition(); - if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) { + BytesReference compressedDefinition = trainedModelConfig.getCompressedDefinition(); + if (compressedDefinition.length() > MAX_COMPRESSED_MODEL_SIZE) { listener.onFailure( ExceptionsHelper.badRequestException( - "Unable to store model as compressed definition has length [{}] the limit is [{}]", - compressedString.length(), - MAX_COMPRESSED_STRING_SIZE)); + "Unable to store model as compressed definition of size [{}] bytes the limit is [{}] bytes", + compressedDefinition.length(), + MAX_COMPRESSED_MODEL_SIZE)); return; } - List chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE); - for(int i = 0; i < chunkedStrings.size(); ++i) { + List chunkedDefinition = chunkDefinitionWithSize(compressedDefinition, COMPRESSED_MODEL_CHUNK_SIZE); + for(int i = 0; i < chunkedDefinition.size(); ++i) { trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder() .setDocNum(i) .setModelId(trainedModelConfig.getModelId()) - .setCompressedString(chunkedStrings.get(i)) + .setBinaryData(chunkedDefinition.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(chunkedStrings.get(i).length()) + .setDefinitionLength(chunkedDefinition.get(i).length()) // If it is the last doc, it is the EOS - .setEos(i == chunkedStrings.size() - 1) + .setEos(i == chunkedDefinition.size() - 1) .build()); } } catch (IOException ex) { @@ -416,9 +417,9 @@ public void getTrainedModelForInference(final String modelId, final ActionListen modelRestorer.restoreModelDefinition(docs::add, success -> { try { - String compressedString = getDefinitionFromDocs(docs, modelId); + BytesReference compressedData = getDefinitionFromDocs(docs, modelId); InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, + compressedData, InferenceDefinition::fromXContent, xContentRegistry); @@ -534,8 +535,8 @@ public void getTrainedModel(final String modelId, ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource( bytes, resourceId, xContentRegistry)); try { - String compressedString = getDefinitionFromDocs(docs, modelId); - builder.setDefinitionFromString(compressedString); + BytesReference compressedData = getDefinitionFromDocs(docs, modelId); + builder.setDefinitionFromBytes(compressedData); } catch (ElasticsearchException elasticsearchException) { getTrainedModelListener.onFailure(elasticsearchException); return; @@ -1086,32 +1087,36 @@ private static List handleHits(SearchHit[] hits, return results; } - private static String getDefinitionFromDocs(List docs, String modelId) throws ElasticsearchException { - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - // BWC for when we tracked the total definition length - // TODO: remove in 9 + private static BytesReference getDefinitionFromDocs(List docs, + String modelId) throws ElasticsearchException { + + BytesReference[] bb = new BytesReference[docs.size()]; + for (int i = 0; i < docs.size(); i++) { + bb[i] = docs.get(i).getBinaryData(); + } + BytesReference bytes = CompositeBytesReference.of(bb); + if (docs.get(0).getTotalDefinitionLength() != null) { - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); - } - } else { - TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); - // Either we are missing the last doc, or some previous doc - if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + if (bytes.length() != docs.get(0).getTotalDefinitionLength()) { throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); } } - return compressedString; + + TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); + // Either we are missing the last doc, or some previous doc + if (lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); + } + return bytes; } - static List chunkStringWithSize(String str, int chunkSize) { - List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); - for (int i = 0; i < str.length();i += chunkSize) { - subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + public static List chunkDefinitionWithSize(BytesReference definition, int chunkSize) { + List chunks = new ArrayList<>((int)Math.ceil(definition.length()/(double)chunkSize)); + for (int i = 0; i < definition.length();i += chunkSize) { + BytesReference chunk = definition.slice(i, Math.min(chunkSize, definition.length() - i)); + chunks.add(chunk); } - return subStrings; + return chunks; } private TrainedModelConfig.Builder parseModelConfigLenientlyFromSource(BytesReference source, String modelId) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java index 1134d767d59a7..338821aa72ab1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java @@ -19,8 +19,6 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.Locale; import java.util.Objects; import java.util.concurrent.ExecutorService; @@ -82,9 +80,9 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr modelSizeWritten = true; } - byte[] rawBytes = Base64.getDecoder().decode(doc.getCompressedString().getBytes(StandardCharsets.UTF_8)); - outputStream.write(rawBytes); - + // The array backing the BytesReference may be bigger than what is + // referred to so write only what is after the offset + outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length()); return true; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java index 73cea41c1449a..c44a988d6a52c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDocTests.java @@ -7,14 +7,38 @@ package org.elasticsearch.xpack.ml.inference.persistence; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; +import java.util.Base64; public class TrainedModelDefinitionDocTests extends AbstractXContentTestCase { - private boolean isLenient = randomBoolean(); + private final boolean isLenient = randomBoolean(); + + public void testParsingDocWithCompressedStringDefinition() throws IOException { + byte[] bytes = randomByteArrayOfLength(50); + String base64 = Base64.getEncoder().encodeToString(bytes); + + // The previous storage format was a base64 encoded string. + // The new format should parse and decode the string storing the raw bytes. + String compressedStringDoc = "{\"doc_type\":\"trained_model_definition_doc\"," + + "\"model_id\":\"bntHUo\"," + + "\"doc_num\":6," + + "\"definition_length\":7," + + "\"total_definition_length\":13," + + "\"compression_version\":3," + + "\"definition\":\"" + base64 + "\"," + + "\"eos\":false}"; + + try (XContentParser parser = createParser(JsonXContent.jsonXContent, compressedStringDoc)) { + TrainedModelDefinitionDoc parsed = doParseInstance(parser); + assertArrayEquals(bytes, parsed.getBinaryData().array()); + } + } @Override protected TrainedModelDefinitionDoc doParseInstance(XContentParser parser) throws IOException { @@ -28,12 +52,13 @@ protected boolean supportsUnknownFields() { @Override protected TrainedModelDefinitionDoc createTestInstance() { - int length = randomIntBetween(1, 10); + int length = randomIntBetween(4, 10); + return new TrainedModelDefinitionDoc.Builder() .setModelId(randomAlphaOfLength(6)) .setDefinitionLength(length) .setTotalDefinitionLength(randomIntBetween(length, length *2)) - .setCompressedString(randomAlphaOfLength(length)) + .setBinaryData(new BytesArray(randomByteArrayOfLength(length))) .setDocNum(randomIntBetween(0, 10)) .setCompressionVersion(randomIntBetween(1, 5)) .setEos(randomBoolean()) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index e447dc29de321..f353a8221d459 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -9,6 +9,8 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.ConstantScoreQueryBuilder; @@ -25,9 +27,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -146,6 +150,24 @@ public void testGetModelThatExistsAsResourceButIsMissing() { assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model"))); } + public void testChunkDefinitionWithSize() { + int totalLength = 100; + int size = 30; + + byte[] bytes = randomByteArrayOfLength(totalLength); + List chunks = TrainedModelProvider.chunkDefinitionWithSize(new BytesArray(bytes), size); + assertThat(chunks, hasSize(4)); + int start = 0; + int end = size; + for (BytesReference chunk : chunks) { + assertArrayEquals(Arrays.copyOfRange(bytes, start, end), + Arrays.copyOfRange(chunk.array(), chunk.arrayOffset(), chunk.arrayOffset() + chunk.length())); + + start += size; + end = Math.min(end + size, totalLength); + } + } + @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());