Skip to content

Commit

Permalink
[BABEL-284] Add analyzer code for PIVOT
Browse files Browse the repository at this point in the history
Added transformPivotClause function in analyer to help identify output
table columns' data type.

Task: BABEL-284
Signed-off-by: Yanjie Xu <[email protected]>
  • Loading branch information
RIC06X committed Oct 4, 2023
1 parent 399374a commit 202e707
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 0 deletions.
156 changes: 156 additions & 0 deletions src/backend/parser/analyze.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ static void transformLockingClause(ParseState *pstate, Query *qry,
static bool test_raw_expression_coverage(Node *node, void *context);
#endif

RawStmt *bbf_pivot_sql1;
RawStmt *bbf_pivot_sql2;

/*
* parse_analyze_fixedparams
Expand Down Expand Up @@ -1368,6 +1370,155 @@ count_rowexpr_columns(ParseState *pstate, Node *expr)
return -1;
}

static ResTarget *
makeResTargetFromColName(char * colName)
{
ResTarget *tempResTarget;
ColumnRef *tempColRef;

tempResTarget = makeNode(ResTarget);
tempColRef = makeNode(ColumnRef);
tempColRef->location = -1;
tempColRef->fields = list_make1(makeString(colName));
tempResTarget->name = NULL;
tempResTarget->name_location = -1;
tempResTarget->indirection = NIL;
tempResTarget->val = (Node *) tempColRef;
tempResTarget->location = -1;
return tempResTarget;
}

static void
transformPivotClause(ParseState *pstate, SelectStmt *stmt)
{
Query *temp_src_query;
ParseState *sub_pstate;
List *temp_src_targetlist;
List *new_src_sql_targetist;
List *new_pivot_aliaslist;
List *src_sql_groupbylist;
char *pivot_colstr;
char *value_colstr;
ColumnRef *value_col;
TargetEntry *aggfunc_te;
RangeFunction *pivot_from_function;
RawStmt *s_sql;
RawStmt *c_sql;
MemoryContext oldContext;

new_src_sql_targetist = NULL;
new_pivot_aliaslist = NULL;
src_sql_groupbylist = NULL;

/* transform temporary src_sql */
sub_pstate = make_parsestate(pstate);
temp_src_query = transformSelectStmt(sub_pstate, stmt->srcSql);
temp_src_targetlist = temp_src_query->targetList;

/* Get pivot column str & value column str from parser result */
pivot_colstr = stmt->pivotCol;
value_col = list_nth_node(ColumnRef, ((FuncCall *)((ResTarget *)stmt->aggFunc)->val)->args, 0);
value_colstr = list_nth_node(String, value_col->fields, 0)->sval;

/* Get the targetList of the src table */
for (int i = 0; i < temp_src_targetlist->length; i++)
{
ResTarget *tempResTarget;
ColumnDef *tempColDef;
TargetEntry *tempEntry = list_nth_node(TargetEntry, temp_src_targetlist, i);

char *colName = tempEntry->resname;

if (strcasecmp(colName, pivot_colstr) == 0 || strcasecmp(colName, value_colstr) == 0)
continue;
/* prepare src_sql's targetList */
tempResTarget = makeResTargetFromColName(colName);

if (new_src_sql_targetist == NULL)
new_src_sql_targetist = list_make1(tempResTarget);
else
new_src_sql_targetist = lappend(new_src_sql_targetist, tempResTarget);

/* prepare pivot sql's alias_clause */
tempColDef = makeColumnDef(colName,
((Var *)tempEntry->expr)->vartype,
((Var *)tempEntry->expr)->vartypmod,
((Var *)tempEntry->expr)->varcollid
);

if (new_pivot_aliaslist == NULL)
new_pivot_aliaslist = list_make1(tempColDef);
else
new_pivot_aliaslist = lappend(new_pivot_aliaslist, tempColDef);
}
/* source_sql: non-pivot column + pivot colunm+ agg(value_col) */
/* complete src_sql's targetList*/
new_src_sql_targetist = lappend(new_src_sql_targetist, makeResTargetFromColName(pivot_colstr));
new_src_sql_targetist = lappend(new_src_sql_targetist, (ResTarget *)stmt->aggFunc);
((SelectStmt *)stmt->srcSql)->targetList = new_src_sql_targetist;

/* complete src_sql's groupby*/
for (int i = 0; i < new_src_sql_targetist->length - 1; i++)
{
A_Const *tempAConst = makeNode(A_Const);
tempAConst->val.ival.type = T_Integer;
tempAConst->val.ival.ival = i+1;
tempAConst->location = -1;

if (src_sql_groupbylist == NULL)
src_sql_groupbylist = list_make1(tempAConst);
else
src_sql_groupbylist = lappend(src_sql_groupbylist, tempAConst);
}
((SelectStmt *)stmt->srcSql)->groupClause = src_sql_groupbylist;

/* Transform the new src_sql & get the output type of that agg function*/
/* ?do we need to clean the memory used by previous sub_pstate? */
sub_pstate = make_parsestate(pstate);
temp_src_query = transformSelectStmt(sub_pstate, stmt->srcSql);
temp_src_targetlist = temp_src_query->targetList;

/* asClause: non-pivot columns + value columns) */
aggfunc_te = list_nth_node(TargetEntry, temp_src_targetlist, temp_src_targetlist->length - 1);

/* complete pivo sql's alias_clause */
/* Rewrite the fromClause in the outer select to have correct alias column name and datatype */
pivot_from_function = list_nth_node(RangeFunction, stmt->fromClause, 0);
for(int i = 0; i < stmt->value_col_strlist->length; i++)
{
ColumnDef *tempColDef;
tempColDef = makeColumnDef((char *) list_nth(stmt->value_col_strlist, i),
((Aggref *)aggfunc_te->expr)->aggtype,
-1,
((Aggref *)aggfunc_te->expr)->aggcollid
);

if (new_pivot_aliaslist == NULL)
new_pivot_aliaslist = list_make1(tempColDef);
else
new_pivot_aliaslist = lappend(new_pivot_aliaslist, tempColDef);
}

pivot_from_function->coldeflist = new_pivot_aliaslist;
/* put the correct src_sql raw parse tree into the memory context for later use */
oldContext = CurrentMemoryContext;
MemoryContextSwitchTo(TopMemoryContext);
/* save rewrited sqls to global variable for later retrive */
s_sql = makeNode(RawStmt);
c_sql = makeNode(RawStmt);
s_sql->stmt = (Node *) stmt->srcSql;
s_sql->stmt_location = 0;
s_sql->stmt_len = 0;

c_sql->stmt = (Node *) stmt->catSql;
c_sql->stmt_location = 0;
c_sql->stmt_len = 0;

bbf_pivot_sql1 = copyObject(s_sql);
bbf_pivot_sql2 = copyObject(c_sql);

MemoryContextSwitchTo(oldContext);
}

/*
* transformSelectStmt -
Expand Down Expand Up @@ -1407,6 +1558,11 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
/* make WINDOW info available for window functions, too */
pstate->p_windowdefs = stmt->windowClause;

if (stmt->isPivot)
{
transformPivotClause(pstate, stmt);
}

/* process the FROM clause */
transformFromClause(pstate, stmt->fromClause);

Expand Down
5 changes: 5 additions & 0 deletions src/include/nodes/parsenodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,11 @@ typedef struct SelectStmt

/* These fields are used only in SelectStmt with PIVOT keyword */
bool isPivot;
struct SelectStmt *srcSql;
struct SelectStmt *catSql;
List *value_col_strlist;
char *pivotCol;
Node *aggFunc;
} SelectStmt;


Expand Down
3 changes: 3 additions & 0 deletions src/include/parser/analyze.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,7 @@ extern List *BuildOnConflictExcludedTargetlist(Relation targetrel,

extern SortGroupClause *makeSortGroupClauseForSetOp(Oid rescoltype, bool require_hash);

extern RawStmt *bbf_pivot_sql1;
extern RawStmt *bbf_pivot_sql2;

#endif /* ANALYZE_H */

0 comments on commit 202e707

Please sign in to comment.