-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathNeuralQueryBuilder.java
261 lines (233 loc) · 10.7 KB
/
NeuralQueryBuilder.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;
import java.io.IOException;
import java.util.function.Supplier;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import com.google.common.annotations.VisibleForTesting;
/**
 * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a
 * k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as
 * the query vector for the k-NN search.
 *
 * <p>This builder is never converted into a Lucene query directly (see {@link #doToQuery(QueryShardContext)}).
 * During query rewrite it registers an asynchronous inference call through the client supplied to
 * {@link #initialize(MLCommonsClientAccessor)} and, once the resulting vector is available, rewrites itself into a
 * {@link KNNQueryBuilder}.
 */
@Log4j2
@Getter
@Setter
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> {

    public static final String NAME = "neural";

    @VisibleForTesting
    static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");

    @VisibleForTesting
    static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

    @VisibleForTesting
    static final ParseField K_FIELD = new ParseField("k");

    private static final int DEFAULT_K = 10;

    private static MLCommonsClientAccessor ML_CLIENT;

    /**
     * Sets the ML Commons client accessor used by all neural queries to run text inference during rewrite.
     *
     * @param mlClient accessor whose {@code inferenceSentence} is invoked from {@link #doRewrite(QueryRewriteContext)}
     */
    public static void initialize(MLCommonsClientAccessor mlClient) {
        NeuralQueryBuilder.ML_CLIENT = mlClient;
    }

    // NOTE: field declaration order defines the @AllArgsConstructor parameter order; do not reorder.
    private String fieldName;
    private String queryText;
    private String modelId;
    private int k = DEFAULT_K;
    @VisibleForTesting
    @Getter(AccessLevel.PACKAGE)
    @Setter(AccessLevel.PACKAGE)
    private Supplier<float[]> vectorSupplier;
    private QueryBuilder filter;

    /**
     * Constructor from stream input
     *
     * @param in StreamInput to initialize object from
     * @throws IOException thrown if unable to read from input stream
     */
    public NeuralQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.queryText = in.readString();
        this.modelId = in.readString();
        this.k = in.readVInt();
        this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
    }

    /**
     * Serializes this builder's own fields; the order must match {@link #NeuralQueryBuilder(StreamInput)}.
     * The vector supplier is transient rewrite state and is intentionally not serialized.
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeString(this.queryText);
        out.writeString(this.modelId);
        out.writeVInt(this.k);
        out.writeOptionalNamedWriteable(this.filter);
    }

    /**
     * Renders the query in the same shape accepted by {@link #fromXContent(XContentParser)}.
     */
    @Override
    protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.startObject(fieldName);
        xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
        xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
        xContentBuilder.field(K_FIELD.getPreferredName(), k);
        if (filter != null) {
            xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
        }
        printBoostAndQueryName(xContentBuilder);
        xContentBuilder.endObject();
        xContentBuilder.endObject();
    }

    /**
     * Creates NeuralQueryBuilder from xContent.
     *
     * The expected parsing form looks like:
     * {
     *  "VECTOR_FIELD": {
     *    "query_text": "string",
     *    "model_id": "string",
     *    "k": int,
     *    "name": "string", (optional)
     *    "boost": float (optional),
     *    "filter": map (optional)
     *  }
     * }
     *
     * @param parser XContentParser positioned on the START_OBJECT of the "neural" query body
     * @return NeuralQueryBuilder
     * @throws IOException can be thrown by parser
     * @throws ParsingException if the body is malformed, has more than one field, or contains unsupported parameters
     */
    public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOException {
        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
        if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT");
        }
        if (parser.nextToken() != XContentParser.Token.FIELD_NAME) {
            throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] query must specify a vector field name");
        }
        neuralQueryBuilder.fieldName(parser.currentName());
        // The per-field parameters must be a JSON object; reject scalars/arrays up front with a clear message.
        if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(
                parser.getTokenLocation(),
                "[" + NAME + "] query parameters for field [" + neuralQueryBuilder.fieldName() + "] must be an object"
            );
        }
        parseQueryParams(parser, neuralQueryBuilder);
        if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(
                parser.getTokenLocation(),
                "["
                    + NAME
                    + "] query doesn't support multiple fields, found ["
                    + neuralQueryBuilder.fieldName()
                    + "] and ["
                    + parser.currentName()
                    + "]"
            );
        }
        requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
        requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
        requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

        return neuralQueryBuilder;
    }

    /**
     * Reads the parameters of the per-field object until its END_OBJECT, populating {@code neuralQueryBuilder}.
     * Unknown scalar or object parameters are rejected rather than skipped, so a stray nested object cannot
     * terminate the loop early and leave the parser mispositioned.
     */
    private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder neuralQueryBuilder) throws IOException {
        XContentParser.Token token;
        String currentFieldName = "";
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token.isValue()) {
                if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryText(parser.text());
                } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.modelId(parser.text());
                } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.k((Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false));
                } else if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.queryName(parser.text());
                } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.boost(parser.floatValue());
                } else {
                    throw new ParsingException(
                        parser.getTokenLocation(),
                        "[" + NAME + "] query does not support [" + currentFieldName + "]"
                    );
                }
            } else if (token == XContentParser.Token.START_OBJECT) {
                if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
                } else {
                    throw new ParsingException(
                        parser.getTokenLocation(),
                        "[" + NAME + "] query does not support [" + currentFieldName + "]"
                    );
                }
            } else {
                throw new ParsingException(
                    parser.getTokenLocation(),
                    "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
                );
            }
        }
    }

    /**
     * Rewrites this builder into a {@link KNNQueryBuilder} once the query-text embedding is available.
     *
     * <p>When a QueryBuilder is unchanged by a rewrite, doRewrite must return itself (see
     * https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/QueryBuilder.java#L90-L98);
     * otherwise it returns the modified copy (see the rewrite loop in
     * https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/index/query/Rewriteable.java#L117).
     * On the first rewrite we register an asynchronous inference call and return a copy holding a supplier that
     * is filled in when that call completes. Later rewrites return {@code this} until the supplier has a value, and
     * then return a {@link KNNQueryBuilder}. {@code boost} and {@code _name} are propagated at each step so that
     * user-supplied values parsed in {@link #fromXContent(XContentParser)} are not lost.
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (vectorSupplier() != null) {
            float[] vector = vectorSupplier().get();
            if (vector == null) {
                // Inference has not completed yet: report "unchanged".
                return this;
            }
            return new KNNQueryBuilder(fieldName(), vector, k(), filter()).boost(boost()).queryName(queryName());
        }

        SetOnce<float[]> vectorSetOnce = new SetOnce<>();
        queryRewriteContext.registerAsyncAction(
            ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> {
                vectorSetOnce.set(vectorAsListToArray(floatList));
                actionListener.onResponse(null);
            }, actionListener::onFailure)))
        );
        return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter()).boost(boost())
            .queryName(queryName());
    }

    /**
     * Always throws: the neural query must be rewritten into a k-NN query before a Lucene query is built.
     */
    @Override
    protected Query doToQuery(QueryShardContext queryShardContext) {
        // All queries should be generated by the k-NN Query Builder
        throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly");
    }

    /**
     * Compares the neural-specific fields; boost and query name are compared by {@link AbstractQueryBuilder}.
     */
    @Override
    protected boolean doEquals(NeuralQueryBuilder obj) {
        if (this == obj) return true;
        if (obj == null || getClass() != obj.getClass()) return false;
        EqualsBuilder equalsBuilder = new EqualsBuilder();
        equalsBuilder.append(fieldName, obj.fieldName);
        equalsBuilder.append(queryText, obj.queryText);
        equalsBuilder.append(modelId, obj.modelId);
        equalsBuilder.append(k, obj.k);
        equalsBuilder.append(filter, obj.filter);
        return equalsBuilder.isEquals();
    }

    /**
     * Hashes the same fields compared in {@link #doEquals(NeuralQueryBuilder)}, including {@code filter}.
     */
    @Override
    protected int doHashCode() {
        return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(k).append(filter).toHashCode();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }
}