Skip to content

Commit

Permalink
feat: use coherent naming scheme for generated java code (#3417)
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Sep 28, 2019
1 parent 5f31309 commit 06a2a57
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
@SuppressWarnings({"SameParameterValue", "OptionalGetWithoutIsPresent"})
public class CodeGenRunnerTest {

private static final String COL_INVALID_JAVA = "col!Invalid:(";

private static final LogicalSchema META_STORE_SCHEMA = LogicalSchema.builder()
.valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT)
.valueColumn(ColumnName.of("COL1"), SqlTypes.STRING)
Expand All @@ -96,6 +98,7 @@ public class CodeGenRunnerTest {
.struct()
.field("A", SqlTypes.STRING)
.build())
.valueColumn(ColumnName.of(COL_INVALID_JAVA), SqlTypes.BIGINT)
.build();

private static final int INT64_INDEX1 = 0;
Expand All @@ -112,6 +115,7 @@ public class CodeGenRunnerTest {
private static final int MAP_INDEX1 = 11;
private static final int MAP_INDEX2 = 12;
private static final int STRUCT_INDEX = 15;
private static final int INVALID_JAVA_IDENTIFIER_INDEX = 16;

private static final Schema STRUCT_SCHEMA = SchemaConverters.sqlToConnectConverter()
.toConnectSchema(META_STORE_SCHEMA.findValueColumn("COL15").get().type());
Expand All @@ -123,7 +127,8 @@ public class CodeGenRunnerTest {
ImmutableMap.of("k1", 4),
ImmutableList.of("one", "two"),
ImmutableList.of(ImmutableList.of("1", "2"), ImmutableList.of("3")),
new Struct(STRUCT_SCHEMA).put("A", "VALUE"));
new Struct(STRUCT_SCHEMA).put("A", "VALUE"),
(long) INVALID_JAVA_IDENTIFIER_INDEX);

@Rule
public final ExpectedException expectedException = ExpectedException.none();
Expand Down Expand Up @@ -656,6 +661,25 @@ public void shouldHandleMaps() {
assertThat(result, is("value1"));
}

@Test
public void shouldHandleInvalidJavaIdentifiers() {
// Given:
final Expression expression = analyzeQuery(
"SELECT `" + COL_INVALID_JAVA + "` FROM codegen_test EMIT CHANGES;",
metaStore)
.getSelectExpressions()
.get(0)
.getExpression();

// When:
final Object result = codeGenRunner
.buildCodeGenFromParseTree(expression, "math")
.evaluate(genericRow(ONE_ROW));

// Then:
assertThat(result, is((long) INVALID_JAVA_IDENTIFIER_INDEX));
}

@Test
public void shouldHandleCaseStatement() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package io.confluent.ksql.execution.codegen;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
Expand Down Expand Up @@ -113,16 +115,23 @@ public ExpressionMetadata buildCodeGenFromParseTree(
final List<Integer> columnIndexes = new ArrayList<>(parameters.size());
final List<Kudf> kudfObjects = new ArrayList<>(parameters.size());

final Builder<String, String> fieldToParamName = ImmutableMap.<String, String>builder();
int index = 0;
for (final ParameterType param : parameters) {
parameterNames[index] = param.paramName;
final String paramName = CodeGenUtil.paramName(index);
fieldToParamName.put(param.fieldName, paramName);
parameterNames[index] = paramName;
parameterTypes[index] = param.type;
columnIndexes.add(schema.valueColumnIndex(param.fieldName).orElse(-1));
kudfObjects.add(param.getKudf());
index++;
}

final String javaCode = new SqlToJavaVisitor(schema, functionRegistry).process(expression);
final String javaCode = new SqlToJavaVisitor(
schema,
functionRegistry,
fieldToParamName.build()::get
).process(expression);

final IExpressionEvaluator ee =
CompilerFactoryFactory.getDefaultCompilerFactory().newExpressionEvaluator();
Expand Down Expand Up @@ -180,7 +189,6 @@ private void addParameter(final Column schemaColumn) {
parameters.add(new ParameterType(
SQL_TO_JAVA_TYPE_CONVERTER.toJavaType(schemaColumn.type()),
schemaColumn.fullName(),
schemaColumn.fullName().replace(".", "_"),
ksqlConfig));
}

Expand All @@ -201,11 +209,10 @@ public Object visitFunctionCall(final FunctionCall node, final Object context) {

final UdfFactory holder = functionRegistry.getUdfFactory(functionName);
final KsqlFunction function = holder.getFunction(argumentTypes);
final String parameterName = node.getName().name() + "_" + functionNumber;
final String parameterName = CodeGenUtil.functionName(node.getName().name(), functionNumber);
parameters.add(new ParameterType(
function,
parameterName,
parameterName,
ksqlConfig));
return null;
}
Expand Down Expand Up @@ -329,59 +336,48 @@ public static final class ParameterType {

private final Class type;
private final Optional<KsqlFunction> function;
private final String paramName;
private final String fieldName;
private final KsqlConfig ksqlConfig;

private ParameterType(
final Class type,
final String fieldName,
final String paramName,
final KsqlConfig ksqlConfig
) {
this(
null,
Objects.requireNonNull(type, "type"),
fieldName,
paramName,
ksqlConfig);
}

private ParameterType(
final KsqlFunction function,
final String fieldName,
final String paramName,
final KsqlConfig ksqlConfig) {
this(
Objects.requireNonNull(function, "function"),
function.getKudfClass(),
fieldName,
paramName,
ksqlConfig);
}

private ParameterType(
final KsqlFunction function,
final Class type,
final String fieldName,
final String paramName,
final KsqlConfig ksqlConfig
) {
this.function = Optional.ofNullable(function);
this.type = Objects.requireNonNull(type, "type");
this.fieldName = Objects.requireNonNull(fieldName, "fieldName");
this.paramName = Objects.requireNonNull(paramName, "paramName");
this.ksqlConfig = Objects.requireNonNull(ksqlConfig, "ksqlConfig");
}

public Class getType() {
return type;
}

public String getParamName() {
return paramName;
}

public String getFieldName() {
return fieldName;
}
Expand All @@ -401,13 +397,12 @@ public boolean equals(final Object o) {
final ParameterType that = (ParameterType) o;
return Objects.equals(type, that.type)
&& Objects.equals(function, that.function)
&& Objects.equals(paramName, that.paramName)
&& Objects.equals(fieldName, that.fieldName);
}

@Override
public int hashCode() {
return Objects.hash(type, function, paramName, fieldName);
return Objects.hash(type, function, fieldName);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"; you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.execution.codegen;

public final class CodeGenUtil {

private static final String PARAM_NAME_PREFIX = "var";

private CodeGenUtil() {
}

public static String paramName(final int index) {
return PARAM_NAME_PREFIX + index;
}

public static String functionName(final String fun, final int index) {
return fun + "_" + index;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringEscapeUtils;
Expand Down Expand Up @@ -128,12 +129,18 @@ public class SqlToJavaVisitor {
private final FunctionRegistry functionRegistry;

private final ExpressionTypeManager expressionTypeManager;
private final Function<String, String> fieldToParamName;

public SqlToJavaVisitor(final LogicalSchema schema, final FunctionRegistry functionRegistry) {
public SqlToJavaVisitor(
final LogicalSchema schema,
final FunctionRegistry functionRegistry,
final Function<String, String> fieldToParamName
) {
this.schema = Objects.requireNonNull(schema, "schema");
this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
this.expressionTypeManager =
new ExpressionTypeManager(schema, functionRegistry);
this.fieldToParamName = Objects.requireNonNull(fieldToParamName, "fieldToParamName");
}

public String process(final Expression expression) {
Expand Down Expand Up @@ -272,7 +279,7 @@ public Pair<String, Schema> visitColumnReference(
new KsqlException("Field not found: " + fieldName));

final Schema schema = SQL_TO_CONNECT_SCHEMA_CONVERTER.toConnectSchema(schemaColumn.type());
return new Pair<>(fieldName.replace(".", "_"), schema);
return new Pair<>(fieldToParamName.apply(fieldName), schema);
}

@Override
Expand Down Expand Up @@ -310,7 +317,10 @@ public Pair<String, Schema> visitFunctionCall(
final Void context) {
final String functionName = node.getName().name();

final String instanceName = functionName + "_" + functionCounter++;
final String instanceName = fieldToParamName.apply(
CodeGenUtil.functionName(functionName, functionCounter++)
);

final Schema functionReturnSchema = getFunctionReturnSchema(node, functionName);
final String javaReturnType = SchemaUtil.getJavaType(functionReturnSchema).getSimpleName();
final String arguments = node.getArguments().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.CodeGenUtil;
import io.confluent.ksql.execution.codegen.ExpressionMetadata;
import io.confluent.ksql.execution.codegen.SqlToJavaVisitor;
import io.confluent.ksql.execution.expression.tree.Expression;
Expand Down Expand Up @@ -68,8 +71,12 @@ public SqlPredicate(
final Class[] parameterTypes = new Class[parameters.size()];
columnIndexes = new int[parameters.size()];
int index = 0;

final Builder<String, String> fieldToParamName = ImmutableMap.<String, String>builder();
for (final CodeGenRunner.ParameterType param : parameters) {
parameterNames[index] = param.getParamName();
final String paramName = CodeGenUtil.paramName(index);
fieldToParamName.put(param.getFieldName(), paramName);
parameterNames[index] = paramName;
parameterTypes[index] = param.getType();
columnIndexes[index] = schema.valueColumnIndex(param.getFieldName()).orElse(-1);
index++;
Expand All @@ -84,7 +91,8 @@ public SqlPredicate(

final String expressionStr = new SqlToJavaVisitor(
schema,
functionRegistry
functionRegistry,
fieldToParamName.build()::get
).process(this.filterExpression);

ee.cook(expressionStr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import java.util.Optional;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
import org.junit.Before;
import org.junit.Rule;
Expand All @@ -87,7 +88,11 @@ public class SqlToJavaVisitorTest {

@Before
public void init() {
sqlToJavaVisitor = new SqlToJavaVisitor(SCHEMA, functionRegistry);
sqlToJavaVisitor = new SqlToJavaVisitor(
SCHEMA,
functionRegistry,
fn -> fn.replace(".", "_")
);
}

@Test
Expand Down

0 comments on commit 06a2a57

Please sign in to comment.