Skip to content

Commit

Permalink
[ML] Store compressed model definitions in ByteReferences (#71679)
Browse files Browse the repository at this point in the history
Binary data is stored in lucene base64 encoded, the same data stored in a 
Java string uses 2 bytes (UTF16) to represent each base64 character 
consuming twice the amount of memory required. The compressed
binary representation of the models can stored in ByteReferences
more efficiently. For BWC a new field mapping binary_definition 
is added .ml-inference-* and the index version incremented.
  • Loading branch information
davidkyle authored Apr 20, 2021
1 parent 91eb2cf commit 64c04e5
Show file tree
Hide file tree
Showing 15 changed files with 248 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand All @@ -29,8 +28,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
Expand All @@ -47,24 +44,24 @@ public final class InferenceToXContentCompressor {

private InferenceToXContentCompressor() {}

public static <T extends ToXContentObject> String deflate(T objectToCompress) throws IOException {
public static <T extends ToXContentObject> BytesReference deflate(T objectToCompress) throws IOException {
BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
return deflate(reference);
}

public static <T> T inflate(String compressedString,
public static <T> T inflate(BytesReference compressedBytes,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry) throws IOException {
return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
return inflate(compressedBytes, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
}

static <T> T inflate(String compressedString,
static <T> T inflate(BytesReference compressedBytes,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry,
long maxBytes) throws IOException {
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, maxBytes))) {
inflate(compressedBytes, maxBytes))) {
return parserFunction.apply(parser);
} catch (XContentParseException parseException) {
SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause =
Expand All @@ -82,32 +79,31 @@ static <T> T inflate(String compressedString,
}
}

static Map<String, Object> inflateToMap(String compressedString) throws IOException {
static Map<String, Object> inflateToMap(BytesReference compressedBytes) throws IOException {
// Don't need the xcontent registry as we are not deflating named objects.
try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, MAX_INFLATED_BYTES))) {
inflate(compressedBytes, MAX_INFLATED_BYTES))) {
return parser.mapOrdered();
}
}

static InputStream inflate(String compressedString, long streamSize) throws IOException {
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
static InputStream inflate(BytesReference compressedBytes, long streamSize) throws IOException {
// If the compressed length is already too large, it make sense that the inflated length would be as well
// In the extremely small string case, the compressed data could actually be longer than the compressed stream
if (compressedBytes.length > Math.max(100L, streamSize)) {
if (compressedBytes.length() > Math.max(100L, streamSize)) {
throw new CircuitBreakingException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]",
CircuitBreaker.Durability.PERMANENT);
}
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
InputStream gzipStream = new GZIPInputStream(compressedBytes.streamInput(), BUFFER_SIZE);
return new SimpleBoundedInputStream(gzipStream, streamSize);
}

private static String deflate(BytesReference reference) throws IOException {
private static BytesReference deflate(BytesReference reference) throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
reference.writeTo(compressedOutput);
}
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
return out.bytes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -36,8 +38,11 @@
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -248,15 +253,15 @@ public InferenceConfig getInferenceConfig() {
}

@Nullable
public String getCompressedDefinition() throws IOException {
public BytesReference getCompressedDefinition() throws IOException {
if (definition == null) {
return null;
}
return definition.getCompressedString();
return definition.getCompressedDefinition();
}

public void clearCompressed() {
definition.compressedString = null;
definition.compressedRepresentation = null;
}

public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException {
Expand Down Expand Up @@ -348,7 +353,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
builder.field(DEFINITION.getPreferredName(), definition);
} else {
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getBase64CompressedDefinition());
}
}
builder.field(TAGS.getPreferredName(), tags);
Expand Down Expand Up @@ -564,11 +569,11 @@ public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
return this;
}

public Builder setDefinitionFromString(String definitionFromString) {
if (definitionFromString == null) {
public Builder setDefinitionFromBytes(BytesReference definition) {
if (definition == null) {
return this;
}
this.definition = LazyModelDefinition.fromCompressedString(definitionFromString);
this.definition = LazyModelDefinition.fromCompressedData(definition);
return this;
}

Expand Down Expand Up @@ -605,7 +610,7 @@ private Builder setLazyDefinition(String compressedString) {
DEFINITION.getPreferredName())
.getFormattedMessage());
}
this.definition = LazyModelDefinition.fromCompressedString(compressedString);
this.definition = LazyModelDefinition.fromBase64String(compressedString);
return this;
}

Expand Down Expand Up @@ -760,64 +765,86 @@ public TrainedModelConfig build() {
}
}

public static class LazyModelDefinition implements ToXContentObject, Writeable {
static class LazyModelDefinition implements ToXContentObject, Writeable {

private String compressedString;
private BytesReference compressedRepresentation;
private TrainedModelDefinition parsedDefinition;

public static LazyModelDefinition fromParsedDefinition(TrainedModelDefinition definition) {
return new LazyModelDefinition(null, definition);
}

public static LazyModelDefinition fromCompressedString(String compressedString) {
return new LazyModelDefinition(compressedString, null);
public static LazyModelDefinition fromCompressedData(BytesReference compressed) {
return new LazyModelDefinition(compressed, null);
}

public static LazyModelDefinition fromBase64String(String base64String) {
byte[] decodedBytes = Base64.getDecoder().decode(base64String);
return new LazyModelDefinition(new BytesArray(decodedBytes), null);
}

public static LazyModelDefinition fromStreamInput(StreamInput input) throws IOException {
return new LazyModelDefinition(input.readString(), null);
if (input.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport
return new LazyModelDefinition(input.readBytesReference(), null);
} else {
return fromBase64String(input.readString());
}
}

private LazyModelDefinition(LazyModelDefinition definition) {
if (definition != null) {
this.compressedString = definition.compressedString;
this.compressedRepresentation = definition.compressedRepresentation;
this.parsedDefinition = definition.parsedDefinition;
}
}

private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
if (compressedString == null && trainedModelDefinition == null) {
private LazyModelDefinition(BytesReference compressedRepresentation, TrainedModelDefinition trainedModelDefinition) {
if (compressedRepresentation == null && trainedModelDefinition == null) {
throw new IllegalArgumentException("unexpected null model definition");
}
this.compressedString = compressedString;
this.compressedRepresentation = compressedRepresentation;
this.parsedDefinition = trainedModelDefinition;
}

public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException {
if (parsedDefinition == null) {
parsedDefinition = InferenceToXContentCompressor.inflate(compressedString,
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
xContentRegistry);
private BytesReference getCompressedDefinition() throws IOException {
if (compressedRepresentation == null) {
compressedRepresentation = InferenceToXContentCompressor.deflate(parsedDefinition);
}
return compressedRepresentation;
}

public String getCompressedString() throws IOException {
if (compressedString == null) {
compressedString = InferenceToXContentCompressor.deflate(parsedDefinition);
private String getBase64CompressedDefinition() throws IOException {
BytesReference compressedDef = getCompressedDefinition();

ByteBuffer bb = Base64.getEncoder().encode(
ByteBuffer.wrap(compressedDef.array(), compressedDef.arrayOffset(), compressedDef.length()));

return new String(bb.array(), StandardCharsets.UTF_8);
}

private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException {
if (parsedDefinition == null) {
parsedDefinition = InferenceToXContentCompressor.inflate(compressedRepresentation,
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
xContentRegistry);
}
return compressedString;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(getCompressedString());
if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport
out.writeBytesReference(getCompressedDefinition());
} else {
out.writeString(getBase64CompressedDefinition());
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (parsedDefinition != null) {
return parsedDefinition.toXContent(builder, params);
}
Map<String, Object> map = InferenceToXContentCompressor.inflateToMap(compressedString);
Map<String, Object> map = InferenceToXContentCompressor.inflateToMap(compressedRepresentation);
return builder.map(map);
}

Expand All @@ -826,15 +853,13 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LazyModelDefinition that = (LazyModelDefinition) o;
return Objects.equals(compressedString, that.compressedString) &&
return Objects.equals(compressedRepresentation, that.compressedRepresentation) &&
Objects.equals(parsedDefinition, that.parsedDefinition);
}

@Override
public int hashCode() {
return Objects.hash(compressedString, parsedDefinition);
return Objects.hash(compressedRepresentation, parsedDefinition);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ public final class InferenceIndexConstants {
*
* version: 7.10.0: 000003
* - adds trained_model_metadata object
*
* version: UNKNOWN_MERGED_ON_FEATURE_BRANCH: 000004 TODO
* - adds binary_definition for TrainedModelDefinitionDoc
*/
public static final String INDEX_VERSION = "000003";
public static final String INDEX_VERSION = "000004";
public static final String INDEX_NAME_PREFIX = ".ml-inference-";
public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
"definition": {
"enabled": false
},
"binary_definition": {
"type": "binary"
},
"compression_version": {
"type": "long"
},
Expand Down Expand Up @@ -135,7 +138,7 @@
"supplied": {
"type": "boolean"
}
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
Expand All @@ -16,7 +18,6 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -27,7 +28,7 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
public void testInflateAndDeflate() throws IOException {
for(int i = 0; i < 10; i++) {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition);
TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate,
parser -> TrainedModelDefinition.fromXContent(parser, false).build(),
xContentRegistry());
Expand All @@ -45,16 +46,17 @@ public void testInflateTooLargeStream() throws IOException {
.limit(100)
.collect(Collectors.toList()))
.build();
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition);
int max = firstDeflate.length() + 10;
IOException ex = expectThrows(IOException.class,
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
assertThat(ex.getMessage(), equalTo("" +
"input stream exceeded maximum bytes of [" + max + "]"));
}

public void testInflateGarbage() {
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
expectThrows(IOException.class, () -> Streams.readFully(
InferenceToXContentCompressor.inflate(new BytesArray(randomByteArrayOfLength(10)), 100L)));
}

public void testInflateParsingTooLargeStream() throws IOException {
Expand All @@ -65,8 +67,8 @@ public void testInflateParsingTooLargeStream() throws IOException {
.limit(100)
.collect(Collectors.toList()))
.build();
String compressedString = InferenceToXContentCompressor.deflate(definition);
int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10;
BytesReference compressedString = InferenceToXContentCompressor.deflate(definition);
int max = compressedString.length() + 10;

CircuitBreakingException e = expectThrows(CircuitBreakingException.class, ()-> InferenceToXContentCompressor.inflate(
compressedString,
Expand Down
Loading

0 comments on commit 64c04e5

Please sign in to comment.