Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support Cast() function #253

Merged
merged 18 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/main/antlr/OpenDistroSqlLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,30 @@ ASC: 'ASC';
BETWEEN: 'BETWEEN';
BY: 'BY';
CASE: 'CASE';
CAST: 'CAST';
CROSS: 'CROSS';
DATETIME: 'DATETIME';
DELETE: 'DELETE';
DESC: 'DESC';
DESCRIBE: 'DESCRIBE';
DISTINCT: 'DISTINCT';
DOUBLE: 'DOUBLE';
ELSE: 'ELSE';
EXISTS: 'EXISTS';
FALSE: 'FALSE';
FLOAT: 'FLOAT';
FROM: 'FROM';
GROUP: 'GROUP';
HAVING: 'HAVING';
IN: 'IN';
INNER: 'INNER';
INT: 'INT';
IS: 'IS';
JOIN: 'JOIN';
LEFT: 'LEFT';
LIKE: 'LIKE';
LIMIT: 'LIMIT';
LONG: 'LONG';
MATCH: 'MATCH';
NATURAL: 'NATURAL';
NOT: 'NOT';
Expand All @@ -79,6 +85,7 @@ REGEXP: 'REGEXP';
RIGHT: 'RIGHT';
SELECT: 'SELECT';
SHOW: 'SHOW';
STRING: 'STRING';
THEN: 'THEN';
TRUE: 'TRUE';
UNION: 'UNION';
Expand Down
13 changes: 11 additions & 2 deletions src/main/antlr/OpenDistroSqlParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ dottedId
| '.' uid
;


// Literals

decimalLiteral
Expand Down Expand Up @@ -297,7 +296,8 @@ functionCall
;

specificFunction
: CASE expression caseFuncAlternative+
: CAST '(' expression AS convertedDataType ')' #dataTypeFunctionCall
| CASE expression caseFuncAlternative+
(ELSE elseArg=functionArg)? END #caseFunctionCall
| CASE caseFuncAlternative+
(ELSE elseArg=functionArg)? END #caseFunctionCall
Expand All @@ -308,6 +308,15 @@ caseFuncAlternative
THEN consequent=functionArg
;

convertedDataType
: typeName=DATETIME
| typeName=INT
| typeName=DOUBLE
| typeName=LONG
| typeName=FLOAT
| typeName=STRING
;

aggregateWindowedFunction
: (AVG | MAX | MIN | SUM)
'(' aggregator=(ALL | DISTINCT)? functionArg ')'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package com.amazon.opendistroforelasticsearch.sql.executor.format;

import com.alibaba.druid.sql.ast.expr.SQLCastExpr;
import com.amazon.opendistroforelasticsearch.sql.domain.Field;
import com.amazon.opendistroforelasticsearch.sql.domain.JoinSelect;
import com.amazon.opendistroforelasticsearch.sql.domain.MethodField;
Expand Down Expand Up @@ -124,6 +125,9 @@ private void loadFromEsState(Query query) {
// Assumption is all indices share the same mapping which is validated in TermFieldRewriter.
Map<String, Map<String, FieldMappingMetaData>> indexMappings = mappings.values().iterator().next();

// if index mappings size is 0 and the expression is a cast: that means that we are casting by alias
// if so, add the original field that was being looked at to the mapping (how?)

/*
* There are three cases regarding type name to consider:
* 1. If the correct type name was given, its typeMapping is retrieved
Expand Down Expand Up @@ -245,9 +249,13 @@ private List<Field> fetchFields(Query query) {
List<Field> groupByFields = select.getGroupBys().isEmpty() ? new ArrayList<>() :
select.getGroupBys().get(0);


for (Field selectField : select.getFields()) {
if (selectField instanceof MethodField && !selectField.isScriptField()) {
groupByFields.add(selectField);
} else if (selectField.isScriptField()
&& selectField.getAlias().equals(groupByFields.get(0).getName())) {
return select.getFields();
}
}
return groupByFields;
Expand Down Expand Up @@ -315,7 +323,10 @@ private Schema.Type fetchMethodReturnType(Field field) {
// TODO: return type information is disconnected from the function definitions in SQLFunctions.
// Refactor SQLFunctions to have functions self-explanatory (types, scripts) and pluggable
// (similar to Strategy pattern)

if (field.getExpression() instanceof SQLCastExpr) {
return SQLFunctions.getCastFunctionReturnType(
((SQLCastExpr) field.getExpression()).getDataType().getName());
}
return SQLFunctions.getScriptFunctionReturnType(
((ScriptMethodField) field).getFunctionName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLCastExpr;
import com.amazon.opendistroforelasticsearch.sql.domain.Field;
import com.amazon.opendistroforelasticsearch.sql.domain.KVValue;
import com.amazon.opendistroforelasticsearch.sql.domain.MethodField;
Expand Down Expand Up @@ -95,7 +96,6 @@ public SqlElasticSearchRequestBuilder explain() throws SqlParseException {
setIndicesAndTypes();

setFields(select.getFields());

setWhere(select.getWhere());
setSorts(select.getOrderBys());
setLimit(select.getOffset(), select.getRowCount());
Expand Down Expand Up @@ -152,6 +152,9 @@ public void setFields(List<Field> fields) throws SqlParseException {
MethodField method = (MethodField) field;
if (method.getName().toLowerCase().equals("script")) {
handleScriptField(method);
if (method.getExpression() instanceof SQLCastExpr) {
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
includeFields.add(method.getParams().get(0).toString());
}
} else if (method.getName().equalsIgnoreCase("include")) {
for (KVValue kvValue : method.getParams()) {
includeFields.add(kvValue.value.toString());
Expand Down Expand Up @@ -265,8 +268,15 @@ private String getNullOrderString(SQLBinaryOpExpr expr) {

private ScriptSortType getScriptSortType(Order order) {
ScriptSortType scriptSortType;
ScriptMethodField smf = (ScriptMethodField) order.getSortField();
Schema.Type scriptFunctionReturnType = SQLFunctions.getScriptFunctionReturnType(smf.getFunctionName());
Schema.Type scriptFunctionReturnType;
if (order.getSortField().getExpression() instanceof SQLCastExpr) {
scriptFunctionReturnType = SQLFunctions.getCastFunctionReturnType(
((SQLCastExpr) order.getSortField().getExpression()).getDataType().getName());
} else {
ScriptMethodField smf = (ScriptMethodField) order.getSortField();
scriptFunctionReturnType = SQLFunctions.getScriptFunctionReturnType(smf.getFunctionName());
}


// as of now script function return type returns only text and double
switch (scriptFunctionReturnType) {
Expand All @@ -275,9 +285,11 @@ private ScriptSortType getScriptSortType(Order order) {
break;

case DOUBLE:
case FLOAT:
case INTEGER:
case LONG:
scriptSortType = ScriptSortType.NUMBER;
break;

default:
throw new RuntimeException("unknown");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,4 +614,25 @@ public static Schema.Type getScriptFunctionReturnType(String functionName) {
"The following method is not supported in Schema: %s",
functionName));
}

public static Schema.Type getCastFunctionReturnType(String castType) {
switch (castType) {
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
case "FLOAT":
return Schema.Type.FLOAT;
case "DOUBLE":
return Schema.Type.DOUBLE;
case "INT":
return Schema.Type.INTEGER;
case "STRING":
return Schema.Type.TEXT;
case "DATETIME":
return Schema.Type.DATE;
case "LONG":
return Schema.Type.LONG;
default:
throw new UnsupportedOperationException(
StringUtils.format("The following type is not supported by cast(): %s", castType)
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.stream.IntStream;

import static com.amazon.opendistroforelasticsearch.sql.esintgtest.TestsConstants.TEST_INDEX_ACCOUNT;
import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.hit;
import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.hitAny;
import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.kvDouble;
import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.kvString;
Expand All @@ -50,6 +51,7 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.hasValue;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isEmptyOrNullString;
import static org.hamcrest.Matchers.not;
Expand Down Expand Up @@ -159,6 +161,113 @@ public void caseChangeWithAggregationTest() throws IOException {
hitAny("/aggregations/UPPER_1/buckets", kvString("/key", equalTo("AMBER"))));
}

@Test
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
public void castIntFieldToDoubleWithoutAliasTest() throws IOException {
String query = "SELECT CAST(age AS DOUBLE) FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " /account limit 1";
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved

final SearchHit[] hits = query(query).getHits();
assertTrue(hits[0].getFields().containsKey("cast_age"));
assertTrue(hits[0].getFields().get("cast_age").getValue() instanceof Double);
}

@Test
public void castIntFieldToDoubleWithAliasTest() throws IOException {
String query = "SELECT CAST(age AS DOUBLE) AS test_alias " +
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " /account limit 1";
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved

final SearchHit[] hits = query(query).getHits();
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
assertTrue(hits[0].getFields().containsKey("test_alias"));
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
assertTrue(hits[0].getFields().get("test_alias").getValue() instanceof Double);
}

@Test
public void castIntFieldToStringWithoutAliasTest() throws IOException {
String query = "SELECT CAST(balance AS STRING) FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " /account limit 1";

final SearchHit[] hits = query(query).getHits();
assertTrue(hits[0].getFields().containsKey("cast_balance"));
assertTrue(hits[0].getFields().get("cast_balance").getValue() instanceof String);
}

@Test
public void castIntFieldToStringWithAliasTest() throws IOException {
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
String query = "SELECT CAST(balance AS STRING) AS cast_string_alias " +
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " /account limit 1";

final SearchHit[] hits = query(query).getHits();
assertTrue(hits[0].getFields().containsKey("cast_string_alias"));
assertTrue(hits[0].getFields().get("cast_string_alias").getValue() instanceof String);
}

@Test
public void castIntFieldToFloatWithoutAliasJdbcFormatTest() {
JSONObject response = executeJdbcRequest(
"SELECT CAST(balance AS FLOAT) FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " limit 1");

String float_type_cast = "{\"name\":\"cast_balance\",\"type\":\"float\"}";
assertEquals(response.getJSONArray("schema").get(0).toString(), float_type_cast);
}

@Test
public void castIntFieldToFloatWithAliasJdbcFormatTest() {
JSONObject response = executeJdbcRequest(
"SELECT CAST(balance AS FLOAT) AS jdbc_float_alias " +
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " limit 1");

String float_type_cast = "{\"name\":\"jdbc_float_alias\",\"type\":\"float\"}";
assertEquals(response.getJSONArray("schema").get(0).toString(), float_type_cast);
}

@Test
public void castIntFieldToDoubleWithoutAliasOrderByTest() throws IOException {
String query = "SELECT CAST(age AS DOUBLE) FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " /account " +
"ORDER BY age limit 1";

final SearchHit[] hits = query(query).getHits();
assertTrue(hits[0].getFields().containsKey("cast_age"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

look into usage of kvInt / kvString in other tests

assertTrue(hits[0].getFields().get("cast_age").getValue() instanceof Double);
}

@Test
public void castIntFieldToDoubleWithAliasOrderByTest() throws IOException {
String query = "SELECT CAST(age AS DOUBLE) AS alias FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " /account " +
"ORDER BY alias limit 1";

final SearchHit[] hits = query(query).getHits();
assertTrue(hits[0].getFields().containsKey("alias"));
assertTrue(hits[0].getFields().get("alias").getValue() instanceof Double);
}

@Test
public void castIntFieldToFloatWithoutAliasJdbcFormatGroupByTest() {
JSONObject response = executeJdbcRequest(
"SELECT CAST(balance AS FLOAT) AS jdbc_float_alias " +
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " GROUP BY balance LIMIT 5");

String float_type_cast = "{\"name\":\"balance\",\"type\":\"long\"}";
assertEquals(response.getJSONArray("schema").get(0).toString(), float_type_cast);
}

@Test
public void castIntFieldToFloatWithAliasJdbcFormatGroupByTest() {
JSONObject response = executeJdbcRequest(
"SELECT CAST(balance AS FLOAT) AS jdbc_float_alias " +
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " GROUP BY jdbc_float_alias LIMIT 5");

String float_type_cast = "{\"name\":\"jdbc_float_alias\",\"type\":\"float\"}";
assertEquals(response.getJSONArray("schema").get(0).toString(), float_type_cast);
}

@Test
public void castIntFieldToDoubleWithAliasJdbcFormatGroupByTest() {
davidcui1225 marked this conversation as resolved.
Show resolved Hide resolved
JSONObject response = executeJdbcRequest(
"SELECT CAST(balance AS DOUBLE) AS jdbc_double_alias " +
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " GROUP BY jdbc_double_alias LIMIT 5");

String float_type_cast = "{\"name\":\"jdbc_double_alias\",\"type\":\"double\"}";
assertEquals(response.getJSONArray("schema").get(0).toString(), float_type_cast);
}

@Test
public void concat_ws_field_and_string() throws Exception {
//here is a bug,csv field with spa
Expand Down Expand Up @@ -302,4 +411,8 @@ private SearchHits query(String query) throws IOException {
rsp);
return SearchResponse.fromXContent(parser).getHits();
}

private JSONObject executeJdbcRequest(String query) {
return new JSONObject(executeQuery(query, "jdbc"));
}
}