Skip to content

Commit

Permalink
适配之前分布式架构在多节点情况下的search接口 (alibaba#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
Huaixinww authored Jan 5, 2024
1 parent 9fca64a commit ca087ee
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ public HavenaskIndexSearcher(

@Override
public void search(Query query, Collector collector) throws IOException {
String sql = QueryTransformer.toSql(tableName, searchContext.request().source(), searchContext.indexShard().mapperService());
String sql = QueryTransformer.toSql(
tableName,
searchContext.request().source(),
searchContext.indexShard().mapperService(),
searchContext.indexShard().shardId().getId()
);
String kvpair = "format:full_json;timeout:10000;databaseName:" + SQL_DATABASE;
HavenaskSqlResponse response = client.execute(HavenaskSqlAction.INSTANCE, new HavenaskSqlRequest(sql, kvpair)).actionGet();
if (false == Strings.isNullOrEmpty(response.getResult())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
import org.havenask.search.builder.SearchSourceBuilder;

public class QueryTransformer {
public static String toSql(String table, SearchSourceBuilder dsl, MapperService mapperService) throws IOException {
public static String toSql(String table, SearchSourceBuilder dsl, MapperService mapperService, int shardId) throws IOException {
StringBuilder sqlQuery = new StringBuilder();
QueryBuilder queryBuilder = dsl.query();
StringBuilder where = new StringBuilder();
StringBuilder shardParams = new StringBuilder();
StringBuilder selectParams = new StringBuilder();
StringBuilder orderBy = new StringBuilder();

shardParams.append(String.format(Locale.ROOT, " /*+ SCAN_ATTR(partitionIds='%s')*/", shardId));
selectParams.append(" _id");

if (dsl.knnSearch().size() > 0) {
Expand Down Expand Up @@ -115,7 +117,7 @@ public static String toSql(String table, SearchSourceBuilder dsl, MapperService
throw new IOException("unsupported DSL: " + dsl);
}
}
sqlQuery.append("select").append(selectParams).append(" from ").append(table);
sqlQuery.append("select").append(shardParams).append(selectParams).append(" from ").append(table);
sqlQuery.append(where).append(orderBy);
int size = 0;
if (dsl.size() >= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ public void setup() throws IOException {

public void testMatchAllDocsQuery() throws IOException {
builder.query(QueryBuilders.matchAllQuery());
String sql = QueryTransformer.toSql("table", builder, mapperService);
assertEquals(sql, "select _id from table");
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table");
}

public void testProximaQuery() throws IOException {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(new KnnQueryBuilder("field", new float[] { 1.0f, 2.0f }, 20));
String sql = QueryTransformer.toSql("table", builder, mapperService);
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(
"select _id, (1/(1+vector_score('field'))) as _score from table where MATCHINDEX('field', '1.0,2.0&n=20') order by _score desc",
"select /*+ SCAN_ATTR(partitionIds='0')*/ _id, (1/(1+vector_score('field'))) "
+ "as _score from table where MATCHINDEX('field', '1.0,2.0&n=20') order by _score desc",
sql
);
}
Expand All @@ -91,7 +92,7 @@ public void testUnsupportedDSL() {
try {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(QueryBuilders.existsQuery("field"));
QueryTransformer.toSql("table", builder, mapperService);
QueryTransformer.toSql("table", builder, mapperService, 0);
fail();
} catch (IOException e) {
assertEquals(e.getMessage(), "unsupported DSL: {\"query\":{\"exists\":{\"field\":\"field\",\"boost\":1.0}}}");
Expand All @@ -102,16 +103,16 @@ public void testUnsupportedDSL() {
public void testTermQuery() throws IOException {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(QueryBuilders.termQuery("field", "value"));
String sql = QueryTransformer.toSql("table", builder, mapperService);
assertEquals(sql, "select _id from table where field='value'");
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table where field='value'");
}

// test match query
public void testMatchQuery() throws IOException {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(QueryBuilders.matchQuery("field", "value"));
String sql = QueryTransformer.toSql("table", builder, mapperService);
assertEquals(sql, "select _id from table where MATCHINDEX('field', 'value')");
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table where MATCHINDEX('field', 'value')");
}

// test limit
Expand All @@ -120,46 +121,46 @@ public void testLimit() throws IOException {
builder.query(QueryBuilders.matchAllQuery());
builder.from(10);
builder.size(10);
String sql = QueryTransformer.toSql("table", builder, mapperService);
assertEquals("select _id from table limit 20", sql);
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals("select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table limit 20", sql);
}

// test no from
public void testNoFrom() throws IOException {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(QueryBuilders.matchAllQuery());
builder.size(10);
String sql = QueryTransformer.toSql("table", builder, mapperService);
assertEquals(sql, "select _id from table limit 10");
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table limit 10");
}

// test no size
public void testNoSize() throws IOException {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(QueryBuilders.matchAllQuery());
builder.from(10);
String sql = QueryTransformer.toSql("table", builder, mapperService);
assertEquals(sql, "select _id from table");
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table");
}

// test knn dsl
public void testKnnDsl() throws IOException {
SearchSourceBuilder l2NormBuilder = new SearchSourceBuilder();
l2NormBuilder.query(QueryBuilders.matchAllQuery());
l2NormBuilder.knnSearch(List.of(new KnnSearchBuilder("field1", new float[] { 1.0f, 2.0f }, 20, 20, null)));
String l2NormSql = QueryTransformer.toSql("table", l2NormBuilder, mapperService);
String l2NormSql = QueryTransformer.toSql("table", l2NormBuilder, mapperService, 0);
assertEquals(
"select _id, ((1/(1+vector_score('field1')))) as _score from table "
"select /*+ SCAN_ATTR(partitionIds='0')*/ _id, ((1/(1+vector_score('field1')))) as _score from table "
+ "where MATCHINDEX('field1', '1.0,2.0&n=20') order by _score desc",
l2NormSql
);

SearchSourceBuilder dotProductBuilder = new SearchSourceBuilder();
dotProductBuilder.query(QueryBuilders.matchAllQuery());
dotProductBuilder.knnSearch(List.of(new KnnSearchBuilder("field2", new float[] { 0.6f, 0.8f }, 20, 20, null)));
String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService);
String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService, 0);
assertEquals(
"select _id, (((1+vector_score('field2'))/2)) as _score from table "
"select /*+ SCAN_ATTR(partitionIds='0')*/ _id, (((1+vector_score('field2'))/2)) as _score from table "
+ "where MATCHINDEX('field2', '0.6,0.8&n=20') order by _score desc",
dotProductSql
);
Expand All @@ -175,9 +176,10 @@ public void testMultiKnnDsl() throws IOException {
new KnnSearchBuilder("field2", new float[] { 0.6f, 0.8f }, 10, 10, null)
)
);
String sql = QueryTransformer.toSql("table", builder, mapperService);
String sql = QueryTransformer.toSql("table", builder, mapperService, 0);
assertEquals(
"select _id, ((1/(1+vector_score('field1'))) + ((1+vector_score('field2'))/2)) as _score from table "
"select /*+ SCAN_ATTR(partitionIds='0')*/ _id, ((1/(1+vector_score('field1'))) + "
+ "((1+vector_score('field2'))/2)) as _score from table "
+ "where MATCHINDEX('field1', '1.0,2.0&n=20') or MATCHINDEX('field2', '0.6,0.8&n=10') order by _score desc",
sql
);
Expand All @@ -188,7 +190,7 @@ public void testIllegalKnnParams() throws IOException {
dotProductBuilder.query(QueryBuilders.matchAllQuery());
dotProductBuilder.knnSearch(List.of(new KnnSearchBuilder("field2", new float[] { 1.0f, 2.0f }, 20, 20, null)));
try {
String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService);
String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService, 0);
fail("should throw IllegalArgumentException");
} catch (IllegalArgumentException e) {
assertEquals("The [dot_product] similarity can only be used with unit-length vectors.", e.getMessage());
Expand All @@ -201,7 +203,7 @@ public void testUnsupportedKnnDsl() {
SearchSourceBuilder builder = new SearchSourceBuilder();
builder.query(QueryBuilders.matchAllQuery());
builder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1.0f, 2.0f }, 20, 20, 1.0f)));
QueryTransformer.toSql("table", builder, mapperService);
QueryTransformer.toSql("table", builder, mapperService, 0);
fail();
} catch (IOException e) {
assertEquals(
Expand All @@ -219,7 +221,7 @@ public void testUnsupportedKnnDsl() {
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder("field", new float[] { 1.0f, 2.0f }, 20, 20, null);
knnSearchBuilder.addFilterQuery(QueryBuilders.matchAllQuery());
builder.knnSearch(List.of(knnSearchBuilder));
QueryTransformer.toSql("table", builder, mapperService);
QueryTransformer.toSql("table", builder, mapperService, 0);
fail();
} catch (IOException e) {
assertEquals(
Expand Down

0 comments on commit ca087ee

Please sign in to comment.