Skip to content

Commit

Permalink
Allow literal in aggregation (#203) (#1288)
Browse files Browse the repository at this point in the history
* Support constant in aggregation.

Signed-off-by: Yury-Fridlyand <[email protected]>

Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand authored Jan 23, 2023
1 parent ff9ac16 commit b29f4c2
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,32 @@ public void noGroupKeyMaxAddMinShouldPass() {
verifyDataRows(response, rows(60));
}

// todo age field should has long type instead of integer type.
@Ignore
@Test
public void noGroupKeyMaxAddLiteralShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT MAX(age) + 1 as add " +
"SELECT MAX(age) + 1 as `add` " +
"FROM %s",
Index.ACCOUNT.getName()));

verifySchema(response, schema("add", "add", "long"));
verifySchema(response, schema("MAX(age) + 1", "add", "long"));
verifyDataRows(response, rows(41));
}

@Ignore("skip this test because the old engine returns an integer instead of a double type")
@Test
public void noGroupKeyAvgOnIntegerShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT AVG(age) as avg " +
"SELECT AVG(age) as `avg` " +
"FROM %s",
Index.BANK.getName()));

verifySchema(response, schema("avg", "avg", "double"));
verifyDataRows(response, rows(34));
verifySchema(response, schema("AVG(age)", "avg", "double"));
verifyDataRows(response, rows(34D));
}

@Test
public void hasGroupKeyAvgOnIntegerShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT gender, AVG(age) as avg " +
"SELECT gender, AVG(age) as `avg` " +
"FROM %s " +
"GROUP BY gender",
Index.BANK.getName()));
Expand Down Expand Up @@ -103,33 +100,30 @@ public void hasGroupKeyMaxAddMinShouldPass() {
rows("f", 60));
}

// todo age field should has long type instead of integer type.
@Ignore
@Test
public void hasGroupKeyMaxAddLiteralShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT gender, MAX(age) + 1 as add " +
"SELECT gender, MAX(age) + 1 as `add` " +
"FROM %s " +
"GROUP BY gender",
Index.ACCOUNT.getName()));

verifySchema(response,
schema("gender", null, "text"),
schema("add", "add", "long"));
schema("MAX(age) + 1", "add", "long"));
verifyDataRows(response,
rows("m", 1),
rows("f", 1));
rows("m", 41),
rows("f", 41));
}

@Ignore("Handled by v2 engine which returns 'name': 'Log(MAX(age) + MIN(age))' instead")
@Test
public void noGroupKeyLogMaxAddMinShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT Log(MAX(age) + MIN(age)) as log " +
"SELECT Log(MAX(age) + MIN(age)) as `log` " +
"FROM %s",
Index.ACCOUNT.getName()));

verifySchema(response, schema("log", "log", "double"));
verifySchema(response, schema("Log(MAX(age) + MIN(age))", "log", "double"));
verifyDataRows(response, rows(4.0943445622221d));
}

Expand All @@ -149,12 +143,10 @@ public void hasGroupKeyLogMaxAddMinShouldPass() {
rows("f", 4.0943445622221d));
}

// todo age field should has long type instead of integer type.
@Ignore
@Test
public void AddLiteralOnGroupKeyShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT gender, age+10, max(balance) as max " +
"SELECT gender, age+10, max(balance) as `max` " +
"FROM %s " +
"WHERE gender = 'm' and age < 22 " +
"GROUP BY gender, age " +
Expand All @@ -163,8 +155,8 @@ public void AddLiteralOnGroupKeyShouldPass() {

verifySchema(response,
schema("gender", null, "text"),
schema("age", "age", "long"),
schema("max", "max", "long"));
schema("age+10", null, "long"),
schema("max(balance)", "max", "long"));
verifyDataRows(response,
rows("m", 30, 49568),
rows("m", 31, 49433));
Expand All @@ -189,8 +181,6 @@ public void logWithAddLiteralOnGroupKeyShouldPass() {
rows("m", 3.4339872044851463d, 49433));
}

// todo max field should has long as type instead of integer type.
@Ignore
@Test
public void logWithAddLiteralOnGroupKeyAndMaxSubtractLiteralShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
Expand All @@ -203,8 +193,8 @@ public void logWithAddLiteralOnGroupKeyAndMaxSubtractLiteralShouldPass() {

verifySchema(response,
schema("gender", null, "text"),
schema("logAge", "logAge", "double"),
schema("max", "max", "long"));
schema("Log(age+10)", "logAge", "double"),
schema("max(balance) - 100", "max", "long"));
verifyDataRows(response,
rows("m", 3.4011973816621555d, 49468),
rows("m", 3.4339872044851463d, 49333));
Expand All @@ -213,38 +203,36 @@ public void logWithAddLiteralOnGroupKeyAndMaxSubtractLiteralShouldPass() {
/**
* The date is in JDBC format.
*/
@Ignore("skip this test due to inconsistency in type in new engine")
@Test
public void groupByDateShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT birthdate, count(*) as count " +
"SELECT birthdate, count(*) as `count` " +
"FROM %s " +
"WHERE age < 30 " +
"GROUP BY birthdate ",
Index.BANK.getName()));

verifySchema(response,
schema("birthdate", null, "date"),
schema("count", "count", "integer"));
schema("birthdate", null, "timestamp"),
schema("count(*)", "count", "integer"));
verifyDataRows(response,
rows("2018-06-23 00:00:00.000", 1));
rows("2018-06-23 00:00:00", 1));
}

@Ignore("skip this test due to inconsistency in type in new engine")
@Test
public void groupByDateWithAliasShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"SELECT birthdate as birth, count(*) as count " +
"SELECT birthdate as birth, count(*) as `count` " +
"FROM %s " +
"WHERE age < 30 " +
"GROUP BY birthdate ",
Index.BANK.getName()));

verifySchema(response,
schema("birth", "birth", "date"),
schema("count", "count", "integer"));
schema("birthdate", "birth", "timestamp"),
schema("count(*)", "count", "integer"));
verifyDataRows(response,
rows("2018-06-23 00:00:00.000", 1));
rows("2018-06-23 00:00:00", 1));
}

@Test
Expand All @@ -256,4 +244,13 @@ public void aggregateCastStatementShouldNotReturnZero() {
verifySchema(response, schema("SUM(CAST(male AS INT))", "male_sum", "integer"));
verifyDataRows(response, rows(4));
}

@Test
public void groupByConstantShouldPass() {
JSONObject response = executeJdbcRequest(String.format(
"select 1 from %s GROUP BY 1", Index.BANK.getName()));

verifySchema(response, schema("1", null, "integer"));
verifyDataRows(response, rows(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.script.Script;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.opensearch.storage.script.ScriptUtils;
import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer;
Expand All @@ -38,7 +39,8 @@ public <T> T build(Expression expression, Function<String, T> fieldBuilder,
if (expression instanceof ReferenceExpression) {
String fieldName = ((ReferenceExpression) expression).getAttr();
return fieldBuilder.apply(ScriptUtils.convertTextToKeyword(fieldName, expression.type()));
} else if (expression instanceof FunctionExpression) {
} else if (expression instanceof FunctionExpression
|| expression instanceof LiteralExpression) {
return scriptBuilder.apply(new Script(
DEFAULT_SCRIPT_TYPE, EXPRESSION_LANG_NAME, serializer.serialize(expression),
emptyMap()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.DSL.literal;
import static org.opensearch.sql.expression.DSL.named;
import static org.opensearch.sql.expression.DSL.ref;
import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD;
Expand Down Expand Up @@ -68,6 +69,27 @@ void should_build_bucket_with_field() {
asc(named("age", ref("age", INTEGER))))));
}

@Test
void should_build_bucket_with_literal() {
var literal = literal(1);
when(serializer.serialize(literal)).thenReturn("mock-serialize");
assertEquals(
"{\n"
+ " \"terms\" : {\n"
+ " \"script\" : {\n"
+ " \"source\" : \"mock-serialize\",\n"
+ " \"lang\" : \"opensearch_query_expression\"\n"
+ " },\n"
+ " \"missing_bucket\" : true,\n"
+ " \"missing_order\" : \"first\",\n"
+ " \"order\" : \"asc\"\n"
+ " }\n"
+ "}",
buildQuery(
Arrays.asList(
asc(named(literal)))));
}

@Test
void should_build_bucket_with_keyword_field() {
assertEquals(
Expand Down

0 comments on commit b29f4c2

Please sign in to comment.