
semantic_text - Field inference #103697

Conversation

carlosdelest
Member

@carlosdelest carlosdelest commented Dec 22, 2023

Performs inference in TransportBulkAction. For every group of BulkShardRequests, it performs inference on each individual request.

Bulk inference is done at the document and model level. We can extend this in the future for multiple docs.

For now, there is no chunking - but the source format is prepared to deal with it as it stores arrays of embeddings and text pairs that will be used in nested queries for passage retrieval:

{
  "infer_field": "these are not the droids you're looking for. He's free to go around",
  "another_infer_field": "Carry on. Carry on",
  "non_infer_field": "hello",
  "_semantic_text_inference": {
    "infer_field": [
      {
        "sparse_embedding": {
          "play": 0.34588584,
          "legend": 0.005075309,
          "about": 0.13270257,
          "ship": 0.13503131,
          "anime": 0.31627595,
          "walk": 0.30274966
        },
        "text": "these are not the droids you're looking for. He's free to go around"
      }
    ],
    "another_infer_field": [
      {
        "sparse_embedding": {
          "gift": 0.027486322,
          "ryan": 0.67748386,
          "possession": 0.37753758,
          "bring": 0.88360184,
          "pocket": 0.08802759
        },
        "text": "Carry on. Carry on"
      }
    ]
  }
}

Mikep86 and others added 30 commits December 8, 2023 18:10
Member Author

Added infra for doing YAML tests on ML plugin

@@ -42,6 +42,9 @@
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME;
Member Author

Changed the dependencies so the field mappers depend on the constants defined in server code. That makes sense as server code is the one generating the embeddings.

Member

I see what you are doing here. OK, I can understand that. Basically, the thing that is satisfying the interface gets access to this param.

This is better than it was. It does seem backwards - the plugin should know the mapper & how it extracts things. But this is better than before :)

Member Author

Yes, it would be more contained if everything was in the plugin. That would mean moving the inference generation class to the plugin and injecting it into the TransportBulkAction and BulkOperation classes.

Let me check what that would look like in a separate branch.

@@ -256,32 +252,8 @@ public void testMissingSubfields() throws IOException {
);
assertThat(
ex.getMessage(),
containsString(
Member Author

This test isn't needed, as there won't be an additional level of nesting for the results.

public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper());

private static final Map<List<String>, Set<String>> REQUIRED_SUBFIELDS_MAP = Map.of(
List.of(),
Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME),
Member Author

The structure for the embeddings changes a bit. The field mapper was prepared to handle an additional nesting level, but that is not required, as the asMap() method from the results does not return the information in that format.

import java.util.function.Consumer;
import java.util.stream.Collectors;

public class BulkShardRequestInferenceProvider {
Member Author

I changed the dependency, so this class defines the constants and they are used from the field mappers. LMK if this addresses your concerns.

public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.setting("xpack.security.enabled", "false")
.setting("xpack.security.http.ssl.enabled", "false")
.plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin")
Member Author

Uses the TestInferenceServicePlugin, which defines a mock inference service.

@carlosdelest carlosdelest marked this pull request as ready for review February 7, 2024 17:12
@carlosdelest carlosdelest requested a review from a team as a code owner February 7, 2024 17:12
Member Author

@jimczi LMKWYT of these integration tests. Do you think it would be valuable to use the _bulk API in a separate test suite to test it?

@carlosdelest carlosdelest requested a review from jimczi February 7, 2024 18:06
Contributor

@Mikep86 Mikep86 left a comment

Partial review. Looking good!

k -> new HashMap<String, Object>()
);

List<String> inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap);
Contributor

We could simplify the method signature to getFieldNamesForInference(Set<String> inferenceFields, Map<String, Object> docMap) by passing fieldModelsEntrySet.getValue(). This would also make it clearer that this helper method doesn't use the model ID.
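For illustration, the suggested signature could look like this (a sketch using the names from the snippets in this thread; the surrounding class and the string-value check are assumptions based on the discussion, not the PR's actual code):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class InferenceFieldHelper {
    // Suggested signature: takes only the field-name set, making it clear
    // that the model ID (the map entry's key) is not used by this helper.
    static List<String> getFieldNamesForInference(Set<String> inferenceFields, Map<String, Object> docMap) {
        List<String> inferenceFieldNames = new ArrayList<>();
        for (String inferenceField : inferenceFields) {
            Object fieldValue = docMap.get(inferenceField);
            // Perform inference on string, non-null values only
            if (fieldValue instanceof String) {
                inferenceFieldNames.add(inferenceField);
            }
        }
        return inferenceFieldNames;
    }

    public static void main(String[] args) {
        Map<String, Object> doc = Map.of("title", "hello", "count", 42);
        List<String> fields = getFieldNamesForInference(Set.of("title", "count", "missing"), doc);
        System.out.println(fields); // only "title" holds a string value
    }
}
```

The caller would pass `fieldModelsEntrySet.getValue()` instead of the whole entry.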

Member Author

Good suggestion! 👍

}

private static List<String> getFieldNamesForInference(Map.Entry<String, Set<String>> fieldModelsEntrySet, Map<String, Object> docMap) {
List<String> inferenceFieldNames = new ArrayList<>();
Contributor

Minor optimization: We could pre-allocate an ArrayList of the maximum required size by using the inference field set size

Member Author

I tend not to do that for small lists - the backing array is expanded to 10 elements when the first element is added anyway, so pre-sizing would probably just remove one list expansion.
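For reference, a quick illustration of the trade-off discussed here (a sketch; ArrayList's backing array is allocated lazily and defaults to capacity 10 on first add):

```java
import java.util.ArrayList;
import java.util.List;

public class ListSizing {
    public static void main(String[] args) {
        // new ArrayList<>() allocates no backing array until the first add,
        // at which point it jumps straight to the default capacity of 10.
        // So for small lists, pre-sizing saves at most one resize.
        List<String> lazy = new ArrayList<>();
        lazy.add("first");

        // Pre-sizing mainly helps when the expected size exceeds the default:
        List<String> preSized = new ArrayList<>(100);
        for (int i = 0; i < 100; i++) {
            preSized.add("field-" + i); // no intermediate resizes
        }
        System.out.println(lazy.size() + " " + preSized.size());
    }
}
```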

// Perform inference on string, non-null values
if (fieldValue instanceof String) {
inferenceFieldNames.add(inferenceField);
}
Contributor

Do we need to handle when the field value is a non-null & non-String value (i.e. when the user has provided a value with an invalid data type)? Or will that be handled somewhere downstream/upstream?

Member Author

Good catch, we don't as of now.

There are some cases to consider; I'll work on them:

  • Array values: We could treat these as chunking, and perform inference on every array value
  • Non-string values: Convert them to strings before doing inference. Don't error out when we have a non-string value, similar to how text works.

My only concern with converting non-strings is that we'd be doing inference on fields where inference makes no sense - and potentially incurring costs - instead of warning the user.

Contributor

I wasn't even thinking of multi-valued text fields (i.e. array of strings), but that's a case we need to handle here as well.

I was thinking of handling obvious error cases, such as when the value is a Map and can't be converted to a string in a sensible way.

Contributor

Just checked how the text field handles this:

  • Primitive value (i.e. string, bool, number): Coerce to string
  • Array of primitive values: Index each value separately, coerce each value to string
  • Object (i.e. Map): Throw error
  • Array containing an object: Throw error
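That text-field behaviour could be mirrored with a small coercion helper along these lines (a sketch; the method name and error message are hypothetical, not the actual mapper code):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class ValueCoercion {
    // Mirrors how the text field handles values: primitives coerce to string,
    // arrays are flattened element by element, objects are rejected.
    static List<String> coerceForInference(Object fieldValue) {
        if (fieldValue == null) {
            return List.of();
        }
        if (fieldValue instanceof List<?> values) {
            // Array: handle each element separately (an object inside still throws)
            List<String> result = new ArrayList<>(values.size());
            for (Object value : values) {
                result.addAll(coerceForInference(value));
            }
            return result;
        }
        if (fieldValue instanceof Map) {
            throw new IllegalArgumentException("Cannot perform inference on an object value");
        }
        // Primitive value (string, boolean, number): coerce to string
        return List.of(fieldValue.toString());
    }

    public static void main(String[] args) {
        System.out.println(coerceForInference("some text"));      // [some text]
        System.out.println(coerceForInference(List.of("a", 42))); // [a, 42]
        System.out.println(coerceForInference(true));             // [true]
    }
}
```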

List<Map<String, Object>> inferenceFieldResultList = (List<Map<String, Object>>) rootInferenceFieldMap
.computeIfAbsent(fieldName, k -> new ArrayList<>());
// Remove previous inference results if any
inferenceFieldResultList.clear();
Contributor

If we always remove previous inference results, doesn't that mean we will re-run inference for every semantic_text field on a reindex?

Member Author

You're totally correct. We need additional logic to handle that.

This code is correct: when we receive inference results back, we want to remove the previous inference.

But we should avoid recalculating on reindex:

  • We should always avoid calculating inference for an index action if there are already inference results.
  • We should always recalculate inference for an update action for the included inference fields in the request

I'll work on that and add some more tests. Thanks!
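Those two rules could be sketched roughly as follows (a sketch only; the method name, the `isUpdateRequest` flag, and the way existing results are looked up are assumptions based on this discussion, not the PR's actual code):

```java
import java.util.List;
import java.util.Map;

public class InferenceSkipLogic {
    // Root field name as discussed in this PR
    static final String ROOT_FIELD = "_semantic_text_inference";

    static boolean needsInference(Map<String, Object> docMap, String fieldName, boolean isUpdateRequest) {
        if (isUpdateRequest) {
            // Updates always recalculate inference for the included fields
            return docMap.containsKey(fieldName);
        }
        Object existing = docMap.get(ROOT_FIELD);
        if (existing instanceof Map<?, ?> inferenceMap && inferenceMap.containsKey(fieldName)) {
            // Index action with existing results (e.g. a reindex): skip inference
            return false;
        }
        return docMap.containsKey(fieldName);
    }

    public static void main(String[] args) {
        Map<String, Object> reindexedDoc = Map.of(
            "infer_field", "some text",
            ROOT_FIELD, Map.of("infer_field", List.of())
        );
        System.out.println(needsInference(reindexedDoc, "infer_field", false)); // false: results already present
        System.out.println(needsInference(reindexedDoc, "infer_field", true));  // true: update recalculates
    }
}
```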

Contributor

We could also break this out into a follow-up task if it makes the scope of this PR too big. It's already pretty chonky 😵‍💫

String modelId = fieldModelsEntrySet.getKey();

@SuppressWarnings("unchecked")
Map<String, Object> rootInferenceFieldMap = (Map<String, Object>) docMap.computeIfAbsent(
Contributor

How do we plan on handling cast errors here? It's theoretically possible for the user to provide a value for the _semantic_text_inference field with an invalid data type.

Member Author

I'd say we're safe as we're casting to Map<String, Object> - and the document source will already be parsed at this point. I don't think this can fail if it's valid JSON (as it should be at that stage) 🤔

Member

and the document source will be parsed already at this point.

Parsed by whom?

There are three scenarios here:

  • User doing a "put/post" with the field already defined (obviously, hasn't been indexed already)
  • A Reindex occurring, obviously, this one is OK as it was indexed previously somewhere (hopefully by us :/)
  • A new document where this field exists or doesn't based on previously inferred values (more than one field, meaning it doesn't exist for the first inference result but does for the next).

This is the tricky part of having things that the mapper generally validates handled further up in the ingest pipeline.

So, we need to validate things are what we expect. We don't need to do a full parse of the internals (the mapper does this), but I am not sure blindly casting is wise.

Do we have tests covering the above scenarios I laid out yet?

Member Author

the variable docMap is retrieved using sourceAsMap() from the index or update request, which I believe invokes the XContentParser. So, it should be parsed at this point, right?

Contributor

I think the issue is that _source can be valid JSON without meeting the cast type expectations. For example:

{
  "_semantic_text_inference": "foo"
}
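One way to guard the cast against that kind of input (a sketch; the method name and error message are illustrative, not the PR's actual code):

```java
import java.util.HashMap;
import java.util.Map;

public class InferenceFieldValidation {
    static final String ROOT_FIELD = "_semantic_text_inference";

    // Instead of blindly casting, validate the shape of the (possibly
    // user-supplied) value before reusing it.
    @SuppressWarnings("unchecked")
    static Map<String, Object> rootInferenceFieldMap(Map<String, Object> docMap) {
        Object value = docMap.computeIfAbsent(ROOT_FIELD, k -> new HashMap<String, Object>());
        if (value instanceof Map == false) {
            throw new IllegalArgumentException(
                "[" + ROOT_FIELD + "] must be an object, got [" + value.getClass().getSimpleName() + "]"
            );
        }
        return (Map<String, Object>) value;
    }

    public static void main(String[] args) {
        Map<String, Object> validDoc = new HashMap<>();
        System.out.println(rootInferenceFieldMap(validDoc).isEmpty()); // true: created empty

        // The problematic case from the example above: valid JSON, wrong type
        Map<String, Object> invalidDoc = new HashMap<>(Map.of(ROOT_FIELD, "foo"));
        try {
            rootInferenceFieldMap(invalidDoc);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```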

Member Author

You're right, for some reason I was not seeing that. Thanks for catching this!

Member

from the index or update request, which I believe invokes the XContentParser. So, it should be parsed at this point, right?

It can be valid JSON (or CBOR, or SMILE), but it can 100% be invalid for what we care about.

}

tasks.named('yamlRestTest') {
usesDefaultDistribution()
Contributor

Is it strictly required that we use the default distribution here? Can these tests simply explicitly install the plugins/modules necessary?

Member Author

I'm not familiar with this. Could you please point me to some docs explaining the process, or examples that don't use the default distribution to check? 🙏

@@ -0,0 +1,15 @@
apply plugin: 'elasticsearch.internal-yaml-rest-test'
Contributor

I don't think we need to create another QA project here. We can just apply this plugin to the :x-pack:plugin:inference project and run these tests there.

Member Author

I thought that was a common pattern - I've moved it directly under the plugin root dir 👍

Contributor

Out of necessity mostly. The new testing framework makes it unnecessary in almost all circumstances.

Member Author

noted, thank you Mark!

Contributor

@jimczi jimczi left a comment

Thanks for adding these tests. I think the infrastructure is in place and we can iterate from here to add the missing pieces. +1 to merge on the branch so that we can start working on batching as a follow up.

Contributor

@Mikep86 Mikep86 left a comment

I really like how this is progressing :)

Contributor

Very nice tests!

@@ -1941,13 +1941,16 @@ protected void assertSnapshotOrGenericThread() {
client,
null,
() -> DocumentParsingObserver.EMPTY_INSTANCE

Contributor

Nitpick: Odd place for whitespace

- match: { _source._semantic_text_inference.inference_field.0.text: "updated inference test" }
- match: { _source._semantic_text_inference.another_inference_field.0.text: "another updated inference test" }


Contributor

I'm not sure how TestInferenceServicePlugin is generating embeddings, but is it possible to test that the embeddings have changed here?


- match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding }
- match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding }

Contributor

Since we know that (currently, at least) TransportBulkAction will re-generate embeddings for semantic_text fields on reindex, IMO this test should indicate as such by failing right now.

Maybe we could have TestInferenceServicePlugin generate random embeddings, regardless of input text, so we can determine when the inference service has been called?

Member Author

I need to iterate on that idea - I have already tried it, but the problem is checking that something doesn't match - AFAIK there's no not_match construct for YAML tests that would support failing the comparison.

* "dragon": 0.50991,
* "type": 0.23241979,
* "dr": 1.9312073,
* "##o": 0.2797593
Contributor

Nitpick: Indentation is off here

Contributor

The updates to this class are incomplete, there are still references to SparseEmbeddingResults.Embedding.EMBEDDING & SparseEmbeddingResults.Embedding.IS_TRUNCATED. The tests still pass because the corresponding references in the tests still exist as well.

I can handle updating SemanticTextInferenceResultFieldMapper & SemanticTextInferenceResultFieldMapperTests in a separate PR if you like; it would also keep the scope of this PR more limited.

Member Author

I see, good catch - it would help if you could push the changes to this branch, or tackle it as a separate PR.

Contributor

I'd like to handle this in a separate PR if that's OK

Member Author

sounds good, thanks Mike!

@carlosdelest
Member Author

@benwtrent , @Mikep86, @jimczi :

I've been working on adding support for:

  • Other field types (boolean, numbers)
  • Arrays
  • _reindex and _update_by_query, so inference is not recalculated

I'm making good progress, but it's adding quite a few lines to this PR. I'd like to add it iteratively if that's ok with you.

If you're missing something that can be added afterwards, please let me know and we can discuss.

Thanks!

@Mikep86
Contributor

Mikep86 commented Feb 8, 2024

@carlosdelest Agree that we should iterate on this through multiple PRs. This one is already huge! Can we just ensure that we capture all the follow-ups before closing this PR?

@benwtrent
Member

Sounds good to me @carlosdelest do your thing :). As long as we get tests and such.

@carlosdelest
Member Author

carlosdelest commented Feb 8, 2024

Sounds good to me @carlosdelest do your thing :).

My thing seems to be iterating on this PR forever. That's my idea of purgatory as of now. 👿 Thank you Ben!

As long as we get tests and such.

I've already implemented some YAML tests that handle some of the cases. I'll get back to them when I add support for the missing pieces.

@carlosdelest
Member Author

Thanks everyone for your input and guidance on this PR. I'm merging it on the feature branch.

@carlosdelest carlosdelest merged commit ca65a70 into elastic:feature/semantic-text Feb 9, 2024
13 of 14 checks passed
Labels
:ml (Machine learning), Team:ML (Meta label for the ML team), WIP
8 participants