Skip to content

Commit

Permalink
fix: support conversion of STRING to BIGINT for window bounds (#4500)
Browse files Browse the repository at this point in the history
Fixes: #4482

Comparison and Between expressions in the WHERE clause already support magic conversion from a STRING containing a ISO formatted datetime into a BIGINT for the `ROWTIME` column.

This changes extends the support to cover the `WINDOWSTART` and `WINDOWEND` columns.

The change also fixes a bug where by a numeric Between expression on `ROWTIME` resulted in a class-cast exception, e.g. `WHERE ROWTIME < 123546794894`.
  • Loading branch information
big-andy-coates authored Feb 10, 2020
1 parent 6b8bc2a commit 9c3cbf8
Show file tree
Hide file tree
Showing 9 changed files with 569 additions and 403 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,16 @@ public Expression process(final Expression expression) {

private final RewritingVisitor<C> rewriter;

@SuppressWarnings("unchecked")
public static <C, T extends Expression> T rewriteWith(
final BiFunction<Expression, Context<C>, Optional<Expression>> plugin, final T expression) {
return rewriteWith(plugin, expression, null);
}

@SuppressWarnings("unchecked")
public static <C, T extends Expression> T rewriteWith(
final BiFunction<Expression, Context<C>, Optional<Expression>> plugin,
final T expression,
final C context) {
return new ExpressionTreeRewriter<C>(plugin).rewrite(expression, context);
return new ExpressionTreeRewriter<>(plugin).rewrite(expression, context);
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -182,7 +180,7 @@ public Expression visitSubscriptExpression(
final SubscriptExpression node,
final C context) {
final Optional<Expression> result
= plugin.apply(node, new Context<C>(context, this));
= plugin.apply(node, new Context<>(context, this));
if (result.isPresent()) {
return result.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.confluent.ksql.engine.rewrite;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
Expand All @@ -24,28 +25,33 @@
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.util.SchemaUtil;
import io.confluent.ksql.util.timestamp.PartialStringToTimestampParser;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class StatementRewriteForMagicPseudoTimestamp {

private static final Set<ColumnName> SUPPORTED_COLUMNS = ImmutableSet.<ColumnName>builder()
.addAll(SchemaUtil.windowBoundsColumnNames())
.add(SchemaUtil.ROWTIME_NAME)
.build();

public class StatementRewriteForRowtime {

private final PartialStringToTimestampParser parser;

public StatementRewriteForRowtime() {
public StatementRewriteForMagicPseudoTimestamp() {
this(new PartialStringToTimestampParser());
}

@VisibleForTesting
StatementRewriteForRowtime(final PartialStringToTimestampParser parser) {
StatementRewriteForMagicPseudoTimestamp(final PartialStringToTimestampParser parser) {
this.parser = Objects.requireNonNull(parser, "parser");
}

public Expression rewriteForRowtime(final Expression expression) {
if (noRewriteRequired(expression)) {
return expression;
}
public Expression rewrite(final Expression expression) {
return new ExpressionTreeRewriter<>(new OperatorPlugin()::process)
.rewrite(expression, null);
}
Expand All @@ -66,16 +72,22 @@ public Optional<Expression> visitBetweenPredicate(
final BetweenPredicate node,
final Context<Void> context
) {
if (noRewriteRequired(node.getValue())) {
if (!supportedColumnRef(node.getValue())) {
return Optional.empty();
}

final Optional<Expression> min = maybeRewriteTimestamp(node.getMin());
final Optional<Expression> max = maybeRewriteTimestamp(node.getMax());
if (!min.isPresent() && !max.isPresent()) {
return Optional.empty();
}

return Optional.of(
new BetweenPredicate(
node.getLocation(),
node.getValue(),
rewriteTimestamp(((StringLiteral) node.getMin()).getValue()),
rewriteTimestamp(((StringLiteral) node.getMax()).getValue())
min.orElse(node.getMin()),
max.orElse(node.getMax())
)
);
}
Expand All @@ -85,42 +97,45 @@ public Optional<Expression> visitComparisonExpression(
final ComparisonExpression node,
final Context<Void> context
) {
if (expressionIsRowtime(node.getLeft()) && node.getRight() instanceof StringLiteral) {
return Optional.of(
new ComparisonExpression(
node.getLocation(),
node.getType(),
node.getLeft(),
rewriteTimestamp(((StringLiteral) node.getRight()).getValue())
)
);
if (supportedColumnRef(node.getLeft())) {
final Optional<Expression> right = maybeRewriteTimestamp(node.getRight());
return right.map(r -> new ComparisonExpression(
node.getLocation(),
node.getType(),
node.getLeft(),
r
));
}

if (expressionIsRowtime(node.getRight()) && node.getLeft() instanceof StringLiteral) {
return Optional.of(
new ComparisonExpression(
node.getLocation(),
node.getType(),
rewriteTimestamp(((StringLiteral) node.getLeft()).getValue()),
node.getRight()
)
);
if (supportedColumnRef(node.getRight())) {
final Optional<Expression> left = maybeRewriteTimestamp(node.getLeft());
return left.map(l -> new ComparisonExpression(
node.getLocation(),
node.getType(),
l,
node.getRight()
));
}

return Optional.empty();
}
}

private static boolean refIsRowtime(final ColumnReferenceExp node) {
return node.getReference().equals(SchemaUtil.ROWTIME_NAME);
}
private Optional<Expression> maybeRewriteTimestamp(final Expression maybeTimestamp) {
if (!(maybeTimestamp instanceof StringLiteral)) {
return Optional.empty();
}

private static boolean expressionIsRowtime(final Expression node) {
return (node instanceof ColumnReferenceExp)
&& refIsRowtime((ColumnReferenceExp) node);
final String text = ((StringLiteral) maybeTimestamp).getValue();

return Optional.of(new LongLiteral(parser.parse(text)));
}

private LongLiteral rewriteTimestamp(final String timestamp) {
return new LongLiteral(parser.parse(timestamp));
private static boolean supportedColumnRef(final Expression maybeColumnRef) {
if (!(maybeColumnRef instanceof ColumnReferenceExp)) {
return false;
}

return SUPPORTED_COLUMNS.contains(((ColumnReferenceExp) maybeColumnRef).getReference());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import static java.util.Objects.requireNonNull;

import io.confluent.ksql.engine.rewrite.StatementRewriteForRowtime;
import io.confluent.ksql.engine.rewrite.StatementRewriteForMagicPseudoTimestamp;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryContext.Stacker;
Expand Down Expand Up @@ -136,8 +136,7 @@ public SchemaKStream<K> filter(
}

static Expression rewriteTimeComparisonForFilter(final Expression expression) {
return new StatementRewriteForRowtime()
.rewriteForRowtime(expression);
return new StatementRewriteForMagicPseudoTimestamp().rewrite(expression);
}

public SchemaKStream<K> select(
Expand Down Expand Up @@ -325,7 +324,7 @@ public SchemaKStream<Struct> selectKey(
final Expression keyExpression,
final QueryContext.Stacker contextStacker
) {
if (!needsRepartition(keyExpression)) {
if (repartitionNotNeeded(keyExpression)) {
return (SchemaKStream<Struct>) this;
}

Expand Down Expand Up @@ -360,9 +359,9 @@ private KeyField getNewKeyField(final Expression expression) {
return getSchema().isMetaColumn(columnName) ? KeyField.none() : newKeyField;
}

protected boolean needsRepartition(final Expression expression) {
protected boolean repartitionNotNeeded(final Expression expression) {
if (!(expression instanceof UnqualifiedColumnReferenceExp)) {
return true;
return false;
}

final ColumnName columnName = ((UnqualifiedColumnReferenceExp) expression).getReference();
Expand All @@ -379,7 +378,7 @@ protected boolean needsRepartition(final Expression expression) {
.map(kf -> kf.ref().equals(proposedKey.ref()))
.orElse(false);

return !namesMatch && !isRowKey(columnName);
return namesMatch || isRowKey(columnName);
}

private boolean isRowKey(final ColumnName fieldName) {
Expand Down Expand Up @@ -453,7 +452,7 @@ public SchemaKGroupedStream groupBy(
);
}

@SuppressWarnings("unchecked")
@SuppressWarnings({"unchecked", "rawtypes"})
private SchemaKGroupedStream groupByKey(
final KeyFormat rekeyedKeyFormat,
final ValueFormat valueFormat,
Expand Down Expand Up @@ -534,7 +533,7 @@ LogicalSchema resolveSchema(final ExecutionStep<?> step) {
return new StepSchemaResolver(ksqlConfig, functionRegistry).resolve(step, schema);
}

LogicalSchema resolveSchema(final ExecutionStep<?> step, final SchemaKStream right) {
LogicalSchema resolveSchema(final ExecutionStep<?> step, final SchemaKStream<?> right) {
return new StepSchemaResolver(ksqlConfig, functionRegistry).resolve(
step,
schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public SchemaKTable<K> select(
@Override
public SchemaKStream<Struct> selectKey(final Expression keyExpression,
final Stacker contextStacker) {
if (!needsRepartition(keyExpression)) {
if (repartitionNotNeeded(keyExpression)) {
return (SchemaKStream<Struct>) this;
}

Expand Down
Loading

0 comments on commit 9c3cbf8

Please sign in to comment.