Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate model inference to build query #20

Merged
merged 4 commits into from
Oct 19, 2022

Conversation

jmazanec15
Copy link
Member

Description

Integrates ml-commons model inference capabilities to transform NeuralQueryBuilder into a KNNQueryBuilder. Minor changes to parsing logic and build.gradle to fix bugs. Minor enhancement to MLCommonsCLientAccessor to add single sentence inference.

Added unit tests to test functionality.

Working on integration tests. Confirmed it works on a local cluster by setting up k-NN index with 1K docs and running the following query:

$ curl -XPOST "localhost:9200/test-index/_search?_source_excludes=cool_field&pretty" -H 'Content-Type: application/json' -d'
{
  "query": {
    "neural": {
      "cool_field": {
        "query_text": "Hello world!",
        "model_id": "rlAg04MB3cG1ZCLOBuDF",
        "k": 1000
      }
    }
  },
  "size": 5
}
'
{
  "took" : 14,
  "timed_out" : false,
  "_shards" : {
    "total" : 3,
    "successful" : 3,
    "skipped" : 0,
    "failed" : 0
  },
  "hits" : {
    "total" : {
      "value" : 1000,
      "relation" : "eq"
    },
    "max_score" : 0.002477972,
    "hits" : [
      {
        "_index" : "test-index",
        "_id" : "613",
        "_score" : 0.002477972,
        "_source" : { }
      },
      {
        "_index" : "test-index",
        "_id" : "864",
        "_score" : 0.0024763448,
        "_source" : { }
      },
      {
        "_index" : "test-index",
        "_id" : "190",
        "_score" : 0.002475292,
        "_source" : { }
      },
      {
        "_index" : "test-index",
        "_id" : "818",
        "_score" : 0.002458808,
        "_source" : { }
      },
      {
        "_index" : "test-index",
        "_id" : "340",
        "_score" : 0.0024474028,
        "_source" : { }
      }
    ]
  }
}

Issues Resolved

#14

Check List

  • New functionality includes testing.
    • All tests pass
  • New functionality has been documented.
    • New functionality has javadoc added
  • Commits are signed as per the DCO using --signoff

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

Integrates ml-commons model inference capabilities to transform
NeuralQueryBuilder into a KNNQueryBuilder. Minor changes to parsing
logic and build.gradle to fix bugs. Minor enhancement to
MLCommonsCLientAccessor to add single sentence inference.

Added uTs.

Signed-off-by: John Mazanec <[email protected]>
@jmazanec15 jmazanec15 requested a review from a team October 13, 2022 21:10
* @param inputText {@link List} of {@link String} on which inference needs to happen
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
*/
public void inferenceSentence(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need this, when we already have a function which does for a list of sentences, please try to use that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im going to be writing the code to convert functionality to single sentence/single vector. Why not incorporate it here so it can be used by other features of the code as well?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All you need to do is do a get on the vector list, for that if we are creating a new function that is over abstraction. We already have 2 versions of this function in the class, and add third which is very specific seems to be an overkill for now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to keep this here. I dont see a downside to having it. The name of the method "inferenceSentence" distinguishes it from "inferenceSentences" so there won't be any confusion when to use what. Also, I think single versus collection isnt so specific that it couldn't be useful outside my current use case. Many other Java interfaces/classes have methods for both.

All you need to do is do a get on the vector list,

True, but you also have to ensure only 1 vector is returned - so each client would also have to add this check/error handling. Using this function, we can centralize that check. Additionally, it is clunky to pass in a List<List<Float>> listener when you expect that there will only be 1 List<Float>

Given that the purpose of this class is to build an easy to use abstraction over MLClient, I think it fits to add this method here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add parity method with inferenceSentences where filters are passed in as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we abstract the single sentence function in the queryBuilder class only? I just want to avoid confusion around the usage and 1 more function where we will have use case of not using the TARGET_RESPONSE_FILTERS, for 1 single sentence.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think there is any usage confusion, given the method name and signature are descriptive of the functionality.

In terms of maintainability, I dont think it makes sense for other components that will want to use the inferenceSentence method to depend on the queryBuilder class. I think that the ml package should be responsible for providing easy/intuitive interaction with ml-commons for the components in the plugin and should therefore be able to house this functionality.

That being said, I think it makes sense to either integrate it into this class or build a new class that uses this class to provide higher level abstractions. In a similar case, OpenSearch has a "HighLevel" rest client that is built on top of the lower level RestClient to provide more functionality. We could do this, or we could treat the MLCommonsClientAccessor as the higher level client.

Comment on lines 198 to 200
if (vectorSupplier() != null && vectorSupplier.get() != null) {
return new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k());
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need these checks, if the query builder is running the vectorSupplier will be null isn't it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VectorSupplier is null initially and then gets set during rewrite. The actual vector will not get set until the async call finishes completely.

queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why setting null on the OnReponse?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was how it was done for Geo: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/AbstractGeometryQueryBuilder.java#L519.

The rewrite context specifies listener response type as wildcard. That being said, we dont control the listener that is passed in, so all we can do is return null.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason of making onResponse(null) is because we have already set the response to vector supplier.

Can we add some more details around how this rewrite query function will work with first if condition like vectorSupplier != null and all.

Because what I am getting from this whole function is the re-writeQuery function can be called atleast 2 times. where one time vectorSupplier will be null and next time it will not be null. is that understanding correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason of making onResponse(null) is because we have already set the response to vector supplier.

We are not setting the response, but instead setting the value of the supplier.

Can we add some more details around how this rewrite query function will work with first if condition like vectorSupplier != null and all.
Because what I am getting from this whole function is the re-writeQuery function can be called atleast 2 times. where one time vectorSupplier will be null and next time it will not be null. is that understanding correct?

Yes, will add additional context in the comment. From my understanding, the query will be rewritten until the rewrite method returns the object itself. So, on first rewrite, the supplier will be null. On second rewrite, the supplier will be set, but the vector may be null, depending if the async call finished in time or not. Let me double check this though and add a comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason of making onResponse(null) is because we have already set the response to vector supplier.

We are not setting the response, but instead setting the value of the supplier.

I mean here the value only.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some more details around how this rewrite query function will work with first if condition like vectorSupplier != null and all.
Because what I am getting from this whole function is the re-writeQuery function can be called atleast 2 times. where one time vectorSupplier will be null and next time it will not be null. is that understanding correct?

Yes, will add additional context in the comment. From my understanding, the query will be rewritten until the rewrite method returns the object itself. So, on first rewrite, the supplier will be null. On second rewrite, the supplier will be set, but the vector may be null, depending if the async call finished in time or not. Let me double check this though and add a comment.

My understanding is if we keep on returning the NeuralQueryBuilder, then rewrite will keep on happening(which we are doing).

To stop query re-write we added vectorSupplier() != null && vectorSupplier.get() != null) condition so that we can return another QueryBuilder(KNNQueryBuilder) which only happens when we have response from MLCommonsClientAccessor.

Please confirm this understanding is correct or not.

Copy link
Collaborator

@zane-neo zane-neo Oct 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second rewrite, the supplier will be set, but the vector may be null, depending if the async call finished in time or not. Let me double check this though and add a comment.

I think when second rewrite happens, the supplier vector shouldn't be null, because, the second call will only happen when the async call complete because the for loop is returned here. Once the recursive call is complete and a KNNQueryBuilder is returned, the passed-in listener will be executed and continue the query flow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is if we keep on returning the NeuralQueryBuilder, then rewrite will keep on happening(which we are doing).

To stop query re-write we added vectorSupplier() != null && vectorSupplier.get() != null) condition so that we can return another QueryBuilder(KNNQueryBuilder) which only happens when we have response from MLCommonsClientAccessor.

Please confirm this understanding is correct or not.

Yes this is correct.

I think when second rewrite happens, the supplier vector shouldn't be null, because, the second call will only happen when the async call complete because the for loop is returned here. Once the recursive call is complete and a KNNQueryBuilder is returned, the passed-in listener will be executed and continue the query flow.

Right, thats correct, but I return a copy at the end instead of "this" to prevent a case where rewrite stops early.

Signed-off-by: John Mazanec <[email protected]>
@jmazanec15 jmazanec15 requested a review from navneet1v October 13, 2022 22:01
@@ -59,6 +59,7 @@ opensearchplugin {
classname "${projectPath}.${pathToPlugin}.${pluginClassName}"
licenseFile rootProject.file('LICENSE')
noticeFile rootProject.file('NOTICE')
extendedPlugins = ['opensearch-knn']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add MLCommons also over here? or this is just for the plugins who we depend during compile time?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ml-commons, we are okay because we just take the client as a dependency.

Comment on lines 213 to 215
// Rewrites will continuously happen until the supplier is set and the vector is generated. Rewrites will stop
// once this object is returned. Hence, if we get here, we need to return a new object
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSupplier());
Copy link
Collaborator

@navneet1v navneet1v Oct 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the case is if we are coming on this line then rewrites will happen continuously. Isn't the documentation little off?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes thats my understanding. Let me provide better comment on this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

@jmazanec15 jmazanec15 Oct 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmazanec15 jmazanec15 requested a review from navneet1v October 14, 2022 04:18
Copy link
Collaborator

@navneet1v navneet1v left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall code looks good to me. Minor comments.

Comment on lines +69 to +74
private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
NeuralQueryBuilder.ML_CLIENT = mlClient;
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Can we init the MLCommonsClientAccessor via the constructor of NeuralQueryBuilder class? is it not possible to do so?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I thought a little bit about this, but I am not sure about how to handle stream constructors. Any ideas on this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave it like this only then. I am worried on the streamInput constructor, that should not cause a NPE for MLCommonsClientAccessor.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will leave as is

queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason of making onResponse(null) is because we have already set the response to vector supplier.

We are not setting the response, but instead setting the value of the supplier.

I mean here the value only.

queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some more details around how this rewrite query function will work with first if condition like vectorSupplier != null and all.
Because what I am getting from this whole function is the re-writeQuery function can be called atleast 2 times. where one time vectorSupplier will be null and next time it will not be null. is that understanding correct?

Yes, will add additional context in the comment. From my understanding, the query will be rewritten until the rewrite method returns the object itself. So, on first rewrite, the supplier will be null. On second rewrite, the supplier will be set, but the vector may be null, depending if the async call finished in time or not. Let me double check this though and add a comment.

My understanding is if we keep on returning the NeuralQueryBuilder, then rewrite will keep on happening(which we are doing).

To stop query re-write we added vectorSupplier() != null && vectorSupplier.get() != null) condition so that we can return another QueryBuilder(KNNQueryBuilder) which only happens when we have response from MLCommonsClientAccessor.

Please confirm this understanding is correct or not.

queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
Copy link
Collaborator

@zane-neo zane-neo Oct 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding here: doRewrite method will be invoked by rewriteAndFetch which is a recursive method by wrapping itself invocation in the listener. And once the predict invocation is done, the vectorSupplier will be populated with valid data, and a KNNQueryBuilder is returned, next recursive call will hit the check on vectorSupplier and return KNNQueryBuilder again, then this method stop.

When the action gets invoked at executeAsyncActions, the actionListener.onResponse(null) here means to invoke the wildcardListener's on Response to count down. So this is a mandatory and necessary operation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so we need to call listener on null so countdown happens

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmazanec15 please add these code links as documentation on top of this code.

// Rewrites will continuously happen until the supplier is set and the vector is generated. Rewrites will stop
// once this object is returned. Hence, if we get here, we need to return a new object
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSupplier());
return this;
Copy link

@ylwu-amzn ylwu-amzn Oct 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the third commit fixed the rescore exception of too many rewrite rounds. The key part is don't return a new NeuralQueryBuilder. That's why the PoC code worked for rescore before

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the rescore issue. The code LGTM.
Please also test other possible use cases.

Copy link

@ylwu-amzn ylwu-amzn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmazanec15 please add these code links as documentation on top of this code.

actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Query] : is there any reason why we are not returning "this" here and returning a new query object? and also I am not able to understand in which case line number 218 will hit?
vectorSupplier can be null or not null.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not returning this here because we are changing the supplier of the query builder. If we return this, this check https://github.com/opensearch-project/OpenSearch/blob/e44158d4d10d4f8905895ffa50bf9398b8550667/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L109 will see the same reference is indicate another round of rewrites does not need to be performed. So instead, we copy it. This is how Geo does it: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/AbstractGeometryQueryBuilder.java#L509.

vectorSupplier can be null or not null.

Good point. Ill simplify

@jmazanec15 jmazanec15 force-pushed the issue-14 branch 2 times, most recently from 43540a3 to b1d826c Compare October 18, 2022 17:33
Signed-off-by: John Mazanec <[email protected]>
@jmazanec15 jmazanec15 requested a review from navneet1v October 18, 2022 17:45
@jmazanec15 jmazanec15 merged commit 272d803 into opensearch-project:main Oct 19, 2022
@jmazanec15 jmazanec15 added the Features Introduces a new unit of functionality that satisfies a requirement label Nov 3, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Features Introduces a new unit of functionality that satisfies a requirement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants