Skip to content

Commit

Permalink
[GR-49949] Backport to 20.3: Loop unroll: handle stride overflow in i…
Browse files Browse the repository at this point in the history
…nt range.

PullRequest: graal/16346
  • Loading branch information
elkorchi committed Dec 16, 2023
2 parents d54cd1f + 0a33a84 commit a10ec61
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2011, 2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2011, 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -24,10 +24,10 @@
*/
package org.graalvm.compiler.core.common;

// JaCoCo Exclude

import jdk.vm.ci.code.CodeUtil;

import org.graalvm.compiler.debug.GraalError;

/**
* A collection of static utility functions that check ranges of numbers.
*/
Expand Down Expand Up @@ -244,4 +244,14 @@ public static long minUnsigned(long a, long b) {
public static boolean sameSign(long a, long b) {
return a < 0 == b < 0;
}

public static long addExact(long a, long b, int bits) {
if (bits == 32) {
return Math.addExact((int) a, (int) b);
} else if (bits == 64) {
return Math.addExact(a, b);
} else {
throw GraalError.shouldNotReachHere("Must be one of java's core datatypes int/long but is " + bits);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.graalvm.collections.EconomicMap;
import org.graalvm.compiler.core.common.RetryableBailoutException;
import org.graalvm.compiler.core.common.calc.CanonicalCondition;
import org.graalvm.compiler.core.common.NumUtil;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.graph.Graph.Mark;
import org.graalvm.compiler.graph.Graph.NodeEventScope;
Expand Down Expand Up @@ -602,6 +604,17 @@ public static List<ControlSplitNode> findUnswitchable(LoopEx loop) {
return controls;
}

public static boolean strideAdditionOverflows(LoopEx loop) {
final int bits = ((IntegerStamp) loop.counted().getCounter().valueNode().stamp(NodeView.DEFAULT)).getBits();
long stride = loop.counted().getCounter().constantStride();
try {
NumUtil.addExact(stride, stride, bits);
return false;
} catch (ArithmeticException ae) {
return true;
}
}

public static boolean isUnrollableLoop(LoopEx loop) {
if (!loop.isCounted() || !loop.counted().getCounter().isConstantStride() || !loop.loop().getChildren().isEmpty() || loop.loopBegin().loopEnds().count() != 1 ||
loop.loopBegin().loopExits().count() > 1) {
Expand All @@ -618,11 +631,8 @@ public static boolean isUnrollableLoop(LoopEx loop) {
condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s condition unsupported %s ", loopBegin, ((CompareNode) condition).condition());
return false;
}
long stride = loop.counted().getCounter().constantStride();
try {
Math.addExact(stride, stride);
} catch (ArithmeticException ae) {
condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s doubling the stride overflows %d", loopBegin, stride);
if (strideAdditionOverflows(loop)) {
condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s doubling the stride overflows %d", loopBegin, loop.counted().getCounter().constantStride());
return false;
}
if (!loop.canDuplicateLoop()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,17 @@
import org.junit.Test;

import jdk.vm.ci.meta.ResolvedJavaMethod;
import jdk.vm.ci.meta.SpeculationLog;

public class LoopPartialUnrollTest extends GraalCompilerTest {

boolean check = true;

@Override
protected void checkMidTierGraph(StructuredGraph graph) {
if (!check) {
return;
}
NodeIterable<LoopBeginNode> loops = graph.getNodes().filter(LoopBeginNode.class);
for (LoopBeginNode loop : loops) {
if (loop.isMainLoop()) {
Expand Down Expand Up @@ -347,4 +353,79 @@ public void testDuplicateBody(String reference, String test) {
canonicalizer.apply(referenceGraph, getDefaultMidTierContext());
assertEquals(referenceGraph, testGraph);
}

public static void twoUsages(int n) {
for (int i = 0; injectIterationCount(100, i < n); i++) {
GraalDirectives.blackhole(i < n ? 1 : 2);
}
}

@Test
public void testUsages() {
check = false;
OptionValues options = new OptionValues(getInitialOptions(), GraalOptions.LoopPeeling, false);
test(options, "twoUsages", 100);
check = true;
}

@Test
public void testIDiv() {
check = false;
for (int i = -1; i < 64; i++) {
test("idivSnippet", i);
}
check = true;
}

static int S = 100;

public static int idivSnippet(int iterations) {
int res = 0;
for (int i = 1; injectBranchProbability(0.99, i < iterations); i++) {
res += 100 / i;
}

return res;
}

static int rr = 0;

static int countedAfterSnippet(int i, int limit) {
int res = 0;
for (int j = i; GraalDirectives.injectIterationCount(1000, j <= limit); j += Integer.MAX_VALUE) {
rr += 42;
res += j;
}
return res;
}

SpeculationLog speculationLog;
boolean useSpeculationLog;

@Override
protected SpeculationLog getSpeculationLog() {
if (!useSpeculationLog) {
speculationLog = null;
return null;
}
if (speculationLog == null) {
speculationLog = getCodeCache().createSpeculationLog();
}
speculationLog.collectFailedSpeculations();
return speculationLog;
}

@Test
public void strideOverflow() {
check = false;
useSpeculationLog = true;
OptionValues opt = new OptionValues(getInitialOptions(), GraalOptions.LoopPeeling, false);
for (int i = -1000; i < 1000; i++) {
for (int j = 0; j < 100; j++) {
test(opt, "countedAfterSnippet", i, j);
}
}
check = true;
useSpeculationLog = false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import org.graalvm.collections.EconomicMap;
import org.graalvm.collections.Equivalence;
import org.graalvm.compiler.core.common.NumUtil;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.debug.DebugCloseable;
import org.graalvm.compiler.debug.DebugContext;
Expand Down Expand Up @@ -236,7 +237,7 @@ public void insertWithinAfter(LoopEx loop, EconomicMap<LoopBeginNode, OpaqueNode
opaqueUnrolledStrides.put(loop.loopBegin(), opaque);
} else {
assert counted.getCounter().isConstantStride();
assert Math.addExact(counted.getCounter().constantStride(), counted.getCounter().constantStride()) == counted.getCounter().constantStride() * 2;
assert !strideAdditionOverflows(loop) : "Stride addition must not overflow";
ValueNode previousValue = opaque.getValue();
opaque.setValue(graph.addOrUniqueWithInputs(AddNode.add(counterStride, previousValue, NodeView.DEFAULT)));
GraphUtil.tryKillUnused(previousValue);
Expand Down Expand Up @@ -747,4 +748,15 @@ public void apply(Node from, Position p) {
}
return newExit;
}

public static boolean strideAdditionOverflows(LoopEx loop) {
final int bits = ((IntegerStamp) loop.counted().getCounter().valueNode().stamp(NodeView.DEFAULT)).getBits();
long stride = loop.counted().getCounter().constantStride();
try {
NumUtil.addExact(stride, stride, bits);
return false;
} catch (ArithmeticException ae) {
return true;
}
}
}

0 comments on commit a10ec61

Please sign in to comment.