Skip to content

Commit

Permalink
[ML][Inference][HLRC] GET trained models (#49464)
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent authored Nov 22, 2019
1 parent 9360dc9 commit 9006926
Show file tree
Hide file tree
Showing 13 changed files with 676 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
import org.elasticsearch.client.ml.GetRecordsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.OpenJobRequest;
import org.elasticsearch.client.ml.PostCalendarEventRequest;
Expand Down Expand Up @@ -709,6 +710,38 @@ static Request estimateMemoryUsage(PutDataFrameAnalyticsRequest estimateRequest)
return request;
}

static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest) {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml", "inference")
.addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIds()))
.build();
RequestConverters.Params params = new RequestConverters.Params();
if (getTrainedModelsRequest.getPageParams() != null) {
PageParams pageParams = getTrainedModelsRequest.getPageParams();
if (pageParams.getFrom() != null) {
params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString());
}
if (pageParams.getSize() != null) {
params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString());
}
}
if (getTrainedModelsRequest.getAllowNoMatch() != null) {
params.putParam(GetTrainedModelsRequest.ALLOW_NO_MATCH,
Boolean.toString(getTrainedModelsRequest.getAllowNoMatch()));
}
if (getTrainedModelsRequest.getDecompressDefinition() != null) {
params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
}
if (getTrainedModelsRequest.getIncludeDefinition() != null) {
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;
}

static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
import org.elasticsearch.client.ml.GetOverallBucketsResponse;
import org.elasticsearch.client.ml.GetRecordsRequest;
import org.elasticsearch.client.ml.GetRecordsResponse;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
import org.elasticsearch.client.ml.MlInfoRequest;
import org.elasticsearch.client.ml.MlInfoResponse;
import org.elasticsearch.client.ml.OpenJobRequest;
Expand Down Expand Up @@ -2290,4 +2292,48 @@ public Cancellable estimateMemoryUsageAsync(PutDataFrameAnalyticsRequest request
listener,
Collections.emptySet());
}

/**
* Gets trained model configs
* <p>
* For additional info
* see <a href="TODO">
* GET Trained Model Configs documentation</a>
*
* @param request The {@link GetTrainedModelsRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @return {@link GetTrainedModelsResponse} response object
*/
public GetTrainedModelsResponse getTrainedModels(GetTrainedModelsRequest request,
RequestOptions options) throws IOException {
return restHighLevelClient.performRequestAndParseEntity(request,
MLRequestConverters::getTrainedModels,
options,
GetTrainedModelsResponse::fromXContent,
Collections.emptySet());
}

/**
* Gets trained model configs asynchronously and notifies listener upon completion
* <p>
* For additional info
* see <a href="TODO">
* GET Trained Model Configs documentation</a>
*
* @param request The {@link GetTrainedModelsRequest}
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @param listener Listener to be notified upon request completion
* @return cancellable that may be used to cancel the request
*/
public Cancellable getTrainedModelsAsync(GetTrainedModelsRequest request,
RequestOptions options,
ActionListener<GetTrainedModelsResponse> listener) {
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
MLRequestConverters::getTrainedModels,
options,
GetTrainedModelsResponse::fromXContent,
listener,
Collections.emptySet());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.common.Nullable;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

public class GetTrainedModelsRequest implements Validatable {

public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";

private final List<String> ids;
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Boolean decompressDefinition;
private PageParams pageParams;

/**
* Helper method to create a request that will get ALL TrainedModelConfigs
* @return new {@link GetTrainedModelsRequest} object for the id "_all"
*/
public static GetTrainedModelsRequest getAllTrainedModelConfigsRequest() {
return new GetTrainedModelsRequest("_all");
}

public GetTrainedModelsRequest(String... ids) {
this.ids = Arrays.asList(ids);
}

public List<String> getIds() {
return ids;
}

public Boolean getAllowNoMatch() {
return allowNoMatch;
}

/**
* Whether to ignore if a wildcard expression matches no trained models.
*
* @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all})
* does not match any trained models
*/
public GetTrainedModelsRequest setAllowNoMatch(boolean allowNoMatch) {
this.allowNoMatch = allowNoMatch;
return this;
}

public PageParams getPageParams() {
return pageParams;
}

public GetTrainedModelsRequest setPageParams(@Nullable PageParams pageParams) {
this.pageParams = pageParams;
return this;
}

public Boolean getIncludeDefinition() {
return includeDefinition;
}

/**
* Whether to include the full model definition.
*
* The full model definition can be very large.
*
* @param includeDefinition If {@code true}, the definition is included.
*/
public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
this.includeDefinition = includeDefinition;
return this;
}

public Boolean getDecompressDefinition() {
return decompressDefinition;
}

/**
* Whether or not to decompress the trained model, or keep it in its compressed string form
*
* @param decompressDefinition If {@code true}, the definition is decompressed.
*/
public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinition) {
this.decompressDefinition = decompressDefinition;
return this;
}

@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {
return Optional.of(ValidationException.withError("trained model id must not be null"));
}
return Optional.empty();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

GetTrainedModelsRequest other = (GetTrainedModelsRequest) o;
return Objects.equals(ids, other.ids)
&& Objects.equals(allowNoMatch, other.allowNoMatch)
&& Objects.equals(decompressDefinition, other.decompressDefinition)
&& Objects.equals(includeDefinition, other.includeDefinition)
&& Objects.equals(pageParams, other.pageParams);
}

@Override
public int hashCode() {
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.client.ml;

import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;

import java.util.List;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;

public class GetTrainedModelsResponse {

public static final ParseField TRAINED_MODEL_CONFIGS = new ParseField("trained_model_configs");
public static final ParseField COUNT = new ParseField("count");

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<GetTrainedModelsResponse, Void> PARSER =
new ConstructingObjectParser<>(
"get_trained_model_configs",
true,
args -> new GetTrainedModelsResponse((List<TrainedModelConfig>) args[0], (Long) args[1]));

static {
PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelConfig.fromXContent(p), TRAINED_MODEL_CONFIGS);
PARSER.declareLong(constructorArg(), COUNT);
}

public static GetTrainedModelsResponse fromXContent(final XContentParser parser) {
return PARSER.apply(parser, null);
}

private final List<TrainedModelConfig> trainedModels;
private final Long count;


public GetTrainedModelsResponse(List<TrainedModelConfig> trainedModels, Long count) {
this.trainedModels = trainedModels;
this.count = count;
}

public List<TrainedModelConfig> getTrainedModels() {
return trainedModels;
}

/**
* @return The total count of the trained models that matched the ID pattern.
*/
public Long getCount() {
return count;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

GetTrainedModelsResponse other = (GetTrainedModelsResponse) o;
return Objects.equals(this.trainedModels, other.trainedModels) && Objects.equals(this.count, other.count);
}

@Override
public int hashCode() {
return Objects.hash(trainedModels, count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
}

public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null).build();
}

private final String modelId;
Expand Down Expand Up @@ -293,12 +293,12 @@ public Builder setModelId(String modelId) {
return this;
}

private Builder setCreatedBy(String createdBy) {
public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}

private Builder setVersion(Version version) {
public Builder setVersion(Version version) {
this.version = version;
return this;
}
Expand All @@ -312,7 +312,7 @@ public Builder setDescription(String description) {
return this;
}

private Builder setCreateTime(Instant createTime) {
public Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
Expand Down Expand Up @@ -347,17 +347,17 @@ public Builder setInput(TrainedModelInput input) {
return this;
}

private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}

private Builder setEstimatedOperations(Long estimatedOperations) {
public Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}

private Builder setLicenseLevel(String licenseLevel) {
public Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
Expand Down
Loading

0 comments on commit 9006926

Please sign in to comment.