-
Notifications
You must be signed in to change notification settings - Fork 25k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML][Inference][HLRC] GET trained models (#49464)
- Loading branch information
Showing
13 changed files
with
676 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
139 changes: 139 additions & 0 deletions
139
...nt/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* | ||
* Licensed to Elasticsearch under one or more contributor | ||
* license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright | ||
* ownership. Elasticsearch licenses this file to you under | ||
* the Apache License, Version 2.0 (the "License"); you may | ||
* not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.elasticsearch.client.ml; | ||
|
||
import org.elasticsearch.client.Validatable; | ||
import org.elasticsearch.client.ValidationException; | ||
import org.elasticsearch.client.core.PageParams; | ||
import org.elasticsearch.common.Nullable; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
public class GetTrainedModelsRequest implements Validatable { | ||
|
||
public static final String ALLOW_NO_MATCH = "allow_no_match"; | ||
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; | ||
public static final String DECOMPRESS_DEFINITION = "decompress_definition"; | ||
|
||
private final List<String> ids; | ||
private Boolean allowNoMatch; | ||
private Boolean includeDefinition; | ||
private Boolean decompressDefinition; | ||
private PageParams pageParams; | ||
|
||
/** | ||
* Helper method to create a request that will get ALL TrainedModelConfigs | ||
* @return new {@link GetTrainedModelsRequest} object for the id "_all" | ||
*/ | ||
public static GetTrainedModelsRequest getAllTrainedModelConfigsRequest() { | ||
return new GetTrainedModelsRequest("_all"); | ||
} | ||
|
||
public GetTrainedModelsRequest(String... ids) { | ||
this.ids = Arrays.asList(ids); | ||
} | ||
|
||
public List<String> getIds() { | ||
return ids; | ||
} | ||
|
||
public Boolean getAllowNoMatch() { | ||
return allowNoMatch; | ||
} | ||
|
||
/** | ||
* Whether to ignore if a wildcard expression matches no trained models. | ||
* | ||
* @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) | ||
* does not match any trained models | ||
*/ | ||
public GetTrainedModelsRequest setAllowNoMatch(boolean allowNoMatch) { | ||
this.allowNoMatch = allowNoMatch; | ||
return this; | ||
} | ||
|
||
public PageParams getPageParams() { | ||
return pageParams; | ||
} | ||
|
||
public GetTrainedModelsRequest setPageParams(@Nullable PageParams pageParams) { | ||
this.pageParams = pageParams; | ||
return this; | ||
} | ||
|
||
public Boolean getIncludeDefinition() { | ||
return includeDefinition; | ||
} | ||
|
||
/** | ||
* Whether to include the full model definition. | ||
* | ||
* The full model definition can be very large. | ||
* | ||
* @param includeDefinition If {@code true}, the definition is included. | ||
*/ | ||
public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) { | ||
this.includeDefinition = includeDefinition; | ||
return this; | ||
} | ||
|
||
public Boolean getDecompressDefinition() { | ||
return decompressDefinition; | ||
} | ||
|
||
/** | ||
* Whether or not to decompress the trained model, or keep it in its compressed string form | ||
* | ||
* @param decompressDefinition If {@code true}, the definition is decompressed. | ||
*/ | ||
public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinition) { | ||
this.decompressDefinition = decompressDefinition; | ||
return this; | ||
} | ||
|
||
@Override | ||
public Optional<ValidationException> validate() { | ||
if (ids == null || ids.isEmpty()) { | ||
return Optional.of(ValidationException.withError("trained model id must not be null")); | ||
} | ||
return Optional.empty(); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
|
||
GetTrainedModelsRequest other = (GetTrainedModelsRequest) o; | ||
return Objects.equals(ids, other.ids) | ||
&& Objects.equals(allowNoMatch, other.allowNoMatch) | ||
&& Objects.equals(decompressDefinition, other.decompressDefinition) | ||
&& Objects.equals(includeDefinition, other.includeDefinition) | ||
&& Objects.equals(pageParams, other.pageParams); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition); | ||
} | ||
} |
86 changes: 86 additions & 0 deletions
86
...t/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* | ||
* Licensed to Elasticsearch under one or more contributor | ||
* license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright | ||
* ownership. Elasticsearch licenses this file to you under | ||
* the Apache License, Version 2.0 (the "License"); you may | ||
* not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.elasticsearch.client.ml; | ||
|
||
import org.elasticsearch.client.ml.inference.TrainedModelConfig; | ||
import org.elasticsearch.common.ParseField; | ||
import org.elasticsearch.common.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.common.xcontent.XContentParser; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; | ||
|
||
public class GetTrainedModelsResponse { | ||
|
||
public static final ParseField TRAINED_MODEL_CONFIGS = new ParseField("trained_model_configs"); | ||
public static final ParseField COUNT = new ParseField("count"); | ||
|
||
@SuppressWarnings("unchecked") | ||
static final ConstructingObjectParser<GetTrainedModelsResponse, Void> PARSER = | ||
new ConstructingObjectParser<>( | ||
"get_trained_model_configs", | ||
true, | ||
args -> new GetTrainedModelsResponse((List<TrainedModelConfig>) args[0], (Long) args[1])); | ||
|
||
static { | ||
PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelConfig.fromXContent(p), TRAINED_MODEL_CONFIGS); | ||
PARSER.declareLong(constructorArg(), COUNT); | ||
} | ||
|
||
public static GetTrainedModelsResponse fromXContent(final XContentParser parser) { | ||
return PARSER.apply(parser, null); | ||
} | ||
|
||
private final List<TrainedModelConfig> trainedModels; | ||
private final Long count; | ||
|
||
|
||
public GetTrainedModelsResponse(List<TrainedModelConfig> trainedModels, Long count) { | ||
this.trainedModels = trainedModels; | ||
this.count = count; | ||
} | ||
|
||
public List<TrainedModelConfig> getTrainedModels() { | ||
return trainedModels; | ||
} | ||
|
||
/** | ||
* @return The total count of the trained models that matched the ID pattern. | ||
*/ | ||
public Long getCount() { | ||
return count; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
|
||
GetTrainedModelsResponse other = (GetTrainedModelsResponse) o; | ||
return Objects.equals(this.trainedModels, other.trainedModels) && Objects.equals(this.count, other.count); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(trainedModels, count); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.