Skip to content

Commit

Permalink
fix: left/right/outer joins have nullable fields (substrait-io#157)
Browse files Browse the repository at this point in the history
During a left, right, or outer join, some of the input fields become
nullable in the output, because rows from one side may have no match on
the other side. For a left join, the right input's fields become
nullable. For a right join, the left input's fields become nullable.
For an outer join, the fields of both inputs become nullable.
  • Loading branch information
carlyeks authored Jul 13, 2023
1 parent fc75482 commit 0efe714
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 4 deletions.
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import com.github.bsideup.jabel.Desugar;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.ImmutablePlan;
Expand Down Expand Up @@ -245,6 +247,14 @@ public List<FieldReference> fieldReferences(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

/**
 * Builds a cast expression that converts {@code input} to {@code type}.
 *
 * <p>The failure behavior of the resulting cast is always {@link FailureBehavior#UNSPECIFIED}.
 *
 * @param input the expression whose value is cast
 * @param type the target type of the cast
 * @return the newly built cast expression
 */
public Expression cast(Expression input, Type type) {
  var castBuilder = Cast.builder();
  castBuilder.input(input);
  castBuilder.type(type);
  castBuilder.failureBehavior(FailureBehavior.UNSPECIFIED);
  return castBuilder.build();
}

public List<Expression.SortField> sortFields(Rel input, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(
Expand Down
166 changes: 162 additions & 4 deletions core/src/main/java/io/substrait/relation/Join.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,32 @@
import io.substrait.expression.Expression;
import io.substrait.proto.JoinRel;
import io.substrait.type.Type;
import io.substrait.type.Type.Binary;
import io.substrait.type.Type.Bool;
import io.substrait.type.Type.Date;
import io.substrait.type.Type.Decimal;
import io.substrait.type.Type.FP32;
import io.substrait.type.Type.FP64;
import io.substrait.type.Type.FixedBinary;
import io.substrait.type.Type.FixedChar;
import io.substrait.type.Type.I16;
import io.substrait.type.Type.I32;
import io.substrait.type.Type.I64;
import io.substrait.type.Type.I8;
import io.substrait.type.Type.IntervalDay;
import io.substrait.type.Type.IntervalYear;
import io.substrait.type.Type.ListType;
import io.substrait.type.Type.Map;
import io.substrait.type.Type.Str;
import io.substrait.type.Type.Struct;
import io.substrait.type.Type.Time;
import io.substrait.type.Type.Timestamp;
import io.substrait.type.Type.TimestampTZ;
import io.substrait.type.Type.UUID;
import io.substrait.type.Type.UserDefined;
import io.substrait.type.Type.VarChar;
import io.substrait.type.TypeCreator;
import io.substrait.type.TypeVisitor;
import java.util.Optional;
import java.util.stream.Stream;
import org.immutables.value.Value;
Expand Down Expand Up @@ -47,12 +72,145 @@ public static JoinType fromProto(JoinRel.JoinType proto) {
}
}

/**
 * Type visitor that maps every visited type to the equivalent type created through {@link
 * TypeCreator#NULLABLE}.
 *
 * <p>Parameterized types (lengths, precision/scale, struct fields, list element, map key/value,
 * user-defined uri/name) are rebuilt from the visited type's own parameters; only the outermost
 * type goes through the nullable creator.
 */
private static final class NullableTypeVisitor implements TypeVisitor<Type, RuntimeException> {

  /** Shorthand for the nullable type factory shared by all overloads. */
  private static final TypeCreator N = TypeCreator.NULLABLE;

  // ---- Boolean and integer types ----

  @Override
  public Type visit(Bool bool) {
    return N.BOOLEAN;
  }

  @Override
  public Type visit(I8 i8) {
    return N.I8;
  }

  @Override
  public Type visit(I16 i16) {
    return N.I16;
  }

  @Override
  public Type visit(I32 i32) {
    return N.I32;
  }

  @Override
  public Type visit(I64 i64) {
    return N.I64;
  }

  // ---- Floating point types ----

  @Override
  public Type visit(FP32 fp32) {
    return N.FP32;
  }

  @Override
  public Type visit(FP64 fp64) {
    return N.FP64;
  }

  // ---- Unparameterized string and binary types ----

  @Override
  public Type visit(Str str) {
    return N.STRING;
  }

  @Override
  public Type visit(Binary binary) {
    return N.BINARY;
  }

  // ---- Date, time and interval types ----

  @Override
  public Type visit(Date date) {
    return N.DATE;
  }

  @Override
  public Type visit(Time time) {
    return N.TIME;
  }

  @Override
  public Type visit(TimestampTZ timestampTz) {
    return N.TIMESTAMP_TZ;
  }

  @Override
  public Type visit(Timestamp timestamp) {
    return N.TIMESTAMP;
  }

  @Override
  public Type visit(IntervalYear intervalYear) {
    return N.INTERVAL_YEAR;
  }

  @Override
  public Type visit(IntervalDay intervalDay) {
    return N.INTERVAL_DAY;
  }

  @Override
  public Type visit(UUID uuid) {
    return N.UUID;
  }

  // ---- Parameterized scalar types: parameters are copied from the visited type ----

  @Override
  public Type visit(FixedChar fixedChar) {
    return N.fixedChar(fixedChar.length());
  }

  @Override
  public Type visit(VarChar varChar) {
    return N.varChar(varChar.length());
  }

  @Override
  public Type visit(FixedBinary fixedBinary) {
    return N.fixedBinary(fixedBinary.length());
  }

  @Override
  public Type visit(Decimal decimal) {
    return N.decimal(decimal.precision(), decimal.scale());
  }

  // ---- Compound types: nested types are passed through as-is ----

  @Override
  public Type visit(Struct struct) {
    return N.struct(struct.fields());
  }

  @Override
  public Type visit(ListType list) {
    return N.list(list.elementType());
  }

  @Override
  public Type visit(Map map) {
    return N.map(map.key(), map.value());
  }

  @Override
  public Type visit(UserDefined userDefined) {
    return N.userDefined(userDefined.uri(), userDefined.name());
  }
}

/**
 * Derives the record type of the join output: the left input's fields followed by the right
 * input's fields, wrapped in a required struct.
 *
 * <p>Fields of a side that may go unmatched are converted to their nullable counterparts:
 *
 * <ul>
 *   <li>{@code LEFT}: right fields become nullable
 *   <li>{@code RIGHT}: left fields become nullable
 *   <li>{@code OUTER}: both left and right fields become nullable
 *   <li>any other join type: field types are passed through unchanged
 * </ul>
 *
 * @return the struct type describing the join's output record
 */
@Override
protected Type.Struct deriveRecordType() {
  // NOTE: the earlier unconditional `return TypeCreator.REQUIRED.struct(Stream.concat(...))`
  // was removed; it made the nullability handling below unreachable.
  var nullable = new NullableTypeVisitor();

  // Left fields are optionally matched in RIGHT and OUTER joins.
  Stream<Type> leftTypes =
      switch (getJoinType()) {
        case RIGHT, OUTER -> getLeft().getRecordType().fields().stream()
            .map(t -> t.accept(nullable));
        default -> getLeft().getRecordType().fields().stream();
      };

  // Right fields are optionally matched in LEFT and OUTER joins.
  Stream<Type> rightTypes =
      switch (getJoinType()) {
        case LEFT, OUTER -> getRight().getRecordType().fields().stream()
            .map(t -> t.accept(nullable));
        default -> getRight().getRecordType().fields().stream();
      };

  return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.plan.Plan;
import io.substrait.relation.Join.JoinType;
import io.substrait.relation.Rel;
import io.substrait.relation.Set.SetOp;
import io.substrait.type.Type;
Expand Down Expand Up @@ -150,6 +151,54 @@ public void emit() {
var relNode = converter.convert(root.getInput());
assertRowMatch(relNode.getRowType(), R.I32, N.STRING);
}

/**
 * A LEFT join keeps the left input's field types as-is and makes the right input's fields
 * nullable in the converted row type.
 */
@Test
public void leftJoin() {
  final List<Type> columnTypes = List.of(R.STRING, R.FP64, R.BINARY);
  final Rel table = b.namedScan(List.of("join"), List.of("a", "b", "c"), columnTypes);

  // Self-join on a constant-true condition; output is left fields followed by right fields.
  var joined = b.join(ji -> b.bool(true), JoinType.LEFT, table, table);

  // Project left fields 0 and 1 plus field 3 (the first field of the right input), then keep
  // only output positions 6..8 (presumably the three projected expressions - verify remap).
  var projected = b.project(r -> b.fieldReferences(r, 0, 1, 3), b.remap(6, 7, 8), joined);
  Plan.Root root = b.root(projected);

  var relNode = converter.convert(root.getInput());
  assertRowMatch(relNode.getRowType(), R.STRING, R.FP64, N.STRING);
}

/**
 * A RIGHT join makes the left input's fields nullable and keeps the right input's field types
 * as-is in the converted row type.
 */
@Test
public void rightJoin() {
  final List<Type> columnTypes = List.of(R.STRING, R.FP64, R.BINARY);
  final Rel table = b.namedScan(List.of("join"), List.of("a", "b", "c"), columnTypes);

  // Self-join on a constant-true condition; output is left fields followed by right fields.
  var joined = b.join(ji -> b.bool(true), JoinType.RIGHT, table, table);

  // Project left fields 0 and 1 plus field 3 (the first field of the right input), then keep
  // only output positions 6..8 (presumably the three projected expressions - verify remap).
  var projected = b.project(r -> b.fieldReferences(r, 0, 1, 3), b.remap(6, 7, 8), joined);
  Plan.Root root = b.root(projected);

  var relNode = converter.convert(root.getInput());
  assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, R.STRING);
}

/**
 * An OUTER join makes the fields of both the left and the right input nullable in the
 * converted row type.
 */
@Test
public void outerJoin() {
  final List<Type> columnTypes = List.of(R.STRING, R.FP64, R.BINARY);
  final Rel table = b.namedScan(List.of("join"), List.of("a", "b", "c"), columnTypes);

  // Self-join on a constant-true condition; output is left fields followed by right fields.
  var joined = b.join(ji -> b.bool(true), JoinType.OUTER, table, table);

  // Project left fields 0 and 1 plus field 3 (the first field of the right input), then keep
  // only output positions 6..8 (presumably the three projected expressions - verify remap).
  var projected = b.project(r -> b.fieldReferences(r, 0, 1, 3), b.remap(6, 7, 8), joined);
  Plan.Root root = b.root(projected);

  var relNode = converter.convert(root.getInput());
  assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, N.STRING);
}
}

@Nested
Expand Down

0 comments on commit 0efe714

Please sign in to comment.