Skip to content

Commit

Permalink
Fixing test cases
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Sep 28, 2023
1 parent e9efd72 commit ff10aba
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.ParseField;
Expand All @@ -51,6 +53,8 @@
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.carrotsearch.randomizedtesting.RandomizedTest;

Expand Down Expand Up @@ -235,6 +239,7 @@ public void testDoToQuery_whenTooManySubqueries_thenFail() {
*/
@SneakyThrows
public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
setUpClusterService();
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startArray("queries")
Expand Down Expand Up @@ -412,6 +417,7 @@ public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() {

@SneakyThrows
public void testStreams_whenWrittingToStream_thenSuccessful() {
setUpClusterService();
HybridQueryBuilder original = new HybridQueryBuilder();
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
.queryText(QUERY_TEXT)
Expand Down Expand Up @@ -716,4 +722,9 @@ private Map<String, Object> getInnerMap(Object innerObject, String queryName, St
Map<String, Object> vectorFieldInnerMap = (Map<String, Object>) neuralInnerMap.get(fieldName);
return vectorFieldInnerMap;
}

private void setUpClusterService() {
ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(Version.CURRENT);
NeuralSearchClusterUtil.instance().initialize(clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

import lombok.SneakyThrows;

import org.opensearch.Version;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.ParseField;
Expand All @@ -50,6 +52,8 @@
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.test.OpenSearchTestCase;

public class NeuralQueryBuilderTests extends OpenSearchTestCase {
Expand All @@ -75,6 +79,7 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() {
}
}
*/
setUpClusterService();
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
Expand Down Expand Up @@ -107,6 +112,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
}
}
*/
setUpClusterService();
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
Expand Down Expand Up @@ -146,6 +152,7 @@ public void testFromXContent_whenBuiltWithFilter_thenBuildSuccessfully() {
}
}
*/
setUpClusterService();
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
Expand Down Expand Up @@ -334,6 +341,7 @@ public void testToXContent() {

@SneakyThrows
public void testStreams() {
setUpClusterService();
NeuralQueryBuilder original = new NeuralQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
Expand Down Expand Up @@ -572,4 +580,9 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() {
KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder;
assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter());
}

private void setUpClusterService() {
ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(Version.CURRENT);
NeuralSearchClusterUtil.instance().initialize(clusterService);
}
}

0 comments on commit ff10aba

Please sign in to comment.