diff --git a/elastic-fed/buildSrc/version.properties b/elastic-fed/buildSrc/version.properties
index 6dd055a6b..b9d65b633 100644
--- a/elastic-fed/buildSrc/version.properties
+++ b/elastic-fed/buildSrc/version.properties
@@ -1,6 +1,6 @@
 havenask = 1.0.0
 lucene = 8.7.0
-runtime_image = 0.2
+runtime_image = 0.3
 bundled_jdk_vendor = openjdk
 bundled_jdk = 11.0.2+9
diff --git a/elastic-fed/client/rest-high-level/src/main/java/org/havenask/client/ha/SqlResponse.java b/elastic-fed/client/rest-high-level/src/main/java/org/havenask/client/ha/SqlResponse.java
index 92675d421..5ec77b8dc 100644
--- a/elastic-fed/client/rest-high-level/src/main/java/org/havenask/client/ha/SqlResponse.java
+++ b/elastic-fed/client/rest-high-level/src/main/java/org/havenask/client/ha/SqlResponse.java
@@ -26,6 +26,7 @@ public class SqlResponse {
     private final double totalTime;
     private final boolean hasSoftFailure;
+    private final double coveredPercent;
     private final int rowCount;
     private final SqlResult sqlResult;
     private final ErrorInfo errorInfo;
@@ -70,9 +71,16 @@ public int GetErrorCode() {
         }
     }

-    public SqlResponse(double totalTime, boolean hasSoftFailure, int rowCount, SqlResult sqlResult, ErrorInfo errorInfo) {
+    public SqlResponse(
+        double totalTime,
+        boolean hasSoftFailure,
+        double coveredPercent,
+        int rowCount,
+        SqlResult sqlResult,
+        ErrorInfo errorInfo) {
         this.totalTime = totalTime;
         this.hasSoftFailure = hasSoftFailure;
+        this.coveredPercent = coveredPercent;
         this.rowCount = rowCount;
         this.sqlResult = sqlResult;
         this.errorInfo = errorInfo;
@@ -86,6 +94,10 @@ public boolean isHasSoftFailure() {
         return hasSoftFailure;
     }

+    public double getCoveredPercent() {
+        return coveredPercent;
+    }
+
     public int getRowCount() {
         return rowCount;
     }
@@ -102,6 +114,7 @@ public static SqlResponse fromXContent(XContentParser parser) throws IOException
         XContentParser.Token token;
         double totalTime = 0;
         boolean hasSoftFailure = false;
+        double coveredPercent = 0;
         int rowCount = 0;
         SqlResult sqlResult = null;
         ErrorInfo errorInfo = null;
@@ -116,6 +129,9 @@ public static SqlResponse fromXContent(XContentParser parser) throws IOException
                     case "has_soft_failure":
                         hasSoftFailure = parser.booleanValue();
                         break;
+                    case "covered_percent":
+                        coveredPercent = parser.doubleValue();
+                        break;
                     case "row_count":
                         rowCount = parser.intValue();
                         break;
@@ -200,7 +216,7 @@ public static SqlResponse fromXContent(XContentParser parser) throws IOException
             }
         }
-        return new SqlResponse(totalTime, hasSoftFailure, rowCount, sqlResult, errorInfo);
+        return new SqlResponse(totalTime, hasSoftFailure, coveredPercent, rowCount, sqlResult, errorInfo);
     }

     public static SqlResponse parse(String strResponse) throws IOException {
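Note on the covered_percent addition: the parser treats it as a plain double and leaves it at 0 when the engine omits the key. A minimal, hypothetical usage sketch (the JSON is cut down from the fixture in HavenaskSearchFetchProcessorTests further down; fromXContent leaves sql_result and error_info null when those blocks are absent):

    // hypothetical caller code, not part of this change
    String json = "{\"total_time\":8.126,\"has_soft_failure\":false,\"covered_percent\":1.0,\"row_count\":0}";
    SqlResponse response = SqlResponse.parse(json);
    // covered_percent is a fraction in [0, 1] despite the name; 1.0 means every shard answered
    double covered = response.getCoveredPercent();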
diff --git a/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/AbstractHavenaskRestTestCase.java b/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/AbstractHavenaskRestTestCase.java
index 652cce1f5..956aaf69b 100644
--- a/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/AbstractHavenaskRestTestCase.java
+++ b/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/AbstractHavenaskRestTestCase.java
@@ -44,6 +44,8 @@ import org.junit.Before;

 public abstract class AbstractHavenaskRestTestCase extends HavenaskRestTestCase {
+    public static final String NUMBER_OF_SHARDS = "number_of_shards";
+    public static final String NUMBER_OF_REPLICAS = "number_of_replicas";
     private static RestHighLevelClient restHighLevelClient;

     @Before
diff --git a/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/SearchIT.java b/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/SearchIT.java
index a5492ea4c..5a24bba1b 100644
--- a/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/SearchIT.java
+++ b/elastic-fed/modules/havenask-engine/src/javaRestTest/java/org/havenask/engine/SearchIT.java
@@ -15,6 +15,8 @@
 package org.havenask.engine;

 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;

 import org.apache.logging.log4j.LogManager;
@@ -47,10 +49,11 @@ public class SearchIT extends AbstractHavenaskRestTestCase {
     // static logger
     private static final Logger logger = LogManager.getLogger(SearchIT.class);

-    private static final String[] SearchITIndices = { "single_shard_test", "multi_shard_test", "multi_vector_test" };
-    private static final int TEST_SINGLE_SHARD_KNN_INDEX_POS = 0;
-    private static final int TEST_MULTI_SHARD_KNN_INDEX_POS = 1;
-    private static final int TEST_MULTI_KNN_QUERY_INDEX_POS = 2;
+    private static final String[] SearchITIndices = { "search_test", "single_shard_test", "multi_shard_test", "multi_vector_test" };
+    private static final int TEST_SEARCH_INDEX_POS = 0;
+    private static final int TEST_SINGLE_SHARD_KNN_INDEX_POS = 1;
+    private static final int TEST_MULTI_SHARD_KNN_INDEX_POS = 2;
+    private static final int TEST_MULTI_KNN_QUERY_INDEX_POS = 3;

     @AfterClass
     public static void cleanIndices() {
@@ -66,6 +69,60 @@ public static void cleanIndices() {
         }
     }

+    public void testSearch() throws Exception {
+        String index = SearchITIndices[TEST_SEARCH_INDEX_POS];
+        int dataNum = 3;
+        double delta = 0.00001;
+        // create index
+        Settings settings = Settings.builder()
+            .put(EngineSettings.ENGINE_TYPE_SETTING.getKey(), EngineSettings.ENGINE_HAVENASK)
+            .put(NUMBER_OF_SHARDS, 2)
+            .put(NUMBER_OF_REPLICAS, 0)
+            .build();
+
+        java.util.Map<String, Object> map = Map.of(
+            "properties",
+            Map.of("seq", Map.of("type", "integer"), "content", Map.of("type", "keyword"), "time", Map.of("type", "date"))
+        );
+        assertTrue(createTestIndex(index, settings, map));
+
+        waitIndexGreen(index);
+
+        // PUT docs
+        String[] idList = { "1", "2", "3" };
+        java.util.List<java.util.Map<String, Object>> sourceList = new ArrayList<>();
+        sourceList.add(Map.of("seq", 1, "content", "欢迎使用1", "time", "20230718"));
+        sourceList.add(Map.of("seq", 2, "content", "欢迎使用2", "time", "20230717"));
+        sourceList.add(Map.of("seq", 3, "content", "欢迎使用3", "time", "20230716"));
+        for (int i = 0; i < idList.length; i++) {
+            putDoc(index, idList[i], sourceList.get(i));
+        }
+
+        // get data with _search API
+        SearchRequest searchRequest = new SearchRequest(index);
+
+        assertBusy(() -> {
+            SearchResponse searchResponse = highLevelClient().search(searchRequest, RequestOptions.DEFAULT);
+            assertEquals(dataNum, searchResponse.getHits().getTotalHits().value);
+        }, 10, TimeUnit.SECONDS);
+        SearchResponse searchResponse = highLevelClient().search(searchRequest, RequestOptions.DEFAULT);
+        assertEquals(dataNum, searchResponse.getHits().getTotalHits().value);
+
+        Set<Integer> expectedSeq = Set.of(1, 2, 3);
+        Set<String> expectedContent = Set.of("欢迎使用1", "欢迎使用2", "欢迎使用3");
+        Set<String> expectedTime = Set.of("20230718", "20230717", "20230716");
+        for (int i = 0; i < dataNum; i++) {
+            assertEquals(index, searchResponse.getHits().getHits()[i].getIndex());
+            assertEquals(1.0, searchResponse.getHits().getHits()[i].getScore(), delta);
+            assertTrue(expectedSeq.contains(searchResponse.getHits().getHits()[i].getSourceAsMap().get("seq")));
+            assertTrue(expectedContent.contains(searchResponse.getHits().getHits()[i].getSourceAsMap().get("content")));
+            assertTrue(expectedTime.contains(searchResponse.getHits().getHits()[i].getSourceAsMap().get("time")));
+        }
+
+        // delete index and HEAD index
+        deleteAndHeadIndex(index);
+    }
+
     public void testSingleShardKnn() throws Exception {
         String index = SearchITIndices[TEST_SINGLE_SHARD_KNN_INDEX_POS];
         String fieldName = "image";
diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/MetaDataSyncer.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/MetaDataSyncer.java
index ff874a8e2..3f12cea75 100644
--- a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/MetaDataSyncer.java
+++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/MetaDataSyncer.java
@@ -24,7 +24,6 @@
 import java.nio.file.StandardCopyOption;
 import java.nio.file.StandardOpenOption;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -40,6 +39,7 @@
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;

+import com.carrotsearch.hppc.cursors.ObjectCursor;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.havenask.cluster.ClusterChangedEvent;
@@ -47,6 +47,7 @@
 import org.havenask.cluster.ClusterStateApplier;
 import org.havenask.cluster.metadata.IndexMetadata;
 import org.havenask.cluster.node.DiscoveryNode;
+import org.havenask.cluster.routing.IndexRoutingTable;
 import org.havenask.cluster.routing.RoutingNode;
 import org.havenask.cluster.routing.ShardRouting;
 import org.havenask.cluster.service.ClusterService;
@@ -66,6 +67,7 @@
 import org.havenask.engine.rpc.UpdateHeartbeatTargetRequest;
 import org.havenask.engine.util.RangeUtil;
 import org.havenask.engine.util.Utils;
+import org.havenask.index.Index;
 import org.havenask.threadpool.ThreadPool;

 public class MetaDataSyncer extends AbstractLifecycleComponent implements ClusterStateApplier {
@@ -317,20 +319,24 @@ protected String getThreadPool() {

     @Override
     public void applyClusterState(ClusterChangedEvent event) {
-        if (isIngestNode && shouldUpdateQrs(event.previousState(), event.state())) {
-            // update qrs target
-            setQrsPendingSync();
+        try {
+            if (isIngestNode && shouldUpdateQrs(event)) {
+                // update qrs target
+                setQrsPendingSync();
+            }
+        } catch (Throwable e) {
+            LOGGER.error("error when updating qrs target: ", e);
         }
     }

-    private boolean shouldUpdateQrs(ClusterState prevClusterState, ClusterState curClusterState) {
+    private boolean shouldUpdateQrs(ClusterChangedEvent event) {
         // check whether any index was created or deleted
-        if (isHavenaskIndexChanged(prevClusterState, curClusterState)) {
+        if (isHavenaskIndexChanged(event)) {
             return true;
         }

-        // TODO: check whether shard relocation should also trigger a qrs update
-        if (isHavenaskShardChanged(prevClusterState, curClusterState)) {
+        // check for shard-level changes
+        if (isHavenaskShardChanged(event.previousState(), event.state())) {
             return true;
         }

@@ -611,45 +617,38 @@ private synchronized void generateDefaultBizConfig(List<String> indexList) throw
         );
     }

-    private boolean isHavenaskIndexChanged(ClusterState prevClusterState, ClusterState curClusterState) {
-        Set<String> prevIndexNamesSet = new HashSet<>(Arrays.asList(prevClusterState.metadata().indices().keys().toArray(String.class)));
-        Set<String> currentIndexNamesSet = new HashSet<>(Arrays.asList(curClusterState.metadata().indices().keys().toArray(String.class)));
-        Set<String> prevDiff = new HashSet<>(prevIndexNamesSet);
-        Set<String> curDiff = new HashSet<>(currentIndexNamesSet);
-        prevDiff.removeAll(currentIndexNamesSet);
-        curDiff.removeAll(prevIndexNamesSet);
-
-        for (String indexName : prevDiff) {
-            IndexMetadata indexMetadata = prevClusterState.metadata().index(indexName);
+    private boolean isHavenaskIndexChanged(ClusterChangedEvent event) {
+        List<Index> indicesDeleted = event.indicesDeleted();
+        List<String> indicesCreated = event.indicesCreated();
+        for (Index index : indicesDeleted) {
+            IndexMetadata indexMetadata = event.previousState().getMetadata().index(index);
             if (EngineSettings.isHavenaskEngine(indexMetadata.getSettings())) {
                 return true;
             }
         }
-        for (String indexName : curDiff) {
-            IndexMetadata indexMetadata = curClusterState.metadata().index(indexName);
+        for (String index : indicesCreated) {
+            IndexMetadata indexMetadata = event.state().metadata().index(index);
             if (EngineSettings.isHavenaskEngine(indexMetadata.getSettings())) {
                 return true;
             }
         }
-
         return false;
     }

     private boolean isHavenaskShardChanged(ClusterState prevClusterState, ClusterState curClusterState) {
-        // TODO: detect the shard-relocation case
-        for (RoutingNode routingNode : prevClusterState.getRoutingNodes()) {
-            for (ShardRouting shardRouting : routingNode) {
-                IndexMetadata indexMetadata = prevClusterState.metadata().index(shardRouting.index().getName());
-                if (false == EngineSettings.isHavenaskEngine(indexMetadata.getSettings())) {
-                    continue;
-                }
-                ShardRouting curShardRouting = curClusterState.getRoutingNodes()
-                    .node(routingNode.nodeId())
-                    .getByShardId(shardRouting.shardId());
-                if (curShardRouting == null || curShardRouting.getTargetRelocatingShard() != null) {
-                    return true;
-                }
+        for (ObjectCursor<String> indexNameCursor : prevClusterState.routingTable().indicesRouting().keys()) {
+            String indexName = indexNameCursor.value;
+            IndexMetadata indexMetadata = prevClusterState.metadata().index(indexName);
+            if (false == EngineSettings.isHavenaskEngine(indexMetadata.getSettings())) {
+                continue;
+            }
+            IndexRoutingTable prevIndexRoutingTable = prevClusterState.routingTable().indicesRouting().get(indexName);
+            IndexRoutingTable curIndexRoutingTable = curClusterState.routingTable().indicesRouting().get(indexName);
+
+            // TODO: shard-level change detection currently relies on IndexRoutingTable.equals, which compares the
+            // index and its shards for equality; revisit and optimize this later
+            if (false == prevIndexRoutingTable.equals(curIndexRoutingTable)) {
+                return true;
             }
         }
         return false;
"max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64"; } public static class BuildConfig { diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/HavenaskSearchFetchProcessor.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/HavenaskSearchFetchProcessor.java new file mode 100644 index 000000000..bd05095b8 --- /dev/null +++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/HavenaskSearchFetchProcessor.java @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2021, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.havenask.engine.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.havenask.client.ha.SqlResponse; +import org.havenask.common.bytes.BytesArray; +import org.havenask.common.lucene.search.TopDocsAndMaxScore; +import org.havenask.common.text.Text; +import org.havenask.engine.rpc.QrsClient; +import org.havenask.engine.rpc.QrsSqlRequest; +import org.havenask.engine.rpc.QrsSqlResponse; +import org.havenask.engine.search.fetch.HavenaskFetchSourcePhase; +import org.havenask.engine.search.fetch.HavenaskFetchSubPhase; +import org.havenask.engine.search.fetch.HavenaskFetchSubPhaseProcessor; +import org.havenask.index.mapper.MapperService; +import org.havenask.search.SearchHit; +import org.havenask.search.SearchHits; +import org.havenask.search.aggregations.InternalAggregations; +import org.havenask.search.builder.SearchSourceBuilder; +import org.havenask.search.fetch.subphase.FetchSourceContext; +import org.havenask.search.internal.InternalSearchResponse; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.havenask.engine.search.rest.RestHavenaskSqlAction.SQL_DATABASE; + +public class HavenaskSearchFetchProcessor { + private static final Logger logger = LogManager.getLogger(HavenaskSearchFetchProcessor.class); + private static final int ID_POS = 0; + private static final int SCORE_POS = 1; + private static final int SOURCE_POS = 1; + private static final Object SOURCE_NOT_FOUND = "{\n" + "\"warn\":\"source not found\"\n" + "}"; + QrsClient qrsClient; + private final List havenaskFetchSubPhases; + + public HavenaskSearchFetchProcessor(QrsClient qrsClient) { + this.qrsClient = qrsClient; + // TODO 目前仅支持source过滤,未来增加更多的subPhase并考虑以plugin形式去支持 + this.havenaskFetchSubPhases = new ArrayList<>(); + this.havenaskFetchSubPhases.add(new HavenaskFetchSourcePhase()); + } + + public InternalSearchResponse executeFetch(SqlResponse queryPhaseSqlResponse, String tableName, SearchSourceBuilder searchSourceBuilder) + throws IOException { + if (searchSourceBuilder == null) { + throw new IllegalArgumentException("request source 
can not be null!"); + } + List idList = new ArrayList<>(queryPhaseSqlResponse.getRowCount()); + TopDocsAndMaxScore topDocsAndMaxScore = buildQuerySearchResult(queryPhaseSqlResponse, idList); + SqlResponse fetchPhaseSqlResponse = searchSourceBuilder.fetchSource() == null + || true == searchSourceBuilder.fetchSource().fetchSource() + ? havenaskFetchWithSql(idList, tableName, searchSourceBuilder.fetchSource(), qrsClient) + : null; + + return transferSqlResponse2FetchResult(tableName, idList, fetchPhaseSqlResponse, topDocsAndMaxScore, searchSourceBuilder); + } + + public TopDocsAndMaxScore buildQuerySearchResult(SqlResponse queryPhaseSqlResponse, List idList) throws IOException { + ScoreDoc[] queryScoreDocs = new ScoreDoc[queryPhaseSqlResponse.getRowCount()]; + float maxScore = 0; + int sqlDataSize = queryPhaseSqlResponse.getRowCount() > 0 ? queryPhaseSqlResponse.getSqlResult().getData()[0].length : 0; + if (sqlDataSize > 2) { + throw new IOException("unknow sqlResponse:" + Arrays.deepToString(queryPhaseSqlResponse.getSqlResult().getData())); + } + for (int i = 0; i < queryPhaseSqlResponse.getRowCount(); i++) { + float defaultScore = 1; + float curScore = sqlDataSize == 1 + ? defaultScore + : ((Double) queryPhaseSqlResponse.getSqlResult().getData()[i][SCORE_POS]).floatValue(); + + queryScoreDocs[i] = new ScoreDoc(i, curScore); + maxScore = Math.max(maxScore, curScore); + idList.add(String.valueOf(queryPhaseSqlResponse.getSqlResult().getData()[i][ID_POS])); + } + TopDocs topDocs = new TopDocs( + new TotalHits(queryPhaseSqlResponse.getRowCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + queryScoreDocs + ); + return new TopDocsAndMaxScore(topDocs, maxScore); + } + + public SqlResponse havenaskFetchWithSql( + List idList, + String tableName, + FetchSourceContext fetchSourceContext, + QrsClient qrsClient + ) throws IOException { + QrsSqlRequest qrsFetchPhaseSqlRequest = getQrsFetchPhaseSqlRequest(idList, tableName, fetchSourceContext); + QrsSqlResponse qrsFetchPhaseSqlResponse = qrsClient.executeSql(qrsFetchPhaseSqlRequest); + SqlResponse fetchPhaseSqlResponse = SqlResponse.parse(qrsFetchPhaseSqlResponse.getResult()); + if (logger.isDebugEnabled()) { + logger.debug("fetch idList length: {}, havenask sqlResponse took: {} ms", idList.size(), fetchPhaseSqlResponse.getTotalTime()); + } + return fetchPhaseSqlResponse; + } + + private static QrsSqlRequest getQrsFetchPhaseSqlRequest(List idList, String tableName, FetchSourceContext fetchSourceContext) { + StringBuilder sqlQuery = new StringBuilder(); + sqlQuery.append("select _id, _source from ").append(tableName).append("_summary_ where _id in("); + for (int i = 0; i < idList.size(); i++) { + sqlQuery.append("'").append(idList.get(i)).append("'"); + if (i < idList.size() - 1) { + sqlQuery.append(","); + } + } + sqlQuery.append(")"); + sqlQuery.append(" limit ").append(idList.size()); + String kvpair = "format:full_json;timeout:10000;databaseName:" + SQL_DATABASE; + return new QrsSqlRequest(sqlQuery.toString(), kvpair); + } + + public InternalSearchResponse transferSqlResponse2FetchResult( + String tableName, + List idList, + SqlResponse fetchPhaseSqlResponse, + TopDocsAndMaxScore topDocsAndMaxScore, + SearchSourceBuilder searchSourceBuilder + ) throws IOException { + int loadSize = idList.size(); + TotalHits totalHits = topDocsAndMaxScore.topDocs.totalHits; + SearchHit[] hits = new SearchHit[loadSize]; + FetchSourceContext fetchSourceContext = searchSourceBuilder.fetchSource(); + List processors = getProcessors(tableName, searchSourceBuilder); + + 
+        // Map each fetched _id to its row index: the query-phase idList is already sorted by _score,
+        // but the fetch-phase rows come back in no particular order
+        Map<String, Integer> fetchResIdListMap = new HashMap<>();
+        if (fetchPhaseSqlResponse != null) {
+            for (int i = 0; i < fetchPhaseSqlResponse.getRowCount(); i++) {
+                fetchResIdListMap.put((String) fetchPhaseSqlResponse.getSqlResult().getData()[i][ID_POS], i);
+            }
+        }
+
+        for (int i = 0; i < loadSize; i++) {
+            // TODO: add _routing
+            SearchHit searchHit = new SearchHit(
+                i,
+                idList.get(i),
+                new Text(MapperService.SINGLE_MAPPING_NAME),
+                Collections.emptyMap(),
+                Collections.emptyMap()
+            );
+            searchHit.setIndex(tableName);
+
+            // Look up each id's _source from the fetch result in idList order; if the row is missing, fall back to "source not found"
+            Object source = null;
+            if (fetchSourceContext == null || true == fetchSourceContext.fetchSource()) {
+                Integer fetchResIndex = fetchResIdListMap.get(idList.get(i));
+                source = fetchResIndex != null
+                    ? fetchPhaseSqlResponse.getSqlResult().getData()[fetchResIndex][SOURCE_POS]
+                    : SOURCE_NOT_FOUND;
+            }
+            HavenaskFetchSubPhase.HitContent hit = new HavenaskFetchSubPhase.HitContent(searchHit, source);
+            hit.getHit().score(topDocsAndMaxScore.topDocs.scoreDocs[i].score);
+            if (fetchSourceContext != null && false == fetchSourceContext.fetchSource()) {
+                hits[i] = hit.getHit();
+            } else if (processors != null && !processors.isEmpty()) {
+                for (HavenaskFetchSubPhaseProcessor processor : processors) {
+                    processor.process(hit);
+                }
+                hits[i] = hit.getHit();
+            } else {
+                hits[i] = hit.getHit();
+                hits[i].sourceRef(new BytesArray((String) source));
+            }
+        }
+        return new InternalSearchResponse(
+            new SearchHits(hits, totalHits, topDocsAndMaxScore.maxScore),
+            InternalAggregations.EMPTY,
+            null,
+            null,
+            false,
+            null,
+            1
+        );
+    }
+
+    List<HavenaskFetchSubPhaseProcessor> getProcessors(String tableName, SearchSourceBuilder searchSourceBuilder) throws IOException {
+        try {
+            List<HavenaskFetchSubPhaseProcessor> processors = new ArrayList<>();
+            for (HavenaskFetchSubPhase fsp : havenaskFetchSubPhases) {
+                HavenaskFetchSubPhaseProcessor processor = fsp.getProcessor(tableName, searchSourceBuilder);
+                if (processor != null) {
+                    processors.add(processor);
+                }
+            }
+            return processors;
+        } catch (Exception e) {
+            throw new IOException("Error building fetch sub-phases", e);
+        }
+    }
+}
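A note on the fetch round trip above: getQrsFetchPhaseSqlRequest always targets the table's summary index, so for a query-phase result with ids 4, 3, 2, 1 it would emit a statement of this shape (table name illustrative, borrowed from the test fixtures below):

    select _id, _source from vector_test_0_summary_ where _id in('4','3','2','1') limit 4

Because the engine may return these rows in any order, transferSqlResponse2FetchResult first builds the _id-to-row-index map and then walks idList, which preserves the query phase's score ordering; any id missing from the fetch result falls back to the SOURCE_NOT_FOUND placeholder instead of failing the request.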
diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/HavenaskSearchQueryProcessor.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/HavenaskSearchQueryProcessor.java
new file mode 100644
index 000000000..15ae3ffa9
--- /dev/null
+++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/HavenaskSearchQueryProcessor.java
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) 2021, Alibaba Group;
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.havenask.engine.search;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.havenask.action.search.SearchRequest;
+import org.havenask.client.ha.SqlResponse;
+import org.havenask.common.Strings;
+import org.havenask.engine.index.mapper.DenseVectorFieldMapper;
+import org.havenask.engine.index.query.ProximaQueryBuilder;
+import org.havenask.engine.rpc.QrsClient;
+import org.havenask.engine.rpc.QrsSqlRequest;
+import org.havenask.engine.rpc.QrsSqlResponse;
+import org.havenask.index.query.MatchAllQueryBuilder;
+import org.havenask.index.query.MatchQueryBuilder;
+import org.havenask.index.query.QueryBuilder;
+import org.havenask.index.query.TermQueryBuilder;
+import org.havenask.search.builder.KnnSearchBuilder;
+import org.havenask.search.builder.SearchSourceBuilder;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Map;
+
+import static org.havenask.engine.search.rest.RestHavenaskSqlAction.SQL_DATABASE;
+
+public class HavenaskSearchQueryProcessor {
+    private static final Logger logger = LogManager.getLogger(HavenaskSearchQueryProcessor.class);
+    private static final String SIMILARITY = "similarity";
+    private static final String PROPERTIES_FIELD = "properties";
+    private static final String VECTOR_SIMILARITY_TYPE_L2_NORM = "L2_NORM";
+    private static final String VECTOR_SIMILARITY_TYPE_DOT_PRODUCT = "DOT_PRODUCT";
+    QrsClient qrsClient;
+
+    public HavenaskSearchQueryProcessor(QrsClient qrsClient) {
+        this.qrsClient = qrsClient;
+    }
+
+    public SqlResponse executeQuery(SearchRequest request, String tableName, Map<String, Object> indexMapping) throws IOException {
+        String sql = transferSearchRequest2HavenaskSql(tableName, request.source(), indexMapping);
+        String kvpair = "format:full_json;timeout:10000;databaseName:" + SQL_DATABASE;
+        QrsSqlRequest qrsQueryPhaseSqlRequest = new QrsSqlRequest(sql, kvpair);
+        QrsSqlResponse qrsQueryPhaseSqlResponse = qrsClient.executeSql(qrsQueryPhaseSqlRequest);
+        if (Strings.isNullOrEmpty(qrsQueryPhaseSqlResponse.getResult())) {
+            // TODO
+        }
+        SqlResponse queryPhaseSqlResponse = SqlResponse.parse(qrsQueryPhaseSqlResponse.getResult());
+        if (logger.isDebugEnabled()) {
+            logger.debug("sql: {}, sqlResponse took: {} ms", sql, queryPhaseSqlResponse.getTotalTime());
+        }
+        return queryPhaseSqlResponse;
+    }
+
+    public String transferSearchRequest2HavenaskSql(String table, SearchSourceBuilder dsl, Map<String, Object> indexMapping)
+        throws IOException {
+        if (dsl == null) {
+            throw new IllegalArgumentException("request source cannot be null!");
+        }
+        StringBuilder sqlQuery = new StringBuilder();
+        QueryBuilder queryBuilder = dsl.query();
+        StringBuilder where = new StringBuilder();
+        StringBuilder selectParams = new StringBuilder();
+        StringBuilder orderBy = new StringBuilder();
+
+        selectParams.append(" _id");
+
+        if (!dsl.knnSearch().isEmpty()) {
+            where.append(" where ");
+            boolean first = true;
+            for (KnnSearchBuilder knnSearchBuilder : dsl.knnSearch()) {
+                if (!knnSearchBuilder.getFilterQueries().isEmpty() || knnSearchBuilder.getSimilarity() != null) {
+                    throw new IOException("unsupported knn parameter: " + dsl);
+                }
+
+                String fieldName = knnSearchBuilder.getField();
+
+                if (indexMapping == null) {
+                    throw new IllegalArgumentException(
+                        String.format(Locale.ROOT, "index mapping is null, field: %s is not a vector type field", fieldName)
+                    );
+                }
+                String similarity = getSimilarity(fieldName, indexMapping);
+                if (similarity == null) {
+                    throw new IOException(String.format(Locale.ROOT, "field: %s is not a vector type field", fieldName));
+                }
+
+                if (false == first) {
+                    where.append(" or ");
+                    selectParams.append(" + ");
+                }
+
+                if (first) {
+                    first = false;
+                    selectParams.append(", (");
+                }
+
+                checkVectorMagnitude(similarity, knnSearchBuilder.getQueryVector());
+                selectParams.append(getScoreComputeStr(fieldName, similarity));
+
+                where.append("MATCHINDEX('").append(fieldName).append("', '");
+                for (int i = 0; i < knnSearchBuilder.getQueryVector().length; i++) {
+                    where.append(knnSearchBuilder.getQueryVector()[i]);
+                    if (i < knnSearchBuilder.getQueryVector().length - 1) {
+                        where.append(",");
+                    }
+                }
+                where.append("&n=").append(knnSearchBuilder.k()).append("')");
+            }
+            selectParams.append(") as _score");
+            orderBy.append(" order by _score desc");
+        } else if (queryBuilder != null) {
+            if (queryBuilder instanceof MatchAllQueryBuilder) {} else if (queryBuilder instanceof ProximaQueryBuilder) {
+                ProximaQueryBuilder proximaQueryBuilder = (ProximaQueryBuilder) queryBuilder;
+                String fieldName = proximaQueryBuilder.getFieldName();
+
+                if (indexMapping == null) {
+                    throw new IllegalArgumentException(
+                        String.format(Locale.ROOT, "index mapping is null, field: %s is not a vector type field", fieldName)
+                    );
+                }
+                String similarity = getSimilarity(fieldName, indexMapping);
+                if (similarity == null) {
+                    throw new IOException(String.format(Locale.ROOT, "field: %s is not a vector type field", fieldName));
+                }
+
+                checkVectorMagnitude(similarity, proximaQueryBuilder.getVector());
+
+                selectParams.append(", ").append(getScoreComputeStr(fieldName, similarity)).append(" as _score");
+                where.append(" where MATCHINDEX('").append(proximaQueryBuilder.getFieldName()).append("', '");
+                for (int i = 0; i < proximaQueryBuilder.getVector().length; i++) {
+                    where.append(proximaQueryBuilder.getVector()[i]);
+                    if (i < proximaQueryBuilder.getVector().length - 1) {
+                        where.append(",");
+                    }
+                }
+                where.append("&n=").append(proximaQueryBuilder.getSize()).append("')");
+                orderBy.append(" order by _score desc");
+            } else if (queryBuilder instanceof TermQueryBuilder) {
+                TermQueryBuilder termQueryBuilder = (TermQueryBuilder) queryBuilder;
+                where.append(" where ").append(termQueryBuilder.fieldName()).append("='").append(termQueryBuilder.value()).append("'");
+            } else if (queryBuilder instanceof MatchQueryBuilder) {
+                MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
+                where.append(" where MATCHINDEX('")
+                    .append(matchQueryBuilder.fieldName())
+                    .append("', '")
+                    .append(matchQueryBuilder.value())
+                    .append("')");
+            } else {
+                throw new IOException("unsupported DSL: " + dsl);
+            }
+        }
+        sqlQuery.append("select").append(selectParams).append(" from ").append(table);
+        sqlQuery.append(where).append(orderBy);
+        int size = 0;
+        if (dsl.size() >= 0) {
+            size += dsl.size();
+            if (dsl.from() >= 0) {
+                size += dsl.from();
+            }
+        }
+
+        if (size > 0) {
+            sqlQuery.append(" limit ").append(size);
+        }
+        return sqlQuery.toString();
+    }
+
+    @SuppressWarnings("unchecked")
+    private String getSimilarity(String fieldName, Map<String, Object> indexMapping) {
+        // TODO: two things to optimize here:
+        // 1. how similarity is looked up,
+        // 2. how nested properties are queried
+        Object propertiesObj = indexMapping.get(PROPERTIES_FIELD);
+        if (propertiesObj instanceof Map) {
+            Map<String, Object> propertiesMapping = (Map<String, Object>) propertiesObj;
+            Object fieldObj = propertiesMapping.get(fieldName);
+            if (fieldObj instanceof Map) {
+                Map<String, Object> fieldMapping = (Map<String, Object>) fieldObj;
+                return (String) fieldMapping.get(SIMILARITY);
+            }
+        }
+
+        return null;
+    }
+
+    private void checkVectorMagnitude(String similarity, float[] queryVector) {
+        if (similarity.equals(VECTOR_SIMILARITY_TYPE_DOT_PRODUCT) && Math.abs(computeSquaredMagnitude(queryVector) - 1.0f) > 1e-4f) {
+            throw new IllegalArgumentException(
+                "The ["
+                    + DenseVectorFieldMapper.Similarity.DOT_PRODUCT.getValue()
+                    + "] "
+                    + "similarity can only be used with unit-length vectors."
+            );
+        }
+    }
+
+    private float computeSquaredMagnitude(float[] queryVector) {
+        float squaredMagnitude = 0;
+        for (float v : queryVector) {
+            squaredMagnitude += v * v;
+        }
+        return squaredMagnitude;
+    }
+
+    private String getScoreComputeStr(String fieldName, String similarity) throws IOException {
+        StringBuilder scoreComputeStr = new StringBuilder();
+        if (similarity != null && similarity.equals(VECTOR_SIMILARITY_TYPE_L2_NORM)) {
+            // e.g. "(1/(1+vector_score('fieldName')))"
+            scoreComputeStr.append("(1/(").append("1+vector_score('").append(fieldName).append("')))");
+        } else if (similarity != null && similarity.equals(VECTOR_SIMILARITY_TYPE_DOT_PRODUCT)) {
+            // e.g. "((1+vector_score('fieldName'))/2)"
+            scoreComputeStr.append("((1+vector_score('").append(fieldName).append("'))/2)");
+        } else {
+            throw new IOException("unsupported similarity: " + similarity);
+        }
+        return scoreComputeStr.toString();
+    }
+}
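To make the two normalization branches concrete: for an L2_NORM field the compiled query-phase SQL matches the expectation in testProximaQuery further down,

    select _id, (1/(1+vector_score('field'))) as _score from table where MATCHINDEX('field', '1.0,2.0&n=20') order by _score desc

while a DOT_PRODUCT field compiles to ((1+vector_score('field'))/2) instead, after checkVectorMagnitude has rejected any query vector whose squared magnitude deviates from 1 by more than 1e-4.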
diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/action/TransportHavenaskSearchAction.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/action/TransportHavenaskSearchAction.java
index f5d71b6e1..b52400a84 100644
--- a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/action/TransportHavenaskSearchAction.java
+++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/action/TransportHavenaskSearchAction.java
@@ -17,24 +17,35 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.havenask.action.ActionListener;
+import org.havenask.action.ingest.IngestActionForwarder;
 import org.havenask.action.search.SearchRequest;
 import org.havenask.action.search.SearchResponse;
+import org.havenask.action.search.ShardSearchFailure;
 import org.havenask.action.support.ActionFilters;
 import org.havenask.action.support.HandledTransportAction;
+import org.havenask.client.ha.SqlResponse;
+import org.havenask.cluster.ClusterState;
 import org.havenask.cluster.service.ClusterService;
 import org.havenask.common.inject.Inject;
 import org.havenask.engine.NativeProcessControlService;
+import org.havenask.engine.search.HavenaskSearchQueryProcessor;
 import org.havenask.engine.rpc.QrsClient;
 import org.havenask.engine.rpc.http.QrsHttpClient;
+import org.havenask.engine.search.HavenaskSearchFetchProcessor;
+import org.havenask.search.internal.InternalSearchResponse;
 import org.havenask.tasks.Task;
 import org.havenask.threadpool.ThreadPool;
 import org.havenask.transport.TransportService;

+import java.util.Map;
+
 public class TransportHavenaskSearchAction extends HandledTransportAction<SearchRequest, SearchResponse> {
     private static final Logger logger = LogManager.getLogger(TransportHavenaskSearchAction.class);
-    private ClusterService clusterService;
+    private final ClusterService clusterService;
+    private final IngestActionForwarder ingestForwarder;
     private QrsClient qrsClient;
+    private HavenaskSearchQueryProcessor havenaskSearchQueryProcessor;
+    private HavenaskSearchFetchProcessor havenaskSearchFetchProcessor;

     @Inject
     public TransportHavenaskSearchAction(
@@ -45,11 +56,75 @@ public TransportHavenaskSearchAction(
     ) {
         super(HavenaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new, ThreadPool.Names.SEARCH);
         this.clusterService = clusterService;
+        this.ingestForwarder = new IngestActionForwarder(transportService);
         this.qrsClient = new QrsHttpClient(nativeProcessControlService.getQrsHttpPort());
+        havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        havenaskSearchFetchProcessor = new HavenaskSearchFetchProcessor(qrsClient);
     }

     @Override
     protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> listener) {
-        // TODO
+        if (false == clusterService.localNode().isIngestNode()) {
+            ingestForwarder.forwardIngestRequest(HavenaskSearchAction.INSTANCE, request, listener);
+            return;
+        }
+
+        try {
+            // TODO: only single-index havenask queries reach this path today; adjust this once multi-index search is supported
+            if (request.indices().length != 1) {
+                throw new IllegalArgumentException("illegal index count! only support search single havenask index.");
+            }
+            String tableName = request.indices()[0];
+
+            long startTime = System.nanoTime();
+
+            ClusterState clusterState = clusterService.state();
+
+            Map<String, Object> indexMapping = clusterState.metadata().index(tableName).mapping() != null
+                ? clusterState.metadata().index(tableName).mapping().getSourceAsMap()
+                : null;
+            SqlResponse havenaskSearchQueryPhaseSqlResponse = havenaskSearchQueryProcessor.executeQuery(request, tableName, indexMapping);
+
+            InternalSearchResponse internalSearchResponse = havenaskSearchFetchProcessor.executeFetch(
+                havenaskSearchQueryPhaseSqlResponse,
+                tableName,
+                request.source()
+            );
+
+            SearchResponse searchResponse = buildSearchResponse(
+                tableName,
+                internalSearchResponse,
+                havenaskSearchQueryPhaseSqlResponse,
+                startTime
+            );
+            listener.onResponse(searchResponse);
+        } catch (Exception e) {
+            logger.error("Failed to execute havenask search, ", e);
+            listener.onFailure(e);
+        }
+    }
+
+    private SearchResponse buildSearchResponse(
+        String indexName,
+        InternalSearchResponse internalSearchResponse,
+        SqlResponse havenaskSearchQueryPhaseSqlResponse,
+        long startTime
+    ) {
+        ClusterState clusterState = clusterService.state();
+        int totalShards = clusterState.metadata().index(indexName).getNumberOfShards();
+        double coveredPercent = havenaskSearchQueryPhaseSqlResponse.getCoveredPercent();
+        int successfulShards = (int) Math.round(totalShards * coveredPercent);
+
+        long endTime = System.nanoTime();
+        return new SearchResponse(
+            internalSearchResponse,
+            null,
+            totalShards,
+            successfulShards,
+            0,
+            (endTime - startTime) / 1000000,
+            ShardSearchFailure.EMPTY_ARRAY,
+            SearchResponse.Clusters.EMPTY
+        );
     }
 }
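Note on the shard accounting in buildSearchResponse: covered_percent is a fraction, so a fully covered 2-shard index reports successfulShards = round(2 * 1.0) = 2, while a half-covered one (0.5) would report 1 of 2; took is the nanosecond delta divided by 1,000,000, i.e. wall-clock milliseconds for the query plus fetch phases.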
diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSourcePhase.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSourcePhase.java
new file mode 100644
index 000000000..67bef631b
--- /dev/null
+++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSourcePhase.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2021, Alibaba Group;
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.havenask.engine.search.fetch;
+
+import org.havenask.HavenaskException;
+import org.havenask.HavenaskParseException;
+import org.havenask.common.bytes.BytesArray;
+import org.havenask.common.bytes.BytesReference;
+import org.havenask.common.collect.Tuple;
+import org.havenask.common.io.stream.BytesStreamOutput;
+import org.havenask.common.xcontent.XContentBuilder;
+import org.havenask.common.xcontent.XContentHelper;
+import org.havenask.common.xcontent.XContentType;
+import org.havenask.search.builder.SearchSourceBuilder;
+import org.havenask.search.fetch.subphase.FetchSourceContext;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class HavenaskFetchSourcePhase implements HavenaskFetchSubPhase {
+    @Override
+    public HavenaskFetchSubPhaseProcessor getProcessor(String indexName, SearchSourceBuilder searchSourceBuilder) throws IOException {
+        FetchSourceContext fetchSourceContext = searchSourceBuilder.fetchSource();
+        if (fetchSourceContext == null || fetchSourceContext.fetchSource() == false) {
+            return null;
+        }
+        assert fetchSourceContext.fetchSource();
+        return new HavenaskFetchSubPhaseProcessor() {
+            @Override
+            public void process(HitContent hitContent) {
+                hitExecute(indexName, fetchSourceContext, hitContent);
+            }
+        };
+    }
+
+    private void hitExecute(String indexName, FetchSourceContext fetchSourceContext, HitContent hitContent) {
+        BytesReference sourceAsBytes = new BytesArray((String) hitContent.source);
+        SourceContent sourceContent = loadSource(sourceAsBytes);
+
+        // If source is disabled in the mapping, then attempt to return early.
+        if (sourceContent.getSourceAsMap() == null && sourceContent.getSourceAsBytes() == null) {
+            if (containsFilters(fetchSourceContext)) {
+                throw new IllegalArgumentException(
+                    "unable to fetch fields from _source field: _source is disabled in the mappings " + "for index [" + indexName + "]"
+                );
+            }
+            return;
+        }
+
+        // filter the source and add it to the hit.
+        Object value = fetchSourceContext.getFilter().apply(sourceContent.getSourceAsMap());
+        try {
+            final int initialCapacity = Math.min(1024, sourceAsBytes.length());
+            BytesStreamOutput streamOutput = new BytesStreamOutput(initialCapacity);
+            XContentBuilder builder = new XContentBuilder(sourceContent.getSourceContentType().xContent(), streamOutput);
+            if (value != null) {
+                builder.value(value);
+            } else {
+                // This happens if the source filtering could not find the specified fields in the _source.
+                // Just doing `builder.value(null)` is valid, but the xcontent validation can't detect what format
+                // it is. In certain cases, for example response serialization, we fail if the xcontent type can't be
+                // detected. So instead we just return an empty top-level object. This is also in line with what was
+                // being returned in this situation in 5.x and earlier.
+                builder.startObject();
+                builder.endObject();
+            }
+            hitContent.hit.sourceRef(BytesReference.bytes(builder));
+        } catch (IOException e) {
+            throw new HavenaskException("Error filtering source", e);
+        }
+    }
+
+    private SourceContent loadSource(BytesReference sourceAsBytes) {
+        Tuple<XContentType, Map<String, Object>> tuple = sourceAsMapAndType(sourceAsBytes);
+        XContentType sourceContentType = tuple.v1();
+        Map<String, Object> source = tuple.v2();
+        return new SourceContent(sourceAsBytes, source, sourceContentType);
+    }
+
+    public class SourceContent {
+        private BytesReference sourceAsBytes;
+        private Map<String, Object> sourceAsMap;
+        private XContentType sourceContentType;
+
+        public SourceContent(BytesReference sourceAsBytes, Map<String, Object> source, XContentType sourceContentType) {
+            this.sourceAsBytes = sourceAsBytes;
+            this.sourceAsMap = source;
+            this.sourceContentType = sourceContentType;
+        }
+
+        public BytesReference getSourceAsBytes() {
+            return sourceAsBytes;
+        }
+
+        public Map<String, Object> getSourceAsMap() {
+            return sourceAsMap;
+        }
+
+        public XContentType getSourceContentType() {
+            return sourceContentType;
+        }
+    }
+
+    private static boolean containsFilters(FetchSourceContext context) {
+        return context.includes().length != 0 || context.excludes().length != 0;
+    }
+
+    private static Tuple<XContentType, Map<String, Object>> sourceAsMapAndType(BytesReference source) throws HavenaskParseException {
+        return XContentHelper.convertToMap(source, false);
+    }
+}
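The include/exclude semantics follow standard _source filtering, which testHitExecute below pins down: a raw source of {"key1":"doc1","name":"alice","length":1} filtered with includes ["na*", "len*"] and excludes ["name"] comes back as {"length":1}, and a FetchSourceContext with fetchSource == false makes getProcessor return null so the hit keeps no source at all.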
diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSubPhase.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSubPhase.java
new file mode 100644
index 000000000..b34e1eee3
--- /dev/null
+++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSubPhase.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021, Alibaba Group;
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.havenask.engine.search.fetch;
+
+import org.havenask.search.SearchHit;
+import org.havenask.search.builder.SearchSourceBuilder;
+
+import java.io.IOException;
+
+public interface HavenaskFetchSubPhase {
+    class HitContent {
+        SearchHit hit;
+        Object source;
+
+        public HitContent(SearchHit searchHit, Object source) {
+            this.hit = searchHit;
+            this.source = source;
+        }
+
+        public SearchHit getHit() {
+            return hit;
+        }
+    }
+
+    HavenaskFetchSubPhaseProcessor getProcessor(String indexName, SearchSourceBuilder searchSourceBuilder) throws IOException;
+}
diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSubPhaseProcessor.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSubPhaseProcessor.java
new file mode 100644
index 000000000..aae61af1d
--- /dev/null
+++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/search/fetch/HavenaskFetchSubPhaseProcessor.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2021, Alibaba Group;
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package org.havenask.engine.search.fetch;
+
+import org.havenask.engine.search.fetch.HavenaskFetchSubPhase.HitContent;
+import java.io.IOException;
+
+public interface HavenaskFetchSubPhaseProcessor {
+    void process(HitContent hitContent) throws IOException;
+}
diff --git a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/BizConfigGeneratorTests.java b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/BizConfigGeneratorTests.java
index 7c3e9a7d0..5e59ac141 100644
--- a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/BizConfigGeneratorTests.java
+++ b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/BizConfigGeneratorTests.java
@@ -100,7 +100,14 @@ public void testBasic() throws IOException {
             + "\t\t\t\"max_doc_count\":100000\n"
             + "\t\t},\n"
             + "\t\t\"merge_config\":{\n"
-            + "\t\t\t\"merge_strategy\":\"combined\"\n"
+            + "\t\t\t\"merge_strategy\":\"combined\",\n"
+            + "\t\t\t\"merge_strategy_params\":{\n"
+            + "\t\t\t\t\"input_limits\":\"max-segment-size=5120\",\n"
+            + "\t\t\t\t\"output_limits\":\"max-merged-segment-size=13312;max-total-merged-size=15360;"
+            + "max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64\",\n"
+            + "\t\t\t\t\"strategy_conditions\":\"priority-feature=valid-doc-count#asc;"
+            + "conflict-segment-count=10;conflict-delete-percent=30\"\n"
+            + "\t\t\t}\n"
             + "\t\t}\n"
             + "\t},\n"
             + "\t\"online_index_config\":{\n"
@@ -327,7 +334,14 @@ public void testDupFieldProcessor() throws IOException {
             + "\t\t\t\"max_doc_count\":100000\n"
             + "\t\t},\n"
             + "\t\t\"merge_config\":{\n"
-            + "\t\t\t\"merge_strategy\":\"combined\"\n"
+            + "\t\t\t\"merge_strategy\":\"combined\",\n"
+            + "\t\t\t\"merge_strategy_params\":{\n"
+            + "\t\t\t\t\"input_limits\":\"max-segment-size=5120\",\n"
"\t\t\t\t\"output_limits\":\"max-merged-segment-size=13312;max-total-merged-size=15360;" + + "max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64\",\n" + + "\t\t\t\t\"strategy_conditions\":\"priority-feature=valid-doc-count#asc;" + + "conflict-segment-count=10;conflict-delete-percent=30\"\n" + + "\t\t\t}\n" + "\t\t}\n" + "\t},\n" + "\t\"online_index_config\":{\n" @@ -609,7 +623,14 @@ public void testMaxDocConfig() throws IOException { + "\t\t\t\"max_doc_count\":100000\n" + "\t\t},\n" + "\t\t\"merge_config\":{\n" - + "\t\t\t\"merge_strategy\":\"combined\"\n" + + "\t\t\t\"merge_strategy\":\"combined\",\n" + + "\t\t\t\"merge_strategy_params\":{\n" + + "\t\t\t\t\"input_limits\":\"max-segment-size=5120\",\n" + + "\t\t\t\t\"output_limits\":\"max-merged-segment-size=13312;max-total-merged-size=15360;" + + "max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64\",\n" + + "\t\t\t\t\"strategy_conditions\":\"priority-feature=valid-doc-count#asc;" + + "conflict-segment-count=10;conflict-delete-percent=30\"\n" + + "\t\t\t}\n" + "\t\t}\n" + "\t},\n" + "\t\"online_index_config\":{\n" diff --git a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/TableConfigGeneratorTests.java b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/TableConfigGeneratorTests.java index c5ed95fd6..039ceb276 100644 --- a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/TableConfigGeneratorTests.java +++ b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/config/generator/TableConfigGeneratorTests.java @@ -94,7 +94,14 @@ public void testBasic() throws IOException { + "\t\t\t\"max_doc_count\":100000\n" + "\t\t},\n" + "\t\t\"merge_config\":{\n" - + "\t\t\t\"merge_strategy\":\"combined\"\n" + + "\t\t\t\"merge_strategy\":\"combined\",\n" + + "\t\t\t\"merge_strategy_params\":{\n" + + "\t\t\t\t\"input_limits\":\"max-segment-size=5120\",\n" + + "\t\t\t\t\"output_limits\":\"max-merged-segment-size=13312;max-total-merged-size=15360;" + + "max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64\",\n" + + "\t\t\t\t\"strategy_conditions\":\"priority-feature=valid-doc-count#asc;" + + "conflict-segment-count=10;conflict-delete-percent=30\"\n" + + "\t\t\t}\n" + "\t\t}\n" + "\t},\n" + "\t\"online_index_config\":{\n" @@ -318,7 +325,14 @@ public void testDupFieldProcessor() throws IOException { + "\t\t\t\"max_doc_count\":100000\n" + "\t\t},\n" + "\t\t\"merge_config\":{\n" - + "\t\t\t\"merge_strategy\":\"combined\"\n" + + "\t\t\t\"merge_strategy\":\"combined\",\n" + + "\t\t\t\"merge_strategy_params\":{\n" + + "\t\t\t\t\"input_limits\":\"max-segment-size=5120\",\n" + + "\t\t\t\t\"output_limits\":\"max-merged-segment-size=13312;max-total-merged-size=15360;" + + "max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64\",\n" + + "\t\t\t\t\"strategy_conditions\":\"priority-feature=valid-doc-count#asc;" + + "conflict-segment-count=10;conflict-delete-percent=30\"\n" + + "\t\t\t}\n" + "\t\t}\n" + "\t},\n" + "\t\"online_index_config\":{\n" @@ -614,7 +628,14 @@ public void testMaxDocConfig() throws IOException { + "\t\t\t\"max_doc_count\":100000\n" + "\t\t},\n" + "\t\t\"merge_config\":{\n" - + "\t\t\t\"merge_strategy\":\"combined\"\n" + + "\t\t\t\"merge_strategy\":\"combined\",\n" + + "\t\t\t\"merge_strategy_params\":{\n" + + "\t\t\t\t\"input_limits\":\"max-segment-size=5120\",\n" + + 
"\t\t\t\t\"output_limits\":\"max-merged-segment-size=13312;max-total-merged-size=15360;" + + "max-small-segment-count=10;merge-size-upperbound=256;merge-size-lowerbound=64\",\n" + + "\t\t\t\t\"strategy_conditions\":\"priority-feature=valid-doc-count#asc;" + + "conflict-segment-count=10;conflict-delete-percent=30\"\n" + + "\t\t\t}\n" + "\t\t}\n" + "\t},\n" + "\t\"online_index_config\":{\n" diff --git a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/search/fetch/HavenaskSearchFetchProcessorTests.java b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/search/fetch/HavenaskSearchFetchProcessorTests.java new file mode 100644 index 000000000..7b791f2f6 --- /dev/null +++ b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/search/fetch/HavenaskSearchFetchProcessorTests.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2021, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.havenask.engine.search.fetch; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.havenask.client.ha.SqlResponse; +import org.havenask.common.lucene.search.TopDocsAndMaxScore; +import org.havenask.engine.rpc.QrsClient; +import org.havenask.engine.search.fetch.HavenaskFetchSubPhase.HitContent; +import org.havenask.engine.search.HavenaskSearchFetchProcessor; +import org.havenask.search.SearchHit; +import org.havenask.search.builder.SearchSourceBuilder; +import org.havenask.search.fetch.subphase.FetchSourceContext; +import org.havenask.search.internal.InternalSearchResponse; +import org.havenask.test.HavenaskTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HavenaskSearchFetchProcessorTests extends HavenaskTestCase { + private QrsClient qrsClient = mock(QrsClient.class); + + public void testBuildQuerySearchResult() throws IOException { + String sqlResponseStr = "{\"total_time\":8.126,\"has_soft_failure\":false,\"covered_percent\":1.0," + + "\"row_count\":4,\"format_type\":\"full_json\",\"search_info\":{},\"rpc_info\":\"\"," + + "\"table_leader_info\":{},\"table_build_watermark\":{},\"sql_query\":" + + "\"query=select _id, vector_score('image') as _score from vector_test_0 where " + + "MATCHINDEX('image', '1.1, 1.1') order by _score desc&&kvpair=format:full_json;databaseName:general" + + "\",\"iquan_plan\":{\"error_code\":0,\"error_message\":\"\",\"result\":{\"rel_plan_version\":\"\"," + + "\"rel_plan\":[],\"exec_params\":{}}},\"navi_graph\":\"\",\"trace\":[],\"sql_result\":{\"data\":[[\"4" + + "\",9.680000305175782],[\"3\",7.260000228881836],[\"2\",4.840000152587891],[\"1\",2.4200000762939455]]," + + "\"column_name\":[\"_id\",\"_score\"],\"column_type\":[\"multi_char\",\"float\"]},\"error_info\":{" + + "\"ErrorCode\":0,\"Error\":\"ERROR_NONE\",\"Message\":\"\"}}"; + SqlResponse queryPhaseSqlResponse = 
+        SqlResponse queryPhaseSqlResponse = SqlResponse.parse(sqlResponseStr);
+        String[] resStr = new String[] { "4", "3", "2", "1" };
+        float[] resFloat = new float[] { 9.6800F, 7.2600F, 4.8400F, 2.4200F };
+        int rowNum = resStr.length;
+        double delta = 0.0001;
+
+        List<String> idList = new ArrayList<>(queryPhaseSqlResponse.getRowCount());
+        HavenaskSearchFetchProcessor havenaskSearchFetchProcessor = new HavenaskSearchFetchProcessor(qrsClient);
+        TopDocsAndMaxScore topDocsAndMaxScore = havenaskSearchFetchProcessor.buildQuerySearchResult(queryPhaseSqlResponse, idList);
+        assertEquals(4L, topDocsAndMaxScore.topDocs.totalHits.value);
+        assertEquals(9.6800F, topDocsAndMaxScore.maxScore, delta);
+        for (int i = 0; i < rowNum; i++) {
+            assertEquals(resFloat[i], topDocsAndMaxScore.topDocs.scoreDocs[i].score, delta);
+            assertEquals(-1, topDocsAndMaxScore.topDocs.scoreDocs[i].shardIndex);
+            assertEquals(i, topDocsAndMaxScore.topDocs.scoreDocs[i].doc);
+            assertEquals(resStr[i], idList.get(i));
+        }
+    }
+
+    public void testHitExecute() throws IOException {
+        String indexName = "test";
+        Boolean[] needFilter = new Boolean[] { true, true, true, true, false, true, true };
+        String[][] includes = new String[][] {
+            { "name" },
+            {},
+            { "key1", "length" },
+            { "name", "length" },
+            {},
+            { "na*" },
+            { "na*", "len*" } };
+        String[][] excludes = new String[][] { {}, { "key1" }, {}, { "name" }, {}, {}, { "name" } };
+        String[] resSourceStr = new String[] {
+            "{\"name\":\"alice\"}",
+            "{\"name\":\"alice\",\"length\":1}",
+            "{\"key1\":\"doc1\",\"length\":1}",
+            "{\"length\":1}",
+            "",
+            "{\"name\":\"alice\"}",
+            "{\"length\":1}" };
+
+        for (int i = 0; i < includes.length; i++) {
+            FetchSourceContext fetchSourceContext = new FetchSourceContext(needFilter[i], includes[i], excludes[i]);
+            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+            searchSourceBuilder.fetchSource(fetchSourceContext);
+
+            SearchHit searchHit = new SearchHit(-1);
+            String sourceStr = "{\n" + " \"key1\" :\"doc1\",\n" + " \"name\" :\"alice\",\n" + " \"length\":1\n" + "}\n";
+            Object source = sourceStr;
+            HitContent hit = new HitContent(searchHit, source);
+
+            HavenaskFetchSubPhaseProcessor processor = new HavenaskFetchSourcePhase().getProcessor(indexName, searchSourceBuilder);
+            if (processor != null) {
+                processor.process(hit);
+            } else {
+                continue;
+            }
+
+            SearchHit res = hit.getHit();
+            assertEquals(resSourceStr[i], res.getSourceAsString());
+        }
+    }
+
+    public void testTransferSqlResponse2FetchResult() throws IOException {
+        String indexName = "table";
+        int docsNum = 4;
+        int loadSize = 3;
+        String[] resStr = new String[] {
+            "{\n"
+                + "  \"_index\" : \"table\",\n"
+                + "  \"_type\" : \"_doc\",\n"
+                + "  \"_id\" : \"4\",\n"
+                + "  \"_score\" : 4.0,\n"
+                + "  \"_source\" : {\n"
+                + "    \"image\" : [\n"
+                + "      4.1,\n"
+                + "      4.1\n"
+                + "    ]\n"
+                + "  }\n"
+                + "}",
+            "{\n"
+                + "  \"_index\" : \"table\",\n"
+                + "  \"_type\" : \"_doc\",\n"
+                + "  \"_id\" : \"3\",\n"
+                + "  \"_score\" : 3.0,\n"
+                + "  \"_source\" : {\n"
+                + "    \"warn\" : \"source not found\"\n"
+                + "  }\n"
+                + "}",
+            "{\n"
+                + "  \"_index\" : \"table\",\n"
+                + "  \"_type\" : \"_doc\",\n"
+                + "  \"_id\" : \"2\",\n"
+                + "  \"_score\" : 2.0,\n"
+                + "  \"_source\" : {\n"
+                + "    \"image\" : [\n"
+                + "      2.1,\n"
+                + "      2.1\n"
+                + "    ]\n"
+                + "  }\n"
+                + "}",
+            "{\n"
+                + "  \"_index\" : \"table\",\n"
+                + "  \"_type\" : \"_doc\",\n"
+                + "  \"_id\" : \"1\",\n"
+                + "  \"_score\" : 1.0,\n"
+                + "  \"_source\" : {\n"
+                + "    \"image\" : [\n"
+                + "      1.1,\n"
+                + "      1.1\n"
+                + "    ]\n"
+                + "  }\n"
+                + "}" };
+
+        // mock SqlResponse
+        Object[][] data = {
"{\n" + " \"image\":[1.1, 1.1]\n" + "}\n" }, + { "2", "{\n" + " \"image\":[2.1, 2.1]\n" + "}" }, + { "4", "{\n" + " \"image\":[4.1, 4.1]\n" + "}\n" } }; + SqlResponse sqlResponse = mock(SqlResponse.class); + when(sqlResponse.getRowCount()).thenReturn(Data.length); + SqlResponse.SqlResult sqlResult = mock(SqlResponse.SqlResult.class); + when(sqlResponse.getSqlResult()).thenReturn(sqlResult); + when(sqlResult.getData()).thenReturn(Data); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + + TopDocs topDocs = new TopDocs( + new TotalHits(sqlResponse.getRowCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 4), new ScoreDoc(1, 3), new ScoreDoc(2, 2), new ScoreDoc(3, 1) } + ); + float maxScore = 4; + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore); + + List idList = new ArrayList<>(loadSize); + for (int i = 0; i < docsNum; i++) { + idList.add(String.valueOf(docsNum - i)); + } + + HavenaskSearchFetchProcessor havenaskSearchFetchProcessor = new HavenaskSearchFetchProcessor(qrsClient); + InternalSearchResponse internalSearchResponse = havenaskSearchFetchProcessor.transferSqlResponse2FetchResult( + indexName, + idList, + sqlResponse, + topDocsAndMaxScore, + searchSourceBuilder + ); + for (int i = 0; i < loadSize; i++) { + assertEquals(resStr[i], internalSearchResponse.hits().getHits()[i].toString()); + } + } +} diff --git a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/search/fetch/HavenaskSearchQueryProcessorTests.java b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/search/fetch/HavenaskSearchQueryProcessorTests.java new file mode 100644 index 000000000..bc30490a7 --- /dev/null +++ b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/search/fetch/HavenaskSearchQueryProcessorTests.java @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2021, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+package org.havenask.engine.search.fetch;
+
+import org.havenask.common.collect.List;
+import org.havenask.engine.index.query.KnnQueryBuilder;
+import org.havenask.engine.rpc.QrsClient;
+import org.havenask.engine.search.HavenaskSearchQueryProcessor;
+import org.havenask.index.query.QueryBuilders;
+import org.havenask.search.builder.KnnSearchBuilder;
+import org.havenask.search.builder.SearchSourceBuilder;
+import org.havenask.test.HavenaskTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.mockito.Mockito.mock;
+
+public class HavenaskSearchQueryProcessorTests extends HavenaskTestCase {
+    private QrsClient qrsClient = mock(QrsClient.class);
+    private Map<String, Object> indexMapping = new HashMap<>();
+
+    @Before
+    public void setup() {
+        Map<String, Object> propertiesMapping = new HashMap<>();
+        Map<String, Object> fieldMapping = new HashMap<>();
+        Map<String, Object> field1Mapping = new HashMap<>();
+        Map<String, Object> field2Mapping = new HashMap<>();
+        fieldMapping.put("similarity", "L2_NORM");
+        field1Mapping.put("similarity", "L2_NORM");
+        field2Mapping.put("similarity", "DOT_PRODUCT");
+        propertiesMapping.put("field", fieldMapping);
+        propertiesMapping.put("field1", field1Mapping);
+        propertiesMapping.put("field2", field2Mapping);
+        indexMapping.put("properties", propertiesMapping);
+    }
+
+    public void testMatchAllDocsQuery() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(QueryBuilders.matchAllQuery());
+
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, null);
+        assertEquals(sql, "select _id from table");
+    }
+
+    public void testProximaQuery() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(new KnnQueryBuilder("field", new float[] { 1.0f, 2.0f }, 20));
+
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, indexMapping);
+        assertEquals(
+            "select _id, (1/(1+vector_score('field'))) as _score from table where MATCHINDEX('field', '1.0,2.0&n=20') order by _score desc",
+            sql
+        );
+    }
+
+    public void testUnsupportedDSL() {
+        try {
+            SearchSourceBuilder builder = new SearchSourceBuilder();
+            builder.query(QueryBuilders.existsQuery("field"));
+            HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+            havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, null);
+            fail();
+        } catch (IOException e) {
+            assertEquals(e.getMessage(), "unsupported DSL: {\"query\":{\"exists\":{\"field\":\"field\",\"boost\":1.0}}}");
+        }
+    }
+
+    public void testMatchQuery() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(QueryBuilders.matchQuery("field", "value"));
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, null);
+        assertEquals(sql, "select _id from table where MATCHINDEX('field', 'value')");
+    }
+
+    public void testLimit() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(QueryBuilders.matchAllQuery());
+        builder.from(10);
+        builder.size(10);
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, null);
+        assertEquals("select _id from table limit 20", sql);
+    }
+
+    public void testNoFrom() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(QueryBuilders.matchAllQuery());
+        builder.size(10);
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, null);
+        assertEquals(sql, "select _id from table limit 10");
+    }
+
+    public void testNoSize() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(QueryBuilders.matchAllQuery());
+        builder.from(10);
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, null);
+        assertEquals(sql, "select _id from table");
+    }
+
+    // test knn dsl
+    public void testKnnDsl() throws IOException {
+        SearchSourceBuilder l2NormBuilder = new SearchSourceBuilder();
+        l2NormBuilder.query(QueryBuilders.matchAllQuery());
+        l2NormBuilder.knnSearch(List.of(new KnnSearchBuilder("field1", new float[] { 1.0f, 2.0f }, 20, 20, null)));
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String l2NormSql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", l2NormBuilder, indexMapping);
+        assertEquals(
+            "select _id, ((1/(1+vector_score('field1')))) as _score from table "
+                + "where MATCHINDEX('field1', '1.0,2.0&n=20') order by _score desc",
+            l2NormSql
+        );
+
+        SearchSourceBuilder dotProductBuilder = new SearchSourceBuilder();
+        dotProductBuilder.query(QueryBuilders.matchAllQuery());
+        dotProductBuilder.knnSearch(List.of(new KnnSearchBuilder("field2", new float[] { 0.6f, 0.8f }, 20, 20, null)));
+        String dotProductSql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", dotProductBuilder, indexMapping);
+        assertEquals(
+            "select _id, (((1+vector_score('field2'))/2)) as _score from table "
+                + "where MATCHINDEX('field2', '0.6,0.8&n=20') order by _score desc",
+            dotProductSql
+        );
+    }
+
+    // test multi knn dsl
+    public void testMultiKnnDsl() throws IOException {
+        SearchSourceBuilder builder = new SearchSourceBuilder();
+        builder.query(QueryBuilders.matchAllQuery());
+        builder.knnSearch(
+            List.of(
+                new KnnSearchBuilder("field1", new float[] { 1.0f, 2.0f }, 20, 20, null),
+                new KnnSearchBuilder("field2", new float[] { 0.6f, 0.8f }, 10, 10, null)
+            )
+        );
+        HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+        String sql = havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, indexMapping);
+        assertEquals(
+            "select _id, ((1/(1+vector_score('field1'))) + ((1+vector_score('field2'))/2)) as _score from table "
+                + "where MATCHINDEX('field1', '1.0,2.0&n=20') or MATCHINDEX('field2', '0.6,0.8&n=10') order by _score desc",
+            sql
+        );
+    }
+
+    public void testIllegalKnnParams() throws IOException {
+        SearchSourceBuilder dotProductBuilder = new SearchSourceBuilder();
+        dotProductBuilder.query(QueryBuilders.matchAllQuery());
+        dotProductBuilder.knnSearch(List.of(new KnnSearchBuilder("field2", new float[] { 1.0f, 2.0f }, 20, 20, null)));
+        try {
+            HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+            havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", dotProductBuilder, indexMapping);
+            fail("should throw IllegalArgumentException");
+        } catch (IllegalArgumentException e) {
+            assertEquals("The [dot_product] similarity can only be used with unit-length vectors.", e.getMessage());
+        }
+    }
+
+    // test unsupported knn dsl
+    public void testUnsupportedKnnDsl() {
+        try {
+            SearchSourceBuilder builder = new SearchSourceBuilder();
+            builder.query(QueryBuilders.matchAllQuery());
+            builder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1.0f, 2.0f }, 20, 20, 1.0f)));
+            HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+            havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, indexMapping);
+            fail();
+        } catch (IOException e) {
+            assertEquals(
+                e.getMessage(),
+                "unsupported knn parameter: {\"query\":{\"match_all\":{\"boost\":1.0}},"
+                    + "\"knn\":[{\"field\":\"field\",\"k\":20,\"num_candidates\":20,\"query_vector\":[1.0,2.0],"
+                    + "\"similarity\":1.0}]}"
+            );
+        }
+
+        // unsupported getFilterQueries
+        try {
+            SearchSourceBuilder builder = new SearchSourceBuilder();
+            builder.query(QueryBuilders.matchAllQuery());
+            KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder("field", new float[] { 1.0f, 2.0f }, 20, 20, null);
+            knnSearchBuilder.addFilterQuery(QueryBuilders.matchAllQuery());
+            builder.knnSearch(List.of(knnSearchBuilder));
+            HavenaskSearchQueryProcessor havenaskSearchQueryProcessor = new HavenaskSearchQueryProcessor(qrsClient);
+            havenaskSearchQueryProcessor.transferSearchRequest2HavenaskSql("table", builder, indexMapping);
+            fail();
+        } catch (IOException e) {
+            assertEquals(
+                e.getMessage(),
+                "unsupported knn parameter: {\"query\":{\"match_all\":{\"boost\":1.0}},"
+                    + "\"knn\":[{\"field\":\"field\",\"k\":20,\"num_candidates\":20,\"query_vector\":[1.0,2.0],"
+                    + "\"filter\":[{\"match_all\":{\"boost\":1.0}}]}]}"
+            );
+        }
+    }
+}
diff --git a/elastic-fed/server/src/main/java/org/havenask/search/SearchHit.java b/elastic-fed/server/src/main/java/org/havenask/search/SearchHit.java
index e9a8a6ac4..e8e23d945 100644
--- a/elastic-fed/server/src/main/java/org/havenask/search/SearchHit.java
+++ b/elastic-fed/server/src/main/java/org/havenask/search/SearchHit.java
@@ -364,6 +364,13 @@ public long getPrimaryTerm() {
         return this.primaryTerm;
     }
 
+    /**
+     * Set index of the hit
+     */
+    public void setIndex(String index) {
+        this.index = index;
+    }
+
     /**
      * The index of the hit.
      */
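Reviewer note: a minimal sketch (not part of the patch) of how the additions above surface to callers, combining the new covered_percent parsing in the high-level client with the new SearchHit#setIndex setter. The class name and the trimmed-down response body below are illustrative assumptions, not code from this PR; SqlResponse.parse, getCoveredPercent, SearchHit(int), getIndex and setIndex are the real APIs touched or used by the diff.

// Illustrative sketch only; CoveredPercentExample and the JSON body are made up for this note.
import org.havenask.client.ha.SqlResponse;
import org.havenask.search.SearchHit;

public class CoveredPercentExample {
    public static void main(String[] args) throws Exception {
        // a trimmed-down QRS SQL response; covered_percent is the field added in this PR
        String body = "{\"total_time\":1.0,\"has_soft_failure\":false,\"covered_percent\":0.5,"
            + "\"row_count\":0,\"error_info\":{\"ErrorCode\":0,\"Error\":\"ERROR_NONE\",\"Message\":\"\"}}";
        SqlResponse response = SqlResponse.parse(body);
        System.out.println(response.getCoveredPercent()); // prints 0.5

        // fetch-phase code can now stamp the owning index onto a hit after construction
        SearchHit hit = new SearchHit(-1);
        hit.setIndex("search_test");
        System.out.println(hit.getIndex()); // prints search_test
    }
}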