semantic_text - Field inference #103697
Conversation
Added infra for doing YAML tests on ML plugin
@@ -42,6 +42,9 @@
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME;
Changed the dependencies so the field mappers depend on the constants defined in server code. That makes sense, as the server code is the one generating the embeddings.
I see what you are doing here. OK, I can understand that. Basically, the thing that satisfies the interface gets access to this param.
This is better than it was. It does seem backwards - the plugin should know the mapper & how it extracts things. But this is better than before :)
Yes, it would be more contained if everything was in the plugin. That would mean moving the inference generation class to the plugin and injecting it into the TransportBulkAction and BulkOperation classes.
Let me check how that would look in a separate branch.
@@ -256,32 +252,8 @@ public void testMissingSubfields() throws IOException { | |||
); | |||
assertThat( | |||
ex.getMessage(), | |||
containsString( |
This test isn't needed as there won't be an additional level of nesting for the results
public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper());

private static final Map<List<String>, Set<String>> REQUIRED_SUBFIELDS_MAP = Map.of(
    List.of(),
    Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME),
The structure for the embeddings changed a bit. The field mapper was prepared to handle an additional nesting level, but that is not required, as the asMap() method from the results does not return the information in that format.
import java.util.function.Consumer;
import java.util.stream.Collectors;

public class BulkShardRequestInferenceProvider {
I changed the dependency, so this class defines the constants and they are used from the field mappers. LMK if this addresses your concerns.
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
    .setting("xpack.security.enabled", "false")
    .setting("xpack.security.http.ssl.enabled", "false")
    .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin")
Uses the TestInferenceServicePlugin, which defines a mock inference service to be used.
@jimczi LMKWYT of these integration tests. Do you think it would be valuable to use the _bulk API in a separate test suite to test it?
Partial review. Looking good!
    k -> new HashMap<String, Object>()
);

List<String> inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap);
We could simplify the method signature to getFieldNamesForInference(Set<String> inferenceFields, Map<String, Object> docMap) by passing fieldModelsEntrySet.getValue(). This would also make it clearer that this helper method doesn't use the model ID.
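The suggested simplification might look like the sketch below. This is a hypothetical standalone version, not the PR's actual code; the class name and the example field names are made up for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class InferenceFieldHelper {
    // Sketch of the simplified helper: taking only the set of inference field
    // names makes it obvious the model ID is never consulted here.
    static List<String> getFieldNamesForInference(Set<String> inferenceFields, Map<String, Object> docMap) {
        List<String> inferenceFieldNames = new ArrayList<>();
        for (String inferenceField : inferenceFields) {
            Object fieldValue = docMap.get(inferenceField);
            // Perform inference only on string, non-null values
            if (fieldValue instanceof String) {
                inferenceFieldNames.add(inferenceField);
            }
        }
        return inferenceFieldNames;
    }

    public static void main(String[] args) {
        Map<String, Object> doc = Map.of("title", "a string value", "count", 42);
        // Only "title" holds a string, so only it qualifies for inference
        System.out.println(getFieldNamesForInference(Set.of("title", "count"), doc));
    }
}
```

The caller would then pass `fieldModelsEntrySet.getValue()` at the call site, keeping the model ID handling where it belongs.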
Good suggestion! 👍
}

private static List<String> getFieldNamesForInference(Map.Entry<String, Set<String>> fieldModelsEntrySet, Map<String, Object> docMap) {
    List<String> inferenceFieldNames = new ArrayList<>();
Minor optimization: We could pre-allocate an ArrayList of the maximum required size by using the inference field set size.
I tend not to do that for small lists - the list will be expanded to 10 elements when the first element is added anyway, so pre-allocating probably just saves one list expansion.
// Perform inference on string, non-null values
if (fieldValue instanceof String) {
    inferenceFieldNames.add(inferenceField);
}
Do we need to handle when the field value is a non-null & non-String value (i.e. when the user has provided a value with an invalid data type)? Or will that be handled somewhere downstream/upstream?
Good catch, we don't as of now.
There are some cases to consider; I'll work on them:
- Array values: We could treat these as chunking, and perform inference on every array value
- Non-string values: Convert them to strings before doing inference. Don't error out when we have a non-string value, similar to how text works.
My only concern with converting non-strings is that we'd be doing inference on fields where inference makes no sense - and potentially incurring costs - instead of warning the user.
I wasn't even thinking of multi-valued text fields (i.e. array of strings), but that's a case we need to handle here as well.
I was thinking of handling obvious error cases, such as when the value is a Map and can't be converted to a string in a sensible way.
Just checked how the text field handles this:
- Primitive value (i.e. string, bool, number): coerce to string
- Array of primitive values: index each value separately, coercing each value to string
- Object (i.e. Map): throw an error
- Array containing an object: throw an error
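Those four rules can be sketched as a small standalone helper. This is an illustrative reimplementation of the behavior described above, not the actual text-mapper code; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class TextCoercion {
    // Sketch of the text-field coercion rules: primitives coerce to string,
    // arrays are handled element-wise, objects (alone or inside arrays) fail.
    static List<String> coerceToStrings(Object value) {
        if (value == null) {
            return List.of();
        }
        if (value instanceof Map) {
            throw new IllegalArgumentException("Cannot coerce an object to string");
        }
        if (value instanceof List<?> list) {
            List<String> out = new ArrayList<>();
            for (Object item : list) {
                if (item instanceof Map || item instanceof List) {
                    throw new IllegalArgumentException("Cannot coerce a nested object/array to string");
                }
                out.add(String.valueOf(item)); // coerce each primitive element
            }
            return out;
        }
        return List.of(String.valueOf(value)); // string, bool, number
    }

    public static void main(String[] args) {
        System.out.println(coerceToStrings(42));                  // single number
        System.out.println(coerceToStrings(List.of("a", true)));  // mixed primitives
    }
}
```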
List<Map<String, Object>> inferenceFieldResultList = (List<Map<String, Object>>) rootInferenceFieldMap
    .computeIfAbsent(fieldName, k -> new ArrayList<>());
// Remove previous inference results if any
inferenceFieldResultList.clear();
If we always remove previous inference results, doesn't that mean we will re-run inference for every semantic_text field on a reindex?
You're totally correct. We need additional logic to handle that.
This code is correct in that when we receive inference results back, we want to remove the previous ones.
But we should avoid recalculating on reindex:
- We should avoid calculating inference for an index action if there are already inference results.
- We should always recalculate inference for an update action for the inference fields included in the request.
I'll work on that and add some more tests. Thanks!
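As a rough illustration of those two rules, the decision could be sketched like this. This is a hypothetical helper, not the PR's implementation; the `_semantic_text_inference` root field name is taken from elsewhere in this conversation.

```java
import java.util.List;
import java.util.Map;

public class InferenceSkipCheck {
    static final String ROOT_FIELD = "_semantic_text_inference";

    // Sketch: updates always recalculate inference for the included fields;
    // index actions (e.g. reindex) skip fields that already carry results.
    static boolean shouldRunInference(boolean isUpdateAction, Map<String, Object> docMap, String fieldName) {
        if (isUpdateAction) {
            return true;
        }
        Object root = docMap.get(ROOT_FIELD);
        if (root instanceof Map<?, ?> rootMap) {
            return rootMap.containsKey(fieldName) == false;
        }
        return true;
    }

    public static void main(String[] args) {
        // A reindexed doc already carries results for "field"
        Map<String, Object> reindexedDoc = Map.of(
            "field", "some text",
            ROOT_FIELD, Map.of("field", List.of())
        );
        System.out.println(shouldRunInference(false, reindexedDoc, "field")); // skip: results exist
        System.out.println(shouldRunInference(true, reindexedDoc, "field"));  // update: recalculate
    }
}
```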
We could also break this out into a follow-up task if it makes the scope of this PR too big. It's already pretty chonky 😵💫
String modelId = fieldModelsEntrySet.getKey();

@SuppressWarnings("unchecked")
Map<String, Object> rootInferenceFieldMap = (Map<String, Object>) docMap.computeIfAbsent(
How do we plan on handling cast errors here? It's theoretically possible for the user to provide a value for the _semantic_text_inference field with an invalid data type.
I'd say we're safe, as we're casting to Map<String, Object> - and the document source will already have been parsed at this point. I don't think this can fail if it's valid JSON (as it should be at that stage) 🤔
> and the document source will already have been parsed at this point.
Parsed by whom?
There are three scenarios here:
- A user doing a "put/post" with the field already defined (obviously, it hasn't been indexed already)
- A reindex occurring; obviously, this one is OK, as it was indexed previously somewhere (hopefully by us :/)
- A new document where this field exists or doesn't, based on previously inferenced values (more than one field, meaning it doesn't exist for the first inference result but does for the next)
This is the tricky part of having things that the mapper generally validates done further up in the ingest pipeline.
So, we need to validate that things are what we expect. We don't need to do a full parse of the internals (the mapper does this), but I am not sure blindly casting is wise.
Do we have tests covering the scenarios I laid out above yet?
The variable docMap is retrieved using sourceAsMap() from the index or update request, which I believe invokes the XContentParser. So it should be parsed at this point, right?
I think the issue is that _source can be valid JSON without meeting the cast type expectations. For example:
{
  "_semantic_text_inference": "foo"
}
You're right, for some reason I was not seeing that. Thanks for catching this!
> from the index or update request, which I believe invokes the XContentParser. So, it should be parsed at this point, right?
It can be valid JSON (or CBOR, or SMILE), but it can 100% be invalid for what we care about.
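A defensive alternative to the blind cast could look like the sketch below. This is hypothetical - the helper name and error message are illustrative, not the PR's code - but it shows how the `"_semantic_text_inference": "foo"` case above would fail with a clear message instead of a ClassCastException.

```java
import java.util.HashMap;
import java.util.Map;

public class SafeRootCast {
    // Sketch: validate the type before casting, so a _source value of the
    // wrong shape produces a descriptive error rather than a raw cast failure.
    @SuppressWarnings("unchecked")
    static Map<String, Object> getRootInferenceMap(Map<String, Object> docMap, String fieldName) {
        Object value = docMap.computeIfAbsent(fieldName, k -> new HashMap<String, Object>());
        if (value instanceof Map == false) {
            throw new IllegalArgumentException(
                "[" + fieldName + "] must be an object, got [" + value.getClass().getSimpleName() + "]"
            );
        }
        return (Map<String, Object>) value;
    }

    public static void main(String[] args) {
        Map<String, Object> ok = new HashMap<>();
        System.out.println(getRootInferenceMap(ok, "_semantic_text_inference")); // empty map, created on demand

        Map<String, Object> bad = new HashMap<>(Map.of("_semantic_text_inference", "foo"));
        try {
            getRootInferenceMap(bad, "_semantic_text_inference");
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // descriptive error instead of ClassCastException
        }
    }
}
```

Note the values inside the map are still unchecked here; deeper validation would be left to the mapper, per the discussion above.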
}

tasks.named('yamlRestTest') {
    usesDefaultDistribution()
Is it strictly required that we use the default distribution here? Can these tests simply explicitly install the plugins/modules necessary?
I'm not familiar with this. Could you please point me to some docs explaining the process, or examples that don't use the default distribution to check? 🙏
@@ -0,0 +1,15 @@
apply plugin: 'elasticsearch.internal-yaml-rest-test'
I don't think we need to create another QA project here. We can just apply this plugin to the :x-pack:plugin:inference project and run these tests there.
I thought that was a common pattern - I've moved it directly under the plugin root dir 👍
Out of necessity, mostly. The new testing framework makes it unnecessary in almost all circumstances.
noted, thank you Mark!
Thanks for adding these tests. I think the infrastructure is in place and we can iterate from here to add the missing pieces. +1 to merge on the branch so that we can start working on batching as a follow up.
I really like how this is progressing :)
server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
Very nice tests!
@@ -1941,13 +1941,16 @@ protected void assertSnapshotOrGenericThread() {
client,
null,
() -> DocumentParsingObserver.EMPTY_INSTANCE

Nitpick: Odd place for whitespace
- match: { _source._semantic_text_inference.inference_field.0.text: "updated inference test" }
- match: { _source._semantic_text_inference.another_inference_field.0.text: "another updated inference test" }
I'm not sure how TestInferenceServicePlugin is generating embeddings, but is it possible to test that the embeddings have changed here?
- match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding }
- match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding }
Since we know that (currently, at least) TransportBulkAction will re-generate embeddings for semantic_text fields on reindex, IMO this test should indicate as much by failing right now.
Maybe we could have TestInferenceServicePlugin generate random embeddings, regardless of input text, so we can determine when the inference service has been called?
I need to iterate on that idea - I already tried that, but the problem is checking that something doesn't match - AFAIK there's no not_match construct for YAML tests that would support failing the comparison.
* "dragon": 0.50991,
* "type": 0.23241979,
* "dr": 1.9312073,
* "##o": 0.2797593
Nitpick: Indentation is off here
The updates to this class are incomplete; there are still references to SparseEmbeddingResults.Embedding.EMBEDDING & SparseEmbeddingResults.Embedding.IS_TRUNCATED. The tests still pass because the corresponding references in the tests still exist as well.
I can handle updating SemanticTextInferenceResultFieldMapper & SemanticTextInferenceResultFieldMapperTests in a separate PR if you like; it would also keep the scope of this PR more limited.
I see, good catch - it would help if you can push the changes to this branch, or tackle that as a separate PR.
I'd like to handle this in a separate PR, if that's OK.
sounds good, thanks Mike!
@benwtrent, @Mikep86, @jimczi: I've been working on adding support for:
I'm making good progress, but it's adding quite a few lines to this PR. I'd like to add that iteratively, if that's ok with you. If you're missing something that can be added afterwards, please let me know and we can discuss. Thanks!
@carlosdelest Agree that we should iterate on this through multiple PRs. This one is already huge! Can we just ensure that we capture all the follow-ups before closing this PR?
Sounds good to me @carlosdelest, do your thing :). As long as we get tests and such.
My thing seems to be iterating on this PR forever. That's my idea of purgatory as of now. 👿 Thank you Ben!
I've already implemented some YAML tests that handle some of the cases. I'll get back to them when I add support for the missing pieces.
Thanks everyone for your input and guidance on this PR. I'm merging it on the feature branch.
Merged ca65a70 into elastic:feature/semantic-text
Performs inference in TransportBulkAction. For every group of BulkShardRequests, it performs inference on each individual request. Bulk inference is done at the document and model level; we can extend this in the future to batch multiple docs.
For now, there is no chunking - but the source format is prepared to deal with it, as it stores arrays of embeddings and text pairs that will be used in nested queries for passage retrieval:
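The stored source shape that description (and the YAML assertions earlier in this conversation) implies can be illustrated as a Java map literal. This is a hypothetical illustration - the token weights and field names are made up - showing each inference field holding an array of text/embedding pairs, one entry per future chunk.

```java
import java.util.List;
import java.util.Map;

public class SemanticTextSourceShape {
    // Illustrative _source shape: original field value plus an array of
    // { text, sparse_embedding } pairs under _semantic_text_inference.
    static Map<String, Object> exampleSource() {
        return Map.of(
            "inference_field", "inference test",
            "_semantic_text_inference", Map.of(
                "inference_field", List.of(
                    Map.of(
                        "text", "inference test",
                        "sparse_embedding", Map.of("inference", 0.5, "test", 0.4)
                    )
                )
            )
        );
    }

    public static void main(String[] args) {
        Map<String, Object> source = exampleSource();
        System.out.println(source.containsKey("_semantic_text_inference"));
    }
}
```

Because each field maps to a list of chunk objects, adding chunking later only means appending more `{ text, sparse_embedding }` entries; the mapping and nested-query structure stay the same.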