diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/HavenaskIndexSearcher.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/HavenaskIndexSearcher.java index 08600eb6..b6b3cf07 100644 --- a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/HavenaskIndexSearcher.java +++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/HavenaskIndexSearcher.java @@ -76,7 +76,12 @@ public HavenaskIndexSearcher( @Override public void search(Query query, Collector collector) throws IOException { - String sql = QueryTransformer.toSql(tableName, searchContext.request().source(), searchContext.indexShard().mapperService()); + String sql = QueryTransformer.toSql( + tableName, + searchContext.request().source(), + searchContext.indexShard().mapperService(), + searchContext.indexShard().shardId().getId() + ); String kvpair = "format:full_json;timeout:10000;databaseName:" + SQL_DATABASE; HavenaskSqlResponse response = client.execute(HavenaskSqlAction.INSTANCE, new HavenaskSqlRequest(sql, kvpair)).actionGet(); if (false == Strings.isNullOrEmpty(response.getResult())) { diff --git a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/QueryTransformer.java b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/QueryTransformer.java index 88d14285..04b63187 100644 --- a/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/QueryTransformer.java +++ b/elastic-fed/modules/havenask-engine/src/main/java/org/havenask/engine/index/engine/QueryTransformer.java @@ -28,13 +28,15 @@ import org.havenask.search.builder.SearchSourceBuilder; public class QueryTransformer { - public static String toSql(String table, SearchSourceBuilder dsl, MapperService mapperService) throws IOException { + public static String toSql(String table, SearchSourceBuilder dsl, MapperService mapperService, int shardId) throws IOException { StringBuilder sqlQuery = new StringBuilder(); QueryBuilder queryBuilder = dsl.query(); StringBuilder where = new StringBuilder(); + StringBuilder shardParams = new StringBuilder(); StringBuilder selectParams = new StringBuilder(); StringBuilder orderBy = new StringBuilder(); + shardParams.append(String.format(Locale.ROOT, " /*+ SCAN_ATTR(partitionIds='%s')*/", shardId)); selectParams.append(" _id"); if (dsl.knnSearch().size() > 0) { @@ -115,7 +117,7 @@ public static String toSql(String table, SearchSourceBuilder dsl, MapperService throw new IOException("unsupported DSL: " + dsl); } } - sqlQuery.append("select").append(selectParams).append(" from ").append(table); + sqlQuery.append("select").append(shardParams).append(selectParams).append(" from ").append(table); sqlQuery.append(where).append(orderBy); int size = 0; if (dsl.size() >= 0) { diff --git a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/engine/QueryTransformerTests.java b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/engine/QueryTransformerTests.java index b66cbcf9..2b87e407 100644 --- a/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/engine/QueryTransformerTests.java +++ b/elastic-fed/modules/havenask-engine/src/test/java/org/havenask/engine/index/engine/QueryTransformerTests.java @@ -73,16 +73,17 @@ public void setup() throws IOException { public void testMatchAllDocsQuery() throws IOException { builder.query(QueryBuilders.matchAllQuery()); - String sql = QueryTransformer.toSql("table", builder, mapperService); - assertEquals(sql, "select _id from table"); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); + assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table"); } public void testProximaQuery() throws IOException { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(new KnnQueryBuilder("field", new float[] { 1.0f, 2.0f }, 20)); - String sql = QueryTransformer.toSql("table", builder, mapperService); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); assertEquals( - "select _id, (1/(1+vector_score('field'))) as _score from table where MATCHINDEX('field', '1.0,2.0&n=20') order by _score desc", + "select /*+ SCAN_ATTR(partitionIds='0')*/ _id, (1/(1+vector_score('field'))) " + + "as _score from table where MATCHINDEX('field', '1.0,2.0&n=20') order by _score desc", sql ); } @@ -91,7 +92,7 @@ public void testUnsupportedDSL() { try { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(QueryBuilders.existsQuery("field")); - QueryTransformer.toSql("table", builder, mapperService); + QueryTransformer.toSql("table", builder, mapperService, 0); fail(); } catch (IOException e) { assertEquals(e.getMessage(), "unsupported DSL: {\"query\":{\"exists\":{\"field\":\"field\",\"boost\":1.0}}}"); @@ -102,16 +103,16 @@ public void testUnsupportedDSL() { public void testTermQuery() throws IOException { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(QueryBuilders.termQuery("field", "value")); - String sql = QueryTransformer.toSql("table", builder, mapperService); - assertEquals(sql, "select _id from table where field='value'"); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); + assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table where field='value'"); } // test match query public void testMatchQuery() throws IOException { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(QueryBuilders.matchQuery("field", "value")); - String sql = QueryTransformer.toSql("table", builder, mapperService); - assertEquals(sql, "select _id from table where MATCHINDEX('field', 'value')"); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); + assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table where MATCHINDEX('field', 'value')"); } // test limit @@ -120,8 +121,8 @@ public void testLimit() throws IOException { builder.query(QueryBuilders.matchAllQuery()); builder.from(10); builder.size(10); - String sql = QueryTransformer.toSql("table", builder, mapperService); - assertEquals("select _id from table limit 20", sql); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); + assertEquals("select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table limit 20", sql); } // test no from @@ -129,8 +130,8 @@ public void testNoFrom() throws IOException { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(QueryBuilders.matchAllQuery()); builder.size(10); - String sql = QueryTransformer.toSql("table", builder, mapperService); - assertEquals(sql, "select _id from table limit 10"); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); + assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table limit 10"); } // test no size @@ -138,8 +139,8 @@ public void testNoSize() throws IOException { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(QueryBuilders.matchAllQuery()); builder.from(10); - String sql = QueryTransformer.toSql("table", builder, mapperService); - assertEquals(sql, "select _id from table"); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); + assertEquals(sql, "select /*+ SCAN_ATTR(partitionIds='0')*/ _id from table"); } // test knn dsl @@ -147,9 +148,9 @@ public void testKnnDsl() throws IOException { SearchSourceBuilder l2NormBuilder = new SearchSourceBuilder(); l2NormBuilder.query(QueryBuilders.matchAllQuery()); l2NormBuilder.knnSearch(List.of(new KnnSearchBuilder("field1", new float[] { 1.0f, 2.0f }, 20, 20, null))); - String l2NormSql = QueryTransformer.toSql("table", l2NormBuilder, mapperService); + String l2NormSql = QueryTransformer.toSql("table", l2NormBuilder, mapperService, 0); assertEquals( - "select _id, ((1/(1+vector_score('field1')))) as _score from table " + "select /*+ SCAN_ATTR(partitionIds='0')*/ _id, ((1/(1+vector_score('field1')))) as _score from table " + "where MATCHINDEX('field1', '1.0,2.0&n=20') order by _score desc", l2NormSql ); @@ -157,9 +158,9 @@ public void testKnnDsl() throws IOException { SearchSourceBuilder dotProductBuilder = new SearchSourceBuilder(); dotProductBuilder.query(QueryBuilders.matchAllQuery()); dotProductBuilder.knnSearch(List.of(new KnnSearchBuilder("field2", new float[] { 0.6f, 0.8f }, 20, 20, null))); - String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService); + String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService, 0); assertEquals( - "select _id, (((1+vector_score('field2'))/2)) as _score from table " + "select /*+ SCAN_ATTR(partitionIds='0')*/ _id, (((1+vector_score('field2'))/2)) as _score from table " + "where MATCHINDEX('field2', '0.6,0.8&n=20') order by _score desc", dotProductSql ); @@ -175,9 +176,10 @@ public void testMultiKnnDsl() throws IOException { new KnnSearchBuilder("field2", new float[] { 0.6f, 0.8f }, 10, 10, null) ) ); - String sql = QueryTransformer.toSql("table", builder, mapperService); + String sql = QueryTransformer.toSql("table", builder, mapperService, 0); assertEquals( - "select _id, ((1/(1+vector_score('field1'))) + ((1+vector_score('field2'))/2)) as _score from table " + "select /*+ SCAN_ATTR(partitionIds='0')*/ _id, ((1/(1+vector_score('field1'))) + " + + "((1+vector_score('field2'))/2)) as _score from table " + "where MATCHINDEX('field1', '1.0,2.0&n=20') or MATCHINDEX('field2', '0.6,0.8&n=10') order by _score desc", sql ); @@ -188,7 +190,7 @@ public void testIllegalKnnParams() throws IOException { dotProductBuilder.query(QueryBuilders.matchAllQuery()); dotProductBuilder.knnSearch(List.of(new KnnSearchBuilder("field2", new float[] { 1.0f, 2.0f }, 20, 20, null))); try { - String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService); + String dotProductSql = QueryTransformer.toSql("table", dotProductBuilder, mapperService, 0); fail("should throw IllegalArgumentException"); } catch (IllegalArgumentException e) { assertEquals("The [dot_product] similarity can only be used with unit-length vectors.", e.getMessage()); @@ -201,7 +203,7 @@ public void testUnsupportedKnnDsl() { SearchSourceBuilder builder = new SearchSourceBuilder(); builder.query(QueryBuilders.matchAllQuery()); builder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1.0f, 2.0f }, 20, 20, 1.0f))); - QueryTransformer.toSql("table", builder, mapperService); + QueryTransformer.toSql("table", builder, mapperService, 0); fail(); } catch (IOException e) { assertEquals( @@ -219,7 +221,7 @@ public void testUnsupportedKnnDsl() { KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder("field", new float[] { 1.0f, 2.0f }, 20, 20, null); knnSearchBuilder.addFilterQuery(QueryBuilders.matchAllQuery()); builder.knnSearch(List.of(knnSearchBuilder)); - QueryTransformer.toSql("table", builder, mapperService); + QueryTransformer.toSql("table", builder, mapperService, 0); fail(); } catch (IOException e) { assertEquals(