diff --git a/docs/querying/aggregations.md b/docs/querying/aggregations.md
index 7da5b5c56bd5..8fd99824e0d6 100644
--- a/docs/querying/aggregations.md
+++ b/docs/querying/aggregations.md
@@ -426,3 +426,26 @@ This makes it possible to compute the results of a filtered and an unfiltered ag
   "aggregator" :
 }
 ```
+
+### Grouping Aggregator
+
+A grouping aggregator can only be used as part of GroupBy queries which have a subtotal spec. It returns a number for
+each output row that lets you infer whether a particular dimension is included in the sub-grouping used for that row. You can pass
+a *non-empty* list of dimensions to this aggregator which *must* be a subset of dimensions that you are grouping on.
+E.g if the aggregator has `["dim1", "dim2"]` as input dimensions and `[["dim1", "dim2"], ["dim1"], ["dim2"], []]` as subtotals,
+following can be the possible output of the aggregator
+
+| subtotal used in query | Output | (bits representation) |
+|------------------------|--------|-----------------------|
+| `["dim1", "dim2"]`     | 0      | (00)                  |
+| `["dim1"]`             | 1      | (01)                  |
+| `["dim2"]`             | 2      | (10)                  |
+| `[]`                   | 3      | (11)                  |
+
+As illustrated in above example, output number can be thought of as an unsigned n bit number where n is the number of dimensions passed to the aggregator.
+The bit at position X is set in this number to 0 if a dimension at position X in input to aggregators is included in the sub-grouping. Otherwise, this bit
+is set to 1.
+
+```json
+{ "type" : "grouping", "name" : , "groupings" : [] }
+```
\ No newline at end of file
diff --git a/docs/querying/groupbyquery.md b/docs/querying/groupbyquery.md
index 652953490321..58d654ae699c 100644
--- a/docs/querying/groupbyquery.md
+++ b/docs/querying/groupbyquery.md
@@ -226,7 +226,9 @@ The response for the query above would look something like:
 ]
 ```
 
-> Notice that dimensions that are not included in an individual subtotalsSpec grouping are returned with a `null` value. This response format represents a behavior change as of Apache Druid 0.18.0. In release 0.17.0 and earlier, such dimensions were entirely excluded from the result.
+> Notice that dimensions that are not included in an individual subtotalsSpec grouping are returned with a `null` value. This response format represents a behavior change as of Apache Druid 0.18.0.
+> In release 0.17.0 and earlier, such dimensions were entirely excluded from the result. If you were relying on this old behavior to determine whether a particular dimension was not part of
+> a subtotal grouping, you can now use [Grouping aggregator](aggregations.md#Grouping Aggregator) instead.
 
 ## Implementation details
 
diff --git a/docs/querying/sql.md b/docs/querying/sql.md
index a743bf34fd5b..2b83ccb1fff0 100644
--- a/docs/querying/sql.md
+++ b/docs/querying/sql.md
@@ -99,7 +99,8 @@ total. Finally, GROUP BY CUBE computes a grouping set for each combination of gr
 `GROUP BY CUBE (country, city)` is equivalent to `GROUP BY GROUPING SETS ( (country, city), (country), (city), () )`.
 Grouping columns that do not apply to a particular row will contain `NULL`. For example, when computing
 `GROUP BY GROUPING SETS ( (country, city), () )`, the grand total row corresponding to `()` will have `NULL` for the
-"country" and "city" columns.
+"country" and "city" columns. Column may also be `NULL` if it was `NULL` in the data itself. To differentiate such rows
+, you can use `GROUPING` aggregation.
 
 When using GROUP BY GROUPING SETS, GROUP BY ROLLUP, or GROUP BY CUBE, be aware that results may not be generated in the
 order that you specify your grouping sets in the query. If you need results to be generated in a particular order, use
@@ -337,6 +338,7 @@ Only the COUNT aggregation can accept DISTINCT.
 |`LATEST(expr, maxBytesPerString)`|Like `LATEST(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.|
 |`ANY_VALUE(expr)`|Returns any value of `expr` including null. `expr` must be numeric. This aggregator can simplify and optimize the performance by returning the first encountered value (including null)|
 |`ANY_VALUE(expr, maxBytesPerString)`|Like `ANY_VALUE(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.|
+|`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#Grouping Aggregator) on how to infer this number.|
 
 For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.html#approx).
 
diff --git a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java
index a974edd764cb..795ea5b6d31c 100644
--- a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java
+++ b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java
@@ -31,6 +31,7 @@
 import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
 import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
 import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
+import org.apache.druid.query.aggregation.GroupingAggregatorFactory;
 import org.apache.druid.query.aggregation.HistogramAggregatorFactory;
 import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory;
 import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
@@ -118,7 +119,8 @@ public AggregatorsModule()
       @JsonSubTypes.Type(name = "longAny", value = LongAnyAggregatorFactory.class),
       @JsonSubTypes.Type(name = "floatAny", value = FloatAnyAggregatorFactory.class),
       @JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class),
-      @JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class)
+      @JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class),
+      @JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class)
   })
   public interface AggregatorFactoryMixin
   {
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java
index 2cc5f1b06662..3c7b8d474395 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java
@@ -134,6 +134,10 @@ public class AggregatorUtil
   public static final byte FLOAT_ANY_CACHE_TYPE_ID = 0x44;
   public static final byte STRING_ANY_CACHE_TYPE_ID = 0x45;
 
+  // GROUPING aggregator
+  public static final byte GROUPING_CACHE_TYPE_ID = 0x46;
+
+
   /**
    * returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg
    *
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
new file mode 100644
index 000000000000..62fbb47aa095
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import org.apache.druid.annotations.EverythingIsNonnullByDefault;
+import org.apache.druid.query.aggregation.constant.LongConstantAggregator;
+import org.apache.druid.query.aggregation.constant.LongConstantBufferAggregator;
+import org.apache.druid.query.aggregation.constant.LongConstantVectorAggregator;
+import org.apache.druid.query.cache.CacheKeyBuilder;
+import org.apache.druid.segment.ColumnInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.utils.CollectionUtils;
+
+import javax.annotation.Nullable;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * This class implements {@code grouping} function to determine the grouping that a row is part of. Different result rows
+ * for a query could have different grouping columns when subtotals are used.
+ *
+ * This aggregator factory takes following arguments
+ *  - {@code name} - Name of aggregators
+ *  - {@code groupings} - List of dimensions that the user is interested in tracking
+ *  - {@code keyDimensions} - The list of grouping dimensions being included in the result row. This list is a subset of
+ *                            {@code groupings}. This argument cannot be passed by the user. It is set by druid engine
+ *                            when a particular subtotal spec is being processed. Whenever druid engine processes a new
+ *                            subtotal spec, engine sets that subtotal spec as new {@code keyDimensions}.
+ *
+ * When key dimensions are updated, {@code value} is updated as well. How the value is determined is captured
+ * at {@link #groupingId(List, Set)}.
+ *
+ * since grouping has to be calculated only once, it could have been implemented as a virtual function or
+ * post-aggregator etc. We modelled it as an aggregation operator so that its output can be used in a post-aggregator.
+ * Calcite too models grouping function as an aggregation operator.
+ * Since it is a non-trivial special aggregation, implementing it required changes in core druid engine to work. There
+ * were few approaches. We chose the approach that required least changes in core druid.
+ * Refer to https://github.com/apache/druid/pull/10518#discussion_r532941216 for more details.
+ *
+ * Currently, it works in following way
+ *  - On data servers (no change),
+ *    - this factory generates {@link LongConstantAggregator} / {@link LongConstantBufferAggregator} / {@link LongConstantVectorAggregator}
+ *      with keyDimensions as null
+ *    - The aggregators don't actually aggregate anything and their result is not actually used. We could have removed
+ *      these aggregators on data servers but that would result in a signature mismatch on broker and data nodes. That requires
+ *      extra handling and is error-prone.
+ *  - On brokers
+ *    - Results from data node is already being re-processed for each subtotal spec. We made modifications in this path to update the
+ *      grouping id for each row.
+ *
+ */
+@EverythingIsNonnullByDefault
+public class GroupingAggregatorFactory extends AggregatorFactory
+{
+  private static final Comparator VALUE_COMPARATOR = Long::compare;
+  private final String name;
+  private final List groupings;
+  private final long value;
+  @Nullable
+  private final Set keyDimensions;
+
+  @JsonCreator
+  public GroupingAggregatorFactory(
+      @JsonProperty("name") String name,
+      @JsonProperty("groupings") List groupings
+  )
+  {
+    this(name, groupings, null);
+  }
+
+  @VisibleForTesting
+  GroupingAggregatorFactory(
+      String name,
+      List groupings,
+      @Nullable Set keyDimensions
+  )
+  {
+    Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
+    this.name = name;
+    this.groupings = groupings;
+    this.keyDimensions = keyDimensions;
+    value = groupingId(groupings, keyDimensions);
+  }
+
+  @Override
+  public Aggregator factorize(ColumnSelectorFactory metricFactory)
+  {
+    return new LongConstantAggregator(value);
+  }
+
+  @Override
+  public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
+  {
+    return new LongConstantBufferAggregator(value);
+  }
+
+  @Override
+  public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
+  {
+    return new LongConstantVectorAggregator(value);
+  }
+
+  @Override
+  public boolean canVectorize(ColumnInspector columnInspector)
+  {
+    return true;
+  }
+
+  /**
+   * Replace the param {@code keyDimensions} with the new set of key dimensions
+   */
+  public GroupingAggregatorFactory withKeyDimensions(Set newKeyDimensions)
+  {
+    return new GroupingAggregatorFactory(name, groupings, newKeyDimensions);
+  }
+
+  @Override
+  public Comparator getComparator()
+  {
+    return VALUE_COMPARATOR;
+  }
+
+  @JsonProperty
+  public List getGroupings()
+  {
+    return groupings;
+  }
+
+  @Override
+  @JsonProperty
+  public String getName()
+  {
+    return name;
+  }
+
+  public long getValue()
+  {
+    return value;
+  }
+
+  @Nullable
+  @Override
+  public Object combine(@Nullable Object lhs, @Nullable Object rhs)
+  {
+    if (null == lhs) {
+      return rhs;
+    }
+    return lhs;
+  }
+
+  @Override
+  public AggregatorFactory getCombiningFactory()
+  {
+    return new GroupingAggregatorFactory(name, groupings, keyDimensions);
+  }
+
+  @Override
+  public List getRequiredColumns()
+  {
+    return Collections.singletonList(new GroupingAggregatorFactory(name, groupings, keyDimensions));
+  }
+
+  @Override
+  public Object deserialize(Object object)
+  {
+    return object;
+  }
+
+  @Nullable
+  @Override
+  public Object finalizeComputation(@Nullable Object object)
+  {
+    return object;
+  }
+
+  @Override
+  public List requiredFields()
+  {
+    // The aggregator doesn't need to read any fields.
+    return Collections.emptyList();
+  }
+
+  @Override
+  public ValueType getType()
+  {
+    return ValueType.LONG;
+  }
+
+  @Override
+  public ValueType getFinalizedType()
+  {
+    return ValueType.LONG;
+  }
+
+  @Override
+  public int getMaxIntermediateSize()
+  {
+    return Long.BYTES;
+  }
+
+  @Override
+  public byte[] getCacheKey()
+  {
+    CacheKeyBuilder keyBuilder = new CacheKeyBuilder(AggregatorUtil.GROUPING_CACHE_TYPE_ID)
+        .appendStrings(groupings);
+    if (null != keyDimensions) {
+      keyBuilder.appendStrings(keyDimensions);
+    }
+    return keyBuilder.build();
+  }
+
+  /**
+   * Given the list of grouping dimensions, returns a long value where each bit at position X in the returned value
+   * corresponds to the dimension in groupings at same position X. X is the position relative to the right end. if
+   * keyDimensions contain the grouping dimension at position X, the bit is set to 0 at position X, otherwise it is
+   * set to 1.
+   *
+   *  groupings    keyDimensions    value (3 least significant bits)    value (long)
+   *    a,b,c          [a]                     011                          3
+   *    a,b,c          [b]                     101                          5
+   *    a,b,c          [c]                     110                          6
+   *    a,b,c          [a,b]                   001                          1
+   *    a,b,c          [a,c]                   010                          2
+   *    a,b,c          [b,c]                   100                          4
+   *    a,b,c          [a,b,c]                 000                          0
+   *    a,b,c          []                      111                          7   // None included
+   *    a,b,c                                  000                          0   // All included
+   */
+  private long groupingId(List groupings, @Nullable Set keyDimensions)
+  {
+    Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions");
+    // (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit
+    // also makes sure that values are always positive.
+    Preconditions.checkArgument(
+        groupings.size() < Long.SIZE,
+        "Number of dimensions %s is more than supported %s",
+        groupings.size(),
+        Long.SIZE - 1
+    );
+    long temp = 0L;
+    for (String groupingDimension : groupings) {
+      temp = temp << 1;
+      if (!isDimensionIncluded(groupingDimension, keyDimensions)) {
+        temp = temp | 1L;
+      }
+    }
+    return temp;
+  }
+
+  private boolean isDimensionIncluded(String dimToCheck, @Nullable Set keyDimensions)
+  {
+    if (null == keyDimensions) {
+      // All dimensions are included
+      return true;
+    } else {
+      return keyDimensions.contains(dimToCheck);
+    }
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    GroupingAggregatorFactory factory = (GroupingAggregatorFactory) o;
+    return name.equals(factory.name) &&
+           groupings.equals(factory.groupings) &&
+           Objects.equals(keyDimensions, factory.keyDimensions);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(name, groupings, keyDimensions);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "GroupingAggregatorFactory{" +
+           "name='" + name + '\'' +
+           ", groupings=" + groupings +
+           ", keyDimensions=" + keyDimensions +
+           '}';
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java
new file mode 100644
index 000000000000..1fae2715b1d4
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.constant;
+
+import org.apache.druid.query.aggregation.Aggregator;
+
+/**
+ * This aggregator is a no-op aggregator with a fixed non-null output value. It can be used in scenarios where
+ * result is constant such as {@link org.apache.druid.query.aggregation.GroupingAggregatorFactory}
+ */
+public class LongConstantAggregator implements Aggregator
+{
+  private final long value;
+
+  public LongConstantAggregator(long value)
+  {
+    this.value = value;
+  }
+
+  @Override
+  public void aggregate()
+  {
+    // No-op
+  }
+
+  @Override
+  public Object get()
+  {
+    return value;
+  }
+
+  @Override
+  public float getFloat()
+  {
+    return (float) value;
+  }
+
+  @Override
+  public long getLong()
+  {
+    return value;
+  }
+
+  @Override
+  public void close()
+  {
+
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java
new file mode 100644
index 000000000000..1ddf11b57d7b
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.constant;
+
+import org.apache.druid.query.aggregation.BufferAggregator;
+
+import java.nio.ByteBuffer;
+
+/**
+ * {@link BufferAggregator} variant of {@link LongConstantAggregator}
+ */
+public class LongConstantBufferAggregator implements BufferAggregator
+{
+  private final long value;
+
+  public LongConstantBufferAggregator(long value)
+  {
+    this.value = value;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    // Since we always return a constant value despite what is in the buffer, there is no need to
+    // update the buffer at all
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int position)
+  {
+
+  }
+
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    return value;
+  }
+
+  @Override
+  public float getFloat(ByteBuffer buf, int position)
+  {
+    return (float) value;
+  }
+
+  @Override
+  public long getLong(ByteBuffer buf, int position)
+  {
+    return value;
+  }
+
+  @Override
+  public void close()
+  {
+
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java
new file mode 100644
index 000000000000..4af4b8a9fe77
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.constant;
+
+import org.apache.druid.query.aggregation.VectorAggregator;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * {@link VectorAggregator} variant of {@link LongConstantAggregator}
+ */
+public class LongConstantVectorAggregator implements VectorAggregator
+{
+  private final long value;
+
+  public LongConstantVectorAggregator(long value)
+  {
+    this.value = value;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    // Since we always return a constant value despite what is in the buffer, there is no need to
+    // update the buffer at all
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable int[] rows, int positionOffset)
+  {
+
+  }
+
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    return value;
+  }
+
+  @Override
+  public void close()
+  {
+
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java
index f86b2f0c3425..7e753af1331e 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java
@@ -29,6 +29,7 @@
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.query.ResourceLimitExceededException;
+import org.apache.druid.query.dimension.DimensionSpec;
 import org.apache.druid.query.groupby.GroupByQuery;
 import org.apache.druid.query.groupby.GroupByQueryConfig;
 import org.apache.druid.query.groupby.ResultRow;
@@ -66,7 +67,7 @@ public interface ResultSupplier extends Closeable
      * @param dimensionsToInclude list of dimensions to include, or null to include all dimensions. Used by processing
      *                            of subtotals. If specified, the results will not necessarily be fully grouped.
      */
-    Sequence results(@Nullable List dimensionsToInclude);
+    Sequence results(@Nullable List dimensionsToInclude);
   }
 
   private GroupByRowProcessor()
@@ -140,7 +141,7 @@ public ByteBuffer get()
     return new ResultSupplier()
     {
       @Override
-      public Sequence results(@Nullable List dimensionsToInclude)
+      public Sequence results(@Nullable List dimensionsToInclude)
       {
         return getRowsFromGrouper(query, grouper, dimensionsToInclude);
       }
@@ -156,7 +157,7 @@ public void close() throws IOException
   private static Sequence getRowsFromGrouper(
       final GroupByQuery query,
       final Grouper grouper,
-      @Nullable List dimensionsToInclude
+      @Nullable List dimensionsToInclude
   )
   {
     return new BaseSequence<>(
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
index c099eedd7617..9a87cb59b7a2 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
@@ -44,6 +44,7 @@
 import org.apache.druid.query.BaseQuery;
 import org.apache.druid.query.ColumnSelectorPlus;
 import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.GroupingAggregatorFactory;
 import org.apache.druid.query.dimension.ColumnSelectorStrategy;
 import org.apache.druid.query.dimension.ColumnSelectorStrategyFactory;
 import org.apache.druid.query.dimension.DimensionSpec;
@@ -87,6 +88,7 @@
 import java.util.function.Function;
 import java.util.function.Predicate;
 import java.util.function.ToLongFunction;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 /**
@@ -534,21 +536,43 @@ public static CloseableGrouperIterator makeGrouperIterat
   public static CloseableGrouperIterator makeGrouperIterator(
       final Grouper grouper,
       final GroupByQuery query,
-      @Nullable final List dimsToInclude,
+      @Nullable final List dimsToInclude,
       final Closeable closeable
   )
   {
     final boolean includeTimestamp = query.getResultRowHasTimestamp();
     final BitSet dimsToIncludeBitSet = new BitSet(query.getDimensions().size());
     final int resultRowDimensionStart = query.getResultRowDimensionStart();
+    final BitSet groupingAggregatorsBitSet = new BitSet(query.getAggregatorSpecs().size());
+    final Object[] groupingAggregatorValues = new Long[query.getAggregatorSpecs().size()];
 
     if (dimsToInclude != null) {
-      for (String dimension : dimsToInclude) {
-        final int dimIndex = query.getResultRowSignature().indexOf(dimension);
+      for (DimensionSpec dimensionSpec : dimsToInclude) {
+        String outputName = dimensionSpec.getOutputName();
+        final int dimIndex = query.getResultRowSignature().indexOf(outputName);
         if (dimIndex >= 0) {
           dimsToIncludeBitSet.set(dimIndex - resultRowDimensionStart);
         }
       }
+
+      // KeyDimensionNames are the input column names of dimensions. Its required since aggregators are not aware of the
+      // output column names.
+      // As we exclude certain dimensions from the result row, the value for any grouping_id aggregators have to change
+      // to reflect the new grouping dimensions, that aggregation is being done upon. We will mark the indices which have
+      // grouping aggregators and update the value for each row at those indices.
+      Set keyDimensionNames = dimsToInclude.stream()
+                                           .map(DimensionSpec::getDimension)
+                                           .collect(Collectors.toSet());
+      for (int i = 0; i < query.getAggregatorSpecs().size(); i++) {
+        AggregatorFactory aggregatorFactory = query.getAggregatorSpecs().get(i);
+        if (aggregatorFactory instanceof GroupingAggregatorFactory) {
+
+          groupingAggregatorsBitSet.set(i);
+          groupingAggregatorValues[i] = ((GroupingAggregatorFactory) aggregatorFactory)
+              .withKeyDimensions(keyDimensionNames)
+              .getValue();
+        }
+      }
     }
 
     return new CloseableGrouperIterator<>(
@@ -576,7 +600,13 @@ public static CloseableGrouperIterator makeGrouperIterat
           // Add aggregations.
           final int resultRowAggregatorStart = query.getResultRowAggregatorStart();
           for (int i = 0; i < entry.getValues().length; i++) {
-            resultRow.set(resultRowAggregatorStart + i, entry.getValues()[i]);
+            if (dimsToInclude != null && groupingAggregatorsBitSet.get(i)) {
+              // Override with a new value, reflecting the new set of grouping dimensions
+              resultRow.set(resultRowAggregatorStart + i, groupingAggregatorValues[i]);
+            } else {
+              resultRow.set(resultRowAggregatorStart + i, entry.getValues()[i]);
+
+            }
           }
 
           return resultRow;
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java
index e81eded2f9c0..5e7c8b2d6bf6 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java
@@ -405,6 +405,8 @@ public Sequence processSubtotalsSpec(
       // Iterate through each subtotalSpec, build results for it and add to subtotalsResults
       for (List subtotalSpec : subtotals) {
         final ImmutableSet dimsInSubtotalSpec = ImmutableSet.copyOf(subtotalSpec);
+        // Dimension spec including dimension name and output name
+        final List subTotalDimensionSpec = new ArrayList<>(dimsInSubtotalSpec.size());
         final List dimensions = query.getDimensions();
         final List newDimensions = new ArrayList<>();
 
@@ -418,6 +420,7 @@ public Sequence processSubtotalsSpec(
                     dimensionSpec.getOutputType()
                 )
             );
+            subTotalDimensionSpec.add(dimensionSpec);
           } else {
             // Insert dummy dimension so all subtotals queries have ResultRows with the same shape.
             // Use a field name that does not appear in the main query result, to assure the result will be null.
@@ -447,7 +450,7 @@ public Sequence processSubtotalsSpec(
           // Since subtotalSpec is a prefix of base query dimensions, so results from base query are also sorted
           // by subtotalSpec as needed by stream merging.
           subtotalsResults.add(
-              processSubtotalsResultAndOptionallyClose(() -> resultSupplierOneFinal, subtotalSpec, subtotalQuery, false)
+              processSubtotalsResultAndOptionallyClose(() -> resultSupplierOneFinal, subTotalDimensionSpec, subtotalQuery, false)
           );
         } else {
           // Since subtotalSpec is not a prefix of base query dimensions, so results from base query are not sorted
@@ -459,7 +462,7 @@ public Sequence processSubtotalsSpec(
           Supplier resultSupplierTwo = () -> GroupByRowProcessor.process(
               baseSubtotalQuery,
               subtotalQuery,
-              resultSupplierOneFinal.results(subtotalSpec),
+              resultSupplierOneFinal.results(subTotalDimensionSpec),
               configSupplier.get(),
               resource,
               spillMapper,
@@ -468,7 +471,7 @@ public Sequence processSubtotalsSpec(
           );
 
           subtotalsResults.add(
-              processSubtotalsResultAndOptionallyClose(resultSupplierTwo, subtotalSpec, subtotalQuery, true)
+              processSubtotalsResultAndOptionallyClose(resultSupplierTwo, subTotalDimensionSpec, subtotalQuery, true)
           );
         }
       }
@@ -486,7 +489,7 @@ public Sequence processSubtotalsSpec(
 
   private Sequence processSubtotalsResultAndOptionallyClose(
       Supplier baseResultsSupplier,
-      List dimsToInclude,
+      List dimsToInclude,
       GroupByQuery subtotalQuery,
       boolean closeOnSequenceRead
   )
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
new file mode 100644
index 000000000000..0117bf1e82a5
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.druid.query.aggregation; + +import com.google.common.collect.Sets; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.query.aggregation.constant.LongConstantAggregator; +import org.apache.druid.query.aggregation.constant.LongConstantBufferAggregator; +import org.apache.druid.query.aggregation.constant.LongConstantVectorAggregator; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.Collection; + +@RunWith(Enclosed.class) +public class GroupingAggregatorFactoryTest +{ + public static GroupingAggregatorFactory makeFactory(String[] groupings, @Nullable String[] keyDims) + { + GroupingAggregatorFactory factory = new GroupingAggregatorFactory("name", Arrays.asList(groupings)); + if (null != keyDims) { + factory = factory.withKeyDimensions(Sets.newHashSet(keyDims)); + } + return factory; + } + + public static class NewAggregatorTests + { + private ColumnSelectorFactory metricFactory; + + @Before + public void setup() + { + metricFactory = EasyMock.mock(ColumnSelectorFactory.class); + } + + @Test + public void testNewAggregator() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + Aggregator aggregator = factory.factorize(metricFactory); + Assert.assertEquals(LongConstantAggregator.class, aggregator.getClass()); + Assert.assertEquals(1, aggregator.getLong()); + } + + @Test + public void testNewBufferAggregator() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + BufferAggregator aggregator = factory.factorizeBuffered(metricFactory); +
Assert.assertEquals(LongConstantBufferAggregator.class, aggregator.getClass()); + Assert.assertEquals(1, aggregator.getLong(null, 0)); + } + + @Test + public void testNewVectorAggregator() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + Assert.assertTrue(factory.canVectorize(metricFactory)); + VectorAggregator aggregator = factory.factorizeVector(null); + Assert.assertEquals(LongConstantVectorAggregator.class, aggregator.getClass()); + Assert.assertEquals(1L, aggregator.get(null, 0)); + } + + @Test + public void testWithKeyDimensions() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + Aggregator aggregator = factory.factorize(metricFactory); + Assert.assertEquals(1, aggregator.getLong()); + factory = factory.withKeyDimensions(Sets.newHashSet("b")); + aggregator = factory.factorize(metricFactory); + Assert.assertEquals(2, aggregator.getLong()); + } + } + + public static class GroupingDimensionsTest + { + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Test + public void testFactory_nullGroupingDimensions() + { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Must have a non-empty grouping dimensions"); + GroupingAggregatorFactory factory = new GroupingAggregatorFactory("name", null, Sets.newHashSet("b")); + } + + @Test + public void testFactory_emptyGroupingDimensions() + { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Must have a non-empty grouping dimensions"); + makeFactory(new String[0], null); + } + + @Test + public void testFactory_highNumberOfGroupingDimensions() + { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(StringUtils.format( + "Number of dimensions %d is more than supported %d", + Long.SIZE, + Long.SIZE - 1 + )); + makeFactory(new String[Long.SIZE], null); + } + } + + @RunWith(Parameterized.class) + public static class ValueTests + 
{ + private final GroupingAggregatorFactory factory; + private final long value; + + public ValueTests(String[] groupings, @Nullable String[] keyDimensions, long value) + { + factory = makeFactory(groupings, keyDimensions); + this.value = value; + } + + @Parameterized.Parameters + public static Collection<Object[]> arguments() + { + String[] maxGroupingList = new String[Long.SIZE - 1]; + for (int i = 0; i < maxGroupingList.length; i++) { + maxGroupingList[i] = String.valueOf(i); + } + return Arrays.asList(new Object[][]{ + {new String[]{"a", "b"}, new String[0], 3}, + {new String[]{"a", "b"}, null, 0}, + {new String[]{"a", "b"}, new String[]{"a"}, 1}, + {new String[]{"a", "b"}, new String[]{"b"}, 2}, + {new String[]{"a", "b"}, new String[]{"a", "b"}, 0}, + {new String[]{"b", "a"}, new String[]{"a"}, 2}, + {maxGroupingList, null, 0}, + {maxGroupingList, new String[0], Long.MAX_VALUE} + }); + } + + @Test + public void testValue() + { + Assert.assertEquals(value, factory.factorize(null).getLong()); + } + } +} diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java new file mode 100644 index 000000000000..d4f8b02220bf --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.commons.lang.math.RandomUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class LongConstantAggregatorTest +{ + private long randomVal; + private LongConstantAggregator aggregator; + + @Before + public void setup() + { + randomVal = RandomUtils.nextLong(); + aggregator = new LongConstantAggregator(randomVal); + } + + @Test + public void testLong() + { + Assert.assertEquals(randomVal, aggregator.getLong()); + } + + @Test + public void testAggregate() + { + aggregator.aggregate(); + Assert.assertEquals(randomVal, aggregator.getLong()); + } + + @Test + public void testFloat() + { + Assert.assertEquals((float) randomVal, aggregator.getFloat(), 0.0001f); + } + + @Test + public void testGet() + { + Assert.assertEquals(randomVal, aggregator.get()); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java new file mode 100644 index 000000000000..0608fca85b74 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.commons.lang.math.RandomUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class LongConstantBufferAggregatorTest +{ + private long randomVal; + private LongConstantBufferAggregator aggregator; + private ByteBuffer byteBuffer; + + @Before + public void setup() + { + randomVal = RandomUtils.nextLong(); + aggregator = new LongConstantBufferAggregator(randomVal); + byteBuffer = EasyMock.mock(ByteBuffer.class); + EasyMock.replay(byteBuffer); + EasyMock.verifyUnexpectedCalls(byteBuffer); + } + + @Test + public void testLong() + { + Assert.assertEquals(randomVal, aggregator.getLong(byteBuffer, 0)); + } + + @Test + public void testAggregate() + { + aggregator.aggregate(byteBuffer, 0); + Assert.assertEquals(randomVal, aggregator.getLong(byteBuffer, 0)); + } + + @Test + public void testFloat() + { + Assert.assertEquals((float) randomVal, aggregator.getFloat(byteBuffer, 0), 0.0001f); + } + + @Test + public void testGet() + { + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java 
b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java new file mode 100644 index 000000000000..f62dd0369c61 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.commons.lang.math.RandomUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class LongConstantVectorAggregatorTest +{ + private long randomVal; + private LongConstantVectorAggregator aggregator; + private ByteBuffer byteBuffer; + + @Before + public void setup() + { + randomVal = RandomUtils.nextLong(); + aggregator = new LongConstantVectorAggregator(randomVal); + byteBuffer = EasyMock.mock(ByteBuffer.class); + EasyMock.replay(byteBuffer); + EasyMock.verifyUnexpectedCalls(byteBuffer); + } + + @Test + public void testAggregate() + { + aggregator.aggregate(byteBuffer, 0, 1, 10); + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } + + @Test + public void testAggregateWithIndirection() + { + aggregator.aggregate(byteBuffer, 2, new int[]{2, 3}, null, 0); + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } + + @Test + public void testGet() + { + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java new file mode 100644 index 000000000000..bbff23f2375e --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.aggregation.builtin; + +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.GroupingAggregatorFactory; +import org.apache.druid.segment.VirtualColumn; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.aggregation.Aggregation; +import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class GroupingSqlAggregator implements SqlAggregator +{ + @Override + public SqlAggFunction calciteFunction() + { + return SqlStdOperatorTable.GROUPING; + } + + @Nullable + @Override + public Aggregation toDruidAggregation( + PlannerContext plannerContext, + RowSignature rowSignature, + VirtualColumnRegistry virtualColumnRegistry, + RexBuilder rexBuilder, + String name, + AggregateCall aggregateCall, + Project project, + List<Aggregation> existingAggregations, + boolean
finalizeAggregations + ) + { + List<String> arguments = aggregateCall.getArgList() + .stream() + .map(i -> getColumnName( + plannerContext, + rowSignature, + project, + virtualColumnRegistry, + i + )) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + if (arguments.size() < aggregateCall.getArgList().size()) { + return null; + } + + for (Aggregation existing : existingAggregations) { + for (AggregatorFactory factory : existing.getAggregatorFactories()) { + if (!(factory instanceof GroupingAggregatorFactory)) { + continue; + } + GroupingAggregatorFactory groupingFactory = (GroupingAggregatorFactory) factory; + if (groupingFactory.getGroupings().equals(arguments) + && groupingFactory.getName().equals(name)) { + return Aggregation.create(groupingFactory); + } + } + } + AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments); + return Aggregation.create(factory); + } + + @Nullable + private String getColumnName( + PlannerContext plannerContext, + RowSignature rowSignature, + Project project, + VirtualColumnRegistry virtualColumnRegistry, + int fieldNumber + ) + { + RexNode node = Expressions.fromFieldAccess(rowSignature, project, fieldNumber); + if (null == node) { + return null; + } + DruidExpression expression = Expressions.toDruidExpression(plannerContext, rowSignature, node); + if (null == expression) { + return null; + } + if (expression.isDirectColumnAccess()) { + return expression.getDirectColumn(); + } + + VirtualColumn virtualColumn = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( + plannerContext, + expression, + node.getType() + ); + return virtualColumn.getOutputName(); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index e1fb0b48fe54..9fbc8dffb53a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -37,6 +37,7 @@ import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator; +import org.apache.druid.sql.calcite.aggregation.builtin.GroupingSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.MaxSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.MinSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.SumSqlAggregator; @@ -130,6 +131,7 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new MaxSqlAggregator()) .add(new SumSqlAggregator()) .add(new SumZeroSqlAggregator()) + .add(new GroupingSqlAggregator()) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 0b40cc66badf..0b94b4ca2730 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -55,6 +55,7 @@ import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; +import org.apache.druid.query.aggregation.GroupingAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; import org.apache.druid.query.aggregation.LongMinAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; @@ -74,6 +75,7 @@ import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory; import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; +import 
org.apache.druid.query.aggregation.post.ExpressionPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; @@ -123,6 +125,7 @@ import org.junit.runner.RunWith; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -12029,7 +12032,7 @@ public void testGroupingSets() throws Exception cannotVectorize(); testQuery( - "SELECT dim2, gran, SUM(cnt)\n" + "SELECT dim2, gran, SUM(cnt), GROUPING(dim2, gran)\n" + "FROM (SELECT FLOOR(__time TO MONTH) AS gran, COALESCE(dim2, '') dim2, cnt FROM druid.foo) AS x\n" + "GROUP BY GROUPING SETS ( (dim2, gran), (dim2), (gran), () )", ImmutableList.of( @@ -12055,7 +12058,10 @@ public void testGroupingSets() throws Exception new DefaultDimensionSpec("v1", "d1", ValueType.LONG) ) ) - .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("a0", "cnt"), + new GroupingAggregatorFactory("a1", Arrays.asList("v0", "v1")) + )) .setSubtotalsSpec( ImmutableList.of( ImmutableList.of("d0", "d1"), @@ -12068,21 +12074,143 @@ public void testGroupingSets() throws Exception .build() ), ImmutableList.of( - new Object[]{"", timestamp("2000-01-01"), 2L}, - new Object[]{"", timestamp("2001-01-01"), 1L}, - new Object[]{"a", timestamp("2000-01-01"), 1L}, - new Object[]{"a", timestamp("2001-01-01"), 1L}, - new Object[]{"abc", timestamp("2001-01-01"), 1L}, - new Object[]{"", null, 3L}, - new Object[]{"a", null, 2L}, - new Object[]{"abc", null, 1L}, - new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L}, - new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L}, - new Object[]{NULL_STRING, null, 6L} + new Object[]{"", timestamp("2000-01-01"), 2L, 0L}, + new Object[]{"", timestamp("2001-01-01"), 1L, 0L}, + new 
Object[]{"a", timestamp("2000-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"abc", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"", null, 3L, 1L}, + new Object[]{"a", null, 2L, 1L}, + new Object[]{"abc", null, 1L, 1L}, + new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 2L}, + new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 2L}, + new Object[]{NULL_STRING, null, 6L, 3L} + ) + ); + } + + @Test + public void testGroupingAggregatorDifferentOrder() throws Exception + { + // Cannot vectorize due to virtual columns. + cannotVectorize(); + + testQuery( + "SELECT dim2, gran, SUM(cnt), GROUPING(gran, dim2)\n" + + "FROM (SELECT FLOOR(__time TO MONTH) AS gran, COALESCE(dim2, '') dim2, cnt FROM druid.foo) AS x\n" + + "GROUP BY GROUPING SETS ( (dim2, gran), (dim2), (gran), () )", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim2\"),\"dim2\",'')", + ValueType.STRING + ), + expressionVirtualColumn( + "v1", + "timestamp_floor(\"__time\",'P1M',null,'UTC')", + ValueType.LONG + ) + ) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "d0"), + new DefaultDimensionSpec("v1", "d1", ValueType.LONG) + ) + ) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("a0", "cnt"), + new GroupingAggregatorFactory("a1", Arrays.asList("v1", "v0")) + )) + .setSubtotalsSpec( + ImmutableList.of( + ImmutableList.of("d0", "d1"), + ImmutableList.of("d0"), + ImmutableList.of("d1"), + ImmutableList.of() + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", timestamp("2000-01-01"), 2L, 0L}, + new Object[]{"", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2000-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2001-01-01"), 1L, 0L}, + new 
Object[]{"abc", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"", null, 3L, 2L}, + new Object[]{"a", null, 2L, 2L}, + new Object[]{"abc", null, 1L, 2L}, + new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 1L}, + new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 1L}, + new Object[]{NULL_STRING, null, 6L, 3L} ) ); } + @Test + public void testGroupingAggregatorWithPostAggregator() throws Exception + { + List<Object[]> resultList; + if (NullHandling.sqlCompatible()) { + resultList = ImmutableList.of( + new Object[]{NULL_STRING, 2L, 0L, "INDIVIDUAL"}, + new Object[]{"", 1L, 0L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 0L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 0L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 1L, "ALL"} + ); + } else { + resultList = ImmutableList.of( + new Object[]{"", 3L, 0L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 0L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 0L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 1L, "ALL"} + ); + } + testQuery( + "SELECT dim2, SUM(cnt), GROUPING(dim2), \n" + + "CASE WHEN GROUPING(dim2) = 1 THEN 'ALL' ELSE 'INDIVIDUAL' END\n" + + "FROM druid.foo\n" + + "GROUP BY GROUPING SETS ( (dim2), () )", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + dimensions( + new DefaultDimensionSpec("dim2", "d0", ValueType.STRING) + ) + ) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("a0", "cnt"), + new GroupingAggregatorFactory("a1", Collections.singletonList("dim2")) + )) + .setSubtotalsSpec( + ImmutableList.of( + ImmutableList.of("d0"), + ImmutableList.of() + ) + ) + .setPostAggregatorSpecs(Collections.singletonList(new ExpressionPostAggregator( + "p0", + "case_searched((\"a1\" == 1),'ALL','INDIVIDUAL')", + null, + ExprMacroTable.nil() + ))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + resultList + ); + } + @Test public void
testGroupingSetsWithNumericDimension() throws Exception {