Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predicate refactor 2 #639

Merged
merged 6 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/basic/InlineFunctions.pvl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//:: cases InlineFunctions
//:: tools silicon
//:: verdict Pass

class C {
ensures false;
inline int f() = 3 + 3;

void m() {
assert f() == 6;
}
}
12 changes: 12 additions & 0 deletions examples/predicates/MutuallyRecursiveInlinePredicates.pvl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//:: cases MutuallyRecursiveInlinePredicates
//:: tools silicon
//:: verdict Error

class C {
inline resource p() = q() ** true;
inline resource q() = true ** p();

void p() {
assert p();
}
}
11 changes: 11 additions & 0 deletions examples/predicates/RecursiveInlinePredicate.pvl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//:: cases RecursiveInlinePredicate
//:: tools silicon
//:: verdict Error

class C {
inline resource p() = true ** p();

void p() {
assert p();
}
}
37 changes: 37 additions & 0 deletions examples/predicates/ScaleInlinePredicate.pvl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//:: cases ScaleInlinePredicate
//:: tools silicon
//:: verdict Pass

class C {
int x;
int y;
inline resource r() = Perm(x, write);

inline resource s() = r() ** Perm(y, 1\2) ** x == y;

requires [1\2]r();
void m() {
unfold [1\2]r();
assert Perm(x, 1\2);
}

requires [1\4]r();
void m2() {
unfold [1\2][1\2]r();
assert Perm(x, 1\4);
}

requires s();
void m3() {
fold [1\2]s();
assert x == y;
assert Perm(y, 1\2) ** Perm(x, write);
}

requires [1\8]s();
void m4() {
fold [1\4][1\4]s();
assert x == y;
assert Perm(y, 1\16) ** Perm(x, 1\8);
}
}
12 changes: 10 additions & 2 deletions src/main/java/vct/col/features/FeatureRainbow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,13 @@ class RainbowVisitor(source: ProgramUnit) extends RecursiveVisitor(source, true)
}

if(isPure(m)) {
if(isInline(m))
addFeature(InlinePredicate, m)
if(isInline(m)) {
if (m.getReturnType.isPrimitive(PrimitiveSort.Resource)) {
addFeature(InlinePredicate, m)
} else if (!m.getReturnType.isPrimitive(PrimitiveSort.Process)) {
addFeature(InlineFunction, m)
}
}
if(m.getBody.isInstanceOf[BlockStatement])
addFeature(PureImperativeMethods, m)
}
Expand Down Expand Up @@ -618,6 +623,7 @@ object Feature {
GivenYields,
StaticFields,
InlinePredicate,
InlineFunction,
KernelClass,
AddrOf,
OpenMP,
Expand Down Expand Up @@ -746,6 +752,7 @@ object Feature {
GivenYields,
StaticFields,
InlinePredicate,
InlineFunction,
KernelClass,
AddrOf,
OpenMP,
Expand Down Expand Up @@ -846,6 +853,7 @@ case object ADTOperator extends ScannableFeature
case object GivenYields extends ScannableFeature
case object StaticFields extends ScannableFeature
case object InlinePredicate extends ScannableFeature
case object InlineFunction extends ScannableFeature
case object KernelClass extends ScannableFeature
case object AddrOf extends ScannableFeature
case object OpenMP extends ScannableFeature
Expand Down
188 changes: 188 additions & 0 deletions src/main/java/vct/col/rewrite/InlinePredicatesAndFunctions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package vct.col.rewrite;

import vct.col.ast.expr.NameExpression;
import vct.col.ast.stmt.decl.ASTFlags;
import vct.col.ast.generic.ASTNode;
import vct.col.ast.stmt.decl.ASTSpecial;
import vct.col.ast.stmt.decl.Method;
import vct.col.ast.expr.MethodInvokation;
import vct.col.ast.expr.OperatorExpression;
import vct.col.ast.stmt.decl.ProgramUnit;
import vct.col.ast.expr.StandardOperator;
import vct.col.ast.type.ASTReserved;
import vct.col.ast.util.AbstractRewriter;
import vct.test.ThreadPool;

import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Stack;

import static hre.lang.System.Output;

public class InlinePredicatesAndFunctions extends AbstractRewriter {

int count = 0;
Stack<String> inlinedScalars = new Stack<>();
IdentityHashMap<Method, Object> currentlyBeingInlined = new IdentityHashMap<>();

public InlinePredicatesAndFunctions(ProgramUnit source) {
super(source);
}

public ASTNode inlineCall(MethodInvokation e, Method def) {
if (currentlyBeingInlined.containsKey(def)) {
def.getOrigin().report("error", "Inline predicate or function cannot contain itself.");
hre.lang.System.Abort("Cyclical inline predicate or function detected");
} else {
currentlyBeingInlined.put(def, def);
}

int N=def.getArity();
HashMap<NameExpression,ASTNode> map=new HashMap<NameExpression, ASTNode>();
Substitution sigma=new Substitution(source(),map);
map.put(create.reserved_name(ASTReserved.This), rewrite(e.object()));
for(int i=0;i<N;i++){
map.put(create.unresolved_name(def.getArgument(i)),rewrite(e.getArg(i)));
}
ASTNode body=rewrite(def.getBody());
InlineMarking marker=new InlineMarking(source(),e.getOrigin());
body.accept(marker);

ASTNode result = sigma.rewrite(body);
currentlyBeingInlined.remove(def);
return result;
}

@Override
public void visit(MethodInvokation e){
if (inline(e)){
result = inlineCall(e, e.getDefinition());
} else if (!inlinedScalars.empty() && e.getDefinition().getKind().equals(Method.Kind.Predicate)) {
// Because the previous if branch was false, we know this is a predicate that is not inline
// Hence, we have to add a scale that scales the predicate according to any earlier encountered scales
super.visit(e);
result = create.expression(StandardOperator.Scale,
getCombinedScalar(null),
result
);
} else {
super.visit(e);
}
}

protected boolean inline(ASTNode node) {
if (node instanceof Method) {
Method def = (Method) node;
if (def.isValidFlag(ASTFlags.INLINE)){
return (def.kind==Method.Kind.Predicate || def.kind==Method.Kind.Pure) && def.getFlag(ASTFlags.INLINE);
} else {
return false;
}
} else if (node instanceof MethodInvokation) {
MethodInvokation invokation = (MethodInvokation) node;
return inline(invokation.getDefinition());
} else if (node instanceof OperatorExpression) {
OperatorExpression operatorExpression = (OperatorExpression) node;
if (operatorExpression.operator() == StandardOperator.Scale) {
return inline(operatorExpression.arg(1));
}
}

return false;
}

@Override
public void visit(Method m){
if (inline(m)){
result=null;
} else {
super.visit(m);
}
}

/**
* Returns all the aggregated scalars so far in one multiplication.
*
* Given a startNode, it will yield the following AST:
*
* (((startNode * scalar1) * scalar2) * ...
*
* If no startNode is given, only the scalars appear in the tree.
*
* The reason startNode is a possible argument is because of a bug in Viper where if startNode
* is not at the bottom of a multiplication tree, some type error is triggered in Viper. Ideally,
* this method would have no arguments, and would just return the multiplication tree of the scalars.
*
* @param startNode Node to start the multiplication with. If not available, can be null.
* @return AST corresponding to: (((startNode * scalar1) * scalar2) * ...
*/
private ASTNode getCombinedScalar(ASTNode startNode) {
if (inlinedScalars.empty()) {
Abort("Cannot combine scalars when stack is empty");
}

ASTNode result;
int startIndex;

if (startNode == null) {
result = create.local_name(inlinedScalars.get(0));
startIndex = 1;
} else {
result = startNode;
startIndex = 0;
}

for (int i = startIndex; i < inlinedScalars.size(); i++) {
result = create.expression(
StandardOperator.Mult,
result,
create.local_name(inlinedScalars.get(i)));
}

return result;
}

@Override
public void visit(OperatorExpression e){
switch(e.operator()){
case Scale:
if (inline(e)) {
String scaleAmountName = "inlineScalar" + count++;

inlinedScalars.push(scaleAmountName);
super.visit(e);
OperatorExpression scaleExpr = (OperatorExpression) result;
inlinedScalars.pop();

result = create.let_expr(
create.field_decl(scaleAmountName, rewrite(e.arg(0).getType()), scaleExpr.arg(0)),
scaleExpr.arg(1)
);
} else {
super.visit(e);
}
break;
case Perm:
super.visit(e);
OperatorExpression newPerm = (OperatorExpression) result;
if (!inlinedScalars.empty()) {
ASTNode permissionAmount = getCombinedScalar(e.arg(1));
result = create.expression(StandardOperator.Perm, newPerm.arg(0), permissionAmount);
}
break;
default:
super.visit(e);
break;
}
}

@Override
public void visit(ASTSpecial e) {
if ((e.kind == ASTSpecial.Kind.Fold || e.kind == ASTSpecial.Kind.Unfold) && inline(e.args[0])) {
Warning("Folding/unfolding an inline predicate is allowed but not encouraged. See https://github.com/utwente-fmt/vercors/wiki/Predicates#inline-predicates for more info.");
result = create.special(ASTSpecial.Kind.Assert, rewrite(e.getArg(0)));
} else {
super.visit(e);
}
}
}
Loading