Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Nested Function in Order By Clause #280

Merged
merged 3 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
Expand Down Expand Up @@ -105,7 +107,18 @@ private void validateArgs(List<UnresolvedExpression> args) {
* @param field : Nested field to generate path of.
* @return : Path of field derived from last level of nesting.
*/
private ReferenceExpression generatePath(String field) {
public static ReferenceExpression generatePath(String field) {
return new ReferenceExpression(field.substring(0, field.lastIndexOf(".")), STRING);
}

/**
* Check if supplied expression is a nested function.
* @param expr Expression checking if is nested function.
* @return True if expression is a nested function.
*/
public static Boolean isNestedFunction(Expression expr) {
return (expr instanceof FunctionExpression
&& ((FunctionExpression) expr).getFunctionName().getFunctionName()
.equalsIgnoreCase(BuiltinFunctionName.NESTED.name()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;
import static org.opensearch.sql.ast.dsl.AstDSL.aggregate;
import static org.opensearch.sql.ast.dsl.AstDSL.alias;
import static org.opensearch.sql.ast.dsl.AstDSL.argument;
Expand Down Expand Up @@ -39,6 +41,7 @@
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.expression.DSL.literal;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC;
Expand Down Expand Up @@ -574,6 +577,10 @@ public void project_nested_field_arg() {
function("nested", qualifiedName("message", "info")), null)
)
);

assertTrue(isNestedFunction(DSL.nested(DSL.ref("message.info", STRING))));
assertFalse(isNestedFunction(DSL.literal("fieldA")));
assertFalse(isNestedFunction(DSL.match(DSL.namedArgument("field", literal("message")))));
}

@Test
Expand Down
11 changes: 11 additions & 0 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4469,6 +4469,17 @@ Example with ``field`` and ``path`` parameters in the SELECT and WHERE clause::
| b |
+---------------------------------+

Example with ``field`` and ``path`` parameters in the SELECT and ORDER BY clause::

os> SELECT nested(message.info, message) FROM nested ORDER BY nested(message.info, message) DESC;
fetched rows / total rows = 2/2
+---------------------------------+
| nested(message.info, message) |
|---------------------------------|
| b |
| a |
+---------------------------------+


System Functions
================
Expand Down
34 changes: 34 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,40 @@ public void nested_function_with_order_by_clause() {
rows("zz"));
}

@Test
public void nested_function_with_order_by_clause_desc() {
String query =
"SELECT nested(message.info) FROM " + TEST_INDEX_NESTED_TYPE
+ " ORDER BY nested(message.info, message) DESC";
JSONObject result = executeJdbcRequest(query);

assertEquals(6, result.getInt("total"));
verifyDataRows(result,
rows("zz"),
rows("c"),
rows("c"),
rows("a"),
rows("b"),
GumpacG marked this conversation as resolved.
Show resolved Hide resolved
rows("a"));
}

@Test
public void nested_function_and_field_with_order_by_clause() {
String query =
"SELECT nested(message.info), myNum FROM " + TEST_INDEX_NESTED_TYPE
+ " ORDER BY nested(message.info, message), myNum";
JSONObject result = executeJdbcRequest(query);

assertEquals(6, result.getInt("total"));
verifyDataRows(result,
rows("a", 1),
rows("c", 4),
rows("a", 4),
rows("b", 2),
rows("c", 3),
rows("zz", new JSONArray(List.of(3, 4))));
}

// Nested function in GROUP BY clause is not yet implemented for JDBC format. This test ensures
// that the V2 engine falls back to legacy implementation.
// TODO Fix the test when NESTED is supported in GROUP BY in the V2 engine.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.opensearch.storage.scan;

import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;

import java.util.function.Function;
import lombok.EqualsAndHashCode;
import org.opensearch.sql.expression.ReferenceExpression;
Expand Down Expand Up @@ -113,9 +115,15 @@ public boolean pushDownNested(LogicalNested nested) {
return delegate.pushDownNested(nested);
}

/**
* Valid if sorting is only by fields.
* @param sort Logical sort
* @return True if sorting by fields only
*/
private boolean sortByFieldsOnly(LogicalSort sort) {
return sort.getSortList().stream()
.map(sortItem -> sortItem.getRight() instanceof ReferenceExpression)
.map(sortItem -> sortItem.getRight() instanceof ReferenceExpression
|| isNestedFunction(sortItem.getRight()))
.reduce(true, Boolean::logicalAnd);
MitchellGale marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package org.opensearch.sql.opensearch.storage.script.filter.lucene;

import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.function.Function;
Expand Down Expand Up @@ -62,10 +64,7 @@ public boolean canSupport(FunctionExpression func) {
* @return return true if function has supported nested function expression.
*/
public boolean isNestedPredicate(FunctionExpression func) {
return ((func.getArguments().get(0) instanceof FunctionExpression
&& ((FunctionExpression)func.getArguments().get(0))
.getFunctionName().getFunctionName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name()))
);
return isNestedFunction(func.getArguments().get(0));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@

package org.opensearch.sql.opensearch.storage.script.sort;

import static org.opensearch.sql.analysis.NestedAnalyzer.generatePath;
import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.NestedSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;

/**
Expand Down Expand Up @@ -53,11 +59,44 @@ public SortBuilder<?> build(Expression expression, Sort.SortOption option) {
return SortBuilders.scoreSort().order(sortOrderMap.get(option.getSortOrder()));
}
return fieldBuild((ReferenceExpression) expression, option);
} else if (isNestedFunction(expression)) {

validateNestedArgs((FunctionExpression) expression);
String orderByName = ((FunctionExpression)expression).getArguments().get(0).toString();
// Generate path if argument not supplied in function.
ReferenceExpression path = ((FunctionExpression)expression).getArguments().size() == 2
MitchellGale marked this conversation as resolved.
Show resolved Hide resolved
? (ReferenceExpression) ((FunctionExpression)expression).getArguments().get(1)
: generatePath(orderByName);
return SortBuilders.fieldSort(orderByName)
.order(sortOrderMap.get(option.getSortOrder()))
.setNestedSort(new NestedSortBuilder(path.toString()));
} else {
throw new IllegalStateException("unsupported expression " + expression.getClass());
}
}

/**
* Validate semantics for arguments in nested function.
* @param nestedFunc Nested function expression.
*/
private void validateNestedArgs(FunctionExpression nestedFunc) {
if (nestedFunc.getArguments().size() > 2) {
throw new IllegalArgumentException(
"nested function supports 2 parameters (field, path) or 1 parameter (field)"
);
}

for (Expression arg : nestedFunc.getArguments()) {
if (!(arg instanceof ReferenceExpression)) {
throw new IllegalArgumentException(
String.format("Illegal nested field name: %s",
arg.toString()
)
);
}
}
}

private FieldSortBuilder fieldBuild(ReferenceExpression ref, Sort.SortOption option) {
return SortBuilders.fieldSort(
OpenSearchTextType.convertTextToKeyword(ref.getAttr(), ref.type()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.DSL.literal;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight;
Expand Down Expand Up @@ -58,6 +59,7 @@
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder;
import org.opensearch.search.sort.NestedSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
Expand Down Expand Up @@ -574,6 +576,76 @@ void only_one_project_should_be_push() {
);
}

@Test
void test_nested_sort_filter_push_down() {
assertEqualsAfterOptimization(
project(
indexScanBuilder(
withFilterPushedDown(QueryBuilders.termQuery("intV", 1)),
withSortPushedDown(
SortBuilders.fieldSort("message.info")
.order(SortOrder.ASC)
.setNestedSort(new NestedSortBuilder("message")))),
DSL.named("intV", DSL.ref("intV", INTEGER))
),
project(
sort(
filter(
relation("schema", table),
DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))
),
Pair.of(
SortOption.DEFAULT_ASC, DSL.nested(DSL.ref("message.info", STRING))
)
),
DSL.named("intV", DSL.ref("intV", INTEGER))
)
);
}

@Test
void test_function_expression_sort_returns_optimized_logical_sort() {
// Invalid use case coverage OpenSearchIndexScanBuilder::sortByFieldsOnly returns false
assertEqualsAfterOptimization(
sort(
indexScanBuilder(),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.match(DSL.namedArgument("field", literal("message")))
)
),
sort(
relation("schema", table),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.match(DSL.namedArgument("field", literal("message"))
)
)
)
);
}

@Test
void test_non_field_sort_returns_optimized_logical_sort() {
// Invalid use case coverage OpenSearchIndexScanBuilder::sortByFieldsOnly returns false
assertEqualsAfterOptimization(
sort(
indexScanBuilder(),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.literal("field")
)
),
sort(
relation("schema", table),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.literal("field")
)
)
);
}

@Test
void sort_with_expression_cannot_merge_with_relation() {
assertEqualsAfterOptimization(
Expand Down
Loading