Integrate model inference to build query #20
Integrates ml-commons model inference capabilities to transform NeuralQueryBuilder into a KNNQueryBuilder. Minor changes to parsing logic and build.gradle to fix bugs. Minor enhancement to MLCommonsClientAccessor to add single sentence inference. Added UTs. Signed-off-by: John Mazanec <[email protected]>
 * @param inputText {@link List} of {@link String} on which inference needs to happen
 * @param listener {@link ActionListener} which will be called when prediction is completed or errored out
 */
public void inferenceSentence(
You don't need this; we already have a function that does this for a list of sentences. Please try to use that.
I'm going to be writing the code to convert the functionality to single sentence/single vector anyway. Why not incorporate it here so it can be used by other features of the code as well?
All you need to do is call get on the vector list; creating a new function for that is over-abstraction. We already have 2 versions of this function in the class, and adding a third that is very specific seems like overkill for now.
I prefer to keep this here. I don't see a downside to having it. The name of the method "inferenceSentence" distinguishes it from "inferenceSentences", so there won't be any confusion about when to use which. Also, I think single versus collection isn't so specific that it couldn't be useful outside my current use case. Many other Java interfaces/classes have methods for both.
All you need to do is do a get on the vector list,
True, but you also have to ensure only 1 vector is returned - so each client would also have to add this check/error handling. Using this function, we can centralize that check. Additionally, it is clunky to pass in a List<List<Float>> listener when you expect that there will only be 1 List<Float>
Given that the purpose of this class is to build an easy to use abstraction over MLClient, I think it fits to add this method here.
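To make the trade-off above concrete, here is a minimal sketch of a single-sentence wrapper that centralizes the "exactly one vector" check. It is not the code from this PR: `Listener` and `BatchInference` are hypothetical stand-ins for `ActionListener` and the existing list-based method, defined locally so the example is self-contained.

```java
import java.util.List;

final class SingleSentenceSketch {

    /** Hypothetical stand-in for a success/failure callback (not the OpenSearch ActionListener). */
    interface Listener<T> {
        void onResponse(T value);
        void onFailure(Exception e);
    }

    /** Hypothetical stand-in for the existing list-based inference call. */
    interface BatchInference {
        void inferenceSentences(String modelId, List<String> texts, Listener<List<List<Float>>> listener);
    }

    /** Delegates to the batch call and unwraps the single result, failing if the count is not 1. */
    static void inferenceSentence(BatchInference batch, String modelId, String text, Listener<List<Float>> listener) {
        batch.inferenceSentences(modelId, List.of(text), new Listener<>() {
            @Override
            public void onResponse(List<List<Float>> vectors) {
                if (vectors.size() != 1) {
                    listener.onFailure(new IllegalStateException("Expected exactly 1 vector but got " + vectors.size()));
                    return;
                }
                listener.onResponse(vectors.get(0));
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }
        });
    }
}
```

With a wrapper like this, callers receive a `List<Float>` directly, and the size check lives in one place instead of in every client.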
I can add a parity method matching inferenceSentences, where filters are passed in as well.
Can we abstract the single-sentence function in the queryBuilder class only? I just want to avoid confusion around the usage, and to avoid one more function where we will have the use case of not using the TARGET_RESPONSE_FILTERS for a single sentence.
I don't think there is any usage confusion, given that the method name and signature are descriptive of the functionality.
In terms of maintainability, I don't think it makes sense for other components that want to use the inferenceSentence method to depend on the queryBuilder class. I think that the ml package should be responsible for providing easy, intuitive interaction with ml-commons for the components in the plugin and should therefore be able to house this functionality.
That being said, I think it makes sense to either integrate it into this class or build a new class that uses this class to provide higher level abstractions. In a similar case, OpenSearch has a "HighLevel" rest client that is built on top of the lower level RestClient to provide more functionality. We could do this, or we could treat the MLCommonsClientAccessor as the higher level client.
if (vectorSupplier() != null && vectorSupplier.get() != null) {
    return new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k());
}
Why do we need these checks? If the query builder is running, the vectorSupplier will be null, won't it?
VectorSupplier is null initially and then gets set during rewrite. The actual vector will not get set until the async call finishes completely.
queryRewriteContext.registerAsyncAction(
    ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
        vectorSetOnce.set(vectorAsListToArray(floatList));
        actionListener.onResponse(null);
Why are we passing null to onResponse?
This was how it was done for Geo: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/AbstractGeometryQueryBuilder.java#L519.
The rewrite context specifies the listener response type as a wildcard. That being said, we don't control the listener that is passed in, so all we can do is return null.
I think the reason for calling onResponse(null) is that we have already set the response on the vector supplier.
Can we add more detail about how this query rewrite function works with the first if condition (vectorSupplier != null and so on)?
What I gather from this function is that the rewrite can be called at least 2 times: the first time vectorSupplier will be null, and the next time it will not be null. Is that understanding correct?
I think the reason for calling onResponse(null) is that we have already set the response on the vector supplier.
We are not setting the response, but instead setting the value of the supplier.
Can we add more detail about how this query rewrite function works with the first if condition (vectorSupplier != null and so on)?
What I gather from this function is that the rewrite can be called at least 2 times: the first time vectorSupplier will be null, and the next time it will not be null. Is that understanding correct?
Yes, I will add additional context in the comment. From my understanding, the query will be rewritten until the rewrite method returns the object itself. So, on the first rewrite, the supplier will be null. On the second rewrite, the supplier will be set, but the vector may be null, depending on whether the async call finished in time. Let me double-check this, though, and add a comment.
I think the reason for calling onResponse(null) is that we have already set the response on the vector supplier.
We are not setting the response, but instead setting the value of the supplier.
I meant the value here.
Can we add more detail about how this query rewrite function works with the first if condition (vectorSupplier != null and so on)?
What I gather from this function is that the rewrite can be called at least 2 times: the first time vectorSupplier will be null, and the next time it will not be null. Is that understanding correct?
Yes, I will add additional context in the comment. From my understanding, the query will be rewritten until the rewrite method returns the object itself. So, on the first rewrite, the supplier will be null. On the second rewrite, the supplier will be set, but the vector may be null, depending on whether the async call finished in time. Let me double-check this, though, and add a comment.
My understanding is that if we keep returning the NeuralQueryBuilder, the rewrite will keep happening (which is what we are doing).
To stop the query rewrite, we added the condition vectorSupplier() != null && vectorSupplier.get() != null so that we can return another QueryBuilder (KNNQueryBuilder), which only happens once we have a response from MLCommonsClientAccessor.
Please confirm whether this understanding is correct.
On the second rewrite, the supplier will be set, but the vector may be null, depending on whether the async call finished in time. Let me double-check this, though, and add a comment.
I think that when the second rewrite happens, the supplier's vector shouldn't be null, because the second call only happens after the async call completes: the for loop returns here. Once the recursive call is complete and a KNNQueryBuilder is returned, the passed-in listener will be executed and continue the query flow.
My understanding is that if we keep returning the NeuralQueryBuilder, the rewrite will keep happening (which is what we are doing).
To stop the query rewrite, we added the condition vectorSupplier() != null && vectorSupplier.get() != null so that we can return another QueryBuilder (KNNQueryBuilder), which only happens once we have a response from MLCommonsClientAccessor.
Please confirm whether this understanding is correct.
Yes, this is correct.
I think that when the second rewrite happens, the supplier's vector shouldn't be null, because the second call only happens after the async call completes: the for loop returns here. Once the recursive call is complete and a KNNQueryBuilder is returned, the passed-in listener will be executed and continue the query flow.
Right, that's correct, but I return a copy at the end instead of "this" to prevent a case where the rewrite stops early.
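The rewrite flow discussed in this thread can be sketched as a small, synchronous simulation. This is an illustration under stated assumptions, not OpenSearch code: `Query`, `Context`, `KnnQuery`, `NeuralQuery`, and `embed` are hypothetical stand-ins, and the async action runs inline rather than on another thread. The fallback branch returns `this`, which matches the later commit in this PR.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

final class RewriteLoopSketch {

    interface Query {
        Query rewrite(Context ctx);
    }

    /** Hypothetical stand-in for the rewrite context: collects async actions and runs them on demand. */
    static final class Context {
        private final List<Runnable> asyncActions = new ArrayList<>();

        void registerAsyncAction(Runnable action) {
            asyncActions.add(action);
        }

        void executeAsyncActions() {
            List<Runnable> pending = new ArrayList<>(asyncActions);
            asyncActions.clear();
            pending.forEach(Runnable::run);
        }
    }

    static final class KnnQuery implements Query {
        final float[] vector;

        KnnQuery(float[] vector) {
            this.vector = vector;
        }

        @Override
        public Query rewrite(Context ctx) {
            return this; // nothing left to rewrite, so the loop stops here
        }
    }

    static final class NeuralQuery implements Query {
        final String text;
        final Supplier<float[]> vectorSupplier; // null until the first rewrite

        NeuralQuery(String text, Supplier<float[]> vectorSupplier) {
            this.text = text;
            this.vectorSupplier = vectorSupplier;
        }

        @Override
        public Query rewrite(Context ctx) {
            if (vectorSupplier != null) {
                float[] vector = vectorSupplier.get();
                return vector == null ? this : new KnnQuery(vector);
            }
            float[][] holder = new float[1][]; // plays the role of a set-once holder
            ctx.registerAsyncAction(() -> holder[0] = embed(text));
            // A new object is returned because the supplier changed.
            return new NeuralQuery(text, () -> holder[0]);
        }
    }

    /** Fake embedding so the sketch has something to compute. */
    static float[] embed(String text) {
        return new float[] { text.length(), 1.0f };
    }

    /** Rewrites until a rewrite returns the same reference, running async actions between rounds. */
    static Query rewriteAndFetch(Query original, Context ctx) {
        Query current = original;
        while (true) {
            Query rewritten = current.rewrite(ctx);
            if (rewritten == current) {
                return current;
            }
            current = rewritten;
            ctx.executeAsyncActions();
        }
    }
}
```

In this simulation the first round registers the action and returns a copy carrying the supplier, the second round finds the vector and returns a `KnnQuery`, and the third round returns the same reference, which ends the loop.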
Signed-off-by: John Mazanec <[email protected]>
@@ -59,6 +59,7 @@ opensearchplugin {
    classname "${projectPath}.${pathToPlugin}.${pluginClassName}"
    licenseFile rootProject.file('LICENSE')
    noticeFile rootProject.file('NOTICE')
    extendedPlugins = ['opensearch-knn']
Should we add MLCommons here as well, or is this only for the plugins we depend on at compile time?
For ml-commons, we are okay because we just take the client as a dependency.
// Rewrites will continuously happen until the supplier is set and the vector is generated. Rewrites will stop
// once this object is returned. Hence, if we get here, we need to return a new object
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSupplier());
Isn't it the case that if we reach this line, rewrites will happen continuously? Isn't the documentation a little off?
Yes, that's my understanding. Let me provide a better comment on this.
Here is what I was basing this off of: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/QueryBuilder.java#L90-L98
Here is the logic: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L83
Actually this one: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117
Overall code looks good to me. Minor comments.
private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
    NeuralQueryBuilder.ML_CLIENT = mlClient;
}
Question: can we initialize the MLCommonsClientAccessor via the constructor of the NeuralQueryBuilder class? Is that not possible?
Good question. I thought a little bit about this, but I am not sure about how to handle stream constructors. Any ideas on this?
Let's leave it as is, then. I am worried about the StreamInput constructor; it should not cause an NPE for MLCommonsClientAccessor.
Sure, will leave as is
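The concern about stream constructors can be illustrated with a small sketch. It assumes nothing about the real classes: `InferenceClient` and `NeuralQuery` here are hypothetical, and `java.io.DataInput`/`DataOutput` stand in for OpenSearch's stream types. The point is that an object rebuilt from a stream only carries its serialized fields, so a client passed to the constructor would have nowhere to come from, while a static reference set once at startup is still reachable.

```java
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

final class StaticClientSketch {

    /** Hypothetical stand-in for the ML client accessor. */
    interface InferenceClient {
        float[] embed(String text);
    }

    static final class NeuralQuery {
        // Shared by all instances, including ones created from a stream.
        private static InferenceClient client;

        static void initialize(InferenceClient inferenceClient) {
            client = inferenceClient;
        }

        final String text;

        NeuralQuery(String text) {
            this.text = text;
        }

        /** Stream constructor: only the serialized fields are available here, not the client. */
        NeuralQuery(DataInput in) throws IOException {
            this.text = in.readUTF();
        }

        void writeTo(DataOutput out) throws IOException {
            out.writeUTF(text);
        }

        float[] embed() {
            if (client == null) {
                throw new IllegalStateException("initialize(...) must be called before use");
            }
            return client.embed(text);
        }
    }
}
```

The explicit null check is one way to turn a missing initialization into a clear error rather than an NPE; whether the PR does this is not stated in the discussion.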
queryRewriteContext.registerAsyncAction(
    ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
        vectorSetOnce.set(vectorAsListToArray(floatList));
        actionListener.onResponse(null);
My understanding here: the doRewrite method is invoked by rewriteAndFetch, which is recursive because it wraps its own invocation in the listener. Once the predict invocation is done, the vectorSupplier is populated with valid data and a KNNQueryBuilder is returned; the next recursive call hits the check on vectorSupplier and returns the KNNQueryBuilder again, and then this method stops.
When the action is invoked in executeAsyncActions, the actionListener.onResponse(null) here invokes the wildcard listener's onResponse to count down. So this is a mandatory operation.
Right, so we need to call the listener with null so that the countdown happens.
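The countdown behaviour described above can be sketched as follows. This is a simplified illustration, not the OpenSearch implementation: `Listener`, `AsyncAction`, and `executeAsyncActions` are hypothetical local definitions. The value passed to `onResponse` is ignored; the call itself is what signals that the action has finished.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

final class CountDownSketch {

    /** The response value is ignored; only the call matters. */
    interface Listener {
        void onResponse(Object ignored);
    }

    interface AsyncAction {
        void run(Listener listener);
    }

    /** Runs every action and invokes onAllDone only after each action has called its listener. */
    static void executeAsyncActions(List<AsyncAction> actions, Runnable onAllDone) {
        if (actions.isEmpty()) {
            onAllDone.run();
            return;
        }
        AtomicInteger remaining = new AtomicInteger(actions.size());
        Listener countDown = ignored -> {
            if (remaining.decrementAndGet() == 0) {
                onAllDone.run();
            }
        };
        for (AsyncAction action : actions) {
            action.run(countDown);
        }
    }
}
```

An action that stores its result elsewhere (as the set-once holder does in the PR) and then calls `listener.onResponse(null)` lets the flow continue; an action that never calls the listener leaves `onAllDone` unexecuted.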
@jmazanec15 please add these code links as documentation on top of this code.
// Rewrites will continuously happen until the supplier is set and the vector is generated. Rewrites will stop
// once this object is returned. Hence, if we get here, we need to return a new object
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSupplier());
return this;
I think the third commit fixed the rescore exception about too many rewrite rounds. The key part is not returning a new NeuralQueryBuilder. That's why the PoC code worked for rescore before.
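A rough sketch of why always returning a new object can end in a "too many rewrite rounds" failure: a loop that stops only on reference equality needs a cap to avoid spinning forever. The types below are hypothetical, and the cap of 16 is an assumption for illustration rather than a value taken from this discussion.

```java
final class RewriteRoundsSketch {

    // Assumed cap for illustration only.
    static final int MAX_REWRITE_ROUNDS = 16;

    interface Query {
        Query rewrite();
    }

    /** Always hands back a fresh copy, so the reference check never succeeds. */
    static final class AlwaysCopies implements Query {
        @Override
        public Query rewrite() {
            return new AlwaysCopies();
        }
    }

    /** Returns itself, which signals that no further rewriting is needed. */
    static final class ReturnsThis implements Query {
        @Override
        public Query rewrite() {
            return this;
        }
    }

    static Query rewrite(Query original) {
        Query current = original;
        for (int round = 0; ; round++) {
            Query next = current.rewrite();
            if (next == current) {
                return current;
            }
            if (round >= MAX_REWRITE_ROUNDS) {
                throw new IllegalStateException("too many rewrite rounds");
            }
            current = next;
        }
    }
}
```

Here `rewrite(new ReturnsThis())` returns on the first round, while `rewrite(new AlwaysCopies())` throws once the cap is reached.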
Thanks for fixing the rescore issue. The code LGTM.
Please also test other possible use cases.
LGTM
        actionListener.onResponse(null);
    }, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get);
[Query]: Is there any reason why we return a new query object here instead of "this"? Also, I am not able to understand in which case line number 218 will be hit.
vectorSupplier can be null or not null.
We are not returning this here because we are changing the supplier of the query builder. If we return this, this check https://github.com/opensearch-project/OpenSearch/blob/e44158d4d10d4f8905895ffa50bf9398b8550667/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L109 will see the same reference, indicating that another round of rewrites does not need to be performed. So instead, we copy it. This is how Geo does it: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/AbstractGeometryQueryBuilder.java#L509.
vectorSupplier can be null or not null.
Good point. I'll simplify.
43540a3 to b1d826c
Signed-off-by: John Mazanec <[email protected]>
Signed-off-by: John Mazanec <[email protected]>
Description
Integrates ml-commons model inference capabilities to transform NeuralQueryBuilder into a KNNQueryBuilder. Minor changes to parsing logic and build.gradle to fix bugs. Minor enhancement to MLCommonsClientAccessor to add single sentence inference.
Added unit tests to test functionality.
Working on integration tests. Confirmed it works on a local cluster by setting up k-NN index with 1K docs and running the following query:
Issues Resolved
#14
Check List
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.