Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Jan 3, 2025
1 parent 6e8e114 commit d419299
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -123,6 +124,10 @@ public boolean isUniformAndHasConstValue(Slot slot) {
return uniformSet.isUniformAndHasConstValue(slot);
}

public Map<Slot, Optional<Expression>> getSlotUniformValueMap() {
return new HashMap<>(uniformSet.slotUniformValue);
}

public Optional<Expression> getUniformValue(Slot slot) {
return uniformSet.slotUniformValue.get(slot);
}
Expand Down Expand Up @@ -195,6 +200,10 @@ public void addUniformSlotForOuterJoinNullableSide(DataTrait dataTrait) {
uniformSet.addUniformSlotForOuterJoinNullableSide(dataTrait.uniformSet);
}

public void addUniformSlotValueMap(Map<Slot, Optional<Expression>> map) {
uniformSet.add(map);
}

public void addUniformSlotAndLiteral(Slot slot, Expression literal) {
uniformSet.add(slot, literal);
}
Expand Down Expand Up @@ -538,13 +547,17 @@ public void add(Set<Slot> slots) {
}
}

public void add(UniformDescription ud) {
slotUniformValue.putAll(ud.slotUniformValue);
for (Map.Entry<Slot, Optional<Expression>> entry : ud.slotUniformValue.entrySet()) {
public void add(Map<Slot, Optional<Expression>> map) {
slotUniformValue.putAll(map);
for (Map.Entry<Slot, Optional<Expression>> entry : map.entrySet()) {
add(entry.getKey(), entry.getValue().orElse(null));
}
}

public void add(UniformDescription ud) {
add(ud.slotUniformValue);
}

public void add(Slot slot, Expression literal) {
if (null == literal) {
slotUniformValue.putIfAbsent(slot, Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.Repeat;
Expand All @@ -34,9 +35,12 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
* LogicalRepeat.
Expand Down Expand Up @@ -190,7 +194,17 @@ public void computeUnique(DataTrait.Builder builder) {

@Override
public void computeUniform(DataTrait.Builder builder) {
builder.addUniformSlot(child(0).getLogicalProperties().getTrait());
DataTrait dataTrait = child(0).getLogicalProperties().getTrait();
Set<Expression> common = getCommonGroupingSetExpressions();
Map<Slot, Optional<Expression>> slotUniformValue = dataTrait.getSlotUniformValueMap();
Map<Slot, Optional<Expression>> newSlotUniformValue = new HashMap<>();
for (Map.Entry<Slot, Optional<Expression>> entry : slotUniformValue.entrySet()) {
Optional<Expression> value = entry.getValue();
if (!value.isPresent() || value.get() instanceof NullLiteral || common.contains(value.get())) {
newSlotUniformValue.put(entry.getKey(), value);
}
}
builder.addUniformSlotValueMap(newSlotUniformValue);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,13 @@ cherry 3

-- !right_anti_right_side --

-- !grouping --
\N \N
\N 1
\N 2
\N 3
1 \N
1 1
1 2
1 3

Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,7 @@ suite("eliminate_group_by_key_by_uniform") {
qt_left_anti_left_side "select t1.b from test1 t1 left anti join (select * from test2 where b=105) t2 on t1.a=t2.a where t1.b=1 group by t1.b,t1.a order by 1;"
qt_right_semi_right_side "select t2.b from test1 t1 right semi join (select * from test2 where b=105) t2 on t1.a=t2.a group by t2.b,t2.a order by 1;"
qt_right_anti_right_side "select t2.b from test1 t1 right anti join (select * from test2 where b=105) t2 on t1.a=t2.a group by t2.b,t2.a order by 1;"

//grouping
qt_grouping "select k, k3 from (select 1 as k, a k3, sum(b) as sum_k1 from test1 group by cube(k,a)) t group by k,k3 order by 1,2"
}

0 comments on commit d419299

Please sign in to comment.