Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Fix CAST bool field to integer issue #600

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class SQLFunctionsIT extends SQLIntegTestCase {
@Override
protected void init() throws Exception {
loadIndex(Index.ACCOUNT);
loadIndex(Index.BANK);
loadIndex(Index.ONLINE);
loadIndex(Index.DATE);
}
Expand Down Expand Up @@ -369,6 +370,54 @@ public void castFieldToDatetimeWithGroupByJdbcFormatTest() {
rows("2019-09-25T02:04:13.469Z"));
}

@Test
public void castBoolFieldToNumericValueInSelectClause() {
JSONObject response =
executeJdbcRequest(
"SELECT "
+ " male, "
+ " CAST(male AS INT) AS cast_int, "
+ " CAST(male AS LONG) AS cast_long, "
+ " CAST(male AS FLOAT) AS cast_float, "
+ " CAST(male AS DOUBLE) AS cast_double "
+ "FROM " + TestsConstants.TEST_INDEX_BANK + " "
+ "WHERE account_number = 1 OR account_number = 13"
);

verifySchema(response,
schema("male", "boolean"),
schema("cast_int", "integer"),
schema("cast_long", "long"),
schema("cast_float", "float"),
schema("cast_double", "double")
);
verifyDataRows(response,
rows(true, 1, 1, 1, 1),
rows(false, 0, 0, 0, 0)
);
}

@Test
public void castBoolFieldToNumericValueWithGroupByAlias() {
JSONObject response =
executeJdbcRequest(
"SELECT "
+ "CAST(male AS INT) AS cast_int, "
+ "COUNT(*) "
+ "FROM " + TestsConstants.TEST_INDEX_BANK + " "
+ "GROUP BY cast_int"
);

verifySchema(response,
schema("cast_int", "cast_int", "double"), //Type is double due to query plan fail to infer
schema("COUNT(*)", "integer")
);
verifyDataRows(response,
rows("0", 3),
rows("1", 4)
);
}

@Test
public void castStatementInWhereClauseGreaterThanTest() {
JSONObject response = executeJdbcRequest("SELECT balance FROM " + TEST_INDEX_ACCOUNT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public static <T> void verifyOrder(JSONArray array, Matcher<T>... matchers) {
assertThat(objects, containsInRelativeOrder(matchers));
}

public static TypeSafeMatcher<JSONObject> schema(String expectedName,
String expectedType) {
return schema(expectedName, null, expectedType);
}

public static TypeSafeMatcher<JSONObject> schema(String expectedName, String expectedAlias,
String expectedType) {
return new TypeSafeMatcher<JSONObject>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,13 +973,10 @@ public String getCastScriptStatement(String name, String castType, List<KVValue>
String castFieldName = String.format("doc['%s'].value", paramers.get(0).toString());
switch (StringUtils.toUpper(castType)) {
case "INT":
return String.format("def %s = Double.parseDouble(%s.toString()).intValue()", name, castFieldName);
case "LONG":
return String.format("def %s = Double.parseDouble(%s.toString()).longValue()", name, castFieldName);
case "FLOAT":
return String.format("def %s = Double.parseDouble(%s.toString()).floatValue()", name, castFieldName);
case "DOUBLE":
return String.format("def %s = Double.parseDouble(%s.toString()).doubleValue()", name, castFieldName);
return getCastToNumericValueScript(name, castFieldName, StringUtils.toLower(castType));
case "STRING":
return String.format("def %s = %s.toString()", name, castFieldName);
case "DATETIME":
Expand All @@ -990,6 +987,14 @@ public String getCastScriptStatement(String name, String castType, List<KVValue>
}
}

private String getCastToNumericValueScript(String varName, String docValue, String targetType) {
String script =
"def %1$s = (%2$s instanceof boolean) "
+ "? (%2$s ? 1 : 0) "
+ ": Double.parseDouble(%2$s.toString()).%3$sValue()";
return StringUtils.format(script, varName, docValue, targetType);
}

/**
* Returns return type of script function. This is simple approach, that might be not the best solution in the long
* term. For example - for JDBC, if the column type in index is INTEGER, and the query is "select column+5", current
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.amazon.opendistroforelasticsearch.sql.legacy.executor.format.Schema;
import com.amazon.opendistroforelasticsearch.sql.legacy.utils.SQLFunctions;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import org.elasticsearch.common.collect.Tuple;
import org.junit.Assert;
import org.junit.Rule;
Expand All @@ -39,11 +40,12 @@
import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class SQLFunctionsTest {

private SQLFunctions sqlFunctions;
private SQLFunctions sqlFunctions = new SQLFunctions();

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -96,4 +98,16 @@ public void testCastReturnType() {
final Schema.Type returnType = sqlFunctions.getScriptFunctionReturnType(field, resolvedType);
Assert.assertEquals(returnType, Schema.Type.INTEGER);
}

@Test
public void testCastIntStatementScript() throws SqlParseException {
assertEquals(
"def result = (doc['age'].value instanceof boolean) "
+ "? (doc['age'].value ? 1 : 0) "
+ ": Double.parseDouble(doc['age'].value.toString()).intValue()",
sqlFunctions.getCastScriptStatement(
"result", "int", Arrays.asList(new KVValue("age")))
);
}

}