diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/SqlFormatter.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/SqlFormatter.java index fabfdc6225c5..44d8467a2116 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/SqlFormatter.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/SqlFormatter.java @@ -20,6 +20,7 @@ import com.google.common.base.Strings; import io.confluent.ksql.execution.expression.formatter.ExpressionFormatter; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.parser.tree.AliasedRelation; import io.confluent.ksql.parser.tree.AllColumns; import io.confluent.ksql.parser.tree.AstNode; @@ -182,7 +183,7 @@ protected Void visitAllColumns(final AllColumns node, final Integer context) { @Override protected Void visitTable(final Table node, final Integer indent) { - builder.append(node.getName().toString()); + builder.append(escapedName(node.getName())); return null; } @@ -212,7 +213,7 @@ protected Void visitAliasedRelation(final AliasedRelation node, final Integer in process(node.getRelation(), indent); builder.append(' ') - .append(node.getAlias()); + .append(ParserUtil.escapeIfReservedIdentifier(node.getAlias())); return null; } @@ -381,7 +382,7 @@ private void visitExtended() { @Override public Void visitRegisterType(final RegisterType node, final Integer context) { builder.append("CREATE TYPE "); - builder.append(node.getName()); + builder.append(ParserUtil.escapeIfReservedIdentifier(node.getName())); builder.append(" AS "); builder.append(ExpressionFormatterUtil.formatExpression(node.getType())); builder.append(";"); @@ -395,7 +396,7 @@ private void visitDrop(final DropStatement node, final String sourceType) { if (node.getIfExists()) { builder.append("IF EXISTS "); } - builder.append(node.getName()); + builder.append(escapedName(node.getName())); if (node.isDeleteTopic()) { builder.append(" DELETE TOPIC"); } @@ -404,7 +405,7 @@ private void visitDrop(final DropStatement node, final String sourceType) { private void processRelation(final Relation relation, final Integer indent) { if (relation instanceof Table) { builder.append("TABLE ") - .append(((Table) relation).getName()) + .append(escapedName(((Table) relation).getName())) .append('\n'); } else { process(relation, indent); @@ -434,7 +435,7 @@ private void formatCreate(final CreateSource node) { builder.append("IF NOT EXISTS "); } - builder.append(node.getName()); + builder.append(escapedName(node.getName())); final String elements = node.getElements().stream() .map(Formatter::formatTableElement) @@ -463,7 +464,7 @@ private void formatCreateAs(final CreateAsSelect node, final Integer indent) { builder.append("IF NOT EXISTS "); } - builder.append(node.getName()); + builder.append(escapedName(node.getName())); final String tableProps = node.getProperties().toString(); if (!tableProps.isEmpty()) { @@ -487,4 +488,11 @@ private static String formatTableElement(final TableElement e) { + (e.getNamespace() == Namespace.KEY ? " KEY" : ""); } } + + private static String escapedName(final QualifiedName name) { + return name.getParts() + .stream() + .map(ParserUtil::escapeIfReservedIdentifier) + .collect(Collectors.joining(".")); + } } diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java index 1cd11e2bbfa1..3fdf01900b93 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/SqlFormatterTest.java @@ -104,6 +104,10 @@ public class SqlFormatterTest { .valueField("CATEGORY", categorySchema) .build(); + private static final LogicalSchema tableSchema = LogicalSchema.builder() + .valueField("TABLE", SqlTypes.STRING) + .build(); + private static final LogicalSchema ORDERS_SCHEMA = LogicalSchema.builder() .valueField("ORDERTIME", SqlTypes.BIGINT) .valueField("ORDERID", SqlTypes.BIGINT) @@ -185,6 +189,18 @@ public void setUp() { ); metaStore.putSource(ksqlTableOrders); + + final KsqlTable ksqlTableTable = new KsqlTable<>( + "sqlexpression", + "TABLE", + tableSchema, + SerdeOption.none(), + KeyField.of("TABLE", tableSchema.findValueField("TABLE").get()), + new MetadataTimestampExtractionPolicy(), + ksqlTopicItems + ); + + metaStore.putSource(ksqlTableTable); } @Test @@ -285,7 +301,7 @@ public void shouldFormatLeftJoinWithWithin() { criteria, Optional.of(new WithinExpression(10, TimeUnit.SECONDS))); - final String expected = "left L\nLEFT OUTER JOIN right R WITHIN 10 SECONDS ON " + final String expected = "`left` L\nLEFT OUTER JOIN `right` R WITHIN 10 SECONDS ON " + "(('left.col0' = 'right.col0'))"; assertEquals(expected, SqlFormatter.formatSql(join)); } @@ -296,7 +312,7 @@ public void shouldFormatLeftJoinWithoutJoinWindow() { criteria, Optional.empty()); final String result = SqlFormatter.formatSql(join); - final String expected = "left L\nLEFT OUTER JOIN right R ON (('left.col0' = 'right.col0'))"; + final String expected = "`left` L\nLEFT OUTER JOIN `right` R ON (('left.col0' = 'right.col0'))"; assertEquals(expected, result); } @@ -306,7 +322,7 @@ public void shouldFormatInnerJoin() { criteria, Optional.of(new WithinExpression(10, TimeUnit.SECONDS))); - final String expected = "left L\nINNER JOIN right R WITHIN 10 SECONDS ON " + final String expected = "`left` L\nINNER JOIN `right` R WITHIN 10 SECONDS ON " + "(('left.col0' = 'right.col0'))"; assertEquals(expected, SqlFormatter.formatSql(join)); } @@ -317,7 +333,7 @@ public void shouldFormatInnerJoinWithoutJoinWindow() { criteria, Optional.empty()); - final String expected = "left L\nINNER JOIN right R ON (('left.col0' = 'right.col0'))"; + final String expected = "`left` L\nINNER JOIN `right` R ON (('left.col0' = 'right.col0'))"; assertEquals(expected, SqlFormatter.formatSql(join)); } @@ -328,7 +344,7 @@ public void shouldFormatOuterJoin() { criteria, Optional.of(new WithinExpression(10, TimeUnit.SECONDS))); - final String expected = "left L\nFULL OUTER JOIN right R WITHIN 10 SECONDS ON" + final String expected = "`left` L\nFULL OUTER JOIN `right` R WITHIN 10 SECONDS ON" + " (('left.col0' = 'right.col0'))"; assertEquals(expected, SqlFormatter.formatSql(join)); } @@ -340,7 +356,7 @@ public void shouldFormatOuterJoinWithoutJoinWindow() { criteria, Optional.empty()); - final String expected = "left L\nFULL OUTER JOIN right R ON (('left.col0' = 'right.col0'))"; + final String expected = "`left` L\nFULL OUTER JOIN `right` R ON (('left.col0' = 'right.col0'))"; assertEquals(expected, SqlFormatter.formatSql(join)); } @@ -696,6 +712,30 @@ public void shouldFormatStructWithReservedWords() { assertThat(result, is("CREATE STREAM S (FOO STRUCT<`END` STRING>) WITH (KAFKA_TOPIC='foo', VALUE_FORMAT='JSON');")); } + @Test + public void shouldEscapeReservedSourceNames() { + // Given: + final Statement statement = parseSingle("CREATE STREAM `SELECT` (foo VARCHAR) WITH (kafka_topic='foo', value_format='JSON');"); + + // When: + final String result = SqlFormatter.formatSql(statement); + + // Then: + assertThat(result, is("CREATE STREAM `SELECT` (FOO STRING) WITH (KAFKA_TOPIC='foo', VALUE_FORMAT='JSON');")); + } + + @Test + public void shouldEscapeReservedNameAndAlias() { + // Given: + final Statement statement = parseSingle("CREATE STREAM a AS SELECT `SELECT` FROM `TABLE`;"); + + // When: + final String result = SqlFormatter.formatSql(statement); + + // Then: + assertThat(result, is("CREATE STREAM A AS SELECT `TABLE`.`SELECT` \"SELECT\"\nFROM `TABLE` `TABLE`")); + } + private Statement parseSingle(final String statementString) { return KsqlParserTestUtil.buildSingleAst(statementString, metaStore).getStatement(); }