Skip to content

Commit

Permalink
[feature](Nereids) support agg state type in create table (#32171)
Browse files Browse the repository at this point in the history
This PR introduces a behavior change: the CREATE TABLE syntax for columns of the agg_state type has changed.
  • Loading branch information
morrySnow authored and Doris-Extras committed Mar 15, 2024
1 parent 62023d7 commit ea2fbfa
Show file tree
Hide file tree
Showing 24 changed files with 121 additions and 111 deletions.
4 changes: 2 additions & 2 deletions docs/en/docs/sql-manual/sql-reference/Data-Types/AGG_STATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ Create table example:
```sql
create table a_table(
k1 int null,
k2 agg_state max_by(int not null,int),
k3 agg_state group_concat(string)
k2 agg_state<max_by(int not null,int)> generic,
k3 agg_state<group_concat(string)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ under the License.
```sql
create table a_table(
k1 int null,
k2 agg_state max_by(int not null,int),
k3 agg_state group_concat(string)
k2 agg_state<max_by(int not null,int)> generic,
k3 agg_state<group_concat(string)> generic
)
aggregate key (k1)
distributed BY hash(k1) buckets 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public boolean getResultIsNullable() {
@Override
public String toSql(int depth) {
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("AGG_STATE(");
stringBuilder.append("AGG_STATE<").append(functionName).append("(");
for (int i = 0; i < subTypes.size(); i++) {
if (i > 0) {
stringBuilder.append(", ");
Expand All @@ -82,7 +82,7 @@ public String toSql(int depth) {
stringBuilder.append(" NULL");
}
}
stringBuilder.append(")");
stringBuilder.append(")>");
return stringBuilder.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ FRONTENDS: 'FRONTENDS';
FULL: 'FULL';
FUNCTION: 'FUNCTION';
FUNCTIONS: 'FUNCTIONS';
GENERIC: 'GENERIC';
GLOBAL: 'GLOBAL';
GRANT: 'GRANT';
GRANTS: 'GRANTS';
Expand Down
15 changes: 13 additions & 2 deletions fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,10 @@ columnDefs

columnDef
: colName=identifier type=dataType
KEY? (aggType=aggTypeDef)? ((NOT NULL) | NULL)? (AUTO_INCREMENT (LEFT_PAREN autoIncInitValue=number RIGHT_PAREN)?)?
KEY?
(aggType=aggTypeDef)?
((NOT)? NULL)?
(AUTO_INCREMENT (LEFT_PAREN autoIncInitValue=number RIGHT_PAREN)?)?
(DEFAULT (nullValue=NULL | INTEGER_VALUE | stringValue=STRING_LITERAL
| CURRENT_TIMESTAMP (LEFT_PAREN defaultValuePrecision=number RIGHT_PAREN)?))?
(ON UPDATE CURRENT_TIMESTAMP (LEFT_PAREN onUpdateValuePrecision=number RIGHT_PAREN)?)?
Expand Down Expand Up @@ -587,7 +590,7 @@ rollupDef
;

aggTypeDef
: MAX | MIN | SUM | REPLACE | REPLACE_IF_NOT_NULL | HLL_UNION | BITMAP_UNION | QUANTILE_UNION
: MAX | MIN | SUM | REPLACE | REPLACE_IF_NOT_NULL | HLL_UNION | BITMAP_UNION | QUANTILE_UNION | GENERIC
;

tabletList
Expand Down Expand Up @@ -846,10 +849,17 @@ unitIdentifier
: YEAR | MONTH | WEEK | DAY | HOUR | MINUTE | SECOND
;

// A data type optionally followed by a nullability marker: `NULL` or `NOT NULL`.
// Used for the argument types listed inside an AGG_STATE<fn(...)> type.
dataTypeWithNullable
: dataType ((NOT)? NULL)?
;

// Column/expression data types. Alternatives, in order:
//   ARRAY<T>, MAP<K, V>, STRUCT<fields>          -> #complexDataType
//   AGG_STATE<fn(T1 [[NOT] NULL], T2 ...)>       -> #aggStateDataType
//   primitive type with optional (n | *) [, n]*  -> #primitiveDataType
dataType
: complex=ARRAY LT dataType GT #complexDataType
| complex=MAP LT dataType COMMA dataType GT #complexDataType
| complex=STRUCT LT complexColTypeList GT #complexDataType
// Aggregate-state type: function name plus one or more argument types,
// each of which may carry its own NULL / NOT NULL marker.
| AGG_STATE LT functionNameIdentifier
LEFT_PAREN dataTypes+=dataTypeWithNullable
(COMMA dataTypes+=dataTypeWithNullable)* RIGHT_PAREN GT #aggStateDataType
| primitiveColType (LEFT_PAREN (INTEGER_VALUE | ASTERISK)
(COMMA INTEGER_VALUE)* RIGHT_PAREN)? #primitiveDataType
;
Expand Down Expand Up @@ -1061,6 +1071,7 @@ nonReserved
| FREE
| FRONTENDS
| FUNCTION
| GENERIC
| GLOBAL
| GRAPH
| GROUPING
Expand Down
33 changes: 13 additions & 20 deletions fe/fe-core/src/main/cup/sql_parser.cup
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ terminal String
KW_FULL,
KW_FUNCTION,
KW_FUNCTIONS,
KW_GENERIC,
KW_GLOBAL,
KW_GRANT,
KW_GRANTS,
Expand Down Expand Up @@ -3132,6 +3133,10 @@ opt_agg_type ::=
{:
RESULT = AggregateType.QUANTILE_UNION;
:}
| KW_GENERIC
{:
RESULT = AggregateType.GENERIC;
:}
;

opt_partition ::=
Expand Down Expand Up @@ -3731,31 +3736,11 @@ column_definition ::=
ColumnDef columnDef = new ColumnDef(columnName, typeDef, isKey, null, isAllowNull, autoIncInitValue, defaultValue, comment);
RESULT = columnDef;
:}
| ident:columnName KW_AGG_STATE IDENT:fnName LPAREN type_def_nullable_list:list RPAREN opt_auto_inc_init_value:autoIncInitValue opt_default_value:defaultValue opt_comment:comment
{:
for (TypeDef def : list) {
def.analyze(null);
}
ColumnDef columnDef = new ColumnDef(columnName, new TypeDef(Expr.createAggStateType(fnName,
list.stream().map(TypeDef::getType).collect(Collectors.toList()),
list.stream().map(TypeDef::getNullable).collect(Collectors.toList()))), false, AggregateType.GENERIC_AGGREGATION, false, defaultValue, comment);
RESULT = columnDef;
:}
| ident:columnName type_def:typeDef opt_is_key:isKey opt_agg_type:aggType opt_is_allow_null:isAllowNull opt_auto_inc_init_value:autoIncInitValue opt_default_value:defaultValue opt_comment:comment
{:
ColumnDef columnDef = new ColumnDef(columnName, typeDef, isKey, aggType, isAllowNull, autoIncInitValue, defaultValue, comment);
RESULT = columnDef;
:}
| ident:columnName KW_AGG_STATE opt_is_key:isKey opt_agg_type:aggType LPAREN type_def_nullable_list:list RPAREN opt_default_value:defaultValue opt_comment:comment
{:
for (TypeDef def : list) {
def.analyze(null);
}
ColumnDef columnDef = new ColumnDef(columnName, new TypeDef(Expr.createAggStateType(aggType.name().toLowerCase(),
list.stream().map(TypeDef::getType).collect(Collectors.toList()),
list.stream().map(TypeDef::getNullable).collect(Collectors.toList()))), isKey, AggregateType.GENERIC_AGGREGATION, false, defaultValue, comment);
RESULT = columnDef;
:}
;

index_definition ::=
Expand Down Expand Up @@ -6553,6 +6538,12 @@ type ::=
{: ScalarType type = ScalarType.createHllType();
RESULT = type;
:}
| KW_AGG_STATE LESSTHAN IDENT:fnName LPAREN type_def_nullable_list:list RPAREN GREATERTHAN
{:
RESULT = Expr.createAggStateType(fnName,
list.stream().map(TypeDef::getType).collect(Collectors.toList()),
list.stream().map(TypeDef::getNullable).collect(Collectors.toList()));
:}
| KW_ALL
{: RESULT = Type.ALL; :}
;
Expand Down Expand Up @@ -7782,6 +7773,8 @@ keyword ::=
{: RESULT = id; :}
| KW_GLOBAL:id
{: RESULT = id; :}
| KW_GENERIC:id
{: RESULT = id; :}
| KW_GRAPH:id
{: RESULT = id; :}
| KW_HASH:id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,12 @@ public void analyze(boolean isOlap) throws AnalysisException {
}

// check if aggregate type is valid
if (aggregateType != AggregateType.GENERIC_AGGREGATION
if (aggregateType != AggregateType.GENERIC
&& !aggregateType.checkCompatibility(type.getPrimitiveType())) {
throw new AnalysisException(String.format("Aggregate type %s is not compatible with primitive type %s",
toString(), type.toSql()));
}
if (aggregateType == AggregateType.GENERIC_AGGREGATION) {
if (aggregateType == AggregateType.GENERIC) {
if (!SessionVariable.enableAggState()) {
throw new AnalysisException("agg state not enable, need set enable_agg_state=true");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ private MVColumnItem buildMVColumnItem(Analyzer analyzer, FunctionCallExpr funct
type = Type.BIGINT;
break;
default:
mvAggregateType = AggregateType.GENERIC_AGGREGATION;
mvAggregateType = AggregateType.GENERIC;
if (functionCallExpr.getParams().isDistinct() || functionCallExpr.getParams().isStar()) {
throw new AnalysisException(
"The Materialized-View's generic aggregation not support star or distinct");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public enum AggregateType {
NONE("NONE"),
BITMAP_UNION("BITMAP_UNION"),
QUANTILE_UNION("QUANTILE_UNION"),
GENERIC_AGGREGATION("GENERIC_AGGREGATION");
GENERIC("GENERIC");

private static EnumMap<AggregateType, EnumSet<PrimitiveType>> compatibilityMap;

Expand Down
35 changes: 5 additions & 30 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ public class Column implements Writable, GsonPostProcessable {
@SerializedName(value = "uniqueId")
private int uniqueId;

@SerializedName(value = "genericAggregationName")
private String genericAggregationName;

@SerializedName(value = "clusterKeyId")
private int clusterKeyId = -1;

Expand Down Expand Up @@ -244,8 +241,8 @@ public Column(String name, Type type, boolean isKey, AggregateType aggregateType
c.setIsAllowNull(aggState.getSubTypeNullables().get(i));
addChildrenColumn(c);
}
this.genericAggregationName = aggState.getFunctionName();
this.aggregationType = AggregateType.GENERIC_AGGREGATION;
this.isAllowNull = false;
this.aggregationType = AggregateType.GENERIC;
}
}

Expand Down Expand Up @@ -449,11 +446,7 @@ public AggregateType getAggregationType() {
}

public String getAggregationString() {
if (getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
return getGenericAggregationString();
} else {
return getAggregationType().name();
}
return getAggregationType().name();
}

public boolean isAggregated() {
Expand Down Expand Up @@ -764,22 +757,6 @@ public String toSql(boolean isUniqueTable) {
return toSql(isUniqueTable, false);
}

public String getGenericAggregationString() {
StringBuilder sb = new StringBuilder();
sb.append(genericAggregationName).append("(");
for (int i = 0; i < children.size(); i++) {
if (i != 0) {
sb.append(", ");
}
sb.append(children.get(i).getType().toSql());
if (children.get(i).isAllowNull()) {
sb.append(" NULL");
}
}
sb.append(")");
return sb.toString();
}

public String toSql(boolean isUniqueTable, boolean isCompatible) {
StringBuilder sb = new StringBuilder();
sb.append("`").append(name).append("` ");
Expand All @@ -791,11 +768,9 @@ public String toSql(boolean isUniqueTable, boolean isCompatible) {
} else {
sb.append(typeStr);
}
if (aggregationType == AggregateType.GENERIC_AGGREGATION) {
sb.append(" ").append(getGenericAggregationString());
} else if (aggregationType != null && aggregationType != AggregateType.NONE && !isUniqueTable
if (aggregationType != null && aggregationType != AggregateType.NONE && !isUniqueTable
&& !isAggregationTypeImplicit) {
sb.append(" ").append(aggregationType.name());
sb.append(" ").append(aggregationType.toSql());
}
if (isAllowNull) {
sb.append(" NULL");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.analysis.TableName;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.BuiltinAggregateFunctions;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.ScalarType;
Expand All @@ -42,6 +43,7 @@
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.DorisParser.AddConstraintContext;
import org.apache.doris.nereids.DorisParser.AggClauseContext;
import org.apache.doris.nereids.DorisParser.AggStateDataTypeContext;
import org.apache.doris.nereids.DorisParser.AliasQueryContext;
import org.apache.doris.nereids.DorisParser.AliasedQueryContext;
import org.apache.doris.nereids.DorisParser.AlterMTMVContext;
Expand Down Expand Up @@ -75,6 +77,7 @@
import org.apache.doris.nereids.DorisParser.CreateRowPolicyContext;
import org.apache.doris.nereids.DorisParser.CreateTableContext;
import org.apache.doris.nereids.DorisParser.CteContext;
import org.apache.doris.nereids.DorisParser.DataTypeWithNullableContext;
import org.apache.doris.nereids.DorisParser.DateCeilContext;
import org.apache.doris.nereids.DorisParser.DateFloorContext;
import org.apache.doris.nereids.DorisParser.Date_addContext;
Expand Down Expand Up @@ -422,6 +425,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import org.apache.doris.nereids.types.AggStateType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
Expand Down Expand Up @@ -2519,7 +2523,9 @@ public ColumnDefinition visitColumnDef(ColumnDefContext ctx) {
String colName = ctx.colName.getText();
DataType colType = ctx.type instanceof PrimitiveDataTypeContext
? visitPrimitiveDataType(((PrimitiveDataTypeContext) ctx.type))
: visitComplexDataType(((ComplexDataTypeContext) ctx.type));
: ctx.type instanceof ComplexDataTypeContext
? visitComplexDataType((ComplexDataTypeContext) ctx.type)
: visitAggStateDataType((AggStateDataTypeContext) ctx.type);
colType = colType.conversion();
boolean isKey = ctx.KEY() != null;
boolean isNotNull = ctx.NOT() != null;
Expand Down Expand Up @@ -3248,6 +3254,32 @@ private ExplainLevel parseExplainPlanType(PlanTypeContext planTypeContext) {
return ExplainLevel.ALL_PLAN;
}

/**
 * Visits a {@code dataTypeWithNullable} node.
 *
 * @param ctx parse-tree node holding a data type and an optional NULL / NOT NULL marker
 * @return a pair of the visited data type and its nullability flag; the flag is
 *         {@code true} unless the NOT token is present
 */
@Override
public Pair<DataType, Boolean> visitDataTypeWithNullable(DataTypeWithNullableContext ctx) {
    return ParserUtils.withOrigin(ctx, () -> {
        DataType visitedType = typedVisit(ctx.dataType());
        // Absence of NOT (either a bare type or an explicit NULL) means nullable.
        boolean nullable = ctx.NOT() == null;
        return Pair.of(visitedType, nullable);
    });
}

/**
 * Visits an {@code AGG_STATE<fn(arg types...)>} node and builds the matching {@link AggStateType}.
 *
 * <p>Each argument is visited through {@link #visitDataTypeWithNullable}, which yields the
 * argument's data type together with its nullability flag. The function name is then checked
 * against {@code BuiltinAggregateFunctions.INSTANCE.aggFuncNames}.
 *
 * <p>NOTE(review): the lookup uses the identifier text exactly as written in the statement;
 * whether {@code aggFuncNames} is case-normalized is not visible here -- confirm.
 *
 * @param ctx parse-tree node of the agg-state type
 * @return the agg-state type carrying the function name, argument types and their nullability
 * @throws ParseException if the function name is not a known builtin aggregate function
 */
@Override
public DataType visitAggStateDataType(AggStateDataTypeContext ctx) {
    return ParserUtils.withOrigin(ctx, () -> {
        ImmutableList.Builder<DataType> argTypes = ImmutableList.builder();
        ImmutableList.Builder<Boolean> argNullables = ImmutableList.builder();
        // Visit every argument before validating the function name, so argument
        // errors surface first.
        for (DataTypeWithNullableContext argCtx : ctx.dataTypes) {
            Pair<DataType, Boolean> typeAndNullable = visitDataTypeWithNullable(argCtx);
            argTypes.add(typeAndNullable.first);
            argNullables.add(typeAndNullable.second);
        }
        List<DataType> dataTypes = argTypes.build();
        List<Boolean> nullables = argNullables.build();

        String functionName = ctx.functionNameIdentifier().getText();
        boolean knownFunction = BuiltinAggregateFunctions.INSTANCE.aggFuncNames.contains(functionName);
        if (!knownFunction) {
            // TODO use function binder to check function exists
            throw new ParseException("Can not found function '" + functionName + "'", ctx);
        }
        return new AggStateType(functionName, dataTypes, nullables);
    });
}

@Override
public DataType visitPrimitiveDataType(PrimitiveDataTypeContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,7 @@ public Expression visitSum(Sum sum, RewriteContext context) {
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction, RewriteContext context) {
String aggStateName = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(
AggregateType.GENERIC_AGGREGATION, StateCombinator.create(aggregateFunction).toSql()));
AggregateType.GENERIC, StateCombinator.create(aggregateFunction).toSql()));

Column mvColumn = context.checkContext.getColumn(aggStateName);
if (mvColumn != null && context.checkContext.valueNameToColumn.containsValue(mvColumn)) {
Expand Down
Loading

0 comments on commit ea2fbfa

Please sign in to comment.