diff --git a/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java b/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java index 5cfc228f1b1..fd47235e828 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableMap; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.BetweenPredicate; @@ -24,6 +25,8 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; @@ -51,6 +54,7 @@ import io.confluent.ksql.execution.expression.tree.Type; import io.confluent.ksql.execution.expression.tree.WhenClause; import java.util.List; +import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; @@ -188,6 +192,27 @@ public Expression visitSubscriptExpression( return new SubscriptExpression(node.getLocation(), base, index); } + @Override + public Expression visitCreateArrayExpression(final CreateArrayExpression exp, final C context) { + final Builder values = ImmutableList.builder(); + for (Expression value : exp.getValues()) { + values.add(rewriter.apply(value, context)); + } + return new CreateArrayExpression(exp.getLocation(), values.build()); + } + + @Override + public Expression visitCreateMapExpression(final CreateMapExpression exp, final C context) { + final ImmutableMap.Builder map = ImmutableMap.builder(); + for (Entry entry : exp.getMap().entrySet()) { + map.put( + rewriter.apply(entry.getKey(), context), + rewriter.apply(entry.getValue(), context) + ); + } + return new CreateMapExpression(exp.getLocation(), map.build()); + } + @Override public Expression visitStructExpression(final CreateStructExpression node, final C context) { final Builder fields = ImmutableList.builder(); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/map/AsMap.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udf/map/AsMap.java index 9fb87c47e67..2c83acc016c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/map/AsMap.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udf/map/AsMap.java @@ -38,5 +38,4 @@ public final Map asMap( } return map; } - -} \ No newline at end of file +} diff --git a/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java b/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java index 4fbaf11dd94..76003c2a9e8 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java @@ -671,6 +671,42 @@ public void shouldHandleMaps() { assertThat(result, is("value1")); } + @Test + public void shouldHandleCreateArray() { + // Given: + final Expression expression = analyzeQuery( + "SELECT ARRAY['foo', COL" + STRING_INDEX1 + "] FROM codegen_test EMIT CHANGES;", metaStore) + .getSelectExpressions() + .get(0) + .getExpression(); + + // When: + final Object result = codeGenRunner + .buildCodeGenFromParseTree(expression, "Array") + .evaluate(genericRow(ONE_ROW)); + + // Then: + assertThat(result, is(ImmutableList.of("foo", "S1"))); + } + + @Test + public void shouldHandleCreateMap() { + // Given: + final Expression expression = analyzeQuery( + "SELECT MAP('foo' := 'foo', 'bar' := COL" + STRING_INDEX1 + ") FROM codegen_test EMIT CHANGES;", metaStore) + .getSelectExpressions() + .get(0) + .getExpression(); + + // When: + final Object result = codeGenRunner + .buildCodeGenFromParseTree(expression, "Map") + .evaluate(genericRow(ONE_ROW)); + + // Then: + assertThat(result, is(ImmutableMap.of("foo", "foo", "bar", "S1"))); + } + @Test public void shouldHandleInvalidJavaIdentifiers() { // Given: diff --git a/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java index 4c7c48aed1d..c385d04565c 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/engine/InsertValuesExecutorTest.java @@ -31,6 +31,7 @@ import io.confluent.ksql.execution.ddl.commands.KsqlTopic; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.BooleanLiteral; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; import io.confluent.ksql.execution.expression.tree.DoubleLiteral; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; @@ -470,46 +471,48 @@ public void shouldHandleNegativeValueExpression() { @Test public void shouldHandleUdfs() { // Given: - givenSourceStreamWithSchema(SINGLE_ARRAY_SCHEMA, SerdeOption.none(), Optional.empty()); + givenSourceStreamWithSchema(SINGLE_FIELD_SCHEMA, SerdeOption.none(), Optional.empty()); final ConfiguredStatement statement = givenInsertValuesStrings( ImmutableList.of("COL0"), ImmutableList.of( new FunctionCall( - FunctionName.of("AS_ARRAY"), - ImmutableList.of(new IntegerLiteral(1), new IntegerLiteral(2)))) + FunctionName.of("SUBSTRING"), + ImmutableList.of(new StringLiteral("foo"), new IntegerLiteral(2)))) ); // When: executor.execute(statement, ImmutableMap.of(), engine, serviceContext); // Then: - verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of(ImmutableList.of(1, 2)))); + verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of("oo"))); verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE)); } @Test public void shouldHandleNestedUdfs() { // Given: - givenSourceStreamWithSchema(SINGLE_MAP_SCHEMA, SerdeOption.none(), Optional.empty()); + givenSourceStreamWithSchema(SINGLE_FIELD_SCHEMA, SerdeOption.none(), Optional.empty()); final ConfiguredStatement statement = givenInsertValuesStrings( ImmutableList.of("COL0"), ImmutableList.of( new FunctionCall( - FunctionName.of("AS_MAP"), + FunctionName.of("SUBSTRING"), ImmutableList.of( - new FunctionCall(FunctionName.of("AS_ARRAY"), ImmutableList.of(new StringLiteral("foo"))), - new FunctionCall(FunctionName.of("AS_ARRAY"), ImmutableList.of(new IntegerLiteral(1))) - )) - ) + new FunctionCall( + FunctionName.of("SUBSTRING"), + ImmutableList.of(new StringLiteral("foo"), new IntegerLiteral(2)) + ), + new IntegerLiteral(2)) + )) ); // When: executor.execute(statement, ImmutableMap.of(), engine, serviceContext); // Then: - verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of(ImmutableMap.of("foo", 1)))); + verify(valueSerializer).serialize(TOPIC_NAME, new GenericRow(ImmutableList.of("o"))); verify(producer).send(new ProducerRecord<>(TOPIC_NAME, null, 1L, KEY, VALUE)); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java index 931eca22958..d99ba83d1c0 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java @@ -26,6 +26,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; @@ -34,6 +35,8 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; @@ -532,6 +535,42 @@ public void shouldRewriteSubscriptExpression() { assertThat(rewritten, equalTo(new SubscriptExpression(parsed.getLocation(), expr1, expr2))); } + @Test + public void shouldRewriteCreateArrayExpression() { + // Given: + final CreateArrayExpression parsed = parseExpression("ARRAY['foo', col4[1]]"); + final Expression firstVal = parsed.getValues().get(0); + final Expression secondVal = parsed.getValues().get(1); + when(processor.apply(firstVal, context)).thenReturn(expr1); + when(processor.apply(secondVal, context)).thenReturn(expr2); + + // When: + final Expression rewritten = expressionRewriter.rewrite(parsed, context); + + // Then: + assertThat(rewritten, equalTo(new CreateArrayExpression(ImmutableList.of(expr1, expr2)))); + } + + @Test + public void shouldRewriteCreateMapExpression() { + // Given: + final CreateMapExpression parsed = parseExpression("MAP('foo' := SUBSTRING('foo',0), 'bar' := col4[1])"); + final Expression firstVal = parsed.getMap().get(new StringLiteral("foo")); + final Expression secondVal = parsed.getMap().get(new StringLiteral("bar")); + when(processor.apply(firstVal, context)).thenReturn(expr1); + when(processor.apply(secondVal, context)).thenReturn(expr2); + when(processor.apply(new StringLiteral("foo"), context)).thenReturn(new StringLiteral("foo")); + when(processor.apply(new StringLiteral("bar"), context)).thenReturn(new StringLiteral("bar")); + + // When: + final Expression rewritten = expressionRewriter.rewrite(parsed, context); + + // Then: + assertThat(rewritten, + equalTo(new CreateMapExpression( + ImmutableMap.of(new StringLiteral("foo"), expr1, new StringLiteral("bar"), expr2)))); + } + @Test public void shouldRewriteStructExpression() { // Given: diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udf/list/AsArrayTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/udf/list/AsArrayTest.java deleted file mode 100644 index 7aeb469b117..00000000000 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/udf/list/AsArrayTest.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"; you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.function.udf.list; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import java.util.List; -import org.junit.Test; - -public class AsArrayTest { - - @Test - public void shouldCreateArrayFromEmpty() { - // When: - final List array = new AsArray().asArray(); - - // Then: - assertThat(array, empty()); - } - - @Test - public void shouldCreateSingleNullArray() { - // When: - final List array = new AsArray().asArray((String) null); - - // Then: - assertThat(array, is(Lists.newArrayList((String) null))); - } - - @Test - public void shouldCreateSingleElementArray() { - // When: - final List array = new AsArray().asArray("foo"); - - // Then: - assertThat(array, is(ImmutableList.of("foo"))); - } - - @Test - public void shouldCreateMultiElementArray() { - // When: - final List array = new AsArray().asArray("foo", "bar"); - - // Then: - assertThat(array, is(ImmutableList.of("foo", "bar"))); - } - - @Test - public void shouldCreateMultiElementArrayWithNulls() { - // When: - final List array = new AsArray().asArray("foo", null); - - // Then: - assertThat(array, is(Lists.newArrayList("foo", null))); - } - - @Test - public void shouldCreateMultiElementArrayOfInts() { - // When: - final List array = new AsArray().asArray(1, 2); - - // Then: - assertThat(array, is(ImmutableList.of(1, 2))); - } - -} \ No newline at end of file diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index d90abb9e2ca..32248b75c26 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -16,6 +16,8 @@ package io.confluent.ksql.execution.codegen; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.Expression; @@ -38,6 +40,7 @@ import io.confluent.ksql.util.KsqlException; import java.util.ArrayList; import java.util.List; +import java.util.Map.Entry; import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -185,6 +188,21 @@ public Void visitSubscriptExpression(final SubscriptExpression node, final Void return null; } + @Override + public Void visitCreateArrayExpression(final CreateArrayExpression exp, final Void context) { + exp.getValues().forEach(val -> process(val, context)); + return null; + } + + @Override + public Void visitCreateMapExpression(final CreateMapExpression exp, final Void context) { + for (Entry entry : exp.getMap().entrySet()) { + process(entry.getKey(), context); + process(entry.getValue(), context); + } + return null; + } + @Override public Void visitStructExpression( final CreateStructExpression exp, diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index c445ac22cb7..85a96823488 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multiset; import io.confluent.ksql.execution.codegen.helpers.ArrayAccess; +import io.confluent.ksql.execution.codegen.helpers.ArrayBuilder; import io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; @@ -31,6 +32,8 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; @@ -82,6 +85,7 @@ import java.math.RoundingMode; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.function.Function; import java.util.function.Supplier; @@ -104,13 +108,15 @@ public class SqlToJavaVisitor { "java.util.List", "java.util.ArrayList", "com.google.common.collect.ImmutableList", + "com.google.common.collect.ImmutableMap", "java.util.function.Supplier", DecimalUtil.class.getCanonicalName(), BigDecimal.class.getCanonicalName(), MathContext.class.getCanonicalName(), RoundingMode.class.getCanonicalName(), SchemaBuilder.class.getCanonicalName(), - Struct.class.getCanonicalName() + Struct.class.getCanonicalName(), + ArrayBuilder.class.getCanonicalName() ); private static final Map DECIMAL_OPERATOR_NAME = ImmutableMap @@ -749,6 +755,45 @@ public Pair visitSubscriptExpression( } } + @Override + public Pair visitCreateArrayExpression( + final CreateArrayExpression exp, + final Void context + ) { + final StringBuilder array = new StringBuilder("new ArrayBuilder("); + array.append(exp.getValues().size()); + array.append((')')); + + for (Expression value : exp.getValues()) { + array.append(".add("); + array.append(process(value, context).getLeft()); + array.append(")"); + } + return new Pair<>( + "((List)" + array.toString() + ".build())", + expressionTypeManager.getExpressionSqlType(exp)); + } + + @Override + public Pair visitCreateMapExpression( + final CreateMapExpression exp, + final Void context + ) { + final StringBuilder map = new StringBuilder("ImmutableMap.builder()"); + + for (Entry entry: exp.getMap().entrySet()) { + map.append(".put("); + map.append(process(entry.getKey(), context).getLeft()); + map.append(", "); + map.append(process(entry.getValue(), context).getLeft()); + map.append(")"); + } + + return new Pair<>( + "((Map)" + map.toString() + ".build())", + expressionTypeManager.getExpressionSqlType(exp)); + } + @Override public Pair visitStructExpression( final CreateStructExpression node, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/list/AsArray.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/ArrayBuilder.java similarity index 52% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udf/list/AsArray.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/ArrayBuilder.java index 4eac80d1ee2..e4833044eab 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/list/AsArray.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/ArrayBuilder.java @@ -13,22 +13,31 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udf.list; +package io.confluent.ksql.execution.codegen.helpers; -import io.confluent.ksql.function.udf.Udf; -import io.confluent.ksql.function.udf.UdfDescription; -import io.confluent.ksql.function.udf.UdfParameter; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; -@UdfDescription(name = "AS_ARRAY", description = "Construct a list based on some inputs") -public class AsArray { +/** + * Used to construct arrays using the builder pattern. Note that we + * cannot use {@link com.google.common.collect.ImmutableList} because + * it does not accept null values. + */ +public class ArrayBuilder { + + private final ArrayList list; + + public ArrayBuilder(final int size) { + list = new ArrayList<>(size); + } + + public ArrayBuilder add(final Object value) { + list.add(value); + return this; + } - @SuppressWarnings("varargs") - @SafeVarargs - @Udf - public final List asArray(@UdfParameter final T... elements) { - return Arrays.asList(elements); + public List build() { + return list; } -} \ No newline at end of file +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java index f4dd6da6e10..3c5947f1706 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java @@ -24,6 +24,8 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; @@ -102,6 +104,31 @@ public String visitSubscriptExpression(final SubscriptExpression node, final Con + "[" + process(node.getIndex(), context) + "]"; } + @Override + public String visitCreateArrayExpression( + final CreateArrayExpression exp, + final Context context + ) { + return exp + .getValues() + .stream() + .map(val -> process(val, context)) + .collect(Collectors.joining(", ", "ARRAY[", "]")); + } + + @Override + public String visitCreateMapExpression(final CreateMapExpression exp, final Context context) { + return exp + .getMap() + .entrySet() + .stream() + .map(entry -> + process(entry.getKey(), context) + + ":=" + + process(entry.getValue(), context)) + .collect(Collectors.joining(", ", "MAP(", ")")); + } + @Override public String visitStructExpression(final CreateStructExpression exp, final Context context) { return exp diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/CreateArrayExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/CreateArrayExpression.java new file mode 100644 index 00000000000..d2168491c32 --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/CreateArrayExpression.java @@ -0,0 +1,68 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.expression.tree; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.parser.NodeLocation; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +@Immutable +public class CreateArrayExpression extends Expression { + + private final ImmutableList values; + + public CreateArrayExpression( + final Optional location, + final List values + ) { + super(location); + this.values = ImmutableList.copyOf(values); + } + + public CreateArrayExpression(final List values) { + this(Optional.empty(), values); + } + + public ImmutableList getValues() { + return values; + } + + @Override + protected R accept(final ExpressionVisitor visitor, final C context) { + return visitor.visitCreateArrayExpression(this, context); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final CreateArrayExpression that = (CreateArrayExpression) o; + return Objects.equals(values, that.values); + } + + @Override + public int hashCode() { + return Objects.hash(values); + } + +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/CreateMapExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/CreateMapExpression.java new file mode 100644 index 00000000000..adcab9f71ec --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/CreateMapExpression.java @@ -0,0 +1,67 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.expression.tree; + +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.parser.NodeLocation; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +@Immutable +public class CreateMapExpression extends Expression { + + private final ImmutableMap map; + + public CreateMapExpression( + final Optional location, + final Map map + ) { + super(location); + this.map = ImmutableMap.copyOf(map); + } + + public CreateMapExpression(final Map map) { + this(Optional.empty(), map); + } + + public ImmutableMap getMap() { + return map; + } + + @Override + protected R accept(final ExpressionVisitor visitor, final C context) { + return visitor.visitCreateMapExpression(this, context); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final CreateMapExpression that = (CreateMapExpression) o; + return Objects.equals(map, that.map); + } + + @Override + public int hashCode() { + return Objects.hash(map); + } +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java index 9a4ccd1ee13..b4c105da6c9 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java @@ -73,6 +73,10 @@ default R process(final Expression node, @Nullable final C context) { R visitSubscriptExpression(SubscriptExpression exp, @Nullable C context); + R visitCreateArrayExpression(CreateArrayExpression exp, @Nullable C context); + + R visitCreateMapExpression(CreateMapExpression exp, @Nullable C context); + R visitStructExpression(CreateStructExpression exp, @Nullable C context); R visitTimeLiteral(TimeLiteral exp, @Nullable C context); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java index d9f9091f5ce..061b7da92da 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java @@ -50,6 +50,19 @@ public Void visitSubscriptExpression(final SubscriptExpression node, final C con return null; } + @Override + public Void visitCreateArrayExpression(final CreateArrayExpression exp, final C context) { + exp.getValues().forEach(val -> process(val, context)); + return null; + } + + @Override + public Void visitCreateMapExpression(final CreateMapExpression exp, final C context) { + exp.getMap().keySet().forEach(key -> process(key, context)); + exp.getMap().values().forEach(val -> process(val, context)); + return null; + } + @Override public Void visitStructExpression(final CreateStructExpression node, final C context) { node.getFields().forEach(field -> process(field.getValue(), context)); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java index 2fcbef3e7d0..fdfd2c6249e 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java @@ -166,6 +166,16 @@ public R visitSubscriptExpression(final SubscriptExpression node, final C contex return visitExpression(node, context); } + @Override + public R visitCreateArrayExpression(final CreateArrayExpression node, final C context) { + return visitExpression(node, context); + } + + @Override + public R visitCreateMapExpression(final CreateMapExpression node, final C context) { + return visitExpression(node, context); + } + @Override public R visitStructExpression(final CreateStructExpression node, final C context) { return visitExpression(node, context); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index d518cad2d81..8147fbea6f5 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -23,6 +23,8 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; @@ -68,6 +70,7 @@ import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.VisitorUtil; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -331,6 +334,95 @@ public Void visitSubscriptExpression( return null; } + @Override + public Void visitCreateArrayExpression( + final CreateArrayExpression exp, + final ExpressionTypeContext context + ) { + if (exp.getValues().isEmpty()) { + throw new KsqlException( + "Array constructor cannot be empty. Please supply at least one element " + + "(see https://github.com/confluentinc/ksql/issues/4239)."); + } + + final List sqlTypes = exp + .getValues() + .stream() + .map(val -> { + process(val, context); + return context.getSqlType(); + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + if (sqlTypes.size() == 0) { + throw new KsqlException("Cannot construct an array with all NULL elements " + + "(see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may " + + "cast a NULL value to the desired type."); + } + + if (new HashSet<>(sqlTypes).size() != 1) { + throw new KsqlException( + String.format( + "Cannot construct an array with mismatching types (%s) from expression %s.", + sqlTypes, + exp)); + } + + context.setSqlType(SqlArray.of(sqlTypes.get(0))); + return null; + } + + @Override + public Void visitCreateMapExpression( + final CreateMapExpression exp, + final ExpressionTypeContext context + ) { + if (exp.getMap().isEmpty()) { + throw new KsqlException("Map constructor cannot be empty. Please supply at least one key " + + "value pair (see https://github.com/confluentinc/ksql/issues/4239)."); + } + + final List keyTypes = exp.getMap() + .keySet() + .stream() + .map(key -> { + process(key, context); + return context.getSqlType(); + }) + .collect(Collectors.toList()); + + if (keyTypes.stream().anyMatch(type -> !SqlTypes.STRING.equals(type))) { + throw new KsqlException("Only STRING keys are supported in maps but got: " + keyTypes); + } + + final List valueTypes = exp.getMap() + .values() + .stream() + .map(val -> { + process(val, context); + return context.getSqlType(); + }) + .distinct() + .collect(Collectors.toList()); + + if (valueTypes.size() != 1) { + throw new KsqlException( + String.format( + "Cannot construct a map with mismatching value types (%s) from expression %s.", + valueTypes, + exp)); + } + + if (valueTypes.get(0) == null) { + throw new KsqlException("Cannot construct MAP with NULL values. As a workaround, you " + + "may cast a NULL value to the desired type."); + } + + context.setSqlType(SqlMap.of(valueTypes.get(0))); + return null; + } + @Override public Void visitStructExpression( final CreateStructExpression exp, diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index cf219302e39..d36901ac035 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -32,12 +32,15 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression.Sign; import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DoubleLiteral; @@ -144,6 +147,44 @@ public void shouldProcessMapExpressionCorrectly() { assertThat(javaExpression, equalTo("((Double) ((java.util.Map)TEST1_COL5).get(\"key1\"))")); } + @Test + public void shouldProcessCreateArrayExpressionCorrectly() { + // Given: + Expression expression = new CreateArrayExpression( + ImmutableList.of( + new SubscriptExpression(MAPCOL, new StringLiteral("key1")), + new DoubleLiteral(1.0d) + ) + ); + + // When: + String java = sqlToJavaVisitor.process(expression); + + // Then: + assertThat( + java, + equalTo("((List)new ArrayBuilder(2).add(((Double) ((java.util.Map)TEST1_COL5).get(\"key1\"))).add(1.0).build())")); + } + + @Test + public void shouldProcessCreateMapExpressionCorrectly() { + // Given: + Expression expression = new CreateMapExpression( + ImmutableMap.of( + new StringLiteral("foo"), + new SubscriptExpression(MAPCOL, new StringLiteral("key1")), + new StringLiteral("bar"), + new DoubleLiteral(1.0d) + ) + ); + + // When: + String java = sqlToJavaVisitor.process(expression); + + // Then: + assertThat(java, equalTo("((Map)ImmutableMap.builder().put(\"foo\", ((Double) ((java.util.Map)TEST1_COL5).get(\"key1\"))).put(\"bar\", 1.0).build())")); + } + @Test public void shouldProcessStructExpressionCorrectly() { // Given: diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java index cc243a91780..e96e2b641b5 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java @@ -17,9 +17,11 @@ import static org.hamcrest.core.IsEqual.equalTo; import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.BetweenPredicate; @@ -27,11 +29,14 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.DoubleLiteral; +import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.InListExpression; import io.confluent.ksql.execution.expression.tree.InPredicate; @@ -87,6 +92,29 @@ public void shouldFormatSubscriptExpression() { equalTo("'abc'[3.0]")); } + @Test + public void shouldFormatCreateArrayExpression() { + assertThat(ExpressionFormatter.formatExpression( + new CreateArrayExpression(ImmutableList.of( + new StringLiteral("foo"), + new SubscriptExpression(new ColumnReferenceExp(ColumnRef.withoutSource(ColumnName.of("abc"))), new IntegerLiteral(1))) + )), + equalTo("ARRAY['foo', abc[1]]") + ); + } + + @Test + public void shouldFormatCreateMapExpression() { + assertThat(ExpressionFormatter.formatExpression( + new CreateMapExpression(ImmutableMap.builder() + .put(new StringLiteral("foo"), new SubscriptExpression(new ColumnReferenceExp(ColumnRef.withoutSource(ColumnName.of("abc"))), new IntegerLiteral(1))) + .put(new StringLiteral("bar"), new StringLiteral("val")) + .build() + )), + equalTo("MAP('foo':=abc[1], 'bar':='val')") + ); + } + @Test public void shouldFormatStructExpression() { assertThat(ExpressionFormatter.formatExpression(new CreateStructExpression( diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index 9c07a3688c7..ee2d1b08f03 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -33,11 +33,14 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; import io.confluent.ksql.execution.expression.tree.BooleanLiteral; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; @@ -48,6 +51,7 @@ import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.NotExpression; +import io.confluent.ksql.execution.expression.tree.NullLiteral; import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression; import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression; import io.confluent.ksql.execution.expression.tree.StringLiteral; @@ -318,6 +322,145 @@ public void shouldFailIfThereIsInvalidFieldNameInStructCall() { expressionTypeManager.getExpressionSqlType(expression); } + @Test + public void shouldEvaluateTypeForCreateArrayExpression() { + // Given: + Expression expression = new CreateArrayExpression( + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of(TEST1, COL0))) + ); + + // When: + final SqlType type = expressionTypeManager.getExpressionSqlType(expression); + + // Then: + assertThat(type, is(SqlTypes.array(SqlTypes.BIGINT))); + } + + + @Test + public void shouldEvaluateTypeForCreateArrayExpressionWithNull() { + // Given: + Expression expression = new CreateArrayExpression( + ImmutableList.of( + new ColumnReferenceExp(ColumnRef.of(TEST1, COL0)), + new NullLiteral() + ) + ); + + // When: + final SqlType type = expressionTypeManager.getExpressionSqlType(expression); + + // Then: + assertThat(type, is(SqlTypes.array(SqlTypes.BIGINT))); + } + + @Test + public void shouldThrowOnArrayAllNulls() { + // Given: + Expression expression = new CreateArrayExpression( + ImmutableList.of( + new NullLiteral() + ) + ); + + // Expect + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Cannot construct an array with all NULL elements"); + + // When: + expressionTypeManager.getExpressionSqlType(expression); + } + + @Test + public void shouldThrowOnArrayMultipleTypes() { + // Given: + Expression expression = new CreateArrayExpression( + ImmutableList.of( + new ColumnReferenceExp(ColumnRef.of(TEST1, COL0)), + new StringLiteral("foo") + ) + ); + + // Expect + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Cannot construct an array with mismatching types"); + + // When: + expressionTypeManager.getExpressionSqlType(expression); + } + + @Test + public void shouldEvaluateTypeForCreateMapExpression() { + // Given: + Expression expression = new CreateMapExpression( + ImmutableMap.of( + COL1, new ColumnReferenceExp(ColumnRef.of(TEST1, COL0)) + ) + ); + + // When: + final SqlType type = expressionTypeManager.getExpressionSqlType(expression); + + // Then: + assertThat(type, is(SqlTypes.map(SqlTypes.BIGINT))); + } + + @Test + public void shouldThrowOnMapOfNonStringKeys() { + // Given: + Expression expression = new CreateMapExpression( + ImmutableMap.of( + new IntegerLiteral(1), + new ColumnReferenceExp(ColumnRef.of(TEST1, COL0)) + ) + ); + + // Expect + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Only STRING keys are supported in maps"); + + // When: + expressionTypeManager.getExpressionSqlType(expression); + } + + @Test + public void shouldThrowOnMapOfMultipleTypes() { + // Given: + Expression expression = new CreateMapExpression( + ImmutableMap.of( + new StringLiteral("foo"), + new ColumnReferenceExp(ColumnRef.of(TEST1, COL0)), + new StringLiteral("bar"), + new StringLiteral("bar") + ) + ); + + // Expect + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Cannot construct a map with mismatching value types"); + + // When: + expressionTypeManager.getExpressionSqlType(expression); + } + + @Test + public void shouldThrowOnMapOfNullValues() { + // Given: + Expression expression = new CreateMapExpression( + ImmutableMap.of( + new StringLiteral("foo"), + new NullLiteral() + ) + ); + + // Expect + expectedException.expect(KsqlException.class); + expectedException.expectMessage("Cannot construct MAP with NULL values"); + + // When: + expressionTypeManager.getExpressionSqlType(expression); + } + @Test public void shouldEvaluateTypeForStructExpression() { // Given: diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/asarray.json b/ksql-functional-tests/src/test/resources/query-validation-tests/asarray.json deleted file mode 100644 index 2d96ae987c4..00000000000 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/asarray.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "comments": [ - "Tests covering AS_LIST functionality with generics" - ], - "tests": [ - { - "name": "construct a list from two elements", - "statements": [ - "CREATE STREAM TEST (a INT, b INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT as_array(a, b, 3) as l FROM TEST;" - ], - "inputs": [ - {"topic": "test_topic", "value": {"a": 1, "b": 2}}, - {"topic": "test_topic", "value": {"a": null, "b": null}} - ], - "outputs": [ - {"topic": "OUTPUT", "value": {"L": [1, 2, 3]}}, - {"topic": "OUTPUT", "value": {"L": [null, null, 3]}} - ] - } - ] -} \ No newline at end of file diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/asmap.json b/ksql-functional-tests/src/test/resources/query-validation-tests/asmap.json deleted file mode 100644 index 51480f599df..00000000000 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/asmap.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "comments": [ - "Tests covering map creation" - ], - "tests": [ - { - "name": "create map from key/value lists", - "statements": [ - "CREATE STREAM TEST (ks ARRAY, vals ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT AS_MAP(ks, vals) as m FROM TEST;" - ], - "inputs": [ - {"topic": "test_topic", "value": {"ks": ["a", "b"], "vals": [1, 2]}}, - {"topic": "test_topic", "value": {"ks": ["a", "b", "c"], "vals": [1, 2, 3]}}, - {"topic": "test_topic", "value": {"ks": ["a", "b"], "vals": [1, 2, 3]}}, - {"topic": "test_topic", "value": {"ks": ["a", "b", "c"], "vals": [1, 2, null]}} - ], - "outputs": [ - {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2}}}, - {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2, "c": 3}}}, - {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2}}}, - {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2, "c": null}}} - ] - } - ] -} \ No newline at end of file diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/create_array.json b/ksql-functional-tests/src/test/resources/query-validation-tests/create_array.json new file mode 100644 index 00000000000..77342d469cd --- /dev/null +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/create_array.json @@ -0,0 +1,68 @@ +{ + "comments": [ + "Tests covering AS_LIST functionality with generics" + ], + "tests": [ + { + "name": "construct a list from two elements", + "statements": [ + "CREATE STREAM TEST (a INT, b INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ARRAY[a, b, 3] as l FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"a": 1, "b": 2}}, + {"topic": "test_topic", "value": {"a": null, "b": null}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"L": [1, 2, 3]}}, + {"topic": "OUTPUT", "value": {"L": [null, null, 3]}} + ] + }, + { + "name": "construct a list from null casted elements", + "statements": [ + "CREATE STREAM TEST (a INT, b INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ARRAY[CAST(NULL AS INT)] as l FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"a": 1, "b": 2}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"L": [null]}} + ] + }, + { + "name": "construct a list from no elements", + "statements": [ + "CREATE STREAM TEST (a INT, b INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ARRAY[] as l FROM TEST;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlException", + "message": "Array constructor cannot be empty. Please supply at least one element (see https://github.com/confluentinc/ksql/issues/4239)." + } + }, + { + "name": "construct a list from null non-casted elements", + "statements": [ + "CREATE STREAM TEST (a INT, b INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ARRAY[NULL] as l FROM TEST;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlException", + "message": "Cannot construct an array with all NULL elements (see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may cast a NULL value to the desired type." + } + }, + { + "name": "construct a list from mismatching elements", + "statements": [ + "CREATE STREAM TEST (a INT, b INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ARRAY[1, 1.0] as l FROM TEST;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlException", + "message": "Cannot construct an array with mismatching types ([INTEGER, DOUBLE]) from expression ARRAY[1, 1.0]" + } + } + ] +} \ No newline at end of file diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/create_map.json b/ksql-functional-tests/src/test/resources/query-validation-tests/create_map.json new file mode 100644 index 00000000000..2291b1d81b0 --- /dev/null +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/create_map.json @@ -0,0 +1,72 @@ +{ + "comments": [ + "Tests covering map creation" + ], + "tests": [ + { + "name": "create map from named tuples", + "statements": [ + "CREATE STREAM TEST (k1 VARCHAR, k2 VARCHAR, v1 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT MAP(k1:=v1, k2:=v1*2) as M FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"k1": "foo", "k2": "bar", "v1": 10}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"M": {"foo": 10, "bar": 20}}} + ] + }, + { + "name": "create map from key/value lists", + "statements": [ + "CREATE STREAM TEST (ks ARRAY, vals ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT AS_MAP(ks, vals) as m FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ks": ["a", "b"], "vals": [1, 2]}}, + {"topic": "test_topic", "value": {"ks": ["a", "b", "c"], "vals": [1, 2, 3]}}, + {"topic": "test_topic", "value": {"ks": ["a", "b"], "vals": [1, 2, 3]}}, + {"topic": "test_topic", "value": {"ks": ["a", "b", "c"], "vals": [1, 2, null]}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2}}}, + {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2, "c": 3}}}, + {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2}}}, + {"topic": "OUTPUT", "value": {"M": {"a": 1, "b": 2, "c": null}}} + ] + }, + { + "name": "create map from named tuples mismatching types", + "statements": [ + "CREATE STREAM TEST (k1 VARCHAR, k2 VARCHAR, v1 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT MAP(k1:=v1, k2:='hello') as M FROM TEST;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlException", + "message": "Cannot construct a map with mismatching value types ([INTEGER, STRING]) from expression MAP(TEST.K1:=TEST.V1, TEST.K2:='hello')." + } + }, + { + "name": "create map from named tuples null values", + "statements": [ + "CREATE STREAM TEST (k1 VARCHAR, k2 VARCHAR, v1 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT MAP(k1:=v1, k2:=NULL) as M FROM TEST;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlException", + "message": "Cannot construct a map with mismatching value types ([INTEGER, null]) from expression MAP(TEST.K1:=TEST.V1, TEST.K2:=null)." + } + }, + { + "name": "create empty map", + "statements": [ + "CREATE STREAM TEST (k1 VARCHAR, k2 VARCHAR, v1 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT MAP() as M FROM TEST;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlException", + "message": "Map constructor cannot be empty. Please supply at least one key value pair (see https://github.com/confluentinc/ksql/issues/4239)." + } + } + ] +} \ No newline at end of file diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/cube.json b/ksql-functional-tests/src/test/resources/query-validation-tests/cube.json index cda8e6b6b0d..fcf45b9d84a 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/cube.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/cube.json @@ -7,7 +7,7 @@ "name": "cube two int columns", "statements": [ "CREATE STREAM TEST (col1 INT, col2 INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT cube_explode(as_array(col1, col2)) VAL FROM TEST;" + "CREATE STREAM OUTPUT AS SELECT cube_explode(array[col1, col2]) VAL FROM TEST;" ], "inputs": [ {"topic": "test_topic", "key": "0", "value": {"col1": 1, "col2": 2}}, @@ -27,7 +27,7 @@ "name": "cube three columns", "statements": [ "CREATE STREAM TEST (col1 VARCHAR, col2 VARCHAR, col3 VARCHAR) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT cube_explode(as_array(col1, col2, col3)) VAL FROM TEST;" + "CREATE STREAM OUTPUT AS SELECT cube_explode(array[col1, col2, col3]) VAL FROM TEST;" ], "inputs": [ {"topic": "test_topic", "key": "0", "value": {"col1": "one", "col2": "two", "col3" : "three"}} @@ -47,7 +47,7 @@ "name": "cube two columns with udf on third", "statements": [ "CREATE STREAM TEST (col1 VARCHAR, col2 VARCHAR, col3 INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT cube_explode(as_array(col1, col2)) VAL1, ABS(col3) VAL2 FROM TEST;" + "CREATE STREAM OUTPUT AS SELECT cube_explode(array[col1, col2]) VAL1, ABS(col3) VAL2 FROM TEST;" ], "inputs": [ {"topic": "test_topic", "key": "0", "value": {"col1": "one", "col2": "two", "col3" : 3}} @@ -63,7 +63,7 @@ "name": "cube two columns twice", "statements": [ "CREATE STREAM TEST (col1 VARCHAR, col2 VARCHAR, col3 INT, col4 INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT cube_explode(as_array(col1, col2)) VAL1, cube_explode(as_array(col3, col4)) VAL2 FROM TEST;" + "CREATE STREAM OUTPUT AS SELECT cube_explode(array[col1, col2]) VAL1, cube_explode(array[col3, col4]) VAL2 FROM TEST;" ], "inputs": [ {"topic": "test_topic", "key": "0", "value": {"col1": "one", "col2": "two", "col3" : 3, "col4" : 4}} diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json b/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json index bab92b33295..cddb568aab4 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json @@ -120,7 +120,7 @@ "name": "table functions with complex expressions", "statements": [ "CREATE STREAM TEST (F0 INT, F1 INT, F2 INT, F3 INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT F0, EXPLODE(AS_ARRAY(ABS(F1 + F2), ABS(F2 + F3), ABS(F3 + F1))) FROM TEST;" + "CREATE STREAM OUTPUT AS SELECT F0, EXPLODE(ARRAY[ABS(F1 + F2), ABS(F2 + F3), ABS(F3 + F1)]) FROM TEST;" ], "inputs": [ {"topic": "test_topic", "key": "0", "value": {"ID": 0, "F0": 1, "F1": 10, "F2": 11, "F3": 12}} diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json index 9f65eb6fa8d..0bc5884658f 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json +++ b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/insert-values.json @@ -226,7 +226,7 @@ "name": "should handle arbitrary expressions", "statements": [ "CREATE STREAM TEST (I INT, A ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", - "INSERT INTO TEST (I, A) VALUES (-1, AS_ARRAY(1, 1 + 1, 3));" + "INSERT INTO TEST (I, A) VALUES (-1, ARRAY[1, 1 + 1, 3]);" ], "inputs": [ ], @@ -234,6 +234,30 @@ {"topic": "test_topic", "key": null, "value": {"I": -1, "A": [1, 2, 3]}} ] }, + { + "name": "should handle arbitrary nested expressions", + "statements": [ + "CREATE STREAM TEST (I INT, A ARRAY>) WITH (kafka_topic='test_topic', value_format='JSON');", + "INSERT INTO TEST (I, A) VALUES (-1, ARRAY[ARRAY[1]]);" + ], + "inputs": [ + ], + "outputs": [ + {"topic": "test_topic", "key": null, "value": {"I": -1, "A": [[1]]}} + ] + }, + { + "name": "should handle map expressions", + "statements": [ + "CREATE STREAM TEST (I INT, A MAP) WITH (kafka_topic='test_topic', value_format='JSON');", + "INSERT INTO TEST (I, A) VALUES (-1, MAP('a':=0, 'b':=1));" + ], + "inputs": [ + ], + "outputs": [ + {"topic": "test_topic", "key": null, "value": {"I": -1, "A": {"a": 0, "b": 1}}} + ] + }, { "name": "should handle quoted identifiers", "statements": [ @@ -253,7 +277,7 @@ "name": "should handle struct expressions", "statements": [ "CREATE STREAM TEST (val STRUCT>) WITH (kafka_topic='test_topic', value_format='JSON');", - "INSERT INTO TEST (val) VALUES (STRUCT(FOO := '2.1', `bar` := AS_ARRAY('bar')));" + "INSERT INTO TEST (val) VALUES (STRUCT(FOO := '2.1', `bar` := ARRAY['bar']));" ], "inputs": [ ], @@ -265,7 +289,7 @@ "name": "should handle struct coercion", "statements": [ "CREATE STREAM TEST (val STRUCT, baz DOUBLE>) WITH (kafka_topic='test_topic', value_format='JSON');", - "INSERT INTO TEST (val) VALUES (STRUCT(FOO := 2, BAR := AS_ARRAY(2), BAZ := 2));" + "INSERT INTO TEST (val) VALUES (STRUCT(FOO := 2, BAR := ARRAY[2], BAZ := 2));" ], "inputs": [ ], diff --git a/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 b/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 index 57775858df4..bc67fce1143 100644 --- a/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 +++ b/ksql-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 @@ -254,6 +254,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CAST '(' expression AS type ')' #cast | ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor + | MAP '(' (expression ASSIGN expression (',' expression ASSIGN expression)*)? ')' #mapConstructor | STRUCT '(' (identifier ASSIGN expression (',' identifier ASSIGN expression)*)? ')' #structConstructor | identifier '(' ASTERISK ')' #functionCall | identifier'(' (expression (',' expression)*)? ')' #functionCall diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index 1c0807e9872..0556ec79afa 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -31,6 +31,8 @@ import io.confluent.ksql.execution.expression.tree.Cast; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; +import io.confluent.ksql.execution.expression.tree.CreateArrayExpression; +import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; import io.confluent.ksql.execution.expression.tree.DecimalLiteral; @@ -60,10 +62,12 @@ import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.parser.SqlBaseParser.ArrayConstructorContext; import io.confluent.ksql.parser.SqlBaseParser.CreateConnectorContext; import io.confluent.ksql.parser.SqlBaseParser.DescribeConnectorContext; import io.confluent.ksql.parser.SqlBaseParser.DropConnectorContext; import io.confluent.ksql.parser.SqlBaseParser.DropTypeContext; +import io.confluent.ksql.parser.SqlBaseParser.ExpressionContext; import io.confluent.ksql.parser.SqlBaseParser.IdentifierContext; import io.confluent.ksql.parser.SqlBaseParser.InsertValuesContext; import io.confluent.ksql.parser.SqlBaseParser.IntervalClauseContext; @@ -927,6 +931,38 @@ public Node visitCast(final SqlBaseParser.CastContext context) { ); } + @Override + public Node visitArrayConstructor(final ArrayConstructorContext context) { + final ImmutableList.Builder values = ImmutableList.builder(); + + for (ExpressionContext exp : context.expression()) { + values.add((Expression) visit(exp)); + } + + return new CreateArrayExpression( + getLocation(context), + values.build() + ); + } + + @Override + public Node visitMapConstructor(final SqlBaseParser.MapConstructorContext context) { + final ImmutableMap.Builder values = ImmutableMap.builder(); + + final List expression = context.expression(); + for (int i = 0; i < expression.size(); i += 2) { + values.put( + (Expression) visit(expression.get(i)), + (Expression) visit(expression.get(i + 1)) + ); + } + + return new CreateMapExpression( + getLocation(context), + values.build() + ); + } + @Override public Node visitStructConstructor(final SqlBaseParser.StructConstructorContext context) { final ImmutableList.Builder fields = ImmutableList.builder();