Skip to content

Commit

Permalink
Merge pull request #41733 from LakshanWeerasinghe/fix-#41525
Browse files Browse the repository at this point in the history
Add `on conflict` clause to query pipeline
  • Loading branch information
LakshanWeerasinghe authored Apr 4, 2024
2 parents fe0590a + 7727e5f commit 35c2599
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ private void createFunctionMap(BLangFunction funcNode, SymbolEnv funcEnv) {
if (funcNode.mapSymbolUpdated) {
return;
}
BLangRecordLiteral emptyRecord = ASTBuilderUtil.createEmptyRecordLiteral(funcNode.pos, symTable.mapType);
BLangRecordLiteral emptyRecord = ASTBuilderUtil.createEmptyRecordLiteral(funcNode.pos, symTable.mapAllType);
BLangSimpleVariable mapVar = ASTBuilderUtil.createVariable(funcNode.pos, funcNode.mapSymbol.name.value,
funcNode.mapSymbol.type, emptyRecord,
funcNode.mapSymbol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ public class QueryDesugar extends BLangNodeVisitor {
private static final Name QUERY_CREATE_GROUP_BY_FUNCTION = new Name("createGroupByFunction");
private static final Name QUERY_CREATE_COLLECT_FUNCTION = new Name("createCollectFunction");
private static final Name QUERY_CREATE_SELECT_FUNCTION = new Name("createSelectFunction");
private static final Name QUERY_CREATE_ON_CONFLICT_FUNCTION = new Name("createOnConflictFunction");
private static final Name QUERY_CREATE_DO_FUNCTION = new Name("createDoFunction");
private static final Name QUERY_CREATE_LIMIT_FUNCTION = new Name("createLimitFunction");
private static final Name QUERY_ADD_STREAM_FUNCTION = new Name("addStreamFunction");
Expand All @@ -229,8 +230,12 @@ public class QueryDesugar extends BLangNodeVisitor {
private static final Name QUERY_TO_STRING_FUNCTION = new Name("toString");
private static final Name QUERY_TO_XML_FUNCTION = new Name("toXML");
private static final Name QUERY_ADD_TO_TABLE_FUNCTION = new Name("addToTable");
private static final Name QUERY_ADD_TO_TABLE_FOR_ON_CONFLICT_FUNCTION = new Name("addToTableForOnConflict");
private static final Name QUERY_ADD_TO_MAP_FUNCTION = new Name("addToMap");
private static final Name QUERY_ADD_TO_MAP_FOR_ON_CONFLICT_FUNCTION = new Name("addToMapForOnConflict");
private static final Name QUERY_GET_STREAM_FROM_PIPELINE_FUNCTION = new Name("getStreamFromPipeline");
private static final Name QUERY_GET_STREAM_FOR_ON_CONFLICT_FROM_PIPELINE_FUNCTION =
new Name("getStreamForOnConflictFromPipeline");
private static final Name QUERY_GET_QUERY_ERROR_ROOT_CAUSE_FUNCTION = new Name("getQueryErrorRootCause");
private static final String FRAME_PARAMETER_NAME = "$frame$";
private static final Name QUERY_BODY_DISTINCT_ERROR_NAME = new Name("Error");
Expand Down Expand Up @@ -298,23 +303,21 @@ BLangStatementExpression desugar(BLangQueryExpr queryExpr, SymbolEnv env,
if (queryExpr.isStream) {
resultType = streamRef.getBType();
} else if (queryExpr.isTable) {
onConflictExpr = (onConflictExpr == null)
? ASTBuilderUtil.createLiteral(pos, symTable.nilType, Names.NIL_VALUE)
: onConflictExpr;
BLangVariableReference tableRef = addTableConstructor(queryExpr, queryBlock);
Name internalFuncName = onConflictExpr == null ? QUERY_ADD_TO_TABLE_FUNCTION
: QUERY_ADD_TO_TABLE_FOR_ON_CONFLICT_FUNCTION;
result = getStreamFunctionVariableRef(queryBlock,
QUERY_ADD_TO_TABLE_FUNCTION, Lists.of(streamRef, tableRef, onConflictExpr, isReadonly), pos);
internalFuncName, Lists.of(streamRef, tableRef, isReadonly), pos);
resultType = tableRef.getBType();
onConflictExpr = null;
} else if (queryExpr.isMap) {
onConflictExpr = (onConflictExpr == null)
? ASTBuilderUtil.createLiteral(pos, symTable.nilType, Names.NIL_VALUE)
: onConflictExpr;
BMapType mapType = getMapType(queryExpr.getBType());
BLangRecordLiteral.BLangMapLiteral mapLiteral = new BLangRecordLiteral.BLangMapLiteral(queryExpr.pos,
mapType, new ArrayList<>());
Name internalFuncName = onConflictExpr == null ? QUERY_ADD_TO_MAP_FUNCTION
: QUERY_ADD_TO_MAP_FOR_ON_CONFLICT_FUNCTION;
result = getStreamFunctionVariableRef(queryBlock,
QUERY_ADD_TO_MAP_FUNCTION, Lists.of(streamRef, mapLiteral, onConflictExpr, isReadonly), pos);
internalFuncName, Lists.of(streamRef, mapLiteral, isReadonly), pos);
onConflictExpr = null;
} else if (queryExpr.getFinalClause().getKind() == NodeKind.COLLECT) {
result = getStreamFunctionVariableRef(queryBlock, COLLECT_QUERY_FUNCTION, Lists.of(streamRef), pos);
Expand Down Expand Up @@ -562,6 +565,9 @@ BLangVariableReference buildStream(List<BLangNode> clauses, BType resultType, Sy
case ON_CONFLICT:
final BLangOnConflictClause onConflict = (BLangOnConflictClause) clause;
onConflictExpr = onConflict.expression;
BLangVariableReference onConflictRef = addOnConflictFunction(block, onConflict,
stmtsToBePropagated);
addStreamFunction(block, initPipeline, onConflictRef);
break;
}
}
Expand Down Expand Up @@ -927,6 +933,28 @@ BLangVariableReference addSelectFunction(BLangBlockStmt blockStmt, BLangSelectCl
return getStreamFunctionVariableRef(blockStmt, QUERY_CREATE_SELECT_FUNCTION, Lists.of(lambda), pos);
}

/**
* Desugar onConflictClause to below and return a reference to created onConflict _StreamFunction.
* _StreamFunction onConflictFunc = createOnConflictFunction
* @param blockStmt parent block to write to.
* @param onConflictClause to be desugared.
* @param stmtsToBePropagated list of statements to be propagated.
* @return variableReference to created onConflict _StreamFunction.
*/
BLangVariableReference addOnConflictFunction(BLangBlockStmt blockStmt, BLangOnConflictClause onConflictClause,
List<BLangStatement> stmtsToBePropagated) {
Location pos = onConflictClause.pos;
BLangLambdaFunction lambda = createPassthroughLambda(pos);
BLangBlockFunctionBody body = (BLangBlockFunctionBody) lambda.function.body;
body.stmts.addAll(0, stmtsToBePropagated);
BVarSymbol oldFrameSymbol = lambda.function.requiredParams.get(0).symbol;
BLangSimpleVarRef frame = ASTBuilderUtil.createVariableRef(pos, oldFrameSymbol);
// $frame#[$error$] = on-conflict-expr;
BLangStatement assignment = getAddToFrameStmt(pos, frame, "$error$", onConflictClause.expression);
body.stmts.add(body.stmts.size() - 1, assignment);
lambda = rewrite(lambda);
return getStreamFunctionVariableRef(blockStmt, QUERY_CREATE_ON_CONFLICT_FUNCTION, Lists.of(lambda), pos);
}
/**
* Desugar doClause to below and return a reference to created do _StreamFunction.
* _StreamFunction doFunc = createDoFunction(function(_Frame frame) {
Expand Down Expand Up @@ -996,8 +1024,12 @@ void addStreamFunction(BLangBlockStmt blockStmt, BLangVariableReference pipeline
*/
BLangVariableReference addGetStreamFromPipeline(BLangBlockStmt blockStmt, BLangVariableReference pipelineRef) {
Location pos = pipelineRef.pos;
if (onConflictExpr == null) {
return getStreamFunctionVariableRef(blockStmt,
QUERY_GET_STREAM_FROM_PIPELINE_FUNCTION, null, Lists.of(pipelineRef), pos);
}
return getStreamFunctionVariableRef(blockStmt,
QUERY_GET_STREAM_FROM_PIPELINE_FUNCTION, null, Lists.of(pipelineRef), pos);
QUERY_GET_STREAM_FOR_ON_CONFLICT_FROM_PIPELINE_FUNCTION, null, Lists.of(pipelineRef), pos);
}

/**
Expand Down Expand Up @@ -2066,6 +2098,9 @@ public void visit(BLangErrorConstructorExpr errorConstructorExpr) {
if (errorConstructorExpr.namedArgs != null) {
rewrite(errorConstructorExpr.namedArgs);
}
if (errorConstructorExpr.positionalArgs != null) {
rewrite(errorConstructorExpr.positionalArgs);
}
errorConstructorExpr.errorDetail = rewrite(errorConstructorExpr.errorDetail);
result = errorConstructorExpr;
}
Expand Down
107 changes: 90 additions & 17 deletions langlib/lang.query/src/main/ballerina/helpers.bal
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ function createSelectFunction(function(_Frame _frame) returns _Frame|error? sele
return new _SelectFunction(selectFunc);
}

function createOnConflictFunction(function(_Frame _frame) returns _Frame|error? onConflictFunc)
returns _StreamFunction => new _OnConflictFunction(onConflictFunc);

function createCollectFunction(string[] nonGroupingKeys, function(_Frame _frame) returns _Frame|error? collectFunc) returns _StreamFunction {
return new _CollectFunction(nonGroupingKeys, collectFunc);
}
Expand All @@ -93,6 +96,9 @@ function getStreamFromPipeline(_StreamPipeline pipeline) returns stream<Type, Co
return pipeline.getStream();
}

function getStreamForOnConflictFromPipeline(_StreamPipeline pipeline) returns stream<Type, CompletionType>
=> pipeline.getStreamForOnConflict();

function toArray(stream<Type, CompletionType> strm, Type[] arr, boolean isReadOnly) returns Type[]|error {
if isReadOnly {
// In this case arr will be an immutable array. Therefore, we will create a new mutable array and pass it to the
Expand Down Expand Up @@ -159,7 +165,8 @@ function toString(stream<Type, CompletionType> strm) returns string|error {
return result;
}

function addToTable(stream<Type, CompletionType> strm, table<map<Type>> tbl, error? err, boolean isReadOnly) returns table<map<Type>>|error {
function addToTable(stream<Type, CompletionType> strm, table<map<Type>> tbl, boolean isReadOnly)
returns table<map<Type>>|error {
if isReadOnly {
// TODO: Properly fix readonly scenario - Issue lang/#36721
// In this case tbl will be an immutable table. Therefore, we will create a new mutable table. Next, we will
Expand All @@ -168,52 +175,83 @@ function addToTable(stream<Type, CompletionType> strm, table<map<Type>> tbl, err
// and make it immutable with createImmutableTable().
table<map<Type>> tempTbl = table [];
table<map<Type>> tbl2 = createTableWithKeySpecifier(tbl, typeof(tempTbl));
table<map<Type>> tempTable = check createTable(strm, tbl2, err);
table<map<Type>> tempTable = check createTable(strm, tbl2);
return createImmutableTable(tbl, tempTable.toArray());
}
return createTable(strm, tbl, err);
return createTable(strm, tbl);
}

function createTable(stream<Type, CompletionType> strm, table<map<Type>> tbl, error? err) returns table<map<Type>>|error {
function createTable(stream<Type, CompletionType> strm, table<map<Type>> tbl) returns table<map<Type>>|error {
record {| Type value; |}|CompletionType v = strm.next();
while (v is record {| Type value; |}) {
error? e = trap tbl.add(<map<Type>> checkpanic v.value);
if (e is error) {
if (err is error) {
return err;
}
tbl.put(<map<Type>> checkpanic v.value);
}
v = strm.next();
}
if (v is error) {
if v is error {
return v;
}
return tbl;
}

function addToMap(stream<Type, CompletionType> strm, map<Type> mp, error? err, boolean isReadOnly) returns map<Type>|error {
// Here, `err` is used to get the expression of on-conflict clause
function addToTableForOnConflict(stream<Type, CompletionType> strm, table<map<Type>> tbl, boolean isReadOnly)
returns table<map<Type>>|error {
if isReadOnly {
// TODO: Properly fix readonly scenario - Issue lang/#36721
// In this case tbl will be an immutable table. Therefore, we will create a new mutable table. Next, we will
// pass the newly created table into createTableWithKeySpecifier() to add the key specifier details from the
// original table variable (tbl). Then the newly created table variable will be populated using createTable()
// and make it immutable with createImmutableTable().
table<map<Type>> tempTbl = table [];
table<map<Type>> mutableTableRef = createTableWithKeySpecifier(tbl, typeof(tempTbl));
_ = check createTableForOnConflict(strm, mutableTableRef);
return createImmutableTable(tbl, mutableTableRef.toArray());
}
return createTableForOnConflict(strm, tbl);
}

function createTableForOnConflict(stream<Type, CompletionType> strm, table<map<Type>> tbl)
returns table<map<Type>>|error {
record {| Type value; |}|CompletionType v = strm.next();
while v is record {| Type value; |} {
record {|Type v; error? err;|}|error value = v.value.ensureType();
if value is error {
return value;
}
map<Type> tblValue = check value.v.ensureType();
error? e = trap tbl.add(tblValue);
error? err = value.err;
if e is error && err is error {
return err;
}
if e is error && err is () {
tbl.put(<map<Type>> checkpanic value.v);
}
v = strm.next();
}
return v is error ? v : tbl;
}

function addToMap(stream<Type, CompletionType> strm, map<Type> mp, boolean isReadOnly) returns map<Type>|error {
if isReadOnly {
// In this case mp will be an immutable map. Therefore, we will create a new mutable map and pass it to the
// createMap() (because we can't update immutable map). Then it will populate the members into it and the
// resultant map will be passed into createImmutableValue() to make it immutable.
map<Type> mp2 = {};
createImmutableValue(check createMap(strm, mp2, err));
createImmutableValue(check createMap(strm, mp2));
return mp2;
}
return createMap(strm, mp, err);
return createMap(strm, mp);
}

function createMap(stream<Type, CompletionType> strm, map<Type> mp, error? err) returns map<Type>|error {
function createMap(stream<Type, CompletionType> strm, map<Type> mp) returns map<Type>|error {
record {| Type value; |}|CompletionType v = strm.next();
while (v is record {| Type value; |}) {
while v is record {| Type value; |} {
[string, Type]|error value = trap (<[string, Type]> checkpanic v.value);
if value !is error {
string key = value[0];
if mp.hasKey(key) && err is error {
return err;
}
mp[key] = value[1];
} else {
return value;
Expand All @@ -227,6 +265,41 @@ function createMap(stream<Type, CompletionType> strm, map<Type> mp, error? err)
return mp;
}

function addToMapForOnConflict(stream<Type, CompletionType> strm, map<Type> mp, boolean isReadOnly)
returns map<Type>|error {
if isReadOnly {
// In this case mp will be an immutable map. Therefore, we will create a new mutable map and pass it to the
// createMap() (because we can't update immutable map). Then it will populate the members into it and the
// resultant map will be passed into createImmutableValue() to make it immutable.
map<Type> mp2 = {};
createImmutableValue(check createMapForOnConflict(strm, mp2));
return mp2;
}
return createMapForOnConflict(strm, mp);
}

function createMapForOnConflict(stream<Type, CompletionType> strm, map<Type> mp) returns map<Type>|error {
record {| Type value; |}|CompletionType v = strm.next();
while v is record {| Type value; |} {
record {|Type v; error? err;|}|error value = v.value.ensureType();
if value is error {
return value;
}
[string, Type]|error keyValue = value.v.ensureType();
if keyValue is error {
return keyValue;
}
string key = keyValue[0];
error? err = value.err;
if mp.hasKey(key) && err is error {
return err;
}
mp[key] = keyValue[1];
v = strm.next();
}
return v is error ? v : mp;
}

function consumeStream(stream<Type, CompletionType> strm) returns any|error {
any|error? v = strm.next();
while (!(v is () || v is error)) {
Expand Down
61 changes: 61 additions & 0 deletions langlib/lang.query/src/main/ballerina/types.bal
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class _StreamPipeline {
var strm = internal:construct(self.constraintTd, self.completionTd, itrObj);
return strm;
}

public function getStreamForOnConflict() returns stream<Type, CompletionType> {
OnConflictIterHelper itrObj = new (self, self.constraintTd);
return internal:construct(self.constraintTd, self.completionTd, itrObj);
}
}

class _InitFunction {
Expand Down Expand Up @@ -794,6 +799,40 @@ class _SelectFunction {
}
}

class _OnConflictFunction {
*_StreamFunction;

# Desugared function to do;
# on conflict error("Duplicate key")
public function (_Frame _frame) returns _Frame|error? onConflictFunc;

function init(function (_Frame _frame) returns _Frame|error? onConflictFunc) {
self.onConflictFunc = onConflictFunc;
self.prevFunc = ();
}

public function process() returns _Frame|error? {
_StreamFunction pf = <_StreamFunction>self.prevFunc;
function (_Frame _frame) returns _Frame|error? f = self.onConflictFunc;
_Frame|error? pFrame = pf.process();
if (pFrame is _Frame) {
_Frame|error? cFrame = f(pFrame);
if (cFrame is error) {
return prepareQueryBodyError(cFrame);
}
return cFrame;
}
return pFrame;
}

public function reset() {
_StreamFunction? pf = self.prevFunc;
if (pf is _StreamFunction) {
pf.reset();
}
}
}

class _DoFunction {
*_StreamFunction;

Expand Down Expand Up @@ -931,6 +970,28 @@ class IterHelper {
}
}

class OnConflictIterHelper {
public _StreamPipeline pipeline;
public typedesc<Type> outputType;

function init(_StreamPipeline pipeline, typedesc<Type> outputType) {
self.pipeline = pipeline;
self.outputType = outputType;
}

public isolated function next() returns record {|Type value;|}|error? {
_StreamPipeline p = self.pipeline;
_Frame|error? f = p.next();
if (f is _Frame) {
Type v = <Type>f["$value$"];
error? err = <error?>f["$error$"];
record {|Type v; error? err;|} value = {v, err};
return internal:setNarrowType(self.outputType, {value: value});
}
return f;
}
}

class _OrderTreeNode {
any? key = ();
_Frame[]? frames = ();
Expand Down
Loading

0 comments on commit 35c2599

Please sign in to comment.