Skip to content

Commit

Permalink
feat: support implicit casting in UDFs (#4406)
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Jan 30, 2020
1 parent 50b25d5 commit 6fc4f72
Show file tree
Hide file tree
Showing 8 changed files with 375 additions and 28 deletions.
30 changes: 23 additions & 7 deletions ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
Expand Down Expand Up @@ -141,8 +142,21 @@ void addFunction(final T function) {

T getFunction(final List<SqlType> arguments) {
final List<Node> candidates = new ArrayList<>();
getCandidates(arguments, 0, root, candidates, new HashMap<>());

// first try to get the candidates without any implicit casting
getCandidates(arguments, 0, root, candidates, new HashMap<>(), false);
final Optional<T> fun = candidates
.stream()
.max(Node::compare)
.map(node -> node.value);

if (fun.isPresent()) {
return fun.get();
}

// if none were found (candidates is empty) try again with
// implicit casting
getCandidates(arguments, 0, root, candidates, new HashMap<>(), true);
return candidates
.stream()
.max(Node::compare)
Expand All @@ -155,7 +169,8 @@ private void getCandidates(
final int argIndex,
final Node current,
final List<Node> candidates,
final Map<GenericType, SqlType> reservedGenerics
final Map<GenericType, SqlType> reservedGenerics,
final boolean allowCasts
) {
if (argIndex == arguments.size()) {
if (current.value != null) {
Expand All @@ -167,9 +182,9 @@ private void getCandidates(
final SqlType arg = arguments.get(argIndex);
for (final Entry<Parameter, Node> candidate : current.children.entrySet()) {
final Map<GenericType, SqlType> reservedCopy = new HashMap<>(reservedGenerics);
if (candidate.getKey().accepts(arg, reservedCopy)) {
if (candidate.getKey().accepts(arg, reservedCopy, allowCasts)) {
final Node node = candidate.getValue();
getCandidates(arguments, argIndex + 1, node, candidates, reservedCopy);
getCandidates(arguments, argIndex + 1, node, candidates, reservedCopy, allowCasts);
}
}
}
Expand Down Expand Up @@ -324,12 +339,13 @@ public int hashCode() {
* @param reservedGenerics a mapping of generics to already reserved types - this map
* will be updated if the parameter is generic to point to the
* current argument for future checks to accept
*
* @param allowCasts whether or not to accept an implicit cast
* @return whether or not this argument can be used as a value for
* this parameter
*/
// CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity
boolean accepts(final SqlType argument, final Map<GenericType, SqlType> reservedGenerics) {
boolean accepts(final SqlType argument, final Map<GenericType, SqlType> reservedGenerics,
final boolean allowCasts) {
if (argument == null) {
return true;
}
Expand All @@ -338,7 +354,7 @@ boolean accepts(final SqlType argument, final Map<GenericType, SqlType> reserved
return reserveGenerics(type, argument, reservedGenerics);
}

return SchemaUtil.areCompatible(argument, type);
return SchemaUtil.areCompatible(argument, type, allowCasts);
}
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public final class SchemaConverters {

private static final FunctionToSqlConverter FUNCTION_TO_SQL_CONVERTER = new FunctionToSql();

private static final FunctionToSqlBase FUNCTION_TO_BASE_CONVERTER = new FunctionToSqlBase();

private SchemaConverters() {
}

Expand Down Expand Up @@ -159,6 +161,10 @@ public interface FunctionToSqlConverter {
SqlType toSqlType(ParamType paramType);
}

/**
 * Converts a function parameter type into an SQL base type.
 */
public interface FunctionToSqlBaseConverter {
/**
 * @param paramType the function parameter type to convert
 * @return the {@link SqlBaseType} that {@code paramType} maps to
 */
SqlBaseType toBaseType(ParamType paramType);
}

public static ConnectToSqlTypeConverter connectToSqlConverter() {
return CONNECT_TO_SQL_CONVERTER;
}
Expand All @@ -183,6 +189,10 @@ public static FunctionToSqlConverter functionToSqlConverter() {
return FUNCTION_TO_SQL_CONVERTER;
}

/**
 * @return the converter held in {@code FUNCTION_TO_BASE_CONVERTER}, which maps
 *         function parameter types to SQL base types
 */
public static FunctionToSqlBaseConverter functionToSqlBaseConverter() {
return FUNCTION_TO_BASE_CONVERTER;
}

private static final class ConnectToSqlConverter implements ConnectToSqlTypeConverter {

private static final Map<Schema.Type, Function<Schema, SqlType>> CONNECT_TO_SQL = ImmutableMap
Expand Down Expand Up @@ -377,6 +387,41 @@ public SqlType toSqlType(final ParamType paramType) {
}
}

/**
 * Resolves a function {@link ParamType} to the {@link SqlBaseType} it represents.
 *
 * <p>The primitive parameter types are looked up in a fixed table; map, array and
 * struct parameter types are recognised by their class. Anything else is rejected
 * with a {@link KsqlException}.
 */
private static class FunctionToSqlBase implements FunctionToSqlBaseConverter {

  // Fixed pairing of each primitive parameter type with its SQL base type.
  private static final BiMap<ParamType, SqlBaseType> FUNCTION_TO_BASE =
      ImmutableBiMap.<ParamType, SqlBaseType>builder()
          .put(ParamTypes.STRING, SqlBaseType.STRING)
          .put(ParamTypes.BOOLEAN, SqlBaseType.BOOLEAN)
          .put(ParamTypes.INTEGER, SqlBaseType.INTEGER)
          .put(ParamTypes.LONG, SqlBaseType.BIGINT)
          .put(ParamTypes.DOUBLE, SqlBaseType.DOUBLE)
          .put(ParamTypes.DECIMAL, SqlBaseType.DECIMAL)
          .build();

  /**
   * @param paramType the function parameter type to convert
   * @return the matching SQL base type
   * @throws KsqlException if {@code paramType} is neither a known primitive nor a
   *                       map, array or struct type
   */
  @Override
  public SqlBaseType toBaseType(final ParamType paramType) {
    // Primitives first: a table hit wins over the container checks below.
    final SqlBaseType primitive = FUNCTION_TO_BASE.get(paramType);

    if (primitive != null) {
      return primitive;
    } else if (paramType instanceof MapType) {
      return SqlBaseType.MAP;
    } else if (paramType instanceof ArrayType) {
      return SqlBaseType.ARRAY;
    } else if (paramType instanceof StructType) {
      return SqlBaseType.STRUCT;
    }

    throw new KsqlException("Cannot convert param type to sql type: " + paramType);
  }
}

private static class SqlToFunction implements SqlToFunctionConverter {

@Override
Expand Down
41 changes: 31 additions & 10 deletions ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package io.confluent.ksql.util;

import static io.confluent.ksql.schema.ksql.SchemaConverters.functionToSqlBaseConverter;

import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.BooleanType;
Expand Down Expand Up @@ -152,22 +154,34 @@ public static Schema ensureOptional(final Schema schema) {
}

/**
 * Checks whether {@code actual} is compatible with {@code declared} without
 * allowing any cast: delegates to
 * {@link #areCompatible(SqlType, ParamType, boolean)} with {@code allowCast}
 * set to {@code false}.
 *
 * @param actual   the SQL type of the supplied argument
 * @param declared the parameter type declared by the function
 * @return whether {@code actual} can be used where {@code declared} is expected
 */
public static boolean areCompatible(final SqlType actual, final ParamType declared) {
return areCompatible(actual, declared, false);
}

/**
 * Checks whether an argument of SQL type {@code actual} can be supplied for a
 * parameter declared as {@code declared}.
 *
 * <p>An array argument against an array parameter is decided by comparing the
 * item type with the declared element type; a map argument against a map
 * parameter by comparing the value type with the declared value type. A struct
 * argument against a struct parameter is handed to {@code isStructCompatible},
 * which is not passed {@code allowCast}. Every other combination is decided by
 * {@code isPrimitiveMatch}.
 *
 * @param actual    the SQL type of the supplied argument
 * @param declared  the parameter type declared by the function
 * @param allowCast forwarded to the array element, map value and primitive checks
 * @return whether {@code actual} can be used where {@code declared} is expected
 */
public static boolean areCompatible(
final SqlType actual,
final ParamType declared,
final boolean allowCast
) {
if (actual.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) {
return areCompatible(((SqlArray) actual).getItemType(), ((ArrayType) declared).element());
return areCompatible(
((SqlArray) actual).getItemType(),
((ArrayType) declared).element(),
allowCast);
}

if (actual.baseType() == SqlBaseType.MAP && declared instanceof MapType) {
return areCompatible(
((SqlMap) actual).getValueType(),
((MapType) declared).value()
((MapType) declared).value(),
allowCast
);
}

if (actual.baseType() == SqlBaseType.STRUCT && declared instanceof StructType) {
return isStructCompatible(actual, declared);
}

return isPrimitiveMatch(actual, declared);
return isPrimitiveMatch(actual, declared, allowCast);
}

private static boolean isStructCompatible(final SqlType actual, final ParamType declared) {
Expand All @@ -181,6 +195,7 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType
for (final Entry<String, ParamType> entry : ((StructType) declared).getSchema().entrySet()) {
final String k = entry.getKey();
final Optional<io.confluent.ksql.schema.ksql.types.Field> field = actualStruct.field(k);
// intentionally do not allow implicit casting within structs
if (!field.isPresent() || !areCompatible(field.get().type(), entry.getValue())) {
return false;
}
Expand All @@ -189,15 +204,21 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType
}

/**
 * Decides compatibility from the base type of {@code actual} and the class of
 * {@code declared}.
 *
 * <p>Returns {@code true} for the exact pairings STRING/StringType,
 * INTEGER/IntegerType, BIGINT/LongType, BOOLEAN/BooleanType, DOUBLE/DoubleType
 * and DECIMAL/DecimalType. When {@code allowCast} is set it additionally returns
 * {@code true} if the base type of {@code actual} reports, via
 * {@code canImplicitlyCast}, that it can be cast to the base type obtained from
 * {@code declared}.
 *
 * @param actual    the SQL type of the supplied argument
 * @param declared  the parameter type declared by the function
 * @param allowCast whether an implicit cast is acceptable in place of an exact pairing
 * @return whether the two types match
 */
// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
private static boolean isPrimitiveMatch(final SqlType actual, final ParamType declared) {
private static boolean isPrimitiveMatch(
final SqlType actual,
final ParamType declared,
final boolean allowCast
) {
// CHECKSTYLE_RULES.ON: CyclomaticComplexity
// CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity
return actual.baseType() == SqlBaseType.STRING && declared instanceof StringType
|| actual.baseType() == SqlBaseType.INTEGER && declared instanceof IntegerType
|| actual.baseType() == SqlBaseType.BIGINT && declared instanceof LongType
|| actual.baseType() == SqlBaseType.BOOLEAN && declared instanceof BooleanType
|| actual.baseType() == SqlBaseType.DOUBLE && declared instanceof DoubleType
|| actual.baseType() == SqlBaseType.DECIMAL && declared instanceof DecimalType;
final SqlBaseType base = actual.baseType();
return base == SqlBaseType.STRING && declared instanceof StringType
|| base == SqlBaseType.INTEGER && declared instanceof IntegerType
|| base == SqlBaseType.BIGINT && declared instanceof LongType
|| base == SqlBaseType.BOOLEAN && declared instanceof BooleanType
|| base == SqlBaseType.DOUBLE && declared instanceof DoubleType
|| base == SqlBaseType.DECIMAL && declared instanceof DecimalType
// the cast path is only consulted when no exact pairing above matched
|| allowCast && base.canImplicitlyCast(functionToSqlBaseConverter().toBaseType(declared));
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public class UdfIndexTest {
private static final ParamType STRING = ParamTypes.STRING;
private static final ParamType DECIMAL = ParamTypes.DECIMAL;
private static final ParamType INT = ParamTypes.INTEGER;
private static final ParamType LONG = ParamTypes.LONG;
private static final ParamType DOUBLE = ParamTypes.DOUBLE;
private static final ParamType STRUCT1 = StructType.builder().field("a", STRING).build();
private static final ParamType STRUCT2 = StructType.builder().field("b", INT).build();
private static final ParamType MAP1 = MapType.of(STRING);
Expand Down Expand Up @@ -88,6 +90,37 @@ public void shouldFindOneArg() {
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldFindOneArgWithCast() {
  // Given: the only registered overload takes a LONG
  final KsqlScalarFunction longOverload = function(EXPECTED, false, LONG);
  udfIndex.addFunction(longOverload);

  // When: the lookup is made with an INTEGER argument
  final KsqlScalarFunction resolved =
      udfIndex.getFunction(ImmutableList.of(SqlTypes.INTEGER));

  // Then: the LONG overload is selected
  assertThat(resolved.name(), equalTo(EXPECTED));
}

@Test
public void shouldFindPreferredOneArgWithCast() {
  // Given: LONG, INT and DOUBLE overloads, created and registered in that order
  final KsqlScalarFunction longOverload = function(OTHER, false, LONG);
  final KsqlScalarFunction intOverload = function(EXPECTED, false, INT);
  final KsqlScalarFunction doubleOverload = function(OTHER, false, DOUBLE);
  udfIndex.addFunction(longOverload);
  udfIndex.addFunction(intOverload);
  udfIndex.addFunction(doubleOverload);

  // When: the lookup is made with an INTEGER argument
  final KsqlScalarFunction resolved =
      udfIndex.getFunction(ImmutableList.of(SqlTypes.INTEGER));

  // Then: the exact INT overload wins over the overloads that would need a cast
  assertThat(resolved.name(), equalTo(EXPECTED));
}

@Test
public void shouldFindTwoDifferentArgs() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,27 @@ public void shouldPassCompatibleSchemas() {
SqlTypes.map(SqlTypes.decimal(1, 1)),
MapType.of(ParamTypes.DECIMAL)),
is(true));
}

// Verifies the pairings areCompatible accepts when implicit casting is allowed.
@Test
public void shouldPassCompatibleSchemasWithImplicitCasting() {
// INTEGER argument against LONG, DOUBLE and DECIMAL parameters
assertThat(SchemaUtil.areCompatible(SqlTypes.INTEGER, ParamTypes.LONG, true), is(true));
assertThat(SchemaUtil.areCompatible(SqlTypes.INTEGER, ParamTypes.DOUBLE, true), is(true));
assertThat(SchemaUtil.areCompatible(SqlTypes.INTEGER, ParamTypes.DECIMAL, true), is(true));

// BIGINT argument against DOUBLE and DECIMAL parameters
assertThat(SchemaUtil.areCompatible(SqlTypes.BIGINT, ParamTypes.DOUBLE, true), is(true));
assertThat(SchemaUtil.areCompatible(SqlTypes.BIGINT, ParamTypes.DECIMAL, true), is(true));

// DECIMAL argument against a DOUBLE parameter
assertThat(SchemaUtil.areCompatible(SqlTypes.decimal(2, 1), ParamTypes.DOUBLE, true), is(true));
}

// Verifies the pairings areCompatible still rejects when implicit casting is allowed.
@Test
public void shouldNotPassInCompatibleSchemasWithImplicitCasting() {
// BIGINT argument against an INTEGER parameter
assertThat(SchemaUtil.areCompatible(SqlTypes.BIGINT, ParamTypes.INTEGER, true), is(false));

// DOUBLE argument against a LONG parameter
assertThat(SchemaUtil.areCompatible(SqlTypes.DOUBLE, ParamTypes.LONG, true), is(false));

// DOUBLE argument against a DECIMAL parameter
assertThat(SchemaUtil.areCompatible(SqlTypes.DOUBLE, ParamTypes.DECIMAL, true), is(false));
}

@Test
Expand Down
Loading

0 comments on commit 6fc4f72

Please sign in to comment.