diff --git a/ksql-engine/src/main/java/io/confluent/ksql/engine/InsertValuesExecutor.java b/ksql-engine/src/main/java/io/confluent/ksql/engine/InsertValuesExecutor.java index c101483a9a82..03d58d634ab0 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/engine/InsertValuesExecutor.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/engine/InsertValuesExecutor.java @@ -16,15 +16,17 @@ package io.confluent.ksql.engine; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; import com.google.common.collect.Streams; import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException; import io.confluent.ksql.GenericRow; import io.confluent.ksql.KsqlExecutionContext; import io.confluent.ksql.exception.KsqlTopicAuthorizationException; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.Literal; -import io.confluent.ksql.execution.expression.tree.NullLiteral; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; +import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.logging.processing.NoopProcessingLogContext; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; @@ -63,6 +65,7 @@ import java.util.concurrent.Future; import java.util.function.LongSupplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.http.HttpStatus; import org.apache.kafka.clients.producer.Producer; @@ -190,7 +193,12 @@ private ProducerRecord buildRecord( } try { - final RowData row = extractRow(insertValues, dataSource); + final RowData row = extractRow( + insertValues, + dataSource, + executionContext.getMetaStore(), + config); + final byte[] key = serializeKey(row.key, dataSource, config, serviceContext); final byte[] value = serializeValue(row.value, dataSource, config, serviceContext); @@ -226,7 +234,9 @@ private void throwIfDisabled(final KsqlConfig config) { private RowData extractRow( final InsertValues insertValues, - final DataSource dataSource + final DataSource dataSource, + final FunctionRegistry functionRegistry, + final KsqlConfig config ) { final List columns = insertValues.getColumns().isEmpty() ? implicitColumns(dataSource, insertValues.getValues()) @@ -234,7 +244,8 @@ private RowData extractRow( final LogicalSchema schema = dataSource.getSchema(); - final Map values = resolveValues(insertValues, columns, schema); + final Map values = resolveValues( + insertValues, columns, schema, functionRegistry, config); handleExplicitKeyField(values, dataSource.getKeyField()); @@ -306,7 +317,9 @@ private static List implicitColumns( private static Map resolveValues( final InsertValues insertValues, final List columns, - final LogicalSchema schema + final LogicalSchema schema, + final FunctionRegistry functionRegistry, + final KsqlConfig config ) { final Map values = new HashMap<>(); for (int i = 0; i < columns.size(); i++) { @@ -314,7 +327,8 @@ private static Map resolveValues( final SqlType columnType = columnType(column, schema); final Expression valueExp = insertValues.getValues().get(i); - final Object value = new ExpressionResolver(columnType, column) + final Object value = + new ExpressionResolver(columnType, column, schema, functionRegistry, config) .process(valueExp, null); values.put(column, value); @@ -482,26 +496,39 @@ private static class ExpressionResolver extends VisitParentExpressionVisitor { diff --git a/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java index 742b0fa5aa02..6eff1494a878 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java @@ -29,9 +29,11 @@ import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; +import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.BooleanLiteral; import io.confluent.ksql.execution.expression.tree.DoubleLiteral; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.LongLiteral; import io.confluent.ksql.execution.expression.tree.StringLiteral; @@ -43,6 +45,7 @@ import io.confluent.ksql.metastore.model.KsqlStream; import io.confluent.ksql.metastore.model.KsqlTable; import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.KsqlParser.PreparedStatement; import io.confluent.ksql.parser.tree.InsertValues; @@ -50,6 +53,8 @@ import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.PersistenceSchema; +import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; @@ -101,6 +106,15 @@ public class InsertValuesExecutorTest { .valueColumn(COL0, SqlTypes.STRING) .build(); + private static final LogicalSchema SINGLE_ARRAY_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("COL0"), SqlArray.of(SqlTypes.INTEGER)) + .build(); + + private static final LogicalSchema SINGLE_MAP_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("COL0"), SqlMap.of(SqlTypes.INTEGER)) + .build(); + + private static final LogicalSchema SCHEMA = LogicalSchema.builder() .valueColumn(COL0, SqlTypes.STRING) .valueColumn(ColumnName.of("COL1"), SqlTypes.BIGINT) @@ -434,6 +448,74 @@ public void shouldHandleNullKeyForSourceWithKeyField() { verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE)); } + @Test + public void shouldHandleNegativeValueExpression() { + // Given: + givenSourceStreamWithSchema(SCHEMA, SerdeOption.none(), Optional.of(ColumnName.of("COL0"))); + + final ConfiguredStatement statement = givenInsertValuesStrings( + ImmutableList.of("COL0", "COL1"), + ImmutableList.of( + new StringLiteral("str"), + ArithmeticUnaryExpression.negative(Optional.empty(), new LongLiteral(1)) + ) + ); + + // When: + executor.execute(statement, ImmutableMap.of(), engine, serviceContext); + + // Then: + verify(keySerializer).serialize(TOPIC_NAME, keyStruct("str")); + verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of("str", -1L))); + verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE)); + } + + @Test + public void shouldHandleUdfs() { + // Given: + givenSourceStreamWithSchema(SINGLE_ARRAY_SCHEMA, SerdeOption.none(), Optional.empty()); + + final ConfiguredStatement statement = givenInsertValuesStrings( + ImmutableList.of("COL0"), + ImmutableList.of( + new FunctionCall( + FunctionName.of("AS_ARRAY"), + ImmutableList.of(new IntegerLiteral(1), new IntegerLiteral(2)))) + ); + + // When: + executor.execute(statement, ImmutableMap.of(), engine, serviceContext); + + // Then: + verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of(ImmutableList.of(1, 2)))); + verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE)); + } + + @Test + public void shouldHandleNestedUdfs() { + // Given: + givenSourceStreamWithSchema(SINGLE_MAP_SCHEMA, SerdeOption.none(), Optional.empty()); + + final ConfiguredStatement statement = givenInsertValuesStrings( + ImmutableList.of("COL0"), + ImmutableList.of( + new FunctionCall( + FunctionName.of("AS_MAP"), + ImmutableList.of( + new FunctionCall(FunctionName.of("AS_ARRAY"), ImmutableList.of(new StringLiteral("foo"))), + new FunctionCall(FunctionName.of("AS_ARRAY"), ImmutableList.of(new IntegerLiteral(1))) + )) + ) + ); + + // When: + executor.execute(statement, ImmutableMap.of(), engine, serviceContext); + + // Then: + verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of(ImmutableMap.of("foo", 1)))); + verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE)); + } + @Test public void shouldAllowUpcast() { // Given: diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json index 8f2df55d6e58..5cfdc22e3c77 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json +++ b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json @@ -245,6 +245,18 @@ "outputs": [ {"topic": "test_topic", "key": null, "value": {"I": 1, "BI": 2, "D": 3.0}} ] + }, + { + "name": "should handle arbitrary expressions", + "statements": [ + "CREATE STREAM TEST (I INT, A ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", + "INSERT INTO TEST (I, A) VALUES (-1, AS_ARRAY(1, 1 + 1, 3));" + ], + "inputs": [ + ], + "outputs": [ + {"topic": "test_topic", "key": null, "value": {"I": -1, "A": [1, 2, 3]}} + ] } ] } \ No newline at end of file diff --git a/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 b/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 index 6d1b873da7db..c9947dc59db2 100644 --- a/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 +++ b/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 @@ -158,7 +158,7 @@ groupingExpressions ; values - : '(' (literal (',' literal)*)? ')' + : '(' (valueExpression (',' valueExpression)*)? ')' ; /* diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index f1eb4fe0e5f8..bda4a1407cc0 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -352,7 +352,7 @@ public Node visitInsertValues(final InsertValuesContext context) { targetLocation, SourceName.of(targetName), columns, - visit(context.values().literal(), Expression.class)); + visit(context.values().valueExpression(), Expression.class)); } @Override diff --git a/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java b/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java index 9dc2c2be6caf..ba898bee1621 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercer.java @@ -15,13 +15,19 @@ package io.confluent.ksql.schema.ksql; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlException; import java.math.BigDecimal; import java.math.MathContext; import java.math.RoundingMode; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Function; @@ -35,23 +41,30 @@ public final class DefaultSqlValueCoercer implements SqlValueCoercer { .put(SqlBaseType.DOUBLE, Number::doubleValue) .build(); - public Optional coerce(final Object value, final SqlType targetType) { - if (targetType.baseType() == SqlBaseType.ARRAY - || targetType.baseType() == SqlBaseType.MAP - || targetType.baseType() == SqlBaseType.STRUCT - ) { - throw new KsqlException("Unsupported SQL type: " + targetType.baseType()); + @Override + public Optional coerce(final Object value, final SqlType targetType) { + return doCoerce(value, targetType); + } + + private static Optional doCoerce(final Object value, final SqlType targetType) { + switch (targetType.baseType()) { + case DECIMAL: + return coerceDecimal(value, (SqlDecimal) targetType); + case ARRAY: + return coerceArray(value, (SqlArray) targetType); + case MAP: + return coerceMap(value, (SqlMap) targetType); + case STRUCT: + throw new KsqlException("Unsupported SQL type: " + targetType.baseType()); + default: + break; } final SqlBaseType valueSqlType = SchemaConverters.javaToSqlConverter() .toSqlType(value.getClass()); if (valueSqlType.equals(targetType.baseType())) { - return optional(value); - } - - if (targetType.baseType() == SqlBaseType.DECIMAL) { - return coerceDecimal(value, (SqlDecimal) targetType); + return Optional.of(value); } if (!(value instanceof Number) || !valueSqlType.canImplicitlyCast(targetType.baseType())) { @@ -59,16 +72,53 @@ public Optional coerce(final Object value, final SqlType targetType) { } final Number result = UPCASTER.get(targetType.baseType()).apply((Number) value); - return optional(result); + return Optional.of(result); + } + + private static Optional coerceArray(final Object value, final SqlArray targetType) { + if (!(value instanceof List)) { + return Optional.empty(); + } + + final List list = (List) value; + final ImmutableList.Builder coerced = ImmutableList.builder(); + for (final Object el : list) { + final Optional coercedEl = doCoerce(el, targetType.getItemType()); + if (!coercedEl.isPresent()) { + return Optional.empty(); + } + coerced.add(coercedEl.get()); + } + + return Optional.of(coerced.build()); } - private static Optional coerceDecimal(final Object value, final SqlDecimal targetType) { + private static Optional coerceMap(final Object value, final SqlMap targetType) { + if (!(value instanceof Map)) { + return Optional.empty(); + } + + final Map map = (Map) value; + final HashMap coerced = new HashMap<>(); + for (final Map.Entry entry : map.entrySet()) { + final Optional coercedKey = doCoerce(entry.getKey(), SqlTypes.STRING); + final Optional coercedValue = doCoerce(entry.getValue(), targetType.getValueType()); + if (!coercedKey.isPresent() || !coercedValue.isPresent()) { + return Optional.empty(); + } + coerced.put(coercedKey.get(), coercedValue.get()); + } + + return Optional.of(coerced); + } + + private static Optional coerceDecimal(final Object value, final SqlDecimal targetType) { final int precision = targetType.getPrecision(); final int scale = targetType.getScale(); if (value instanceof String) { try { - return optional(new BigDecimal((String) value, new MathContext(precision)) + return Optional.of(new BigDecimal((String) value, new MathContext(precision)) .setScale(scale, RoundingMode.UNNECESSARY)); } catch (final NumberFormatException e) { throw new KsqlException("Cannot coerce value to DECIMAL: " + value, e); @@ -76,7 +126,7 @@ private static Optional coerceDecimal(final Object value, final SqlDecima } if (value instanceof Number && !(value instanceof Double)) { - return optional( + return Optional.of( new BigDecimal( ((Number) value).doubleValue(), new MathContext(precision)) @@ -85,9 +135,4 @@ private static Optional coerceDecimal(final Object value, final SqlDecima return Optional.empty(); } - - @SuppressWarnings("unchecked") - private static Optional optional(final Object value) { - return Optional.of((T)value); - } } diff --git a/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/SqlValueCoercer.java b/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/SqlValueCoercer.java index baeb5c8ff4d0..221c5467edee 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/SqlValueCoercer.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/schema/ksql/SqlValueCoercer.java @@ -30,8 +30,7 @@ public interface SqlValueCoercer { * * @param value the value to try to coerce. * @param targetSchema the target SQL type. - * @param target Java type * @return the coerced value if the value could be coerced, {@link Optional#empty()} otherwise. */ - Optional coerce(Object value, SqlType targetSchema); + Optional coerce(Object value, SqlType targetSchema); } diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java index 9f94b5b750b1..982b962e4aa5 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java @@ -649,16 +649,16 @@ public void shouldFormatInsertValuesNoSchema() { } @Test - public void shouldNotParseArbitraryExpressions() { + public void shouldParseArbitraryExpressions() { // Given: final String statementString = "INSERT INTO ADDRESS VALUES (2 + 1);"; - - // Expect: - expectedException.expect(ParseFailedException.class); - expectedException.expectMessage("mismatched input"); + final Statement statement = KsqlParserTestUtil.buildSingleAst(statementString, metaStore).getStatement(); // When: - KsqlParserTestUtil.buildSingleAst(statementString, metaStore); + final String result = SqlFormatter.formatSql(statement); + + // Then: + assertThat(result, is("INSERT INTO ADDRESS VALUES ((2 + 1))")); } @Test diff --git a/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java b/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java index 868942cd0037..4bb374809a55 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/schema/ksql/DefaultSqlValueCoercerTest.java @@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlException; @@ -56,6 +58,8 @@ public class DefaultSqlValueCoercerTest { .put(SqlBaseType.DECIMAL, SqlTypes.decimal(2, 1)) .put(SqlBaseType.DOUBLE, SqlTypes.DOUBLE) .put(SqlBaseType.STRING, SqlTypes.STRING) + .put(SqlBaseType.ARRAY, SqlArray.of(SqlTypes.BIGINT)) + .put(SqlBaseType.MAP, SqlMap.of(SqlTypes.BIGINT)) .build(); private static final Map INSTANCES = ImmutableMap @@ -66,6 +70,8 @@ public class DefaultSqlValueCoercerTest { .put(SqlBaseType.DECIMAL, BigDecimal.ONE) .put(SqlBaseType.DOUBLE, 3.0D) .put(SqlBaseType.STRING, "4.1") + .put(SqlBaseType.ARRAY, ImmutableList.of(1L, 2L)) + .put(SqlBaseType.MAP, ImmutableMap.of("foo", 1L)) .build(); private DefaultSqlValueCoercer coercer; @@ -78,16 +84,6 @@ public void setUp() { coercer = new DefaultSqlValueCoercer(); } - @Test(expected = KsqlException.class) - public void shouldThrowOnArray() { - coercer.coerce(ImmutableList.of(), SqlTypes.array(SqlTypes.STRING)); - } - - @Test(expected = KsqlException.class) - public void shouldThrowOnMap() { - coercer.coerce(ImmutableMap.of(), SqlTypes.map(SqlTypes.STRING)); - } - @Test(expected = KsqlException.class) public void shouldThrowOnStruct() { coercer.coerce(new Struct(SchemaBuilder.struct()), @@ -167,6 +163,40 @@ public void shouldNotCoerceToDouble() { assertThat(coercer.coerce("1", SqlTypes.DOUBLE), is(Optional.empty())); } + @Test + public void shouldCoerceToArray() { + final SqlType arrayType = SqlTypes.array(SqlTypes.DOUBLE); + assertThat(coercer.coerce(ImmutableList.of(1), arrayType), is(Optional.of(ImmutableList.of(1d)))); + assertThat(coercer.coerce(ImmutableList.of(1L), arrayType), is(Optional.of(ImmutableList.of(1d)))); + assertThat(coercer.coerce(ImmutableList.of(1.1), arrayType), is(Optional.of(ImmutableList.of(1.1d)))); + } + + @Test + public void shouldNotCoerceToArray() { + final SqlType arrayType = SqlTypes.array(SqlTypes.DOUBLE); + assertThat(coercer.coerce(true, arrayType), is(Optional.empty())); + assertThat(coercer.coerce(1L, arrayType), is(Optional.empty())); + assertThat(coercer.coerce("foo", arrayType), is(Optional.empty())); + assertThat(coercer.coerce(ImmutableMap.of("foo", 1), arrayType), is(Optional.empty())); + } + + @Test + public void shouldCoerceToMap() { + final SqlType mapType = SqlTypes.map(SqlTypes.DOUBLE); + assertThat(coercer.coerce(ImmutableMap.of("foo", 1), mapType), is(Optional.of(ImmutableMap.of("foo", 1d)))); + assertThat(coercer.coerce(ImmutableMap.of("foo", 1L), mapType), is(Optional.of(ImmutableMap.of("foo", 1d)))); + assertThat(coercer.coerce(ImmutableMap.of("foo", 1.1), mapType), is(Optional.of(ImmutableMap.of("foo", 1.1d)))); + } + + @Test + public void shouldNotCoerceToMap() { + final SqlType mapType = SqlTypes.map(SqlTypes.DOUBLE); + assertThat(coercer.coerce(true, mapType), is(Optional.empty())); + assertThat(coercer.coerce(1L, mapType), is(Optional.empty())); + assertThat(coercer.coerce("foo", mapType), is(Optional.empty())); + assertThat(coercer.coerce(ImmutableList.of("foo"), mapType), is(Optional.empty())); + } + @Test public void shouldCoerceToString() { assertThat(coercer.coerce("foobar", SqlTypes.STRING), is(Optional.of("foobar")));