diff --git a/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java b/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java index 28b66526a949..b9d08bbf859d 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java @@ -76,6 +76,8 @@ @SuppressWarnings({"SameParameterValue", "OptionalGetWithoutIsPresent"}) public class CodeGenRunnerTest { + private static final String COL_INVALID_JAVA = "col!Invalid:("; + private static final LogicalSchema META_STORE_SCHEMA = LogicalSchema.builder() .valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT) .valueColumn(ColumnName.of("COL1"), SqlTypes.STRING) @@ -96,6 +98,7 @@ public class CodeGenRunnerTest { .struct() .field("A", SqlTypes.STRING) .build()) + .valueColumn(ColumnName.of(COL_INVALID_JAVA), SqlTypes.BIGINT) .build(); private static final int INT64_INDEX1 = 0; @@ -112,6 +115,7 @@ public class CodeGenRunnerTest { private static final int MAP_INDEX1 = 11; private static final int MAP_INDEX2 = 12; private static final int STRUCT_INDEX = 15; + private static final int INVALID_JAVA_IDENTIFIER_INDEX = 16; private static final Schema STRUCT_SCHEMA = SchemaConverters.sqlToConnectConverter() .toConnectSchema(META_STORE_SCHEMA.findValueColumn("COL15").get().type()); @@ -123,7 +127,8 @@ public class CodeGenRunnerTest { ImmutableMap.of("k1", 4), ImmutableList.of("one", "two"), ImmutableList.of(ImmutableList.of("1", "2"), ImmutableList.of("3")), - new Struct(STRUCT_SCHEMA).put("A", "VALUE")); + new Struct(STRUCT_SCHEMA).put("A", "VALUE"), + (long) INVALID_JAVA_IDENTIFIER_INDEX); @Rule public final ExpectedException expectedException = ExpectedException.none(); @@ -656,6 +661,25 @@ public void shouldHandleMaps() { assertThat(result, is("value1")); } + @Test + public void shouldHandleInvalidJavaIdentifiers() { + // Given: + final Expression expression = analyzeQuery( + "SELECT `" + COL_INVALID_JAVA + "` FROM codegen_test EMIT CHANGES;", + metaStore) + .getSelectExpressions() + .get(0) + .getExpression(); + + // When: + final Object result = codeGenRunner + .buildCodeGenFromParseTree(expression, "math") + .evaluate(genericRow(ONE_ROW)); + + // Then: + assertThat(result, is((long) INVALID_JAVA_IDENTIFIER_INDEX)); + } + @Test public void shouldHandleCaseStatement() { // Given: diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 0222dad98386..dbf7c0b66869 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -15,6 +15,8 @@ package io.confluent.ksql.execution.codegen; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMap.Builder; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.BetweenPredicate; @@ -113,16 +115,23 @@ public ExpressionMetadata buildCodeGenFromParseTree( final List columnIndexes = new ArrayList<>(parameters.size()); final List kudfObjects = new ArrayList<>(parameters.size()); + final Builder fieldToParamName = ImmutableMap.builder(); int index = 0; for (final ParameterType param : parameters) { - parameterNames[index] = param.paramName; + final String paramName = CodeGenUtil.paramName(index); + fieldToParamName.put(param.fieldName, paramName); + parameterNames[index] = paramName; parameterTypes[index] = param.type; columnIndexes.add(schema.valueColumnIndex(param.fieldName).orElse(-1)); kudfObjects.add(param.getKudf()); index++; } - final String javaCode = new SqlToJavaVisitor(schema, functionRegistry).process(expression); + final String javaCode = new SqlToJavaVisitor( + schema, + functionRegistry, + fieldToParamName.build()::get + ).process(expression); final IExpressionEvaluator ee = CompilerFactoryFactory.getDefaultCompilerFactory().newExpressionEvaluator(); @@ -180,7 +189,6 @@ private void addParameter(final Column schemaColumn) { parameters.add(new ParameterType( SQL_TO_JAVA_TYPE_CONVERTER.toJavaType(schemaColumn.type()), schemaColumn.fullName(), - schemaColumn.fullName().replace(".", "_"), ksqlConfig)); } @@ -201,11 +209,10 @@ public Object visitFunctionCall(final FunctionCall node, final Object context) { final UdfFactory holder = functionRegistry.getUdfFactory(functionName); final KsqlFunction function = holder.getFunction(argumentTypes); - final String parameterName = node.getName().name() + "_" + functionNumber; + final String parameterName = CodeGenUtil.functionName(node.getName().name(), functionNumber); parameters.add(new ParameterType( function, parameterName, - parameterName, ksqlConfig)); return null; } @@ -329,34 +336,29 @@ public static final class ParameterType { private final Class type; private final Optional function; - private final String paramName; private final String fieldName; private final KsqlConfig ksqlConfig; private ParameterType( final Class type, final String fieldName, - final String paramName, final KsqlConfig ksqlConfig ) { this( null, Objects.requireNonNull(type, "type"), fieldName, - paramName, ksqlConfig); } private ParameterType( final KsqlFunction function, final String fieldName, - final String paramName, final KsqlConfig ksqlConfig) { this( Objects.requireNonNull(function, "function"), function.getKudfClass(), fieldName, - paramName, ksqlConfig); } @@ -364,13 +366,11 @@ private ParameterType( final KsqlFunction function, final Class type, final String fieldName, - final String paramName, final KsqlConfig ksqlConfig ) { this.function = Optional.ofNullable(function); this.type = Objects.requireNonNull(type, "type"); this.fieldName = Objects.requireNonNull(fieldName, "fieldName"); - this.paramName = Objects.requireNonNull(paramName, "paramName"); this.ksqlConfig = Objects.requireNonNull(ksqlConfig, "ksqlConfig"); } @@ -378,10 +378,6 @@ public Class getType() { return type; } - public String getParamName() { - return paramName; - } - public String getFieldName() { return fieldName; } @@ -401,13 +397,12 @@ public boolean equals(final Object o) { final ParameterType that = (ParameterType) o; return Objects.equals(type, that.type) && Objects.equals(function, that.function) - && Objects.equals(paramName, that.paramName) && Objects.equals(fieldName, that.fieldName); } @Override public int hashCode() { - return Objects.hash(type, function, paramName, fieldName); + return Objects.hash(type, function, fieldName); } } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java new file mode 100644 index 000000000000..57d50827d41d --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenUtil.java @@ -0,0 +1,33 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.codegen; + +public final class CodeGenUtil { + + private static final String PARAM_NAME_PREFIX = "var"; + + private CodeGenUtil() { + } + + public static String paramName(final int index) { + return PARAM_NAME_PREFIX + index; + } + + public static String functionName(final String fun, final int index) { + return fun + "_" + index; + } + +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 5bda9ce36ed6..748e97537593 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -77,6 +77,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import org.apache.commons.lang3.StringEscapeUtils; @@ -128,12 +129,18 @@ public class SqlToJavaVisitor { private final FunctionRegistry functionRegistry; private final ExpressionTypeManager expressionTypeManager; + private final Function fieldToParamName; - public SqlToJavaVisitor(final LogicalSchema schema, final FunctionRegistry functionRegistry) { + public SqlToJavaVisitor( + final LogicalSchema schema, + final FunctionRegistry functionRegistry, + final Function fieldToParamName + ) { this.schema = Objects.requireNonNull(schema, "schema"); this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry"); this.expressionTypeManager = new ExpressionTypeManager(schema, functionRegistry); + this.fieldToParamName = Objects.requireNonNull(fieldToParamName, "fieldToParamName"); } public String process(final Expression expression) { @@ -272,7 +279,7 @@ public Pair visitColumnReference( new KsqlException("Field not found: " + fieldName)); final Schema schema = SQL_TO_CONNECT_SCHEMA_CONVERTER.toConnectSchema(schemaColumn.type()); - return new Pair<>(fieldName.replace(".", "_"), schema); + return new Pair<>(fieldToParamName.apply(fieldName), schema); } @Override @@ -310,7 +317,10 @@ public Pair visitFunctionCall( final Void context) { final String functionName = node.getName().name(); - final String instanceName = functionName + "_" + functionCounter++; + final String instanceName = fieldToParamName.apply( + CodeGenUtil.functionName(functionName, functionCounter++) + ); + final Schema functionReturnSchema = getFunctionReturnSchema(node, functionName); final String javaReturnType = SchemaUtil.getJavaType(functionReturnSchema).getSimpleName(); final String arguments = node.getArguments().stream() diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/sqlpredicate/SqlPredicate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/sqlpredicate/SqlPredicate.java index b90442225b39..703b8f56005e 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/sqlpredicate/SqlPredicate.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/sqlpredicate/SqlPredicate.java @@ -17,8 +17,11 @@ import static java.util.Objects.requireNonNull; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMap.Builder; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.CodeGenUtil; import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.codegen.SqlToJavaVisitor; import io.confluent.ksql.execution.expression.tree.Expression; @@ -68,8 +71,12 @@ public SqlPredicate( final Class[] parameterTypes = new Class[parameters.size()]; columnIndexes = new int[parameters.size()]; int index = 0; + + final Builder fieldToParamName = ImmutableMap.builder(); for (final CodeGenRunner.ParameterType param : parameters) { - parameterNames[index] = param.getParamName(); + final String paramName = CodeGenUtil.paramName(index); + fieldToParamName.put(param.getFieldName(), paramName); + parameterNames[index] = paramName; parameterTypes[index] = param.getType(); columnIndexes[index] = schema.valueColumnIndex(param.getFieldName()).orElse(-1); index++; @@ -84,7 +91,8 @@ public SqlPredicate( final String expressionStr = new SqlToJavaVisitor( schema, - functionRegistry + functionRegistry, + fieldToParamName.build()::get ).process(this.filterExpression); ee.cook(expressionStr); diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index e5f84fae16b9..3961b46e448d 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -63,6 +63,7 @@ import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import java.util.Optional; +import java.util.function.Function; import org.apache.kafka.connect.data.Schema; import org.junit.Before; import org.junit.Rule; @@ -87,7 +88,11 @@ public class SqlToJavaVisitorTest { @Before public void init() { - sqlToJavaVisitor = new SqlToJavaVisitor(SCHEMA, functionRegistry); + sqlToJavaVisitor = new SqlToJavaVisitor( + SCHEMA, + functionRegistry, + fn -> fn.replace(".", "_") + ); } @Test