Skip to content

Commit

Permalink
Merge pull request #42125 from heshanpadmasiri/optim/foreach-desugar
Browse files Browse the repository at this point in the history
Desugar common foreach loop to while loops
  • Loading branch information
heshanpadmasiri authored Apr 4, 2024
2 parents 68c91b5 + f50bccf commit 223d87f
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import org.wso2.ballerinalang.compiler.tree.statements.BLangSimpleVariableDef;
import org.wso2.ballerinalang.compiler.tree.statements.BLangStatement;
import org.wso2.ballerinalang.compiler.tree.statements.BLangTupleVariableDef;
import org.wso2.ballerinalang.compiler.tree.statements.BLangWhile;
import org.wso2.ballerinalang.compiler.tree.types.BLangType;
import org.wso2.ballerinalang.compiler.tree.types.BLangUserDefinedType;
import org.wso2.ballerinalang.compiler.tree.types.BLangValueType;
Expand Down Expand Up @@ -270,6 +271,16 @@ static BLangForeach createForeach(Location pos,
return foreach;
}

static BLangWhile createWhile(Location pos,
BLangExpression condition,
BLangBlockStmt body) {
final BLangWhile whileNode = (BLangWhile) TreeBuilder.createWhileNode();
whileNode.pos = pos;
whileNode.body = body;
whileNode.expr = condition;
return whileNode;
}

static BLangSimpleVariableDef createVariableDefStmt(Location pos, BlockNode target) {
final BLangSimpleVariableDef variableDef = createVariableDef(pos);
target.addStatement(variableDef);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@
import java.util.Set;
import java.util.Stack;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.xml.XMLConstants;
Expand Down Expand Up @@ -5094,6 +5095,8 @@ public void visit(BLangForeach foreach) {
foreach.body.failureBreakMode = BLangBlockStmt.FailureBreakMode.NOT_BREAKABLE;
BLangDo doStmt = wrapStatementWithinDo(foreach.pos, foreach, onFailClause);
result = rewrite(doStmt, env);
} else if (canEliminateIterator(foreach)) {
result = rewrite(desugarForeachToWhileWithoutIterator(foreach), env);
} else {
// We need to create a new variable for the expression as well. This is needed because integer ranges can be
// added as the expression so we cannot get the symbol in such cases.
Expand All @@ -5114,6 +5117,123 @@ public void visit(BLangForeach foreach) {
}
}

private boolean canEliminateIterator(BLangForeach loop) {
BLangExpression collection = loop.collection;
if (collection.getKind() == NodeKind.BINARY_EXPR) {
// range expression
return true;
}
TypeKind kind = collection.getBType().getKind();
return kind == TypeKind.ARRAY || kind == TypeKind.TUPLE;
}

// We desguar certain foreach loops to hardcoded while loops to avoid having to use an iterator. This shouldn't
// create any observable difference in behavior and should be faster than using iterators. Currently, we support
// this optimization for two kinds foreach loops,
// 1) Range expressions
// foreach int i in m ..< n {
// // code;
// } will become,
// int $index$ = m;
// int $indexMax$ = n;
// while $index$ < $indexMax$ {
// int i = $index$;
// $index$ = $index$ + 1;
// // code;
// }
// 2) Foreach over lists
// foreach float f in [1.0, 2.0, 3.0] {
// // code;
// } will become,
// float[] $data$ = [1.0, 2.0, 3.0];
// int $index$ = 0;
// int $indexMax$ = $data$.length();
// while $index$ < $indexMax$ {
// float f = $data$[$index$];
// $index$ = $index$ + 1;
// // code;
// }
private BLangBlockStmt desugarForeachToWhileWithoutIterator(BLangForeach foreach) {
Location pos = foreach.pos;
BLangBlockStmt scopeBlock = ASTBuilderUtil.createBlockStmt(pos);
if (foreach.collection.getKind() == NodeKind.BINARY_EXPR) {
BLangBinaryExpr rangeExpr = (BLangBinaryExpr) foreach.collection;
OperatorKind comparisonOp =
rangeExpr.opKind == OperatorKind.HALF_OPEN_RANGE ? OperatorKind.LESS_THAN : OperatorKind.LESS_EQUAL;
Function<BLangVariableReference, BLangExpression> loopValueGenerator =
(BLangVariableReference indexVarRef) -> indexVarRef;
finishDesugarForeachToWhile(foreach, rangeExpr.lhsExpr, rangeExpr.rhsExpr, comparisonOp, loopValueGenerator,
scopeBlock);
return scopeBlock;
}
BSymbol listSymbol =
addTemporaryVariableToScope(pos, "$data$", foreach.collection, foreach.collection.expectedType,
scopeBlock);
BLangExpression listRef = ASTBuilderUtil.createVariableRef(pos, listSymbol);
BLangExpression indexInitVal = ASTBuilderUtil.createLiteral(pos, symTable.intType, 0L);
BLangExpression indexMaxVal =
createLangLibInvocationNode("length", listRef, new ArrayList<>(), symTable.intType, pos);
OperatorKind comparisonOp = OperatorKind.LESS_THAN;

Function<BLangVariableReference, BLangExpression> loopValueGenerator =
(BLangVariableReference indexVarRef) -> {
BLangVariable loopVal = (BLangVariable) foreach.variableDefinitionNode.getVariable();
return ASTBuilderUtil.createIndexBasesAccessExpr(pos, loopVal.getBType(), (BVarSymbol) listSymbol,
indexVarRef);
};
finishDesugarForeachToWhile(foreach, indexInitVal, indexMaxVal, comparisonOp, loopValueGenerator, scopeBlock);
return scopeBlock;
}

private void finishDesugarForeachToWhile(BLangForeach foreach, BLangExpression indexInitVal,
BLangExpression indexMaxVal, OperatorKind comparisonOp,
Function<BLangVariableReference, BLangExpression> loopValueGenerator,
BLangBlockStmt scopeBlock) {
Location pos = foreach.pos;
BSymbol indexSymbol = addTemporaryVariableToScope(pos, "$index$", indexInitVal, symTable.intType, scopeBlock);
// This needs to be a variable defined outside the loop in order to give the same observable behavior in case
// of adding elements to array. May need to change when ballerina-spec/899 is resolved.
BSymbol indexMaxSymbol =
addTemporaryVariableToScope(pos, "$indexMax$", indexMaxVal, symTable.intType, scopeBlock);

BLangBinaryExpr condition =
ASTBuilderUtil.createBinaryExpr(pos, ASTBuilderUtil.createVariableRef(pos, indexSymbol),
ASTBuilderUtil.createVariableRef(pos, indexMaxSymbol), symTable.booleanType, comparisonOp,
(BOperatorSymbol) symResolver
.resolveBinaryOperator(OperatorKind.EQUAL, symTable.booleanType, symTable.booleanType));

BLangBlockStmt whileBody = ASTBuilderUtil.createBlockStmt(pos);
whileBody.scope = foreach.body.scope;

VariableDefinitionNode loopValDef = foreach.variableDefinitionNode;
Location loopValPos = loopValDef.getPosition();
BLangSimpleVarRef indexRef = ASTBuilderUtil.createVariableRef(loopValPos, indexSymbol);

loopValDef.getVariable().setInitialExpression(loopValueGenerator.apply(indexRef));
whileBody.addStatement(loopValDef);

// Increment the index
whileBody.addStatement(ASTBuilderUtil.createAssignmentStmt(pos, indexRef,
ASTBuilderUtil.createBinaryExpr(pos, indexRef, ASTBuilderUtil.createLiteral(pos, symTable.intType, 1L),
symTable.intType, OperatorKind.ADD,
(BOperatorSymbol) symResolver.resolveBinaryOperator(OperatorKind.ADD, symTable.intType,
symTable.intType))));

whileBody.stmts.addAll(foreach.body.stmts);

BLangWhile whileLoop = ASTBuilderUtil.createWhile(pos, condition, whileBody);
scopeBlock.addStatement(whileLoop);
}

private BSymbol addTemporaryVariableToScope(Location pos, String name, BLangExpression initValue, BType type,
BLangBlockStmt scopeBlock) {
BLangSimpleVariable var = ASTBuilderUtil.createVariable(pos, name, type, initValue,
new BVarSymbol(0, Names.fromString(name), this.env.scope.owner.pkgID, type,
this.env.scope.owner, pos, VIRTUAL));
scopeBlock.addStatement(ASTBuilderUtil.createVariableDef(pos, var));
return var.symbol;
}

BLangBlockStmt desugarForeachStmt(BVarSymbol collectionSymbol, BType collectionType, BLangForeach foreach,
BLangSimpleVariableDef dataVarDef) {
// Get the symbol of the variable (collection).
Expand Down Expand Up @@ -5286,7 +5406,7 @@ private BLangBlockStmt desugarForeachWithIteratorDef(BLangForeach foreach,
boolean isIteratorFuncFromLangLib) {
BLangSimpleVariableDef iteratorVarDef = getIteratorVariableDefinition(foreach.pos, collectionSymbol,
iteratorInvokableSymbol, isIteratorFuncFromLangLib);
BLangBlockStmt blockNode = desugarForeachToWhile(foreach, iteratorVarDef);
BLangBlockStmt blockNode = desugarForeachToWhileWithIterator(foreach, iteratorVarDef);
blockNode.stmts.add(0, dataVariableDefinition);
return blockNode;
}
Expand All @@ -5310,7 +5430,7 @@ BInvokableSymbol getLangLibIteratorInvokableSymbol(BVarSymbol collectionSymbol)
names.fromString(BLangCompilerConstants.ITERABLE_COLLECTION_ITERATOR_FUNC), env);
}

private BLangBlockStmt desugarForeachToWhile(BLangForeach foreach, BLangSimpleVariableDef varDef) {
private BLangBlockStmt desugarForeachToWhileWithIterator(BLangForeach foreach, BLangSimpleVariableDef varDef) {

// We desugar the foreach statement to a while loop here.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import org.ballerinalang.test.BCompileUtil;
import org.ballerinalang.test.BRunUtil;
import org.ballerinalang.test.CompileResult;
import org.ballerinalang.test.exceptions.BLangTestException;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.Arrays;
Expand Down Expand Up @@ -218,9 +220,36 @@ public void testArrayWithNullElements() {
Assert.assertEquals(returns.toString(), "0:d0 1: 2:d2 3: ");
}

@Test(dataProvider = "listIterationTestFunctions")
public void testListIteration(String functionName) {
BRunUtil.invoke(program, functionName);
}

@DataProvider(name = "listIterationTestFunctions")
public Object[][] listIterationTestFunctions() {
return new Object[][]{
{"testQueryInsideLoop"},
{"testListConstructor"},
{"testTuple"},
{"testEmptyArray"},
{"testFunctionCall"},
{"testQueryExpressions"}
};
}

// These ensure observable behaviour of foreach statements are the same even when we optimize away the iterator
@Test
public void testEmptyArray() {
BRunUtil.invoke(program, "testEmptyArray");
public void testMutatingArray() {
BRunUtil.invoke(program, "testMutatingArray");
}

@Test
public void testRemoveElementsWhileIterating() {
try {
BRunUtil.invoke(program, "testRemoveElementsWhileIterating");
} catch (BLangTestException ex) {
Assert.assertTrue(ex.getMessage().contains("array index out of range: index: 2, size: 1"));
}
}

@AfterClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,72 @@ function testArrayWithNullElements() returns string {
return output;
}

function testQueryInsideLoop() {
string[] values = ["a", "b", "c"];
int sum = 0;
foreach string entity in values {
string[] relationFields = from string each in values
where entity + each != "foo"
select each;
sum += relationFields.length();
}
assertEquality(9, sum);
}

function testMutatingArray() {
int[] vals = [1, 2, 3];
int sum = 0;
foreach int val in vals {
vals = [1, ...vals];
sum += vals[0];
}
assertEquality(3, sum);
int[] prefixVals = [1, 2, 3];
sum = 0;
foreach int val in prefixVals {
vals.push(10);
sum += val;
}
assertEquality(6, sum);
sum = 0;
vals = [1, 2, 3];
foreach int val in vals {
sum += val;
vals = [4, 5, 6];
}
assertEquality(6, sum);
}

function testListConstructor() {
int sum = 0;
foreach int val in [1, 2, 3] {
sum += val;
}
assertEquality(6, sum);
}

function testTuple() {
int sum = 0;
[int, int, int] vals = [1, 2, 3];
foreach int val in vals {
sum += val;
}
assertEquality(6, sum);
sum = 0;
[int, int, int...] moreVals = [1, 2, 3, 4, 5];
foreach int val in moreVals {
sum += val;
}
assertEquality(15, sum);
}

function testRemoveElementsWhileIterating() {
int[] vals = [1, 2, 3];
foreach int val in vals {
var _ = vals.pop();
}
}

function testEmptyArray() {
output = "hello";
foreach var item in [] {
Expand All @@ -247,6 +313,26 @@ function testEmptyArray() {
assertEquality(output, "hello");
}

function testFunctionCall() {
int sum = 0;
foreach int val in foo() {
sum += val;
}
assertEquality(6, sum);
}

function foo() returns int[] {
return [1, 2, 3];
}

function testQueryExpressions() {
int sum = 0;
foreach int val in from int i in [1, 2, 3] select i {
sum += val;
}
assertEquality(6, sum);
}

const ASSERTION_ERROR_REASON = "AssertionError";

function assertEquality(any|error expected, any|error actual) {
Expand Down

0 comments on commit 223d87f

Please sign in to comment.