Skip to content

Commit

Permalink
Adding trained model metadata class. (elastic#106988)
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret authored Apr 22, 2024
1 parent 8ed92db commit 80bafa3
Show file tree
Hide file tree
Showing 20 changed files with 948 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ static TransportVersion def(int id) {
public static final TransportVersion WATERMARK_THRESHOLDS_STATS = def(8_637_00_0);
public static final TransportVersion ENRICH_CACHE_ADDITIONAL_STATS = def(8_638_00_0);
public static final TransportVersion ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED = def(8_639_00_0);
public static final TransportVersion ML_TRAINED_MODEL_CACHE_METADATA_ADDED = def(8_640_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.TimeValue;

import java.io.IOException;
import java.util.Objects;

public class FlushTrainedModelCacheAction extends ActionType<AcknowledgedResponse> {

public static final FlushTrainedModelCacheAction INSTANCE = new FlushTrainedModelCacheAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/clear_model_cache";

private FlushTrainedModelCacheAction() {
super(NAME);
}

public static class Request extends AcknowledgedRequest<FlushTrainedModelCacheAction.Request> {
public Request() {
super();
}

Request(TimeValue timeout) {
super(timeout);
}

public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public int hashCode() {
return Objects.hashCode(ackTimeout());
}

@Override
public boolean equals(Object other) {
if (other == this) return true;
if (other == null || getClass() != other.getClass()) return false;
Request that = (Request) other;
return Objects.equals(that.ackTimeout(), ackTimeout());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.AbstractNamedDiffable;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.NamedDiff;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.Objects;

public class TrainedModelCacheMetadata extends AbstractNamedDiffable<Metadata.Custom> implements Metadata.Custom {
public static final String NAME = "trained_model_cache_metadata";
public static final TrainedModelCacheMetadata EMPTY = new TrainedModelCacheMetadata(0L);
private static final ParseField VERSION_FIELD = new ParseField("version");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<TrainedModelCacheMetadata, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
args -> new TrainedModelCacheMetadata((long) args[0])
);

static {
PARSER.declareLong(ConstructingObjectParser.constructorArg(), VERSION_FIELD);
}

public static TrainedModelCacheMetadata fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

public static TrainedModelCacheMetadata fromState(ClusterState clusterState) {
TrainedModelCacheMetadata cacheMetadata = clusterState.getMetadata().custom(NAME);
return cacheMetadata == null ? EMPTY : cacheMetadata;
}

public static NamedDiff<Metadata.Custom> readDiffFrom(StreamInput streamInput) throws IOException {
return readDiffFrom(Metadata.Custom.class, NAME, streamInput);
}

private final long version;

public TrainedModelCacheMetadata(long version) {
this.version = version;
}

public TrainedModelCacheMetadata(StreamInput in) throws IOException {
this.version = in.readVLong();
}

public long version() {
return version;
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params ignored) {
return Iterators.single(((builder, params) -> { return builder.field(VERSION_FIELD.getPreferredName(), version); }));
}

@Override
public EnumSet<Metadata.XContentContext> context() {
return Metadata.ALL_CONTEXTS;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_TRAINED_MODEL_CACHE_METADATA_ADDED;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(version);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelCacheMetadata that = (TrainedModelCacheMetadata) o;
return Objects.equals(version, that.version);
}

@Override
public int hashCode() {
return Objects.hash(version);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.action.FlushTrainedModelCacheAction.Request;

import java.io.IOException;

public class FlushTrainedModelCacheActionRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
@Override
protected Request createTestInstance() {
return randomBoolean() ? new Request() : new Request(randomTimeout());
}

@Override
protected Request mutateInstance(Request instance) throws IOException {
return new Request(randomValueOtherThan(instance.timeout(), this::randomTimeout));
}

@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}

@Override
protected Request mutateInstanceForVersion(Request instance, TransportVersion version) {
return instance;
}

private TimeValue randomTimeout() {
return TimeValue.parseTimeValue(randomTimeValue(), null, "timeout");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractChunkedSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;

public class TrainedModelCacheMetadataTests extends AbstractChunkedSerializingTestCase<TrainedModelCacheMetadata> {
@Override
protected TrainedModelCacheMetadata doParseInstance(XContentParser parser) throws IOException {
return TrainedModelCacheMetadata.fromXContent(parser);
}

@Override
protected Writeable.Reader<TrainedModelCacheMetadata> instanceReader() {
return TrainedModelCacheMetadata::new;
}

@Override
protected TrainedModelCacheMetadata createTestInstance() {
return new TrainedModelCacheMetadata(randomNonNegativeLong());
}

@Override
protected TrainedModelCacheMetadata mutateInstance(TrainedModelCacheMetadata instance) {
return new TrainedModelCacheMetadata(randomValueOtherThan(instance.version(), () -> randomNonNegativeLong()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class MlLearningToRankRescorerIT extends ESRestTestCase {

@Before
public void setupModelAndData() throws IOException {
putRegressionModel(MODEL_ID, """
putLearningToRankModel(MODEL_ID, """
{
"description": "super complex model for tests",
"input": { "field_names": ["cost", "product"] },
Expand Down Expand Up @@ -328,6 +328,95 @@ public void testLtrCanMatch() throws Exception {
assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));
}

@SuppressWarnings("unchecked")
public void testModelCacheIsFlushedOnModelChange() throws IOException {
String searchBody = """
{
"rescore": {
"window_size": 10,
"learning_to_rank": {
"model_id": "basic-ltr-model"
}
}
}""";

Response searchResponse = searchDfs(searchBody);
Map<String, Object> response = responseAsMap(searchResponse);
assertThat(
response.toString(),
(List<Double>) XContentMapValues.extractValue("hits.hits._score", response),
contains(20.0, 20.0, 9.0, 9.0, 6.0)
);

deleteLearningToRankModel(MODEL_ID);
putLearningToRankModel(MODEL_ID, """
{
"input": { "field_names": ["cost"] },
"inference_config": {
"learning_to_rank": {
"feature_extractors": [
{
"query_extractor": {
"feature_name": "cost",
"query": {
"script_score": {
"query": { "match_all": {} },
"script": { "source": "return doc[\\"cost\\"].value" }
}
}
}
}
]
}
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": ["cost"],
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": ["cost"],
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_gain": 12,
"threshold": 1000,
"decision_type": "lt",
"default_left": true,
"left_child": 1,
"right_child": 2
},
{
"node_index": 1,
"leaf_value": 1.0
},
{
"node_index": 2,
"leaf_value": 10
}
],
"target_type": "regression"
}
}
]
}
}
}
}
""");

searchResponse = searchDfs(searchBody);
response = responseAsMap(searchResponse);
assertThat(
response.toString(),
(List<Double>) XContentMapValues.extractValue("hits.hits._score", response),
contains(10.0, 1.0, 1.0, 1.0, 1.0)
);
}

private void indexData(String data) throws IOException {
Request request = new Request("POST", INDEX_NAME + "/_doc");
request.setJsonEntity(data);
Expand All @@ -354,7 +443,12 @@ private Response searchCanMatch(String searchBody, boolean dfs) throws IOExcepti
return client().performRequest(request);
}

private void putRegressionModel(String modelId, String body) throws IOException {
private void deleteLearningToRankModel(String modelId) throws IOException {
Request model = new Request("DELETE", "_ml/trained_models/" + modelId);
assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
}

private void putLearningToRankModel(String modelId, String body) throws IOException {
Request model = new Request("PUT", "_ml/trained_models/" + modelId);
model.setJsonEntity(body);
assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
Expand Down Expand Up @@ -352,6 +353,12 @@ protected void ensureClusterStateConsistency() throws IOException {
);
entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new));
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom));
entries.add(
new NamedWriteableRegistry.Entry(Metadata.Custom.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::new)
);
entries.add(
new NamedWriteableRegistry.Entry(NamedDiff.class, TrainedModelCacheMetadata.NAME, TrainedModelCacheMetadata::readDiffFrom)
);
entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new));
entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, IndexLifecycleMetadata.TYPE, IndexLifecycleMetadata::new));
entries.add(
Expand Down
Loading

0 comments on commit 80bafa3

Please sign in to comment.