Skip to content

Commit

Permalink
Addressed changes requested by Martin
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Sep 28, 2023
1 parent 4be198c commit e9db552
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import java.util.Map;

import lombok.Getter;
import lombok.Setter;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.Nullable;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor;
Expand All @@ -16,18 +20,21 @@
import org.opensearch.search.pipeline.SearchRequestProcessor;

/**
* Neural Search Query Request Processor
* Neural Search Query Request Processor, It modifies the search request with neural query clause
* and adds model Id if not present in the search query.
*/
@Setter
@Getter
public class NeuralQueryProcessor extends AbstractProcessor implements SearchRequestProcessor {

/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "neural_query";
public static final String TYPE = "enriching_query_defaults";

final String modelId;
private final String modelId;

final Map<String, Object> neuralFieldDefaultIdMap;
private final Map<String, Object> neuralFieldDefaultIdMap;

/**
* Returns the type of the processor.
Expand All @@ -39,12 +46,12 @@ public String getType() {
return TYPE;
}

protected NeuralQueryProcessor(
private NeuralQueryProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
Map<String, Object> neuralFieldDefaultIdMap
@Nullable String modelId,
@Nullable Map<String, Object> neuralFieldDefaultIdMap
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
Expand Down Expand Up @@ -81,7 +88,12 @@ public NeuralQueryProcessor create(
Map<String, Object> config,
PipelineContext pipelineContext
) throws IllegalArgumentException {
String modelId = (String) config.remove(DEFAULT_MODEL_ID);
String modelId;
try {
modelId = (String) config.remove(DEFAULT_MODEL_ID);
} catch (ClassCastException e) {
throw new IllegalArgumentException("model Id must of String type");
}
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID);

if (modelId == null && neuralInfoMap == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
this.modelId = in.readOptionalString();
} else {
Expand All @@ -109,6 +110,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
@AllArgsConstructor
public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {

private String modelId;
private Map<String, Object> neuralFieldMap;
private final String modelId;
private final Map<String, Object> neuralFieldMap;

/**
* Accept method accepts every query builder from the search request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.neuralsearch.util;

import java.util.Locale;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -48,19 +46,7 @@ public void initialize(final ClusterService clusterService) {
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
try {
return this.clusterService.state().getNodes().getMinNodeVersion();
} catch (Exception exception) {
log.error(
String.format(
Locale.ROOT,
"Failed to get cluster minimum node version, returning current node version %s instead.",
Version.CURRENT
),
exception
);
return Version.CURRENT;
}
return this.clusterService.state().getNodes().getMinNodeVersion();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ public class NeuralQueryProcessorTests extends OpenSearchTestCase {
public void testFactory() throws Exception {
NeuralQueryProcessor.Factory factory = new NeuralQueryProcessor.Factory();
NeuralQueryProcessor processor = createTestProcessor(factory);
assertEquals("vasdcvkcjkbldbjkd", processor.modelId);
assertEquals("bahbkcdkacb", processor.neuralFieldDefaultIdMap.get("fieldName").toString());
assertEquals("vasdcvkcjkbldbjkd", processor.getModelId());
assertEquals("bahbkcdkacb", processor.getNeuralFieldDefaultIdMap().get("fieldName").toString());

// Missing "query" parameter:
expectThrows(
Expand All @@ -39,7 +39,7 @@ public void testProcessRequest() throws Exception {
assertEquals(processSearchRequest, searchRequest);
}

public NeuralQueryProcessor createTestProcessor(NeuralQueryProcessor.Factory factory) throws Exception {
private NeuralQueryProcessor createTestProcessor(NeuralQueryProcessor.Factory factory) throws Exception {
Map<String, Object> configMap = new HashMap<>();
configMap.put("default_model_id", "vasdcvkcjkbldbjkd");
configMap.put("neural_field_default_id", Map.of("fieldName", "bahbkcdkacb"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.neuralsearch.util;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils.mockClusterService;

import org.opensearch.Version;
Expand Down Expand Up @@ -36,16 +34,4 @@ public void testMinNodeVersion_whenMultipleNodesCluster_thenSuccess() {

assertTrue(Version.V_2_3_0.equals(minVersion));
}

public void testMinNodeVersion_WhenErrorOnClusterState_thenMatchCurrentVersion() {
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready"));

final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance();
neuralSearchClusterUtil.initialize(clusterService);

final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion();

assertTrue(Version.CURRENT.equals(minVersion));
}
}

0 comments on commit e9db552

Please sign in to comment.