Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Store compressed model definitions in ByteReferences #71679

Merged
merged 9 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand All @@ -29,8 +28,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
Expand All @@ -47,24 +44,24 @@ public final class InferenceToXContentCompressor {

private InferenceToXContentCompressor() {}

public static <T extends ToXContentObject> String deflate(T objectToCompress) throws IOException {
/**
 * Serializes {@code objectToCompress} to JSON and returns the GZIP-compressed bytes
 * produced by the private {@code deflate(BytesReference)} overload.
 *
 * @param objectToCompress the object to serialize and compress
 * @return the compressed bytes of the object's JSON representation
 * @throws IOException if serializing or compressing the object fails
 */
public static <T extends ToXContentObject> BytesReference deflate(T objectToCompress) throws IOException {
BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
return deflate(reference);
}

public static <T> T inflate(String compressedString,
public static <T> T inflate(BytesReference compressedBytes,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry) throws IOException {
return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
return inflate(compressedBytes, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
}

static <T> T inflate(String compressedString,
static <T> T inflate(BytesReference compressedBytes,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry,
long maxBytes) throws IOException {
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, maxBytes))) {
inflate(compressedBytes, maxBytes))) {
return parserFunction.apply(parser);
} catch (XContentParseException parseException) {
SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause =
Expand All @@ -82,32 +79,31 @@ static <T> T inflate(String compressedString,
}
}

static Map<String, Object> inflateToMap(String compressedString) throws IOException {
static Map<String, Object> inflateToMap(BytesReference compressedBytes) throws IOException {
// Don't need the xcontent registry as we are not deflating named objects.
try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, MAX_INFLATED_BYTES))) {
inflate(compressedBytes, MAX_INFLATED_BYTES))) {
return parser.mapOrdered();
}
}

static InputStream inflate(String compressedString, long streamSize) throws IOException {
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
static InputStream inflate(BytesReference compressedBytes, long streamSize) throws IOException {
// If the compressed length is already too large, it makes sense that the inflated length would be as well
// In the extremely small string case, the compressed data could actually be longer than the inflated stream
if (compressedBytes.length > Math.max(100L, streamSize)) {
if (compressedBytes.length() > Math.max(100L, streamSize)) {
throw new CircuitBreakingException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]",
CircuitBreaker.Durability.PERMANENT);
}
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
InputStream gzipStream = new GZIPInputStream(compressedBytes.streamInput(), BUFFER_SIZE);
return new SimpleBoundedInputStream(gzipStream, streamSize);
}

private static String deflate(BytesReference reference) throws IOException {
private static BytesReference deflate(BytesReference reference) throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
reference.writeTo(compressedOutput);
}
return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so this will write out the raw bytes of the GZIP. Is this what we want or do we want to run the Base64 encoder?

I thought the guarantees around base64 character sizes were one of the reasons we could skip transforming into a string?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Binary data is stored in lucene base64 encoded,

Ah, so since the mapping is binary we get that for free.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's handled by the Jackson JSON generator which is used by the various XContentBuilder::value(byte[] value) methods to write bytes

return out.bytes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -36,8 +38,11 @@
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -248,15 +253,15 @@ public InferenceConfig getInferenceConfig() {
}

@Nullable
public String getCompressedDefinition() throws IOException {
public BytesReference getCompressedDefinition() throws IOException {
if (definition == null) {
return null;
}
return definition.getCompressedString();
return definition.getCompressedDefinition();
}

public void clearCompressed() {
definition.compressedString = null;
definition.compressedRepresentation = null;
}

public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException {
Expand Down Expand Up @@ -348,7 +353,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
builder.field(DEFINITION.getPreferredName(), definition);
} else {
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getBase64CompressedDefinition());
}
}
builder.field(TAGS.getPreferredName(), tags);
Expand Down Expand Up @@ -564,11 +569,11 @@ public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
return this;
}

public Builder setDefinitionFromString(String definitionFromString) {
if (definitionFromString == null) {
public Builder setDefinitionFromBytes(BytesReference definition) {
if (definition == null) {
return this;
}
this.definition = LazyModelDefinition.fromCompressedString(definitionFromString);
this.definition = LazyModelDefinition.fromCompressedData(definition);
return this;
}

Expand Down Expand Up @@ -605,7 +610,7 @@ private Builder setLazyDefinition(String compressedString) {
DEFINITION.getPreferredName())
.getFormattedMessage());
}
this.definition = LazyModelDefinition.fromCompressedString(compressedString);
this.definition = LazyModelDefinition.fromBase64String(compressedString);
return this;
}

Expand Down Expand Up @@ -760,64 +765,86 @@ public TrainedModelConfig build() {
}
}

public static class LazyModelDefinition implements ToXContentObject, Writeable {
static class LazyModelDefinition implements ToXContentObject, Writeable {

private String compressedString;
private BytesReference compressedRepresentation;
private TrainedModelDefinition parsedDefinition;

public static LazyModelDefinition fromParsedDefinition(TrainedModelDefinition definition) {
return new LazyModelDefinition(null, definition);
}

public static LazyModelDefinition fromCompressedString(String compressedString) {
return new LazyModelDefinition(compressedString, null);
public static LazyModelDefinition fromCompressedData(BytesReference compressed) {
return new LazyModelDefinition(compressed, null);
}

/**
 * Creates a lazy definition from the Base64 encoding of a compressed definition.
 * The string is decoded immediately; the definition itself is not parsed here.
 *
 * @param base64String Base64 encoding of the compressed definition bytes
 * @return a definition holding the decoded bytes and no parsed definition
 * @throws IllegalArgumentException if {@code base64String} is not valid Base64
 */
public static LazyModelDefinition fromBase64String(String base64String) {
byte[] decodedBytes = Base64.getDecoder().decode(base64String);
return new LazyModelDefinition(new BytesArray(decodedBytes), null);
}

public static LazyModelDefinition fromStreamInput(StreamInput input) throws IOException {
return new LazyModelDefinition(input.readString(), null);
if (input.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport
return new LazyModelDefinition(input.readBytesReference(), null);
} else {
return fromBase64String(input.readString());
}
}

private LazyModelDefinition(LazyModelDefinition definition) {
if (definition != null) {
this.compressedString = definition.compressedString;
this.compressedRepresentation = definition.compressedRepresentation;
this.parsedDefinition = definition.parsedDefinition;
}
}

private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
if (compressedString == null && trainedModelDefinition == null) {
private LazyModelDefinition(BytesReference compressedRepresentation, TrainedModelDefinition trainedModelDefinition) {
if (compressedRepresentation == null && trainedModelDefinition == null) {
throw new IllegalArgumentException("unexpected null model definition");
}
this.compressedString = compressedString;
this.compressedRepresentation = compressedRepresentation;
this.parsedDefinition = trainedModelDefinition;
}

public void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException {
if (parsedDefinition == null) {
parsedDefinition = InferenceToXContentCompressor.inflate(compressedString,
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
xContentRegistry);
/**
 * Returns the compressed form of the definition. When no compressed representation is
 * held yet, the parsed definition is compressed and the result is kept for later calls.
 *
 * @return the compressed definition bytes
 * @throws IOException if compressing the parsed definition fails
 */
private BytesReference getCompressedDefinition() throws IOException {
if (compressedRepresentation == null) {
compressedRepresentation = InferenceToXContentCompressor.deflate(parsedDefinition);
}
return compressedRepresentation;
}

public String getCompressedString() throws IOException {
if (compressedString == null) {
compressedString = InferenceToXContentCompressor.deflate(parsedDefinition);
/**
 * Returns the compressed definition encoded as a Base64 string.
 *
 * @return Base64 text of the bytes returned by {@code getCompressedDefinition()}
 * @throws IOException if the definition has to be compressed first and that fails
 */
private String getBase64CompressedDefinition() throws IOException {
    BytesReference compressedDef = getCompressedDefinition();

    // Obtain the bytes through BytesReference.toBytes instead of array()/arrayOffset():
    // direct array access is only valid when the reference is backed by a single array.
    // NOTE(review): presumably array() throws for non array-backed references (e.g. a large
    // BytesStreamOutput result) - confirm against BytesReference#hasArray.
    byte[] compressedBytes = BytesReference.toBytes(compressedDef);
    ByteBuffer encoded = Base64.getEncoder().encode(ByteBuffer.wrap(compressedBytes));

    // Respect the buffer's own bounds rather than assuming the backing array holds
    // exactly the encoded content.
    return new String(
        encoded.array(),
        encoded.arrayOffset() + encoded.position(),
        encoded.remaining(),
        StandardCharsets.UTF_8);
}

/**
 * Inflates and parses the compressed representation into {@code parsedDefinition} when no
 * parsed definition is held yet; does nothing if one is already present.
 *
 * @param xContentRegistry registry handed to the parser used while inflating
 * @throws IOException if inflating or parsing the compressed representation fails
 */
private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) throws IOException {
if (parsedDefinition == null) {
parsedDefinition = InferenceToXContentCompressor.inflate(compressedRepresentation,
// NOTE(review): the boolean passed to fromXContent is presumably a lenient-parsing flag - confirm
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
xContentRegistry);
}
}
return compressedString;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(getCompressedString());
if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport
out.writeBytesReference(getCompressedDefinition());
} else {
out.writeString(getBase64CompressedDefinition());
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (parsedDefinition != null) {
return parsedDefinition.toXContent(builder, params);
}
Map<String, Object> map = InferenceToXContentCompressor.inflateToMap(compressedString);
Map<String, Object> map = InferenceToXContentCompressor.inflateToMap(compressedRepresentation);
return builder.map(map);
}

Expand All @@ -826,15 +853,13 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LazyModelDefinition that = (LazyModelDefinition) o;
return Objects.equals(compressedString, that.compressedString) &&
return Objects.equals(compressedRepresentation, that.compressedRepresentation) &&
Objects.equals(parsedDefinition, that.parsedDefinition);
}

@Override
public int hashCode() {
return Objects.hash(compressedString, parsedDefinition);
return Objects.hash(compressedRepresentation, parsedDefinition);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ public final class InferenceIndexConstants {
*
* version: 7.10.0: 000003
* - adds trained_model_metadata object
*
* version: UNKNOWN_MERGED_ON_FEATURE_BRANCH: 000004 TODO
* - adds binary_definition for TrainedModelDefinitionDoc
*/
public static final String INDEX_VERSION = "000003";
public static final String INDEX_VERSION = "000004";
public static final String INDEX_NAME_PREFIX = ".ml-inference-";
public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
"definition": {
"enabled": false
},
"binary_definition": {
"type": "binary"
},
"compression_version": {
"type": "long"
},
Expand Down Expand Up @@ -135,7 +138,7 @@
"supplied": {
"type": "boolean"
}
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
Expand All @@ -16,7 +18,6 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -27,7 +28,7 @@ public class InferenceToXContentCompressorTests extends ESTestCase {
public void testInflateAndDeflate() throws IOException {
for(int i = 0; i < 10; i++) {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition);
TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate,
parser -> TrainedModelDefinition.fromXContent(parser, false).build(),
xContentRegistry());
Expand All @@ -45,16 +46,17 @@ public void testInflateTooLargeStream() throws IOException {
.limit(100)
.collect(Collectors.toList()))
.build();
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
BytesReference firstDeflate = InferenceToXContentCompressor.deflate(definition);
int max = firstDeflate.length() + 10;
IOException ex = expectThrows(IOException.class,
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
assertThat(ex.getMessage(), equalTo("" +
"input stream exceeded maximum bytes of [" + max + "]"));
}

public void testInflateGarbage() {
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
expectThrows(IOException.class, () -> Streams.readFully(
InferenceToXContentCompressor.inflate(new BytesArray(randomByteArrayOfLength(10)), 100L)));
}

public void testInflateParsingTooLargeStream() throws IOException {
Expand All @@ -65,8 +67,8 @@ public void testInflateParsingTooLargeStream() throws IOException {
.limit(100)
.collect(Collectors.toList()))
.build();
String compressedString = InferenceToXContentCompressor.deflate(definition);
int max = compressedString.getBytes(StandardCharsets.UTF_8).length + 10;
BytesReference compressedString = InferenceToXContentCompressor.deflate(definition);
int max = compressedString.length() + 10;

CircuitBreakingException e = expectThrows(CircuitBreakingException.class, ()-> InferenceToXContentCompressor.inflate(
compressedString,
Expand Down
Loading