-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Passing input type through to cohere request #104781
Changes from 6 commits
ee9477d
e129e4c
4182cd5
e3b0b9b
fed3ebf
9ddc9b5
e081684
6ab8503
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
package org.elasticsearch.xpack.core.inference.action; | ||
|
||
import org.elasticsearch.ElasticsearchStatusException; | ||
import org.elasticsearch.TransportVersion; | ||
import org.elasticsearch.TransportVersions; | ||
import org.elasticsearch.action.ActionRequest; | ||
import org.elasticsearch.action.ActionRequestValidationException; | ||
|
@@ -96,7 +97,7 @@ public Request(StreamInput in) throws IOException { | |
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { | ||
this.inputType = in.readEnum(InputType.class); | ||
} else { | ||
this.inputType = InputType.INGEST; | ||
this.inputType = InputType.UNSPECIFIED; | ||
} | ||
} | ||
|
||
|
@@ -146,11 +147,22 @@ public void writeTo(StreamOutput out) throws IOException { | |
out.writeString(input.get(0)); | ||
} | ||
out.writeGenericMap(taskSettings); | ||
// in version ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED the input type enum was added, so we only want to write the enum if we're | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Double check this code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM |
||
// at that version or later | ||
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) { | ||
out.writeEnum(inputType); | ||
out.writeEnum(getInputTypeToWrite(out.getTransportVersion())); | ||
} | ||
} | ||
|
||
private InputType getInputTypeToWrite(TransportVersion version) { | ||
// in version ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED the UNSPECIFIED value was added, so if we're before that | ||
// version other nodes won't know about it, so set it to INGEST instead | ||
if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED) && inputType == InputType.UNSPECIFIED) { | ||
return InputType.INGEST; | ||
} | ||
return inputType; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
|
@@ -203,7 +215,7 @@ public Builder setTaskSettings(Map<String, Object> taskSettings) { | |
} | ||
|
||
public Request build() { | ||
return new Request(taskType, inferenceEntityId, input, taskSettings, InputType.INGEST); | ||
return new Request(taskType, inferenceEntityId, input, taskSettings, InputType.UNSPECIFIED); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we directly use UNSPECIFIED in this constructor? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. It'd probably be clearer if we add a setter on the builder. Currently the builder is only used in the rest code path. We don't expose the input type parameter to rest requests (only internal requests) so it'll always be unspecified. |
||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,8 @@ | |
import java.util.List; | ||
import java.util.Objects; | ||
|
||
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage; | ||
|
||
public record CohereEmbeddingsRequestEntity( | ||
List<String> input, | ||
CohereEmbeddingsTaskSettings taskSettings, | ||
|
@@ -29,14 +31,6 @@ public record CohereEmbeddingsRequestEntity( | |
|
||
private static final String SEARCH_DOCUMENT = "search_document"; | ||
private static final String SEARCH_QUERY = "search_query"; | ||
/** | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm going to go with a switch case here instead because that seems a little cleaner now that we have a value we don't want to allow. |
||
* Maps the {@link InputType} to the expected value for cohere for the input_type field in the request using the enum's ordinal. | ||
* The order of these entries is important and needs to match the order in the enum | ||
*/ | ||
private static final String[] INPUT_TYPE_MAPPING = { SEARCH_DOCUMENT, SEARCH_QUERY }; | ||
static { | ||
assert INPUT_TYPE_MAPPING.length == InputType.values().length : "input type mapping was incorrectly defined"; | ||
} | ||
|
||
private static final String TEXTS_FIELD = "texts"; | ||
|
||
|
@@ -56,23 +50,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws | |
builder.field(CohereServiceSettings.MODEL, model); | ||
} | ||
|
||
if (taskSettings.inputType() != null) { | ||
builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.inputType())); | ||
if (taskSettings.getInputType() != null) { | ||
builder.field(INPUT_TYPE_FIELD, covertToString(taskSettings.getInputType())); | ||
} | ||
|
||
if (embeddingType != null) { | ||
builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType)); | ||
} | ||
|
||
if (taskSettings.truncation() != null) { | ||
builder.field(CohereServiceFields.TRUNCATE, taskSettings.truncation()); | ||
if (taskSettings.getTruncation() != null) { | ||
builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); | ||
} | ||
|
||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
private static String covertToString(InputType inputType) { | ||
return INPUT_TYPE_MAPPING[inputType.ordinal()]; | ||
// default for testing | ||
static String covertToString(InputType inputType) { | ||
return switch (inputType) { | ||
case INGEST -> SEARCH_DOCUMENT; | ||
case SEARCH -> SEARCH_QUERY; | ||
default -> { | ||
assert false : invalidInputTypeMessage(inputType); | ||
yield null; | ||
} | ||
}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,10 +11,10 @@ | |
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.common.ValidationException; | ||
import org.elasticsearch.common.settings.SecureString; | ||
import org.elasticsearch.core.CheckedFunction; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.core.Strings; | ||
import org.elasticsearch.inference.InferenceService; | ||
import org.elasticsearch.inference.InputType; | ||
import org.elasticsearch.inference.Model; | ||
import org.elasticsearch.inference.TaskType; | ||
import org.elasticsearch.rest.RestStatus; | ||
|
@@ -110,7 +110,7 @@ public static String mustBeNonEmptyString(String settingName, String scope) { | |
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); | ||
} | ||
|
||
public static String invalidValue(String settingName, String scope, String invalidType, String... requiredTypes) { | ||
public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredTypes) { | ||
return Strings.format( | ||
"[%s] Invalid value [%s] received. [%s] must be one of [%s]", | ||
scope, | ||
|
@@ -225,25 +225,43 @@ public static <T> T extractOptionalEnum( | |
Map<String, Object> map, | ||
String settingName, | ||
String scope, | ||
CheckedFunction<String, T, IllegalArgumentException> converter, | ||
T[] validTypes, | ||
EnumConstructor<T> constructor, | ||
T[] validValues, | ||
ValidationException validationException | ||
) { | ||
var enumString = extractOptionalString(map, settingName, scope, validationException); | ||
if (enumString == null) { | ||
return null; | ||
} | ||
|
||
var validTypesAsStrings = Arrays.stream(validTypes).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); | ||
var validValuesAsStrings = Arrays.stream(validValues).map(type -> type.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); | ||
try { | ||
return converter.apply(enumString); | ||
var createdEnum = constructor.apply(enumString); | ||
validateEnumValue(createdEnum, validValues); | ||
|
||
return createdEnum; | ||
} catch (IllegalArgumentException e) { | ||
validationException.addValidationError(invalidValue(settingName, scope, enumString, validTypesAsStrings)); | ||
validationException.addValidationError(invalidValue(settingName, scope, enumString, validValuesAsStrings)); | ||
} | ||
|
||
return null; | ||
} | ||
|
||
private static <T> void validateEnumValue(T enumValue, T[] validValues) { | ||
if (Arrays.asList(validValues).contains(enumValue) == false) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you achieve the same by passing an Nit: in the docs a generic enum parameter is referred to as https://docs.oracle.com/javase/8/docs/api/java/util/EnumSet.html |
||
throw new IllegalArgumentException(Strings.format("Enum value [%s] is not one of the acceptable values", enumValue.toString())); | ||
} | ||
} | ||
|
||
/** | ||
* Functional interface for creating an enum from a string. | ||
* @param <T> | ||
*/ | ||
@FunctionalInterface | ||
public interface EnumConstructor<T> { | ||
T apply(String name) throws IllegalArgumentException; | ||
} | ||
|
||
public static String parsePersistedConfigErrorMsg(String inferenceEntityId, String serviceName) { | ||
return format( | ||
"Failed to parse stored model [%s] for [%s] service, please delete and add the service again", | ||
|
@@ -272,7 +290,7 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod | |
public static void getEmbeddingSize(Model model, InferenceService service, ActionListener<Integer> listener) { | ||
assert model.getTaskType() == TaskType.TEXT_EMBEDDING; | ||
|
||
service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), listener.delegateFailureAndWrap((delegate, r) -> { | ||
service.infer(model, List.of(TEST_EMBEDDING_INPUT), Map.of(), InputType.INGEST, listener.delegateFailureAndWrap((delegate, r) -> { | ||
if (r instanceof TextEmbedding embeddingResults) { | ||
try { | ||
delegate.onResponse(embeddingResults.getFirstEmbeddingSize()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.inference.InputType; | ||
import org.elasticsearch.inference.Model; | ||
import org.elasticsearch.inference.ModelConfigurations; | ||
import org.elasticsearch.inference.ModelSecrets; | ||
|
@@ -123,6 +124,7 @@ public void doInfer( | |
Model model, | ||
List<String> input, | ||
Map<String, Object> taskSettings, | ||
InputType inputType, | ||
ActionListener<InferenceServiceResults> listener | ||
) { | ||
if (model instanceof CohereModel == false) { | ||
|
@@ -133,7 +135,7 @@ public void doInfer( | |
CohereModel cohereModel = (CohereModel) model; | ||
var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); | ||
|
||
var action = cohereModel.accept(actionCreator, taskSettings); | ||
var action = cohereModel.accept(actionCreator, taskSettings, inputType); | ||
action.execute(input, listener); | ||
} | ||
|
||
|
@@ -174,6 +176,6 @@ private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsMo | |
|
||
@Override | ||
public TransportVersion getMinimalSupportedVersion() { | ||
return TransportVersions.ML_INFERENCE_COHERE_EMBEDDINGS_ADDED; | ||
return TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bumping the cohere service min version just in case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not quite sure what the implication of increasing the minimal supported version of a class is. Do you know what happens when we change this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it means that if a cluster has nodes with different versions and one of the nodes is on |
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm open to other names for this. Maybe
UNKNOWN
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
UNSPECIFIED
is what it is, it's a good name