From 8a09ff5b7b52eb55d0c4ea7584c5ed8116084932 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 20 Jun 2022 15:48:19 +0200 Subject: [PATCH 01/25] Added an expression equality checker class --- .../ast/util/ExpressionEqualityCheck.scala | 286 ++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala diff --git a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala new file mode 100644 index 0000000000..410ff9ac0b --- /dev/null +++ b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala @@ -0,0 +1,286 @@ +package vct.col.ast.util + +import hre.lang.System.Warning +import vct.col.ast.util.ExpressionEqualityCheck.is_constant_int +import vct.col.ast.{And, BitAnd, BitNot, BitOr, BitShl, BitShr, BitUShr, BitXor, Div, Eq, Exp, Expr, FloorDiv, Greater, GreaterEq, Implies, IntegerValue, Less, LessEq, Local, Minus, Mod, Mult, Neq, Not, Or, Plus, Star, UMinus, Wand} + +import scala.collection.mutable + +object ExpressionEqualityCheck { + def apply[G](info: Option[AnnotationVariableInfo[G]] = None): ExpressionEqualityCheck[G] = new ExpressionEqualityCheck[G](info) + + def is_constant_int[G](e: Expr[G]): Option[BigInt] = { + ExpressionEqualityCheck().is_constant_int(e) + } + + def equal_expressions[G](lhs: Expr[G], rhs: Expr[G]): Boolean = { + ExpressionEqualityCheck().equal_expressions(lhs, rhs) + } +} + +class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { + var replacer_depth = 0 + var replacer_depth_int = 0 + val max_depth = 100 + + def is_constant_int(e: Expr[G]): Option[BigInt] = { + replacer_depth_int = 0 + is_constant_int_(e) + } + + def is_constant_int_(e: Expr[G]): Option[BigInt] = e match { + case e: Local[G] => + // Does it have a direct int value? + info.flatMap(_.variable_values.get(e)) match { + case Some(x) => Some(x) + case None => + info.flatMap(_.variable_equalities.get(e)) match { + case None => None + case Some(equals) => + for (eq <- equals) { + // Make sure we do not loop indefinitely by keep replacing the same expressions somehow + if (replacer_depth_int > max_depth) return None + replacer_depth_int += 1 + val res = is_constant_int_(eq) + if (res.isDefined) return res + } + None + } + } + + case IntegerValue(value) => Some(value) + case Exp(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1.pow(i2.toInt) + case Plus(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 + i2 + case Minus(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 - i2 + case Mult(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 * i2 + case FloorDiv(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 / i2 + case Mod(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 % i2 + + case BitAnd(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 & i2 + case BitOr(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 | i2 + case BitXor(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 ^ i2 + case BitShl(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 << i2.toInt + case BitShr(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 >> i2.toInt + case BitUShr(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1.toInt >>> i2.toInt + + case _ => None + } + + def equal_expressions(lhs: Expr[G], rhs: Expr[G]): Boolean = { + replacer_depth = 0 + equal_expressions_(lhs, rhs) + } + + // + def equal_expressions_(lhs: Expr[G], rhs: Expr[G]): Boolean = { + (is_constant_int(lhs), is_constant_int(rhs)) match { + case (Some(i1), Some(i2)) => return i1 == i2 + case (None, None) => () + //If one is a constant expression, and the other is not, this cannot be the same + case _ => return false + } + + (lhs, rhs) match { + // Unsure if we could check/pattern match on this easier + + // Commutative operators + case (Plus(lhs1, lhs2), Plus(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (Mult(lhs1, lhs2), Mult(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (BitAnd(lhs1, lhs2), BitAnd(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (BitOr(lhs1, lhs2), BitOr(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (BitXor(lhs1, lhs2), BitXor(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (And(lhs1, lhs2), And(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (Or(lhs1, lhs2), Or(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (Eq(lhs1, lhs2), Eq(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + case (Neq(lhs1, lhs2), Neq(rhs1, rhs2)) => + (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || + (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + + //Non commutative operators + case (Exp(lhs1, lhs2), Exp(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (Minus(lhs1, lhs2), Minus(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (Div(lhs1, lhs2), Div(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (FloorDiv(lhs1, lhs2), FloorDiv(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (Mod(lhs1, lhs2), Mod(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + + case (BitShl(lhs1, lhs2), BitShl(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (BitShr(lhs1, lhs2), BitShr(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (BitUShr(lhs1, lhs2), BitUShr(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + + case (Implies(lhs1, lhs2), Implies(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (Star(lhs1, lhs2), Star(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (Wand(lhs1, lhs2), Wand(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + + case (Greater(lhs1, lhs2), Greater(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (Less(lhs1, lhs2), Less(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (GreaterEq(lhs1, lhs2), GreaterEq(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + case (LessEq(lhs1, lhs2), LessEq(rhs1, rhs2)) => + equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + + // Unary expressions + case (UMinus(lhs), UMinus(rhs)) => equal_expressions_(lhs, rhs) + case (BitNot(lhs), BitNot(rhs)) => equal_expressions_(lhs, rhs) + case (Not(lhs), Not(rhs)) => equal_expressions_(lhs, rhs) + + // Variables + case (name1: Local[G], name2: Local[G]) => + if (name1 == name2) true + else if (info.isDefined) { + // Check if the variables are synonyms + (info.get.variable_synonyms.get(name1), info.get.variable_synonyms.get(name2)) match { + case (Some(x), Some(y)) => x == y + case _ => false + } + } else false + case (name1: Local[G], e2) => + replace_variable(name1, e2) + case (e1, name2: Local[G]) => + replace_variable(name2, e1) + + // In the general case, we are just interested in syntactic equality + case (e1, e2) => e1 == e2 + } + } + + // + def replace_variable(name: Local[G], other_e: Expr[G]): Boolean = { + if (info.isDefined) { + info.get.variable_equalities.get(name) match { + case None => false + case Some(equals) => + for (eq <- equals) { + // Make sure we do not loop indefinitely by keep replacing the same expressions somehow + if (replacer_depth > max_depth) return false + replacer_depth += 1 + if (equal_expressions_(eq, other_e)) return true + } + false + } + } else { + false + } + } +} + +case class AnnotationVariableInfo[G](variable_equalities: Map[Local[G], List[Expr[G]]], variable_values: Map[Local[G], BigInt], + variable_synonyms: Map[Local[G], Int]) + +/** This class gathers information about variables, such as: + * `requires x == 0` and stores that x is equal to the value 0. + * Which we can use in simplify steps + * This information is returned with ```get_info(annotations: Iterable[Expr[G]])``` + */ +class AnnotationVariableInfoGetter[G]() { + + val variable_equalities: mutable.Map[Local[G], mutable.ListBuffer[Expr[G]]] = + mutable.Map() + val variable_values: mutable.Map[Local[G], BigInt] = mutable.Map() + // We put synonyms in the same group and give them a group number, to identify the same synonym groups + val variable_synonyms: mutable.Map[Local[G], Int] = mutable.Map() + var current_synonym_group = 0 + + def extract_equalities(e: Expr[G]): Unit = { + e match{ + case Eq(e1, e2) => + (e1, e2) match{ + case (v1: Local[G], v2: Local[G]) => add_synonym(v1, v2) + case (v1: Local[G], _) => add_name(v1, e2) + case (_, v2: Local[G]) => add_name(v2, e1) + case _ => + } + case _ => + } + } + + def add_synonym(v1: Local[G], v2: Local[G]): Unit = { + (variable_synonyms.get(v1), variable_synonyms.get(v2)) match { + // We make a new group + case (None, None) => + variable_synonyms(v1) = current_synonym_group + variable_synonyms(v2) = current_synonym_group + current_synonym_group += 1 + // Add to the found group + case (Some(id1), None) => variable_synonyms(v2) = id1 + case (None, Some(id2)) => variable_synonyms(v1) = id2 + // Merge the groups, give every synonym group member of id2 value id1 + case (Some(id1), Some(id2)) => + variable_synonyms.mapValuesInPlace((_, group) => if (group == id2) id1 else group) + } + } + + def add_name(v: Local[G], expr: Expr[G]): Unit ={ + // Add to constant list + is_constant_int[G](expr) match { + case Some(x) => variable_values.get(v) match { + case Some(x_) => if (x!=x_) Warning("Value of %s is required to be both %d and %d", v, x, x_); + case None => variable_values(v) = x + } + case None => + val list = variable_equalities.getOrElseUpdate(v, mutable.ListBuffer()) + list.addOne(expr) + } + } + + def get_info(annotations: Iterable[Expr[G]]): AnnotationVariableInfo[G] = { + variable_equalities.clear() + variable_values.clear() + + for(clause <- annotations){ + extract_equalities(clause) + } + + distribute_info() + AnnotationVariableInfo(variable_equalities.view.mapValues(_.toList).toMap, variable_values.toMap, + variable_synonyms.toMap) + } + + def distribute_info(): Unit = { + // First distribute value knowledge over the rest of the map + val begin_size = variable_values.size + + for((name, equals) <- variable_equalities){ + if(!variable_values.contains(name)) + for(equal <- equals){ + equal match { + case n : Local[G] => + variable_values.get(n).foreach(variable_values(name) = _) + case _ => + } + } + } + + // If sizes are not the same, we know more, so distribute again! + if(variable_values.size != begin_size) distribute_info() + } + +} \ No newline at end of file From 1f3fe599ac99ca9a238f1a9cff9f53486184480b Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 24 Jun 2022 13:20:03 +0200 Subject: [PATCH 02/25] First version of translating SimplifiedNestedQuantifiers to new ast --- .../SimplifyNestedQuantifiers.scala | 836 ++++++++++++++++++ 1 file changed, 836 insertions(+) create mode 100644 src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala new file mode 100644 index 0000000000..7795fcb06e --- /dev/null +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -0,0 +1,836 @@ +package vct.col.newrewrite + +import vct.col.ast.{ArraySubscript, _} +import vct.col.ast.util.ExpressionEqualityCheck +import vct.col.newrewrite.util.Comparison +import vct.col.origin.{Origin, PanicBlame} +import vct.col.ref.Ref +import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.col.util.AstBuildHelpers._ +import vct.col.util.{AstBuildHelpers, Substitute} +import vct.result.VerificationError.Unreachable + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.math.BigInt +import scala.annotation.nowarn + +/** + * This rewrite pass simplifies expressions of roughly this form: + * forall(i,j: Int . 0 <= i < i_max && 0 <= j < j_max; xs[a*(i_max*j + i) + b]) + * and collapses it into a single forall: + * forall(k: Int. b <= k <= i_max*j_max*a + b && k % a == 0; xs[k]) + * + * We also check on if a quantifier takes just a single value. E.g. + * forall(i,j: Int; i == 5 && i < n && i <= j && j < 5; xs[j+i]) ====> 5 < n ==> forall(int j; 0 <= j < 5; xs[j]) + * + * and if a quantifier isn't in the "body" of the forall. E.g. + * forall(i,j: Int. 1 <= i && i< n && 0 < j; xs[j]>0) ====> n > 1 ==> forall(j: Int; 0 < j; xs[j] >0) + * + */ +case object SimplifyNestedQuantifiers extends RewriterBuilder { + override def key: String = "simplifyNestedQuantifiers" + override def desc: String = "Simplify nested quantifiers." +} + +//case class SimplifyNestedQuantifiers[Pre <: Generation]() extends NonLatchingRewriter[Pre, Rewritten[Pre]] { +case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] { + + case object SimplifyNestedQuantifiersOrigin extends Origin { + override def preferredName: String = "unknown" + + override def shortPosition: String = "generated" + + override def context: String = "[At generated expression for the simplification of nested quantifiers]" + + override def inlineContext: String = "[Simplified expression]" + } + + private implicit val o: Origin = SimplifyNestedQuantifiersOrigin + + private def one: IntegerValue[Pre] = IntegerValue(1) + + // TODO: Supply information towards the expression equality checker when encountering a contract + var equality_checker: ExpressionEqualityCheck[Pre] = ExpressionEqualityCheck() + + override def dispatch(e: Expr[Pre]): Expr[Post] = { + e match { + case e: Binder[Pre] => + rewriteLinearArray(e) match { + case None => rewriteDefault(e) + case Some(e) => e + } + case other => rewriteDefault(other) + } + } + + def rewriteLinearArray(e: Binder[Pre]): Option[Expr[Post]] = { + val originalBody = e match { + case Forall(_, _, body) => body + case Starall(_, _, body) => body + case _ => return None + } + + if (e.bindings.exists(_.t != TInt())) return None + + val quantifierData = new RewriteQuantifierData(originalBody, e, this) + quantifierData.setData() + quantifierData.check_single_value_variables() + quantifierData.check_independent_variables() + + // Check if we have valid bounds to rewrite, otherwise we stop + if(!quantifierData.check_bounds()) return quantifierData.result() + + quantifierData.lookForLinearAccesses() + } + + class RewriteQuantifierData(val bindings: mutable.Set[Variable[Pre]], + var lowerBounds: Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], + var upperBounds: Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], + var upperExclusiveBounds: Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], + var independentConditions: ArrayBuffer[Expr[Pre]], + val dependentConditions: ArrayBuffer[Expr[Pre]], + var body: Expr[Pre], + val originalBinder: Binder[Pre], + val rewriter: SimplifyNestedQuantifiers[Pre] + ) { + def this(originalBody: Expr[Pre], originalBinder: Binder[Pre], rewriter: SimplifyNestedQuantifiers[Pre]) = { + this(originalBinder.bindings.to(mutable.Set), + originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).toMap, + originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).toMap, + originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).toMap, + ArrayBuffer[Expr[Pre]](), + ArrayBuffer[Expr[Pre]](), + originalBody, + originalBinder, + rewriter + ) + } + + /** Keeps track if it is already feasible to make a new quantifier */ + var new_binder = false + + def setData(): Unit = { + val (allConditions, main_body) = unfoldImplies[Pre](body) + body = main_body + // Split bounds that are independent of any binding variables + val (new_independentConditions, potentialBounds) = allConditions.partition(indepOf(bindings, _)) + independentConditions.addAll(new_independentConditions) + getBounds(potentialBounds) + } + + /** + * Process the potential bounds to be either a bound or just a dependent condition. + * @param potentialBounds Bounds to be processed. + */ + def getBounds(potentialBounds: Iterable[Expr[Pre]]): Unit = { + for (bound <- potentialBounds) { + // First try to match a simple comparison + Comparison.of(bound) match { + case Some((_, Comparison.NEQ, _)) => dependentConditions.addOne(bound) + case Some((left, comp, right)) => + if (indepOf(bindings, right)) { + // x >|>=|==|<=|< 5 + left match { + case Local(Ref(v)) if bindings.contains(v) => addSingleBound(v, right, comp) + case _ => dependentConditions.addOne(bound) + } + } else if (indepOf(bindings, left)) { + right match { + case Local(Ref(v)) if bindings.contains(v) => + // If the quantified variable is the second argument: flip the relation + addSingleBound(v, right, comp.flip) + case _ => dependentConditions.addOne(bound) + } + } + case None => bound match { + // If we do not have a simple comparison, we support one special case: i \in {a..b} + case SeqMember(Local(Ref(v)), Range(from, to)) + if bindings.contains(v) && indepOf(bindings, from) && indepOf(bindings, to) => + addSingleBound(v, from, Comparison.GREATER_EQ) + addSingleBound(v, to, Comparison.LESS) + case _ => dependentConditions.addOne(bound) + } + } + } + } + + /** + * Add a bound like v >= right. + */ + @nowarn("msg=xhaust") + def addSingleBound(v: Variable[Pre], right: Expr[Pre], comp: Comparison): Unit = { + comp match { + // v < right + case Comparison.LESS => + upperExclusiveBounds(v).addOne(right) + upperBounds(v).addOne(right + one) + // v <= right + case Comparison.LESS_EQ => + upperExclusiveBounds(v).addOne(right - one) + upperBounds(v).addOne(right) + // v == right + case Comparison.EQ => + lowerBounds(v).addOne(right) + upperExclusiveBounds(v).addOne(right - one) + upperBounds(v).addOne(right) + // v >= right + case Comparison.GREATER_EQ => lowerBounds(v).addOne(right) + // v > right + case Comparison.GREATER => lowerBounds(v).addOne(right + one) + } + } + + /** We check if there now any binding variables which resolve to just a single value, which happens if it + * has equal lower and upper bounds. + * E.g. forall(int i,j; i == 0 && i <= j && j < 5; xs[j+i]) ==> forall(int j; 0 <= j < 5; xs[j]) + * We just replace each reference to that value, and check our bounds again. + * We don't worry if a we have something like x == 5 && x < 0, since that will resolve to 5 < 0, which equally + * does not work. + * */ + def check_single_value_variables(): Unit = { + for (name <- bindings) { + val equal_bounds = lowerBounds(name).intersect(upperBounds(name)) + if (equal_bounds.nonEmpty) { + // We will put out a new quantifier + new_binder = true + val new_value = equal_bounds.head + val name_var: Expr[Pre] = Local(name.ref) + val sub = Substitute[Pre](Map(name_var -> new_value)) + val replacer = sub.dispatch(_: Expr[Pre]) + body = replacer(body) + + // Do not quantify over name anymore + bindings.remove(name) + + // Some dependent selects, might now have become independent or even bounds + val oldDependentBounds = dependentConditions.map(replacer) + dependentConditions.clear() + + val (new_independentConditions, potentialBounds) = oldDependentBounds.partition(indepOf(bindings, _)) + independentConditions.addAll(new_independentConditions) + getBounds(potentialBounds) + + // Bounds for the name, have now become independent conditions + lowerBounds(name).foreach(lb => + if (lb != new_value) independentConditions.addOne(LessEq(lb, new_value))) + upperBounds(name).foreach(ub => + if (ub != new_value) independentConditions.addOne(LessEq(new_value, ub))) + + lowerBounds = lowerBounds.removed(name) + upperBounds = upperBounds.removed(name) + upperExclusiveBounds = upperExclusiveBounds.removed(name) + + // Strictly speaking, a binding variable could be newly removed, if a previous one has been found constant + // and then the bounds deem another binding variable also constant. We check that by doing recursion. + check_single_value_variables() + return + } + } + } + + def check_independent_variables(): Unit = { + for (name <- bindings) { + if (indepOf(mutable.Set(name), body)) { + var independent = true + dependentConditions.foreach(s => if (!indepOf(mutable.Set(name), s)) independent = false) + if (independent) { + // We can freely remove this named variable + val max_bound = extremeValue(name, maximizing = true) + val min_bound = extremeValue(name, maximizing = false) + (max_bound, min_bound) match { + case (Some(max_bound), Some(min_bound)) => + new_binder = true + // Do not quantify over name anymore + bindings.remove(name) + lowerBounds = lowerBounds.removed(name) + upperBounds = upperBounds.removed(name) + upperExclusiveBounds = upperExclusiveBounds.removed(name) + + // We remove the forall variable i, but need to rewrite some expressions + // (forall i; a <= i <= b; ...Perm(ar, x)...) ====> b>=a ==> ...Perm(ar, x*(b-a+1))... + independentConditions.addOne(GreaterEq(max_bound, min_bound)) + + body = Scale(Plus(one, Minus(max_bound, min_bound)), body)( + PanicBlame("Error in SimplifyNestedQuantifiers class, implication should make sure scale is" + + " never negative when accessed.")) // rp.dispatch(body) + case _ => + } + } + } + } + } + + def extremeValue(name: Variable[Pre], maximizing: Boolean): Option[Expr[Pre]] = { + if (maximizing && upperBounds(name).nonEmpty) + Some(extremes(upperBounds(name).toSeq, maximizing)) + else if (!maximizing && lowerBounds(name).nonEmpty) + Some(extremes(lowerBounds(name).toSeq, maximizing)) + else + None + } + + def extremes(xs: Seq[Expr[Pre]], maximizing: Boolean): Expr[Pre] = { + xs match { + case expr :: Nil => expr + case left :: right :: tail => + Select( + condition = if(maximizing) left > right else left < right, + whenTrue = extremes(left :: tail, maximizing), + whenFalse = extremes(right :: tail, maximizing), + ) + } + } + + // This allows only forall's to be rewritten, if they have at least one lower bound of zero + // TODO: Generalize this, so we don't have this restriction + def check_bounds(): Boolean = { + for (name <- bindings) { + var one_zero = false + val zero = BigInt(0) + lowerBounds.getOrElse(name, ArrayBuffer()) + .foreach(lower => equality_checker.is_constant_int(lower) match { + case Some(`zero`) => one_zero = true + case _ => + }) + //Exit when notAt least one zero, or no upper bounds + if (!one_zero || upperBounds.getOrElse(name, ArrayBuffer()).isEmpty) { + return false + } + } + true + } + + case class ForallSubstitute(subs: Map[Expr[Pre], Expr[Post]]) + extends Rewriter[Pre] { + + override def dispatch(e: Expr[Pre]): Expr[Post] = e match { + case expr if subs.contains(expr) => subs(expr) + case other => rewriteDefault(other) + } + } + + def lookForLinearAccesses(): Option[Expr[Post]] = { + val linearAccesses = new FindLinearArrayAccesses(this) + withCollectInScope(rewriter.variableScopes) {linearAccesses.search(body)} match { + case (bindings, Some(substituteForall)) => + if(bindings.size != 1){ + ??? + } + val sub = ForallSubstitute(substituteForall.substituteOldVars) + val newBody = sub.dispatch(body) + val select = Seq(substituteForall.newBounds) ++ independentConditions.map(sub.dispatch) ++ + dependentConditions.map(sub.dispatch) + val main = if (select.nonEmpty) Implies(AstBuildHelpers.foldAnd(select), newBody) else newBody + val forall: Binder[Post] = originalBinder match { + case _: Forall[Pre] => Forall(bindings, substituteForall.newTrigger, main) + case originalBinder: Starall[Pre] => Starall(bindings, substituteForall.newTrigger, main)(originalBinder.blame) + case _ => ??? + } + Some(forall) + case (_, None) => result() + } + } + + def result(): Option[Expr[Post]] = { + // If we changed something we always return a result, even if we could not rewrite further + val res = if(new_binder) { + val select = independentConditions ++ dependentConditions + if (bindings.isEmpty) { + if (select.isEmpty) Some(body) else Some(Implies(AstBuildHelpers.foldAnd(select.toSeq), body)) + } else { + upperExclusiveBounds.foreach { + case (n: Variable[Pre], upperBounds: ArrayBuffer[Expr[Pre]]) => + val i: Expr[Pre] = Local(n.ref) + upperBounds.foreach(upperBound => + select.addOne(i < upperBound) + ) + } + lowerBounds.foreach { + case (n: Variable[Pre], lowerBounds: ArrayBuffer[Expr[Pre]]) => + val i: Expr[Pre] = Local(n.ref) + lowerBounds.foreach(lowerBound => + select.addOne(lowerBound <= i ) + ) + } + val new_body = if (select.nonEmpty) Implies(AstBuildHelpers.foldAnd(select.toSeq), body) + else body + + // TODO: Should we get the old triggers? And then filter if the triggers contain variables which + // are not there anymore? + @nowarn("msg=xhaust") + val forall: Expr[Pre] = originalBinder match{ + case _: Forall[Pre] => Forall(bindings.toSeq, Seq(), new_body) + case e: Starall[Pre] => Starall(bindings.toSeq, Seq(), new_body)(e.blame) + } + Some(forall) + } + } else { + None + } + + res.map(rewriter.rewriteDefault) + } + } + + def indepOf[G](bindings: mutable.Set[Variable[G]], e: Expr[G]): Boolean = + e.transSubnodes.collectFirst { case Local(ref) if bindings.contains(ref.decl) => () }.isEmpty + + class FindLinearArrayAccesses(quantifierData: RewriteQuantifierData){ + + // Search for linear array expressions + def search(e: Expr[Pre]): Option[SubstituteForall] = { + e match { + case e @ ArraySubscript(_, index) => + if (indepOf(quantifierData.bindings, index)) { + return None + } + linear_expression(e) match { + case Some(substitute_forall) => Some(substitute_forall) + case None => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} + } + case _ => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} + } + } + + def search(n: Node[Pre]): Option[SubstituteForall] = None + + def linear_expression(e: ArraySubscript[Pre]): Option[SubstituteForall] = { + val ArraySubscript(_, index) = e + val pot = new PotentialLinearExpressions(e) + pot.visit(index) + pot.can_rewrite() + } + + class PotentialLinearExpressions(val arrayIndex: ArraySubscript[Pre]){ + val linear_expressions: mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() + var constant_expression: Option[Expr[Pre]] = None + var is_linear: Boolean = true + var current_multiplier: Option[Expr[Pre]] = None + + def visit(e: Expr[Pre]): Unit = { + e match{ + case Plus(left, right) => + // if the first is constant, the second argument cannot be + if (isConstant(left)) { + addToConstant(left) + visit(right) + } else if (isConstant(right)) { + addToConstant(right) + visit(left) + } else { // Both arguments contain linear information + visit(left) + visit(right) + } + case Minus(left, right) => + // if the first is constant, the second argument cannot be + if (isConstant(left)) { + addToConstant(left) + val old_multiplier = current_multiplier + multiplyMultiplier(IntegerValue(-1)) + visit(right) + current_multiplier = old_multiplier + } else if (isConstant(right)) { + addToConstant(right, is_plus=false) + visit(left) + } else { // Both arguments contain linear information + visit(left) + val old_multiplier = current_multiplier + multiplyMultiplier(IntegerValue(-1)) + visit(right) + current_multiplier = old_multiplier + } + case Mult(left, right) => + if (isConstant(left)) { + val old_multiplier = current_multiplier + multiplyMultiplier(left) + visit(right) + current_multiplier = old_multiplier + } else if (isConstant(right)) { + val old_multiplier = current_multiplier + multiplyMultiplier(right) + visit(left) + current_multiplier = old_multiplier + } else { + is_linear = false + } + // TODO: Check if division is right conceptually with an example. Take special care to think about + // the order of division + case e@Div(left, right) => + if (isConstant(right)){ + val old_multiplier = current_multiplier + multiplyMultiplier(Div(IntegerValue(1), right)(e.blame)) + visit(left) + current_multiplier = old_multiplier + } else { + is_linear = false + } + case Local(ref) => + if(quantifierData.bindings.contains(ref.decl)) { + linear_expressions get ref.decl match { + case None => linear_expressions(ref.decl) = current_multiplier.getOrElse(IntegerValue(1)) + case Some(old) => linear_expressions(ref.decl) = + Plus(old, current_multiplier.getOrElse(IntegerValue(1))) + } + } else { + Unreachable("We should not end up here, the precondition of \'FindLinearArrayAccesses\' was not uphold.") + } + case _ => + is_linear = false + } + } + + def can_rewrite(): Option[SubstituteForall] = { + if(!is_linear) { + return None + } + + // Checking the preconditions of the check_vars_list function + if(quantifierData.bindings.isEmpty) return None + for(v <- quantifierData.bindings){ + if(!(linear_expressions.contains(v) && + quantifierData.upperExclusiveBounds.contains(v) && + quantifierData.upperExclusiveBounds(v).nonEmpty) + ) { + return None + } + } + + quantifierData.bindings.toList.reverse.permutations.map(check_vars_list) + .collectFirst({case Some(subst) => subst}) + } + + /** + * This function determines if the vars in this specific order allow the forall to be rewritten to one + * forall. + * + * Precondition: + * * At least one var in `quantifierData.bindings` + * * linear_expressions has an expression for all `vars` + * * quantifierData.upperExclusiveBounds has a non-empty list for all `vars` + * + * We are looking for patterns: + * /\_{0 <= i < k} {0 <= x_i < n_i} : ... ar[Sum_{0 <= i < k} {a_i * x_i} + b] ... + * and we require that for i>0 + * a_i == a_{i-1} * n_{i-1} + * (or equivalent a_i == Prod_{0 <= j < i} {n_j} * a_0 ) + * + * Further more we require that n_i > 0 and a_i > 0 (although I think a_0<0 is also valid) + * TODO: We are not checking n_i and a_i on this + * We can than replace the forall with + * b <= x_new < a_{k-1} * n_{k-1} + b && (x_new - b) % a_0 == 0 : ... ar[x_new] ... + * and each x_i gets replaced by + * x_i -> ((x_new - b) / a_i) % n_i + * And since we never go past a_{k-1} * n_{k-1} + b, no modulo needed here + * x_{k-1} -> (x_new - b) / a_{k-1}\ + */ + def check_vars_list(vars: List[Variable[Pre]]): Option[SubstituteForall] = { + val x_0 = vars.head + val a_0 = linear_expressions(x_0) + // x_{i-1}, a_{i-1}, n_{i-1} + var x_i_last = x_0 + var a_i_last = a_0 + var n_i_last: Expr[Pre] = null + val ns : mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() + + val x_new = new Variable[Post](TInt()) + + + val newGen: Expr[Pre] => Expr[Post] = quantifierData.rewriter.dispatch(_) + + // x_base == (x_new -b) + val x_base: Expr[Post]= constant_expression match { + case None => Local(x_new.ref) + case Some(b) => Minus(Local(x_new.ref), newGen(b)) + } + val replace_map: mutable.Map[Expr[Pre], Expr[Post]] = mutable.Map() + + for(x_i <- vars.tail){ + val a_i = linear_expressions(x_i) + var found_valid_n = false + + // Find a suitable upper bound + for (n_i_last_candidate <- quantifierData.upperExclusiveBounds(x_i_last)) { + if( !found_valid_n && equality_checker.equal_expressions(a_i, simplified_mult(a_i_last, n_i_last_candidate)) ) { + found_valid_n = true + n_i_last = n_i_last_candidate + ns(x_i_last) = n_i_last_candidate + } + } + + if(!found_valid_n) return None + // We now know the valid bound of x_{i-1} + // x_{i-1} -> ((x_new -b) / a_{i-1}) % n_{i-1} + + replace_map(Local(x_i_last.ref)) = + if(is_value(a_i_last, 1)) + Mod(x_base, newGen(n_i_last))(PanicBlame("TODO")) + else + Mod(FloorDiv(x_base, newGen(a_i_last))(PanicBlame("TODO")), newGen(n_i_last))(PanicBlame("TODO")) + + // Yay we are good up to now, go check out the next i + x_i_last = x_i + a_i_last = a_i + n_i_last = null + } + // We found a replacement! + // Make the declaration final + x_new.declareDefault(quantifierData.rewriter) + val ArraySubscript(arr, index) = arrayIndex + // Replace the linear expression with the new variable + val x_new_var: Expr[Post] = Local(x_new.ref) + replace_map(index) = x_new_var + val newTrigger : Seq[Seq[Expr[Post]]] = Seq(Seq(ArraySubscript(newGen(arr), x_new_var)(arrayIndex.blame))) + + // Add the last value, no need to do modulo + //TODO + replace_map(Local(x_i_last.ref)) = FloorDiv(x_base, newGen(a_i_last))(PanicBlame("TODO")) + // Get a random upperbound for x_i_last; + n_i_last = quantifierData.upperExclusiveBounds(x_i_last).head + ns(x_i_last) = n_i_last + // 0 <= x_new - b < a_{k-1} * n_{k-1} + var new_bounds = And( + LessEq( IntegerValue(0), x_base), + Less(x_base, newGen(simplified_mult(a_i_last, n_i_last))) + ) + // && (x_new - b) % a_0 == 0 + new_bounds = if(is_value(a_0, 1)) new_bounds else + And(new_bounds, + //TODO + Eq( Mod(x_base, newGen(a_0))(PanicBlame("TODO")), + IntegerValue(0)) + ) + + for(x_i <- vars){ + val n_i = ns(x_i) + // Remove the upper bound we used, but keep the others + for(old_upper_bound <- quantifierData.upperExclusiveBounds(x_i)){ + if(old_upper_bound != n_i){ + new_bounds = And(Less(replace_map(Local(x_i.ref)), newGen(old_upper_bound)), new_bounds) + } + } + + // Remove the lower zero bound, but keep the others + for(old_lower_bound <- quantifierData.lowerBounds(x_i)) + if(!is_value(old_lower_bound, 0)) + new_bounds = And(LessEq(newGen(old_lower_bound), replace_map(Local(x_i.ref))), new_bounds) + + // Since we know the lower bound was also 0, and the we multiply the upper bounds, + // we do have to require that each upper bound is at least bigger than 0. + new_bounds = And(Less(IntegerValue(0), newGen(n_i)), new_bounds) + } + + Some(SubstituteForall(new_bounds, replace_map.toMap, newTrigger)) + } + + def simplified_mult(lhs: Expr[Pre], rhs: Expr[Pre]): Expr[Pre] = { + if (is_value(lhs, 1)) rhs + else if (is_value(rhs, 1)) lhs + else Mult(lhs, rhs) + } + + def isConstant(node: Expr[Pre]): Boolean = indepOf(quantifierData.bindings, node) + + def addToConstant(node : Expr[Pre], is_plus: Boolean = true): Unit = { + val added_node: Expr[Pre] = current_multiplier match { + case None => node + case Some(expr) => Mult(expr, node) + } + constant_expression = Some(constant_expression match { + case None => if(is_plus) added_node else Mult(IntegerValue(-1), added_node) + case Some(expr) => if(is_plus) Plus(expr, added_node) else Minus(expr, added_node) + }) + } + + def multiplyMultiplier(node : Expr[Pre]): Unit ={ + current_multiplier match { + case None => current_multiplier = Some(node); + case Some(expr) => current_multiplier = Some(Mult(expr, node)) + } + } + + def is_value(e: Expr[Pre], x: Int): Boolean = + equality_checker.is_constant_int(e) match { + case None => false + case Some(y) => y == x + } + } + } + + // // The `new_forall_var` will be the name of variable of the newly made forall. + // // The `newBounds`, will contain all the new equations for "select" part of the forall. + // // The `substituteOldVars` contains a map, so we can replace the old forall variables with new expressions + // // We also store the `linear_expression`, so if we ever come across it, we can replace it with the new variable. + case class SubstituteForall(newBounds: Expr[Post], substituteOldVars: Map[Expr[Pre], Expr[Post]], newTrigger: Seq[Seq[Expr[Post]]]) +} + +// +// var equality_checker: ExpressionEqualityCheck = ExpressionEqualityCheck() +// +// override def visit(special: ASTSpecial): Unit = { +// if(special.kind == ASTSpecial.Kind.Inhale){ +// val info_getter = new AnnotationVariableInfoGetter() +// val annotations = ASTUtils.conjuncts(special.args(0), StandardOperator.Star).asScala +// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// +// result = create special(special.kind, rewrite(special.args):_*) +// +// equality_checker = ExpressionEqualityCheck() +// +// } else { +// result = create special(special.kind, rewrite(special.args): _*) +// } +// } +// +// override def visit(c: ASTClass): Unit = { //checkPermission(c); +// val name = c.getName +// if (name == null) Abort("illegal class without name") +// else { +// Debug("rewriting class " + name) +// val new_pars = rewrite(c.parameters) +// val new_supers = rewrite(c.super_classes) +// val new_implemented = rewrite(c.implemented_classes) +// val res = new ASTClass(name, c.kind, new_pars, new_supers, new_implemented) +// res.setOrigin(c.getOrigin) +// currentTargetClass = res +// val contract = c.getContract +// if (currentContractBuilder == null) currentContractBuilder = new ContractBuilder +// if (contract != null) { +// val info_getter = new AnnotationVariableInfoGetter() +// val annotations = LazyList(ASTUtils.conjuncts(contract.pre_condition, StandardOperator.Star).asScala +// , ASTUtils.conjuncts(contract.invariant, StandardOperator.Star).asScala).flatten +// +// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// rewrite(contract, currentContractBuilder) +// equality_checker = ExpressionEqualityCheck() +// } +// res.setContract(currentContractBuilder.getContract) +// currentContractBuilder = null +// +// for (i <- 0 until c.size()) { +// res.add(rewrite(c.get(i))) +// } +// result = res +// currentTargetClass = null +// } +// } +// +// override def visit(s: ForEachLoop): Unit = { +// val new_decl = rewrite(s.decls) +// val res = create.foreach(new_decl, rewrite(s.guard), rewrite(s.body)) +// +// val mc = s.getContract +// if (mc != null) { +// val info_getter = new AnnotationVariableInfoGetter() +// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala +// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten +// +// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// res.setContract(rewrite(mc)) +// equality_checker = ExpressionEqualityCheck() +// } else { +// res.setContract(rewrite(mc)) +// } +// +// +// res.set_before(rewrite(s.get_before)) +// res.set_after(rewrite(s.get_after)) +// result = res +// } +// +// override def visit(s: LoopStatement): Unit = { //checkPermission(s); +// val res = new LoopStatement +// var tmp = s.getInitBlock +// if (tmp != null) res.setInitBlock(tmp.apply(this)) +// tmp = s.getUpdateBlock +// if (tmp != null) res.setUpdateBlock(tmp.apply(this)) +// tmp = s.getEntryGuard +// if (tmp != null) res.setEntryGuard(tmp.apply(this)) +// tmp = s.getExitGuard +// if (tmp != null) res.setExitGuard(tmp.apply(this)) +// val mc = s.getContract +// if (mc != null) { +// val info_getter = new AnnotationVariableInfoGetter() +// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala +// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten +// +// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// res.appendContract(rewrite(mc)) +// equality_checker = ExpressionEqualityCheck() +// } else { +// res.appendContract(rewrite(mc)) +// } +// +// +// tmp = s.getBody +// res.setBody(tmp.apply(this)) +// res.set_before(rewrite(s.get_before)) +// res.set_after(rewrite(s.get_after)) +// res.setOrigin(s.getOrigin) +// result = res +// } +// +// override def visit(m: Method): Unit = { //checkPermission(m); +// val name = m.getName +// if (currentContractBuilder == null) { +// currentContractBuilder = new ContractBuilder +// } +// val args = rewrite(m.getArgs) +// val mc = m.getContract +// +// var c: Contract = null +// // Ensure we maintain the type of emptiness of mc +// // If the contract was null previously, the new contract can also be null +// // If the contract was non-null previously, the new contract cannot be null +// if (mc != null) { +// val info_getter = new AnnotationVariableInfoGetter() +// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala +// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten +// +// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// +// rewrite(mc, currentContractBuilder) +// c = currentContractBuilder.getContract(false) +// equality_checker = ExpressionEqualityCheck() +// } +// else { +// c = currentContractBuilder.getContract(true) +// } +// if (mc != null && c != null && c.getOrigin == null) { +// c.setOrigin(mc.getOrigin) +// } +// currentContractBuilder = null +// val kind = m.kind +// val rt = rewrite(m.getReturnType) +// val signals = rewrite(m.signals) +// val body = rewrite(m.getBody) +// result = create.method_kind(kind, rt, signals, c, name, args, m.usesVarArgs, body) +// } +// +// override def visit(expr: BindingExpression): Unit = { +// expr.binder match { +// case Binder.Forall | Binder.Star => +// val bindings = expr.getDeclarations.map(_.name).toSet +// val (select, main) = splitSelect(rewrite(expr.select), rewrite(expr.main)) +// val (independentSelect, potentialBounds) = select.partition(independentOf(bindings, _)) +// val (bounds, dependent_bounds) = getBounds(bindings, potentialBounds) +// //Only rewrite main, when the dependent bounds are not existing +// if(dependent_bounds.isEmpty && expr.binder != Binder.Star){ +// rewriteMain(bounds, main) match { +// case Some(main) => +// result = create expression(Implies, (independentSelect ++ bounds.selectNonEmpty).reduce(and), main); return +// case None => +// } +// } +// rewriteLinearArray(bounds, main, independentSelect, dependent_bounds, expr.binder, expr.result_type) match { +// case Some(new_forall) => +// result = new_forall; +// return +// case None => +// } +// super.visit(expr) +// case _ => +// super.visit(expr) +// } +// } +//} \ No newline at end of file From abf483853752952896de8b4da68defaa73ee99f8 Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 10 Aug 2022 11:02:48 +0200 Subject: [PATCH 03/25] Unfold forall, small bugfix & refactor naming conventions Unfold forall, small bugfix & refactor naming conventions Added to passes --- .../ast/util/ExpressionEqualityCheck.scala | 230 +++++++-------- .../SimplifyNestedQuantifiers.scala | 270 +++++++++++------- .../java/vct/main/stages/Transformation.scala | 1 + 3 files changed, 278 insertions(+), 223 deletions(-) diff --git a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala index 410ff9ac0b..1232c3ca62 100644 --- a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala +++ b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala @@ -1,7 +1,7 @@ package vct.col.ast.util import hre.lang.System.Warning -import vct.col.ast.util.ExpressionEqualityCheck.is_constant_int +import vct.col.ast.util.ExpressionEqualityCheck.isConstantInt import vct.col.ast.{And, BitAnd, BitNot, BitOr, BitShl, BitShr, BitUShr, BitXor, Div, Eq, Exp, Expr, FloorDiv, Greater, GreaterEq, Implies, IntegerValue, Less, LessEq, Local, Minus, Mod, Mult, Neq, Not, Or, Plus, Star, UMinus, Wand} import scala.collection.mutable @@ -9,39 +9,39 @@ import scala.collection.mutable object ExpressionEqualityCheck { def apply[G](info: Option[AnnotationVariableInfo[G]] = None): ExpressionEqualityCheck[G] = new ExpressionEqualityCheck[G](info) - def is_constant_int[G](e: Expr[G]): Option[BigInt] = { - ExpressionEqualityCheck().is_constant_int(e) + def isConstantInt[G](e: Expr[G]): Option[BigInt] = { + ExpressionEqualityCheck().isConstantInt(e) } - def equal_expressions[G](lhs: Expr[G], rhs: Expr[G]): Boolean = { - ExpressionEqualityCheck().equal_expressions(lhs, rhs) + def equalExpressions[G](lhs: Expr[G], rhs: Expr[G]): Boolean = { + ExpressionEqualityCheck().equalExpressions(lhs, rhs) } } class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { - var replacer_depth = 0 - var replacer_depth_int = 0 + var replacerDepth = 0 + var replacerDepthInt = 0 val max_depth = 100 - def is_constant_int(e: Expr[G]): Option[BigInt] = { - replacer_depth_int = 0 - is_constant_int_(e) + def isConstantInt(e: Expr[G]): Option[BigInt] = { + replacerDepthInt = 0 + isConstantIntRecurse(e) } - def is_constant_int_(e: Expr[G]): Option[BigInt] = e match { + def isConstantIntRecurse(e: Expr[G]): Option[BigInt] = e match { case e: Local[G] => // Does it have a direct int value? - info.flatMap(_.variable_values.get(e)) match { + info.flatMap(_.variableValues.get(e)) match { case Some(x) => Some(x) case None => - info.flatMap(_.variable_equalities.get(e)) match { + info.flatMap(_.variableEqualities.get(e)) match { case None => None case Some(equals) => for (eq <- equals) { // Make sure we do not loop indefinitely by keep replacing the same expressions somehow - if (replacer_depth_int > max_depth) return None - replacer_depth_int += 1 - val res = is_constant_int_(eq) + if (replacerDepthInt > max_depth) return None + replacerDepthInt += 1 + val res = isConstantIntRecurse(eq) if (res.isDefined) return res } None @@ -49,31 +49,31 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { } case IntegerValue(value) => Some(value) - case Exp(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1.pow(i2.toInt) - case Plus(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 + i2 - case Minus(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 - i2 - case Mult(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 * i2 - case FloorDiv(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 / i2 - case Mod(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 % i2 - - case BitAnd(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 & i2 - case BitOr(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 | i2 - case BitXor(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 ^ i2 - case BitShl(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 << i2.toInt - case BitShr(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1 >> i2.toInt - case BitUShr(e1, e2) => for {i1 <- is_constant_int_(e1); i2 <- is_constant_int_(e2)} yield i1.toInt >>> i2.toInt + case Exp(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1.pow(i2.toInt) + case Plus(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 + i2 + case Minus(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 - i2 + case Mult(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 * i2 + case FloorDiv(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 / i2 + case Mod(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 % i2 + + case BitAnd(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 & i2 + case BitOr(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 | i2 + case BitXor(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 ^ i2 + case BitShl(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 << i2.toInt + case BitShr(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1 >> i2.toInt + case BitUShr(e1, e2) => for {i1 <- isConstantIntRecurse(e1); i2 <- isConstantIntRecurse(e2)} yield i1.toInt >>> i2.toInt case _ => None } - def equal_expressions(lhs: Expr[G], rhs: Expr[G]): Boolean = { - replacer_depth = 0 - equal_expressions_(lhs, rhs) + def equalExpressions(lhs: Expr[G], rhs: Expr[G]): Boolean = { + replacerDepth = 0 + equalExpressionsRecurse(lhs, rhs) } // - def equal_expressions_(lhs: Expr[G], rhs: Expr[G]): Boolean = { - (is_constant_int(lhs), is_constant_int(rhs)) match { + def equalExpressionsRecurse(lhs: Expr[G], rhs: Expr[G]): Boolean = { + (isConstantInt(lhs), isConstantInt(rhs)) match { case (Some(i1), Some(i2)) => return i1 == i2 case (None, None) => () //If one is a constant expression, and the other is not, this cannot be the same @@ -85,87 +85,87 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { // Commutative operators case (Plus(lhs1, lhs2), Plus(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (Mult(lhs1, lhs2), Mult(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (BitAnd(lhs1, lhs2), BitAnd(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (BitOr(lhs1, lhs2), BitOr(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (BitXor(lhs1, lhs2), BitXor(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (And(lhs1, lhs2), And(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (Or(lhs1, lhs2), Or(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (Eq(lhs1, lhs2), Eq(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) case (Neq(lhs1, lhs2), Neq(rhs1, rhs2)) => - (equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2)) || - (equal_expressions_(lhs1, rhs2) && equal_expressions_(lhs2, rhs1)) + (equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2)) || + (equalExpressionsRecurse(lhs1, rhs2) && equalExpressionsRecurse(lhs2, rhs1)) //Non commutative operators case (Exp(lhs1, lhs2), Exp(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Minus(lhs1, lhs2), Minus(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Div(lhs1, lhs2), Div(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (FloorDiv(lhs1, lhs2), FloorDiv(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Mod(lhs1, lhs2), Mod(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (BitShl(lhs1, lhs2), BitShl(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (BitShr(lhs1, lhs2), BitShr(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (BitUShr(lhs1, lhs2), BitUShr(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Implies(lhs1, lhs2), Implies(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Star(lhs1, lhs2), Star(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Wand(lhs1, lhs2), Wand(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Greater(lhs1, lhs2), Greater(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (Less(lhs1, lhs2), Less(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (GreaterEq(lhs1, lhs2), GreaterEq(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) case (LessEq(lhs1, lhs2), LessEq(rhs1, rhs2)) => - equal_expressions_(lhs1, rhs1) && equal_expressions_(lhs2, rhs2) + equalExpressionsRecurse(lhs1, rhs1) && equalExpressionsRecurse(lhs2, rhs2) // Unary expressions - case (UMinus(lhs), UMinus(rhs)) => equal_expressions_(lhs, rhs) - case (BitNot(lhs), BitNot(rhs)) => equal_expressions_(lhs, rhs) - case (Not(lhs), Not(rhs)) => equal_expressions_(lhs, rhs) + case (UMinus(lhs), UMinus(rhs)) => equalExpressionsRecurse(lhs, rhs) + case (BitNot(lhs), BitNot(rhs)) => equalExpressionsRecurse(lhs, rhs) + case (Not(lhs), Not(rhs)) => equalExpressionsRecurse(lhs, rhs) // Variables case (name1: Local[G], name2: Local[G]) => if (name1 == name2) true else if (info.isDefined) { // Check if the variables are synonyms - (info.get.variable_synonyms.get(name1), info.get.variable_synonyms.get(name2)) match { + (info.get.variableSynonyms.get(name1), info.get.variableSynonyms.get(name2)) match { case (Some(x), Some(y)) => x == y case _ => false } } else false case (name1: Local[G], e2) => - replace_variable(name1, e2) + replaceVariable(name1, e2) case (e1, name2: Local[G]) => - replace_variable(name2, e1) + replaceVariable(name2, e1) // In the general case, we are just interested in syntactic equality case (e1, e2) => e1 == e2 @@ -173,16 +173,16 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { } // - def replace_variable(name: Local[G], other_e: Expr[G]): Boolean = { + def replaceVariable(name: Local[G], other_e: Expr[G]): Boolean = { if (info.isDefined) { - info.get.variable_equalities.get(name) match { + info.get.variableEqualities.get(name) match { case None => false case Some(equals) => for (eq <- equals) { // Make sure we do not loop indefinitely by keep replacing the same expressions somehow - if (replacer_depth > max_depth) return false - replacer_depth += 1 - if (equal_expressions_(eq, other_e)) return true + if (replacerDepth > max_depth) return false + replacerDepth += 1 + if (equalExpressionsRecurse(eq, other_e)) return true } false } @@ -192,95 +192,95 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { } } -case class AnnotationVariableInfo[G](variable_equalities: Map[Local[G], List[Expr[G]]], variable_values: Map[Local[G], BigInt], - variable_synonyms: Map[Local[G], Int]) +case class AnnotationVariableInfo[G](variableEqualities: Map[Local[G], List[Expr[G]]], variableValues: Map[Local[G], BigInt], + variableSynonyms: Map[Local[G], Int]) /** This class gathers information about variables, such as: * `requires x == 0` and stores that x is equal to the value 0. * Which we can use in simplify steps - * This information is returned with ```get_info(annotations: Iterable[Expr[G]])``` + * This information is returned with ```getInfo(annotations: Iterable[Expr[G]])``` */ class AnnotationVariableInfoGetter[G]() { - val variable_equalities: mutable.Map[Local[G], mutable.ListBuffer[Expr[G]]] = + val variableEqualities: mutable.Map[Local[G], mutable.ListBuffer[Expr[G]]] = mutable.Map() - val variable_values: mutable.Map[Local[G], BigInt] = mutable.Map() + val variableValues: mutable.Map[Local[G], BigInt] = mutable.Map() // We put synonyms in the same group and give them a group number, to identify the same synonym groups - val variable_synonyms: mutable.Map[Local[G], Int] = mutable.Map() - var current_synonym_group = 0 + val variableSynonyms: mutable.Map[Local[G], Int] = mutable.Map() + var currentSynonymGroup = 0 - def extract_equalities(e: Expr[G]): Unit = { + def extractEqualities(e: Expr[G]): Unit = { e match{ case Eq(e1, e2) => (e1, e2) match{ - case (v1: Local[G], v2: Local[G]) => add_synonym(v1, v2) - case (v1: Local[G], _) => add_name(v1, e2) - case (_, v2: Local[G]) => add_name(v2, e1) + case (v1: Local[G], v2: Local[G]) => addSynonym(v1, v2) + case (v1: Local[G], _) => addName(v1, e2) + case (_, v2: Local[G]) => addName(v2, e1) case _ => } case _ => } } - def add_synonym(v1: Local[G], v2: Local[G]): Unit = { - (variable_synonyms.get(v1), variable_synonyms.get(v2)) match { + def addSynonym(v1: Local[G], v2: Local[G]): Unit = { + (variableSynonyms.get(v1), variableSynonyms.get(v2)) match { // We make a new group case (None, None) => - variable_synonyms(v1) = current_synonym_group - variable_synonyms(v2) = current_synonym_group - current_synonym_group += 1 + variableSynonyms(v1) = currentSynonymGroup + variableSynonyms(v2) = currentSynonymGroup + currentSynonymGroup += 1 // Add to the found group - case (Some(id1), None) => variable_synonyms(v2) = id1 - case (None, Some(id2)) => variable_synonyms(v1) = id2 + case (Some(id1), None) => variableSynonyms(v2) = id1 + case (None, Some(id2)) => variableSynonyms(v1) = id2 // Merge the groups, give every synonym group member of id2 value id1 case (Some(id1), Some(id2)) => - variable_synonyms.mapValuesInPlace((_, group) => if (group == id2) id1 else group) + variableSynonyms.mapValuesInPlace((_, group) => if (group == id2) id1 else group) } } - def add_name(v: Local[G], expr: Expr[G]): Unit ={ + def addName(v: Local[G], expr: Expr[G]): Unit ={ // Add to constant list - is_constant_int[G](expr) match { - case Some(x) => variable_values.get(v) match { + isConstantInt[G](expr) match { + case Some(x) => variableValues.get(v) match { case Some(x_) => if (x!=x_) Warning("Value of %s is required to be both %d and %d", v, x, x_); - case None => variable_values(v) = x + case None => variableValues(v) = x } case None => - val list = variable_equalities.getOrElseUpdate(v, mutable.ListBuffer()) + val list = variableEqualities.getOrElseUpdate(v, mutable.ListBuffer()) list.addOne(expr) } } - def get_info(annotations: Iterable[Expr[G]]): AnnotationVariableInfo[G] = { - variable_equalities.clear() - variable_values.clear() + def getInfo(annotations: Iterable[Expr[G]]): AnnotationVariableInfo[G] = { + variableEqualities.clear() + variableValues.clear() for(clause <- annotations){ - extract_equalities(clause) + extractEqualities(clause) } - distribute_info() - AnnotationVariableInfo(variable_equalities.view.mapValues(_.toList).toMap, variable_values.toMap, - variable_synonyms.toMap) + distributeInfo() + AnnotationVariableInfo(variableEqualities.view.mapValues(_.toList).toMap, variableValues.toMap, + variableSynonyms.toMap) } - def distribute_info(): Unit = { + def distributeInfo(): Unit = { // First distribute value knowledge over the rest of the map - val begin_size = variable_values.size + val beginSize = variableValues.size - for((name, equals) <- variable_equalities){ - if(!variable_values.contains(name)) + for((name, equals) <- variableEqualities){ + if(!variableValues.contains(name)) for(equal <- equals){ equal match { case n : Local[G] => - variable_values.get(n).foreach(variable_values(name) = _) + variableValues.get(n).foreach(variableValues(name) = _) case _ => } } } // If sizes are not the same, we know more, so distribute again! - if(variable_values.size != begin_size) distribute_info() + if(variableValues.size != beginSize) distributeInfo() } } \ No newline at end of file diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 7795fcb06e..0faab4d2d7 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -33,7 +33,6 @@ case object SimplifyNestedQuantifiers extends RewriterBuilder { override def desc: String = "Simplify nested quantifiers." } -//case class SimplifyNestedQuantifiers[Pre <: Generation]() extends NonLatchingRewriter[Pre, Rewritten[Pre]] { case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] { case object SimplifyNestedQuantifiersOrigin extends Origin { @@ -46,19 +45,30 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] override def inlineContext: String = "[Simplified expression]" } + case class BinderOrigin(name: String) extends Origin { + override def preferredName: String = name + + override def shortPosition: String = "generated" + + override def context: String = "[At generated expression for the simplification of nested quantifiers]" + + override def inlineContext: String = "[Simplified expression]" + } + private implicit val o: Origin = SimplifyNestedQuantifiersOrigin private def one: IntegerValue[Pre] = IntegerValue(1) // TODO: Supply information towards the expression equality checker when encountering a contract - var equality_checker: ExpressionEqualityCheck[Pre] = ExpressionEqualityCheck() + var equalityChecker: ExpressionEqualityCheck[Pre] = ExpressionEqualityCheck() override def dispatch(e: Expr[Pre]): Expr[Post] = { e match { case e: Binder[Pre] => rewriteLinearArray(e) match { case None => rewriteDefault(e) - case Some(e) => e + case Some(newE) + => newE } case other => rewriteDefault(other) } @@ -75,30 +85,30 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] val quantifierData = new RewriteQuantifierData(originalBody, e, this) quantifierData.setData() - quantifierData.check_single_value_variables() - quantifierData.check_independent_variables() + quantifierData.checkSingleValueVariables() + quantifierData.checkIndependentVariables() // Check if we have valid bounds to rewrite, otherwise we stop - if(!quantifierData.check_bounds()) return quantifierData.result() + if(!quantifierData.checkBounds() || quantifierData.checkOtherBinders()) return quantifierData.result() quantifierData.lookForLinearAccesses() } class RewriteQuantifierData(val bindings: mutable.Set[Variable[Pre]], - var lowerBounds: Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], - var upperBounds: Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], - var upperExclusiveBounds: Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], + var lowerBounds: mutable.Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], + var upperBounds: mutable.Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], + var upperExclusiveBounds: mutable.Map[Variable[Pre], ArrayBuffer[Expr[Pre]]], var independentConditions: ArrayBuffer[Expr[Pre]], val dependentConditions: ArrayBuffer[Expr[Pre]], var body: Expr[Pre], val originalBinder: Binder[Pre], - val rewriter: SimplifyNestedQuantifiers[Pre] + val mainRewriter: SimplifyNestedQuantifiers[Pre] ) { def this(originalBody: Expr[Pre], originalBinder: Binder[Pre], rewriter: SimplifyNestedQuantifiers[Pre]) = { this(originalBinder.bindings.to(mutable.Set), - originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).toMap, - originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).toMap, - originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).toMap, + originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).to(mutable.Map), + originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).to(mutable.Map), + originalBinder.bindings.map(_ -> ArrayBuffer[Expr[Pre]]()).to(mutable.Map), ArrayBuffer[Expr[Pre]](), ArrayBuffer[Expr[Pre]](), originalBody, @@ -108,17 +118,46 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } /** Keeps track if it is already feasible to make a new quantifier */ - var new_binder = false + var newBinder = false def setData(): Unit = { - val (allConditions, main_body) = unfoldImplies[Pre](body) - body = main_body + val allConditions = unfoldBody(Seq()) // Split bounds that are independent of any binding variables - val (new_independentConditions, potentialBounds) = allConditions.partition(indepOf(bindings, _)) - independentConditions.addAll(new_independentConditions) + val (newIndependentConditions, potentialBounds) = allConditions.partition(indepOf(bindings, _)) + independentConditions.addAll(newIndependentConditions) getBounds(potentialBounds) } + def unfoldBody(prevConditions: Seq[Expr[Pre]]): Seq[Expr[Pre]] = { + val (allConditions, mainBody) = unfoldImplies[Pre](body) + val newConditions = prevConditions ++ allConditions + val (newVars, secondBody) = mainBody match { + case Forall(newVars, _, secondBody) => (newVars, secondBody) + case Starall(newVars, _, secondBody) => (newVars, secondBody) + case _ => return newConditions + } + + bindings.addAll(newVars) + + for(v <- newVars){ + lowerBounds(v) = ArrayBuffer[Expr[Pre]]() + upperBounds(v) = ArrayBuffer[Expr[Pre]]() + upperExclusiveBounds(v) = ArrayBuffer[Expr[Pre]]() + } + + body = secondBody + + unfoldBody(newConditions) + } + + def containsOtherBinders(e: Expr[Pre]): Boolean = { + e match { + case _: Binder[Pre] => return true + case _ => e.transSubnodes.collectFirst { case e: Binder[Pre] => return true } + } + false + } + /** * Process the potential bounds to be either a bound or just a dependent condition. * @param potentialBounds Bounds to be processed. @@ -139,7 +178,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] right match { case Local(Ref(v)) if bindings.contains(v) => // If the quantified variable is the second argument: flip the relation - addSingleBound(v, right, comp.flip) + addSingleBound(v, left, comp.flip) case _ => dependentConditions.addOne(bound) } } @@ -188,15 +227,15 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] * We don't worry if a we have something like x == 5 && x < 0, since that will resolve to 5 < 0, which equally * does not work. * */ - def check_single_value_variables(): Unit = { + def checkSingleValueVariables(): Unit = { for (name <- bindings) { - val equal_bounds = lowerBounds(name).intersect(upperBounds(name)) - if (equal_bounds.nonEmpty) { + val equalBounds = lowerBounds(name).intersect(upperBounds(name)) + if (equalBounds.nonEmpty) { // We will put out a new quantifier - new_binder = true - val new_value = equal_bounds.head - val name_var: Expr[Pre] = Local(name.ref) - val sub = Substitute[Pre](Map(name_var -> new_value)) + newBinder = true + val newValue = equalBounds.head + val nameVar: Expr[Pre] = Local(name.ref) + val sub = Substitute[Pre](Map(nameVar -> newValue)) val replacer = sub.dispatch(_: Expr[Pre]) body = replacer(body) @@ -213,45 +252,45 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // Bounds for the name, have now become independent conditions lowerBounds(name).foreach(lb => - if (lb != new_value) independentConditions.addOne(LessEq(lb, new_value))) + if (lb != newValue) independentConditions.addOne(LessEq(lb, newValue))) upperBounds(name).foreach(ub => - if (ub != new_value) independentConditions.addOne(LessEq(new_value, ub))) + if (ub != newValue) independentConditions.addOne(LessEq(newValue, ub))) - lowerBounds = lowerBounds.removed(name) - upperBounds = upperBounds.removed(name) - upperExclusiveBounds = upperExclusiveBounds.removed(name) + lowerBounds.remove(name) + upperBounds.remove(name) + upperExclusiveBounds.remove(name) // Strictly speaking, a binding variable could be newly removed, if a previous one has been found constant // and then the bounds deem another binding variable also constant. We check that by doing recursion. - check_single_value_variables() + checkSingleValueVariables() return } } } - def check_independent_variables(): Unit = { + def checkIndependentVariables(): Unit = { for (name <- bindings) { if (indepOf(mutable.Set(name), body)) { var independent = true dependentConditions.foreach(s => if (!indepOf(mutable.Set(name), s)) independent = false) if (independent) { // We can freely remove this named variable - val max_bound = extremeValue(name, maximizing = true) - val min_bound = extremeValue(name, maximizing = false) - (max_bound, min_bound) match { - case (Some(max_bound), Some(min_bound)) => - new_binder = true + val maxBound = extremeValue(name, maximizing = true) + val minBound = extremeValue(name, maximizing = false) + (maxBound, minBound) match { + case (Some(maxBound), Some(minBound)) => + newBinder = true // Do not quantify over name anymore bindings.remove(name) - lowerBounds = lowerBounds.removed(name) - upperBounds = upperBounds.removed(name) - upperExclusiveBounds = upperExclusiveBounds.removed(name) + lowerBounds.remove(name) + upperBounds.remove(name) + upperExclusiveBounds.remove(name) // We remove the forall variable i, but need to rewrite some expressions // (forall i; a <= i <= b; ...Perm(ar, x)...) ====> b>=a ==> ...Perm(ar, x*(b-a+1))... - independentConditions.addOne(GreaterEq(max_bound, min_bound)) + independentConditions.addOne(GreaterEq(maxBound, minBound)) - body = Scale(Plus(one, Minus(max_bound, min_bound)), body)( + body = Scale(Plus(one, Minus(maxBound, minBound)), body)( PanicBlame("Error in SimplifyNestedQuantifiers class, implication should make sure scale is" + " never negative when accessed.")) // rp.dispatch(body) case _ => @@ -284,26 +323,42 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // This allows only forall's to be rewritten, if they have at least one lower bound of zero // TODO: Generalize this, so we don't have this restriction - def check_bounds(): Boolean = { + def checkBounds(): Boolean = { + if(bindings.size == 1) { + val name = bindings.head + return upperBounds.getOrElse(name, ArrayBuffer()).nonEmpty && + lowerBounds.getOrElse(name, ArrayBuffer()).nonEmpty + } + for (name <- bindings) { - var one_zero = false + var oneZero = false val zero = BigInt(0) lowerBounds.getOrElse(name, ArrayBuffer()) - .foreach(lower => equality_checker.is_constant_int(lower) match { - case Some(`zero`) => one_zero = true + .foreach(lower => equalityChecker.isConstantInt(lower) match { + case Some(`zero`) => oneZero = true case _ => }) //Exit when notAt least one zero, or no upper bounds - if (!one_zero || upperBounds.getOrElse(name, ArrayBuffer()).isEmpty) { + if (!oneZero || upperBounds.getOrElse(name, ArrayBuffer()).isEmpty) { return false } } true } + // Returns true if contains other binders, which we won't rewrite + def checkOtherBinders(): Boolean = { + independentConditions.foldLeft(containsOtherBinders(body))(_ || containsOtherBinders(_)) + } + case class ForallSubstitute(subs: Map[Expr[Pre], Expr[Post]]) extends Rewriter[Pre] { + override def lookupSuccessor: Declaration[Pre] => Option[Declaration[Post]] = { + val here = mainRewriter.lookupSuccessor + decl => here(decl) + } + override def dispatch(e: Expr[Pre]): Expr[Post] = e match { case expr if subs.contains(expr) => subs(expr) case other => rewriteDefault(other) @@ -312,7 +367,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] def lookForLinearAccesses(): Option[Expr[Post]] = { val linearAccesses = new FindLinearArrayAccesses(this) - withCollectInScope(rewriter.variableScopes) {linearAccesses.search(body)} match { + withCollectInScope(mainRewriter.variableScopes) {linearAccesses.search(body)} match { case (bindings, Some(substituteForall)) => if(bindings.size != 1){ ??? @@ -334,7 +389,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] def result(): Option[Expr[Post]] = { // If we changed something we always return a result, even if we could not rewrite further - val res = if(new_binder) { + val res = if(newBinder) { val select = independentConditions ++ dependentConditions if (bindings.isEmpty) { if (select.isEmpty) Some(body) else Some(Implies(AstBuildHelpers.foldAnd(select.toSeq), body)) @@ -369,7 +424,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] None } - res.map(rewriter.rewriteDefault) + res.map(mainRewriter.rewriteDefault) } } @@ -379,34 +434,32 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] class FindLinearArrayAccesses(quantifierData: RewriteQuantifierData){ // Search for linear array expressions - def search(e: Expr[Pre]): Option[SubstituteForall] = { + def search(e: Node[Pre]): Option[SubstituteForall] = { e match { case e @ ArraySubscript(_, index) => if (indepOf(quantifierData.bindings, index)) { return None } - linear_expression(e) match { - case Some(substitute_forall) => Some(substitute_forall) + linearExpression(e) match { + case Some(substituteForall) => Some(substituteForall) case None => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} } case _ => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} } } - def search(n: Node[Pre]): Option[SubstituteForall] = None - - def linear_expression(e: ArraySubscript[Pre]): Option[SubstituteForall] = { + def linearExpression(e: ArraySubscript[Pre]): Option[SubstituteForall] = { val ArraySubscript(_, index) = e val pot = new PotentialLinearExpressions(e) pot.visit(index) - pot.can_rewrite() + pot.canRewrite() } class PotentialLinearExpressions(val arrayIndex: ArraySubscript[Pre]){ - val linear_expressions: mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() - var constant_expression: Option[Expr[Pre]] = None - var is_linear: Boolean = true - var current_multiplier: Option[Expr[Pre]] = None + val linearExpressions: mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() + var constantExpression: Option[Expr[Pre]] = None + var isLinear: Boolean = true + var currentMultiplier: Option[Expr[Pre]] = None def visit(e: Expr[Pre]): Unit = { e match{ @@ -426,69 +479,69 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // if the first is constant, the second argument cannot be if (isConstant(left)) { addToConstant(left) - val old_multiplier = current_multiplier + val old_multiplier = currentMultiplier multiplyMultiplier(IntegerValue(-1)) visit(right) - current_multiplier = old_multiplier + currentMultiplier = old_multiplier } else if (isConstant(right)) { addToConstant(right, is_plus=false) visit(left) } else { // Both arguments contain linear information visit(left) - val old_multiplier = current_multiplier + val old_multiplier = currentMultiplier multiplyMultiplier(IntegerValue(-1)) visit(right) - current_multiplier = old_multiplier + currentMultiplier = old_multiplier } case Mult(left, right) => if (isConstant(left)) { - val old_multiplier = current_multiplier + val old_multiplier = currentMultiplier multiplyMultiplier(left) visit(right) - current_multiplier = old_multiplier + currentMultiplier = old_multiplier } else if (isConstant(right)) { - val old_multiplier = current_multiplier + val old_multiplier = currentMultiplier multiplyMultiplier(right) visit(left) - current_multiplier = old_multiplier + currentMultiplier = old_multiplier } else { - is_linear = false + isLinear = false } // TODO: Check if division is right conceptually with an example. Take special care to think about // the order of division case e@Div(left, right) => if (isConstant(right)){ - val old_multiplier = current_multiplier + val old_multiplier = currentMultiplier multiplyMultiplier(Div(IntegerValue(1), right)(e.blame)) visit(left) - current_multiplier = old_multiplier + currentMultiplier = old_multiplier } else { - is_linear = false + isLinear = false } case Local(ref) => if(quantifierData.bindings.contains(ref.decl)) { - linear_expressions get ref.decl match { - case None => linear_expressions(ref.decl) = current_multiplier.getOrElse(IntegerValue(1)) - case Some(old) => linear_expressions(ref.decl) = - Plus(old, current_multiplier.getOrElse(IntegerValue(1))) + linearExpressions get ref.decl match { + case None => linearExpressions(ref.decl) = currentMultiplier.getOrElse(IntegerValue(1)) + case Some(old) => linearExpressions(ref.decl) = + Plus(old, currentMultiplier.getOrElse(IntegerValue(1))) } } else { Unreachable("We should not end up here, the precondition of \'FindLinearArrayAccesses\' was not uphold.") } case _ => - is_linear = false + isLinear = false } } - def can_rewrite(): Option[SubstituteForall] = { - if(!is_linear) { + def canRewrite(): Option[SubstituteForall] = { + if(!isLinear) { return None } // Checking the preconditions of the check_vars_list function if(quantifierData.bindings.isEmpty) return None for(v <- quantifierData.bindings){ - if(!(linear_expressions.contains(v) && + if(!(linearExpressions.contains(v) && quantifierData.upperExclusiveBounds.contains(v) && quantifierData.upperExclusiveBounds(v).nonEmpty) ) { @@ -506,7 +559,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] * * Precondition: * * At least one var in `quantifierData.bindings` - * * linear_expressions has an expression for all `vars` + * * linearExpressions has an expression for all `vars` * * quantifierData.upperExclusiveBounds has a non-empty list for all `vars` * * We are looking for patterns: @@ -526,32 +579,33 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] */ def check_vars_list(vars: List[Variable[Pre]]): Option[SubstituteForall] = { val x_0 = vars.head - val a_0 = linear_expressions(x_0) + val a_0 = linearExpressions(x_0) // x_{i-1}, a_{i-1}, n_{i-1} var x_i_last = x_0 var a_i_last = a_0 var n_i_last: Expr[Pre] = null val ns : mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() - val x_new = new Variable[Post](TInt()) + val new_name = vars.tail.foldLeft(vars.head.o.preferredName)(_ + "_" + _.o.preferredName) + val x_new = new Variable[Post](TInt())(BinderOrigin(new_name)) - val newGen: Expr[Pre] => Expr[Post] = quantifierData.rewriter.dispatch(_) + val newGen: Expr[Pre] => Expr[Post] = quantifierData.mainRewriter.dispatch(_) // x_base == (x_new -b) - val x_base: Expr[Post]= constant_expression match { + val x_base: Expr[Post]= constantExpression match { case None => Local(x_new.ref) case Some(b) => Minus(Local(x_new.ref), newGen(b)) } val replace_map: mutable.Map[Expr[Pre], Expr[Post]] = mutable.Map() for(x_i <- vars.tail){ - val a_i = linear_expressions(x_i) + val a_i = linearExpressions(x_i) var found_valid_n = false // Find a suitable upper bound for (n_i_last_candidate <- quantifierData.upperExclusiveBounds(x_i_last)) { - if( !found_valid_n && equality_checker.equal_expressions(a_i, simplified_mult(a_i_last, n_i_last_candidate)) ) { + if( !found_valid_n && equalityChecker.equalExpressions(a_i, simplified_mult(a_i_last, n_i_last_candidate)) ) { found_valid_n = true n_i_last = n_i_last_candidate ns(x_i_last) = n_i_last_candidate @@ -575,7 +629,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } // We found a replacement! // Make the declaration final - x_new.declareDefault(quantifierData.rewriter) + x_new.declareDefault(quantifierData.mainRewriter) val ArraySubscript(arr, index) = arrayIndex // Replace the linear expression with the new variable val x_new_var: Expr[Post] = Local(x_new.ref) @@ -632,25 +686,25 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] def isConstant(node: Expr[Pre]): Boolean = indepOf(quantifierData.bindings, node) def addToConstant(node : Expr[Pre], is_plus: Boolean = true): Unit = { - val added_node: Expr[Pre] = current_multiplier match { + val added_node: Expr[Pre] = currentMultiplier match { case None => node case Some(expr) => Mult(expr, node) } - constant_expression = Some(constant_expression match { + constantExpression = Some(constantExpression match { case None => if(is_plus) added_node else Mult(IntegerValue(-1), added_node) case Some(expr) => if(is_plus) Plus(expr, added_node) else Minus(expr, added_node) }) } def multiplyMultiplier(node : Expr[Pre]): Unit ={ - current_multiplier match { - case None => current_multiplier = Some(node); - case Some(expr) => current_multiplier = Some(Mult(expr, node)) + currentMultiplier match { + case None => currentMultiplier = Some(node); + case Some(expr) => currentMultiplier = Some(Mult(expr, node)) } } def is_value(e: Expr[Pre], x: Int): Boolean = - equality_checker.is_constant_int(e) match { + equalityChecker.isConstantInt(e) match { case None => false case Some(y) => y == x } @@ -660,22 +714,22 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // // The `new_forall_var` will be the name of variable of the newly made forall. // // The `newBounds`, will contain all the new equations for "select" part of the forall. // // The `substituteOldVars` contains a map, so we can replace the old forall variables with new expressions - // // We also store the `linear_expression`, so if we ever come across it, we can replace it with the new variable. + // // We also store the `linearExpression`, so if we ever come across it, we can replace it with the new variable. case class SubstituteForall(newBounds: Expr[Post], substituteOldVars: Map[Expr[Pre], Expr[Post]], newTrigger: Seq[Seq[Expr[Post]]]) } // -// var equality_checker: ExpressionEqualityCheck = ExpressionEqualityCheck() +// var equalityChecker: ExpressionEqualityCheck = ExpressionEqualityCheck() // // override def visit(special: ASTSpecial): Unit = { // if(special.kind == ASTSpecial.Kind.Inhale){ // val info_getter = new AnnotationVariableInfoGetter() // val annotations = ASTUtils.conjuncts(special.args(0), StandardOperator.Star).asScala -// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) // // result = create special(special.kind, rewrite(special.args):_*) // -// equality_checker = ExpressionEqualityCheck() +// equalityChecker = ExpressionEqualityCheck() // // } else { // result = create special(special.kind, rewrite(special.args): _*) @@ -700,9 +754,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // val annotations = LazyList(ASTUtils.conjuncts(contract.pre_condition, StandardOperator.Star).asScala // , ASTUtils.conjuncts(contract.invariant, StandardOperator.Star).asScala).flatten // -// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) // rewrite(contract, currentContractBuilder) -// equality_checker = ExpressionEqualityCheck() +// equalityChecker = ExpressionEqualityCheck() // } // res.setContract(currentContractBuilder.getContract) // currentContractBuilder = null @@ -725,9 +779,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala // , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten // -// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) // res.setContract(rewrite(mc)) -// equality_checker = ExpressionEqualityCheck() +// equalityChecker = ExpressionEqualityCheck() // } else { // res.setContract(rewrite(mc)) // } @@ -754,9 +808,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala // , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten // -// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) // res.appendContract(rewrite(mc)) -// equality_checker = ExpressionEqualityCheck() +// equalityChecker = ExpressionEqualityCheck() // } else { // res.appendContract(rewrite(mc)) // } @@ -787,11 +841,11 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala // , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten // -// equality_checker = ExpressionEqualityCheck(Some(info_getter.get_info(annotations))) +// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) // // rewrite(mc, currentContractBuilder) // c = currentContractBuilder.getContract(false) -// equality_checker = ExpressionEqualityCheck() +// equalityChecker = ExpressionEqualityCheck() // } // else { // c = currentContractBuilder.getContract(true) diff --git a/src/main/java/vct/main/stages/Transformation.scala b/src/main/java/vct/main/stages/Transformation.scala index 91e30be908..50a12439f8 100644 --- a/src/main/java/vct/main/stages/Transformation.scala +++ b/src/main/java/vct/main/stages/Transformation.scala @@ -169,6 +169,7 @@ case class SilverTransformation CheckContractSatisfiability.withArg(checkSat), ) ++ simplifyBeforeRelations ++ Seq( SimplifyQuantifiedRelations, + SimplifyNestedQuantifiers, ) ++ simplifyAfterRelations ++ Seq( ResolveExpressionSideChecks, From 280d32b24eceb516217ea03d4580f624f8930ea8 Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 10 Aug 2022 11:02:19 +0200 Subject: [PATCH 04/25] Parblock without extra variables --- .../vct/col/newrewrite/ParBlockEncoder.scala | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala index 5d123f40c5..652e06a50f 100644 --- a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala +++ b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala @@ -131,23 +131,15 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { quantify(block, block.context_everywhere &* block.ensures) } - def ranges(region: ParRegion[Pre], rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])]): Statement[Post] = region match { - case ParParallel(regions) => Block(regions.map(ranges(_, rangeValues)))(region.o) - case ParSequential(regions) => Block(regions.map(ranges(_, rangeValues)))(region.o) + def ranges(region: ParRegion[Pre], rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])]): Unit = region match { + case ParParallel(regions) => regions.foreach(ranges(_, rangeValues)) + case ParSequential(regions) => regions.foreach(ranges(_, rangeValues)) case block @ ParBlock(decl, iters, _, _, _, _) => decl.drop() blockDecl(decl) = block - Block(iters.map { v => - implicit val o: Origin = v.o - val lo = new Variable(TInt())(LowEvalOrigin(v)).declareDefault(this) - val hi = new Variable(TInt())(HighEvalOrigin(v)).declareDefault(this) - rangeValues(v.variable) = (lo.get, hi.get) - - Block(Seq( - assignLocal(lo.get, dispatch(v.from)), - assignLocal(hi.get, dispatch(v.to)), - )) - })(region.o) + iters.foreach { v => + rangeValues(v.variable) = (dispatch(v.from), dispatch(v.to)) + } } def execute(region: ParRegion[Pre]): Statement[Post] = { @@ -192,18 +184,15 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { val rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])] = mutable.Map() - val (vars, evalRanges) = withCollectInScope(variableScopes) { - ranges(region, rangeValues) - } + ranges(region, rangeValues) currentRanges.having(rangeValues.toMap) { - Scope(vars, Block(Seq( - evalRanges, + Block(Seq( IndetBranch(Seq( execute(region), Block(Seq(check(region), Inhale(ff))) )), - ))) + )) } case inv @ ParInvariant(decl, dependentInvariant, body) => From f2e26a576c3b48172535267537f9069122b27eef Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 10 Aug 2022 11:18:28 +0200 Subject: [PATCH 05/25] Add names of forall binders after rewriting --- .../col/newrewrite/ApplyTermRewriter.scala | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala index 87ef977bba..223fbf9a12 100644 --- a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala +++ b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala @@ -49,6 +49,14 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] "of which the body is an equality.") } + case class MalformedSimplificationRuleBinders(body: Expr[_]) extends UserError { + override def code: String = "malformedSimpRuleBinders" + override def text: String = + body.o.messageInContext( + "The body of a simplification rule with nested \\forall predicates, " + + "must have consistent order of forall binders indicators (e.g. not (r1!t1,t2) and (r2!t2,t1)") + } + val rules: Seq[(Seq[Variable[Rule]], Expr[Rule], Expr[Rule], Origin)] = ruleNodes.map(node => (node.o, consumeForalls(node.axiom))).map { case (o, (free, body)) => body match { case Eq(left, right) => (free, left, right, o) @@ -79,7 +87,38 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] override def dispatch(decl: Declaration[Pre]): Unit = decl.succeedDefault(decl) } - case class ApplyRule(inst: Map[Variable[Rule], (Expr[Pre], Seq[Variable[Pre]])], typeInst: Map[Variable[Rule], Type[Pre]], defaultOrigin: Origin) extends NonLatchingRewriter[Rule, Pre] { + case class ApplyRule(inst: Map[Variable[Rule], (Expr[Pre], Seq[Variable[Pre]])], typeInst: Map[Variable[Rule], Type[Pre]], + defaultOrigin: Origin, rule: Expr[Rule]) extends NonLatchingRewriter[Rule, Pre] { + val binderOrigins: mutable.Map[Variable[Rule], Origin] = mutable.Map.empty + + def addBinderOrigin(newVar: Ref[Rule,Variable[Rule]], originalVar: Variable[Pre]): Unit = { + val newOrigin = originalVar.o + if (binderOrigins.contains(newVar.decl) && binderOrigins(newVar.decl) != newOrigin) { + throw MalformedSimplificationRuleBinders(rule) + } + else { + binderOrigins(newVar.decl) = newOrigin + } + } + + /* We want to reuse the names of forall binders, we do this by how the binders are captured, e.g.: + * axiom starall_star { + * (∀type T, resource r1, resource r2; + * (∀* T t1; (r1!t1) ** (r2!t1)) == + * ((∀* T t2; (r1!t2)) ** (∀* T t3; (r2!t3)))) + * } + * we capture the name "t1" by the pattern "(r1!t1)" and "(r1!t2) and (r2!t3) to name "t2" and "t3" the same + * as we have no other way to relate the forall's since they internally bind to different variables. + */ + def findBinderOrigin(e: Node[Rule]): Unit = { + e match { + case FunctionOf(Ref(v), ruleVars) => + val (_, vars) = inst(v) + ruleVars.zip(vars).foreach(t => addBinderOrigin(t._1, t._2)) + case _ => e.subnodes.foreach(findBinderOrigin) + } + } + override def dispatch(o: Origin): Origin = defaultOrigin override def dispatch(e: Expr[Rule]): Expr[Pre] = e match { @@ -89,6 +128,15 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] case FunctionOf(Ref(v), ruleVars) => val (replacement, vars) = inst(v) ApplyParametricBindings(vars.zip(ruleVars.map(succ[Variable[Pre]])).toMap).dispatch(replacement) + case e: Binder[Rule] => + findBinderOrigin(e) + rewriteDefault(e) + case other => rewriteDefault(other) + } + + override def dispatch(d: Declaration[Rule]): Unit = d match { + case v: Variable[Rule] => + new Variable(dispatch(v.t))(binderOrigins.getOrElse(v,defaultOrigin)).succeedDefault(v)(this) case other => rewriteDefault(other) } @@ -209,7 +257,7 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] return None } - val result = ApplyRule(inst.toMap, typeInst.toMap, subject.o).dispatch(subtitute) + val result = ApplyRule(inst.toMap, typeInst.toMap, subject.o, subtitute).dispatch(subtitute) if(debugMatch && debugFilter) { if(debugMatchShort) { From c5fdc5e6fa5d9ea2e60b13755257ea16c8b21f5b Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 15 Aug 2022 17:21:34 +0200 Subject: [PATCH 06/25] Take account equalities whilst simplifying quantifiers --- .../ast/util/ExpressionEqualityCheck.scala | 11 +++- .../SimplifyNestedQuantifiers.scala | 55 ++++++++++++++++++- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala index 1232c3ca62..8ecdb2d413 100644 --- a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala +++ b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala @@ -2,7 +2,8 @@ package vct.col.ast.util import hre.lang.System.Warning import vct.col.ast.util.ExpressionEqualityCheck.isConstantInt -import vct.col.ast.{And, BitAnd, BitNot, BitOr, BitShl, BitShr, BitUShr, BitXor, Div, Eq, Exp, Expr, FloorDiv, Greater, GreaterEq, Implies, IntegerValue, Less, LessEq, Local, Minus, Mod, Mult, Neq, Not, Or, Plus, Star, UMinus, Wand} +import vct.col.ast.{And, BitAnd, BitNot, BitOr, BitShl, BitShr, BitUShr, BitXor, Div, Eq, Exp, Expr, FloorDiv, Greater, GreaterEq, Implies, IntegerValue, Less, LessEq, Local, Loop, Minus, Mod, Mult, Neq, Not, Or, Plus, Star, UMinus, Wand} +import vct.result.VerificationError.UserError import scala.collection.mutable @@ -18,6 +19,11 @@ object ExpressionEqualityCheck { } } +case class InconsistentVariableEquality(v: Local[_], x: BigInt, y: BigInt) extends UserError { + override def code: String = "inconsistentVariableEquality" + override def text: String = s"Inconsistent variable equality: value of $v is required to be both $x and $y" +} + class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { var replacerDepth = 0 var replacerDepthInt = 0 @@ -242,7 +248,8 @@ class AnnotationVariableInfoGetter[G]() { // Add to constant list isConstantInt[G](expr) match { case Some(x) => variableValues.get(v) match { - case Some(x_) => if (x!=x_) Warning("Value of %s is required to be both %d and %d", v, x, x_); + case Some(y) => if (x!=y) + throw InconsistentVariableEquality(v, x, y) case None => variableValues(v) = x } case None => diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 0faab4d2d7..716db2127b 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -1,9 +1,9 @@ package vct.col.newrewrite import vct.col.ast.{ArraySubscript, _} -import vct.col.ast.util.ExpressionEqualityCheck +import vct.col.ast.util.{AnnotationVariableInfoGetter, ExpressionEqualityCheck} import vct.col.newrewrite.util.Comparison -import vct.col.origin.{Origin, PanicBlame} +import vct.col.origin.{Blame, NontrivialUnsatisfiable, Origin, PanicBlame} import vct.col.ref.Ref import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} import vct.col.util.AstBuildHelpers._ @@ -74,6 +74,53 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } } + override def dispatch(stat: Statement[Pre]): Statement[Post] = { + val e = stat match { + case Exhale(e) => e + case Inhale(e) => e + case _ => return rewriteDefault(stat) + } + + val conditions = getConditions(e) + val infoGetter = new AnnotationVariableInfoGetter[Pre]() + equalityChecker = ExpressionEqualityCheck(Some(infoGetter.getInfo(conditions))) + val result = rewriteDefault(stat) + equalityChecker = ExpressionEqualityCheck() + result + } + + def getConditions(preds: AccountedPredicate[Pre]): Seq[Expr[Pre]] = preds match { + case UnitAccountedPredicate(pred) => getConditions(pred) + case SplitAccountedPredicate(left, right) => getConditions(left) ++ getConditions(right) + } + + def getConditions(e: Expr[Pre]): Seq[Expr[Pre]] = e match { + case And(left, right) => getConditions(left) ++ getConditions(right) + case Star(left, right) => getConditions(left) ++ getConditions(right) + case other => Seq[Expr[Pre]](other) + } + + override def dispatch(contract: ApplicableContract[Pre]): ApplicableContract[Post] = { + val infoGetter = new AnnotationVariableInfoGetter[Pre]() + val reqConditions = getConditions(contract.requires) + val contextConditions = getConditions(contract.contextEverywhere) + val ensureConditions = getConditions(contract.ensures) + equalityChecker = ExpressionEqualityCheck(Some(infoGetter.getInfo(reqConditions ++ contextConditions))) + val requires = dispatch(contract.requires) + equalityChecker = ExpressionEqualityCheck(Some(infoGetter.getInfo(ensureConditions ++ contextConditions))) + val ensures = dispatch(contract.ensures) + equalityChecker = ExpressionEqualityCheck(Some(infoGetter.getInfo(contextConditions))) + val contextEverywhere = dispatch(contract.contextEverywhere) + equalityChecker = ExpressionEqualityCheck() + + val signals = contract.signals.map(element => dispatch(element)) + val givenArgs = collectInScope(variableScopes) {contract.givenArgs.foreach(dispatch)} + val yieldsArgs = collectInScope(variableScopes) {contract.yieldsArgs.foreach(dispatch)} + val decreases = contract.decreases.map(element => rewriter.dispatch(element)) + + ApplicableContract(requires, ensures, contextEverywhere, signals, givenArgs, yieldsArgs, decreases)(contract.blame)(contract.o) + } + def rewriteLinearArray(e: Binder[Pre]): Option[Expr[Post]] = { val originalBody = e match { case Forall(_, _, body) => body @@ -134,7 +181,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] val (newVars, secondBody) = mainBody match { case Forall(newVars, _, secondBody) => (newVars, secondBody) case Starall(newVars, _, secondBody) => (newVars, secondBody) - case _ => return newConditions + case _ => + body = mainBody + return newConditions } bindings.addAll(newVars) From 73cd952b70621bb3d46ccf9659fe313b3d17f489 Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 17 Aug 2022 14:07:18 +0200 Subject: [PATCH 07/25] Small c file read fix & debug info print --- parsers/src/main/java/vct/parsers/ColCParser.scala | 2 +- src/main/java/vct/main/modes/Verify.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parsers/src/main/java/vct/parsers/ColCParser.scala b/parsers/src/main/java/vct/parsers/ColCParser.scala index 7cf5cc9d41..94549ef43d 100644 --- a/parsers/src/main/java/vct/parsers/ColCParser.scala +++ b/parsers/src/main/java/vct/parsers/ColCParser.scala @@ -51,7 +51,7 @@ case class ColCParser(override val originProvider: OriginProvider, val interpreted = File.createTempFile("vercors-interpreted-", ".i") interpreted.deleteOnExit() - val process = interpret(localInclude=Seq(Paths.get(readable.fileName).getParent), input="-", output=interpreted.toString) + val process = interpret(localInclude=Seq(Paths.get(readable.fileName).toAbsolutePath.getParent), input="-", output=interpreted.toString) new Thread(() => { val writer = new OutputStreamWriter(process.getOutputStream, StandardCharsets.UTF_8) readable.read { reader => diff --git a/src/main/java/vct/main/modes/Verify.scala b/src/main/java/vct/main/modes/Verify.scala index 7a432acd8a..8f269cdbcf 100644 --- a/src/main/java/vct/main/modes/Verify.scala +++ b/src/main/java/vct/main/modes/Verify.scala @@ -33,7 +33,7 @@ case object Verify extends LazyLogging { def verifyWithOptions(options: Options, inputs: Seq[Readable]): Either[VerificationError, Seq[VerificationFailure]] = { val collector = BlameCollector() val stages = Stages.ofOptions(options, ConstantBlameProvider(collector)) - logger.debug(stages.toString) + logger.debug("Stages: " ++ stages.flatNames.map(_._1).mkString(", ")) stages.run(inputs) match { case Left(error) => Left(error) case Right(()) => Right(collector.errs.toSeq) From 95b5df7b73718d1ab89fbb372b29bfba62078707 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 29 Aug 2022 11:10:13 +0200 Subject: [PATCH 08/25] merge --- .../col/newrewrite/ApplyTermRewriter.scala | 2 +- .../SimplifyNestedQuantifiers.scala | 22 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala index 47daa7cd0f..a473bb2e46 100644 --- a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala +++ b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala @@ -142,7 +142,7 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] override def dispatch(d: Declaration[Rule]): Unit = d match { case v: Variable[Rule] => - new Variable(dispatch(v.t))(binderOrigins.getOrElse(v,defaultOrigin)).succeedDefault(v)(this) + allScopes.anySucceedOnly(v, new Variable(dispatch(v.t))(binderOrigins.getOrElse(v,defaultOrigin))) case other => rewriteDefault(other) } diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 716db2127b..15d1136ee5 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -14,6 +14,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigInt import scala.annotation.nowarn +import scala.reflect.ClassTag /** * This rewrite pass simplifies expressions of roughly this form: @@ -114,8 +115,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] equalityChecker = ExpressionEqualityCheck() val signals = contract.signals.map(element => dispatch(element)) - val givenArgs = collectInScope(variableScopes) {contract.givenArgs.foreach(dispatch)} - val yieldsArgs = collectInScope(variableScopes) {contract.yieldsArgs.foreach(dispatch)} + val givenArgs = variables.collect { contract.givenArgs.foreach(dispatch) }._1 + val yieldsArgs = variables.collect {contract.yieldsArgs.foreach(dispatch)}._1 val decreases = contract.decreases.map(element => rewriter.dispatch(element)) ApplicableContract(requires, ensures, contextEverywhere, signals, givenArgs, yieldsArgs, decreases)(contract.blame)(contract.o) @@ -403,10 +404,13 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] case class ForallSubstitute(subs: Map[Expr[Pre], Expr[Post]]) extends Rewriter[Pre] { - override def lookupSuccessor: Declaration[Pre] => Option[Declaration[Post]] = { - val here = mainRewriter.lookupSuccessor - decl => here(decl) - } +// override def lookupSuccessor: Declaration[Pre] => Option[Declaration[Post]] = { +// val here = mainRewriter.lookupSuccessor +// decl => here(decl) +// } + + override def anySucc[RefDecl <: Declaration[Post]](decl: Declaration[Pre])(implicit tag: ClassTag[RefDecl]): Ref[Post, RefDecl] + = mainRewriter.anySucc(decl)(tag) override def dispatch(e: Expr[Pre]): Expr[Post] = e match { case expr if subs.contains(expr) => subs(expr) @@ -416,7 +420,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] def lookForLinearAccesses(): Option[Expr[Post]] = { val linearAccesses = new FindLinearArrayAccesses(this) - withCollectInScope(mainRewriter.variableScopes) {linearAccesses.search(body)} match { + + mainRewriter.variables.collect {linearAccesses.search(body)} match { case (bindings, Some(substituteForall)) => if(bindings.size != 1){ ??? @@ -678,7 +683,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } // We found a replacement! // Make the declaration final - x_new.declareDefault(quantifierData.mainRewriter) + quantifierData.mainRewriter.variables.declare(x_new) +// x_new.declareDefault(quantifierData.mainRewriter) val ArraySubscript(arr, index) = arrayIndex // Replace the linear expression with the new variable val x_new_var: Expr[Post] = Local(x_new.ref) From 4cd25a74e72e61ba53dea4ff23f5bf03b8ab7388 Mon Sep 17 00:00:00 2001 From: Lars Date: Tue, 30 Aug 2022 15:58:07 +0200 Subject: [PATCH 09/25] Added greater than/not zero information --- .../ast/util/ExpressionEqualityCheck.scala | 106 ++++++++++++++---- 1 file changed, 87 insertions(+), 19 deletions(-) diff --git a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala index 8ecdb2d413..876c5c7695 100644 --- a/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala +++ b/col/src/main/java/vct/col/ast/util/ExpressionEqualityCheck.scala @@ -1,6 +1,5 @@ package vct.col.ast.util -import hre.lang.System.Warning import vct.col.ast.util.ExpressionEqualityCheck.isConstantInt import vct.col.ast.{And, BitAnd, BitNot, BitOr, BitShl, BitShr, BitUShr, BitXor, Div, Eq, Exp, Expr, FloorDiv, Greater, GreaterEq, Implies, IntegerValue, Less, LessEq, Local, Loop, Minus, Mod, Mult, Neq, Not, Or, Plus, Star, UMinus, Wand} import vct.result.VerificationError.UserError @@ -77,6 +76,16 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { equalExpressionsRecurse(lhs, rhs) } + def isGreaterThanZero(e: Expr[G]): Boolean = e match { + case v: Local[G] => info.exists(_.variableGreaterThanZero.contains(v)) + case _ => isConstantInt(e).getOrElse(0: BigInt) > 0 + } + + def isNonZero(e: Expr[G]):Boolean = e match { + case v: Local[G] => info.exists(_.variableNotZero.contains(v)) + case _ => isConstantInt(e).getOrElse(0) != 0 + } + // def equalExpressionsRecurse(lhs: Expr[G], rhs: Expr[G]): Boolean = { (isConstantInt(lhs), isConstantInt(rhs)) match { @@ -199,7 +208,7 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) { } case class AnnotationVariableInfo[G](variableEqualities: Map[Local[G], List[Expr[G]]], variableValues: Map[Local[G], BigInt], - variableSynonyms: Map[Local[G], Int]) + variableSynonyms: Map[Local[G], Int], variableNotZero: Set[Local[G]], variableGreaterThanZero: Set[Local[G]]) /** This class gathers information about variables, such as: * `requires x == 0` and stores that x is equal to the value 0. @@ -213,7 +222,10 @@ class AnnotationVariableInfoGetter[G]() { val variableValues: mutable.Map[Local[G], BigInt] = mutable.Map() // We put synonyms in the same group and give them a group number, to identify the same synonym groups val variableSynonyms: mutable.Map[Local[G], Int] = mutable.Map() + val variableNotZero: mutable.Set[Local[G]] = mutable.Set() + val variableGreaterThanZero: mutable.Set[Local[G]] = mutable.Set() var currentSynonymGroup = 0 + var equalCheck: ExpressionEqualityCheck[G] = ExpressionEqualityCheck() def extractEqualities(e: Expr[G]): Unit = { e match{ @@ -228,6 +240,32 @@ class AnnotationVariableInfoGetter[G]() { } } + def extractComparisons(e: Expr[G]): Unit = { + e match{ + // x != 0 + case Neq(v: Local[G], e2) => equalCheck.isConstantInt(e2).foreach{i => if(i == 0) variableNotZero.add(v)} + // 0 != x + case Neq(e1, v: Local[G]) => equalCheck.isConstantInt(e1).foreach{i => if(i == 0) variableNotZero.add(v)} + // x < 0 + case Less(v: Local[G], e2) => equalCheck.isConstantInt(e2).foreach{i => if(i <= 0) variableNotZero.add(v) } + // x <= -1 + case LessEq(v: Local[G], e2) => equalCheck.isConstantInt(e2).foreach{i => if(i < 0) variableNotZero.add(v) } + // 0 < x + case Less(e1, v: Local[G]) => equalCheck.isConstantInt(e1).foreach{i => if(i >= 0) variableNotZero.add(v); variableGreaterThanZero.add(v) } + // 1 <= x + case LessEq(e1, v: Local[G]) => equalCheck.isConstantInt(e1).foreach{i => if(i > 0) variableNotZero.add(v); variableGreaterThanZero.add(v) } + // x > 0 + case Greater(v: Local[G], e2) => equalCheck.isConstantInt(e2).foreach{i => if(i >= 0) variableNotZero.add(v); variableGreaterThanZero.add(v) } + // x >= 1 + case GreaterEq(v: Local[G], e2) => equalCheck.isConstantInt(e2).foreach{i => if(i > 0) variableNotZero.add(v); variableGreaterThanZero.add(v) } + // 0 > x + case Greater(e1, v: Local[G]) => equalCheck.isConstantInt(e1).foreach{i => if(i <= 0) variableNotZero.add(v) } + // 0 >= x + case GreaterEq(e1, v: Local[G]) => equalCheck.isConstantInt(e1).foreach{i => if(i < 0) variableNotZero.add(v) } + case _ => + } + } + def addSynonym(v1: Local[G], v2: Local[G]): Unit = { (variableSynonyms.get(v1), variableSynonyms.get(v2)) match { // We make a new group @@ -239,55 +277,85 @@ class AnnotationVariableInfoGetter[G]() { case (Some(id1), None) => variableSynonyms(v2) = id1 case (None, Some(id2)) => variableSynonyms(v1) = id2 // Merge the groups, give every synonym group member of id2 value id1 - case (Some(id1), Some(id2)) => + case (Some(id1), Some(id2)) if id1 != id2 => variableSynonyms.mapValuesInPlace((_, group) => if (group == id2) id1 else group) + case _ => } } + def addValue(v: Local[G], x: BigInt): Unit = + variableValues.get(v) match { + case Some(y) => if (x!=y) throw InconsistentVariableEquality(v, x, y) + case None => + variableValues(v) = x + if(x>0) variableGreaterThanZero.add(v) + if(x!=0) variableNotZero.add(v) + } + def addName(v: Local[G], expr: Expr[G]): Unit ={ // Add to constant list isConstantInt[G](expr) match { - case Some(x) => variableValues.get(v) match { - case Some(y) => if (x!=y) - throw InconsistentVariableEquality(v, x, y) - case None => variableValues(v) = x - } + case Some(x) => addValue(v, x) case None => val list = variableEqualities.getOrElseUpdate(v, mutable.ListBuffer()) list.addOne(expr) } } - def getInfo(annotations: Iterable[Expr[G]]): AnnotationVariableInfo[G] = { + def getInfo(annotations: Seq[Expr[G]]): AnnotationVariableInfo[G] = { variableEqualities.clear() variableValues.clear() + variableSynonyms.clear() + currentSynonymGroup = 0 + variableNotZero.clear() + variableGreaterThanZero.clear() for(clause <- annotations){ extractEqualities(clause) } + val res = AnnotationVariableInfo[G](variableEqualities.view.mapValues(_.toList).toMap, variableValues.toMap, + variableSynonyms.toMap, Set[Local[G]](), Set[Local[G]]()) + equalCheck = ExpressionEqualityCheck(Some(res)) + + for(clause <- annotations){ + extractComparisons(clause) + } + distributeInfo() + AnnotationVariableInfo(variableEqualities.view.mapValues(_.toList).toMap, variableValues.toMap, - variableSynonyms.toMap) + variableSynonyms.toMap, variableNotZero.toSet, variableGreaterThanZero.toSet) } def distributeInfo(): Unit = { - // First distribute value knowledge over the rest of the map - val beginSize = variableValues.size - + // First check if expressions have become integers for((name, equals) <- variableEqualities){ if(!variableValues.contains(name)) for(equal <- equals){ - equal match { - case n : Local[G] => - variableValues.get(n).foreach(variableValues(name) = _) - case _ => + equalCheck.isConstantInt(equal) match { + case Some(x) => addValue(name, x) + case None => } } } - // If sizes are not the same, we know more, so distribute again! - if(variableValues.size != beginSize) distributeInfo() + // Group synonym sets + val synonymSets: mutable.Map[Int, mutable.Set[Local[G]]] = mutable.Map() + variableSynonyms.foreach{ case (v, groupId) => synonymSets.getOrElse(groupId,mutable.Set()).add(v) } + + def hasValue(vars: mutable.Set[Local[G]]): Option[BigInt] = { + vars.foreach{v => if(variableValues.contains(v)) return variableValues.get(v) } + None + } + + synonymSets.foreach{ case (_, vars) => + // Redistribute values over synonyms + hasValue(vars).foreach{x => vars.foreach{addValue(_, x)}} + // Redistribute not-zero over synonyms + if(vars.intersect(variableNotZero).nonEmpty) variableNotZero.addAll(vars) + // Redistribute greater than zero over synonyms + if(vars.intersect(variableGreaterThanZero).nonEmpty) variableGreaterThanZero.addAll(vars) } } } \ No newline at end of file From f200d2628277a747a529ca1bbf5ee928b237d9a1 Mon Sep 17 00:00:00 2001 From: Lars Date: Tue, 30 Aug 2022 15:58:59 +0200 Subject: [PATCH 10/25] Smaller fixes --- .../main/java/vct/col/feature/FeatureRainbow.scala | 2 ++ col/src/main/java/vct/col/print/Printer.scala | 12 +++++++++--- .../java/vct/col/newrewrite/ApplyTermRewriter.scala | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/col/src/main/java/vct/col/feature/FeatureRainbow.scala b/col/src/main/java/vct/col/feature/FeatureRainbow.scala index 43e02c28d7..cddb5efdb8 100644 --- a/col/src/main/java/vct/col/feature/FeatureRainbow.scala +++ b/col/src/main/java/vct/col/feature/FeatureRainbow.scala @@ -450,6 +450,8 @@ class FeatureRainbow[G] { case node: CStructAccess[G] => return Nil case node: CStructDeref[G] => return Nil case node: GpgpuCudaKernelInvocation[G] => return Nil + case node: LocalThreadId[G] => return Nil + case node: GlobalThreadId[G] => return Nil case node: CPrimitiveType[G] => return Nil case node: JavaName[G] => return Nil case node: JavaImport[G] => return Nil diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 38675de737..969dc13e4f 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -633,9 +633,9 @@ case class Printer(out: Appendable, case MapDisjoint(left, right) => (phrase("disjointMap(", left, ",", space, right, ")"), 100) case Forall(bindings, triggers, body) => - (phrase("(", "\\forall", space, phrase(bindings.map(NodePhrase):_*), "; true; ", body, ")"), 120) + (phrase("(", "\\forall", space, commas(bindings.map(NodePhrase)), "; true; ", body, ")"), 120) case Exists(bindings, triggers, body) => - (phrase("(", "\\exists", space, phrase(bindings.map(NodePhrase):_*), "; true; ", body, ")"), 120) + (phrase("(", "\\exists", space, commas(bindings.map(NodePhrase)), "; true; ", body, ")"), 120) case ValidArray(arr, len) => (phrase("\\array(", arr, ",", space, len, ")"), 100) case ValidMatrix(mat, w, h) => @@ -655,13 +655,15 @@ case class Printer(out: Appendable, case JoinToken(thread) => (phrase("running", "(", thread, ")"), 100) case Starall(bindings, triggers, body) => - (phrase("(", "\\forall*", space, phrase(bindings.map(NodePhrase):_*), "; true; ", body, ")"), 120) + (phrase("(", "\\forall*", space, commas(bindings.map(NodePhrase)), "; true; ", body, ")"), 120) case Star(left, right) => (phrase(assoc(40, left), space, "**", space, assoc(40, right)), 40) case Wand(left, right) => (phrase(bind(30, left), space, "-*", space, assoc(30, right)), 30) case Scale(scale, res) => (phrase("[", scale, "]", assoc(90, res)), 90) + case ScaleByParBlock(block, res) => + (phrase("[", block.decl, "]", assoc(90, res)), 90) case Perm(loc, perm) => (phrase("Perm(", loc, ",", space, perm, ")"), 100) case PointsTo(loc, perm, value) => @@ -849,6 +851,10 @@ case class Printer(out: Appendable, (phrase(assoc(100, arr), "[", index, "]"), 100) case PointerSubscript(pointer, index) => (phrase(assoc(100, pointer), "[", index, "]"), 100) + case PointerBlockOffset(pointer) => + (phrase("pointer_block(", pointer ,")"), 100) + case PointerBlockLength(pointer) => + (phrase("block_length(", pointer ,")"), 100) case Cons(x, xs) => (phrase(bind(87, x), space, "::", space, assoc(87, xs)), 87) case Head(xs) => diff --git a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala index a473bb2e46..fb9cb23db4 100644 --- a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala +++ b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala @@ -142,7 +142,7 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] override def dispatch(d: Declaration[Rule]): Unit = d match { case v: Variable[Rule] => - allScopes.anySucceedOnly(v, new Variable(dispatch(v.t))(binderOrigins.getOrElse(v,defaultOrigin))) + variables.succeed(v, new Variable(dispatch(v.t))(binderOrigins.getOrElse(v,defaultOrigin))) case other => rewriteDefault(other) } From 0dc703d6364122adc6ec1d4e131900d8cd5dac09 Mon Sep 17 00:00:00 2001 From: Lars Date: Tue, 30 Aug 2022 15:59:15 +0200 Subject: [PATCH 11/25] Improvements to nested quantifier check --- .../SimplifyNestedQuantifiers.scala | 142 +++++++++++------- 1 file changed, 86 insertions(+), 56 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 15d1136ee5..a9235d838b 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -1,9 +1,10 @@ package vct.col.newrewrite +import com.typesafe.scalalogging.LazyLogging import vct.col.ast.{ArraySubscript, _} import vct.col.ast.util.{AnnotationVariableInfoGetter, ExpressionEqualityCheck} import vct.col.newrewrite.util.Comparison -import vct.col.origin.{Blame, NontrivialUnsatisfiable, Origin, PanicBlame} +import vct.col.origin.{Origin, PanicBlame} import vct.col.ref.Ref import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} import vct.col.util.AstBuildHelpers._ @@ -12,7 +13,6 @@ import vct.result.VerificationError.Unreachable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.math.BigInt import scala.annotation.nowarn import scala.reflect.ClassTag @@ -34,7 +34,7 @@ case object SimplifyNestedQuantifiers extends RewriterBuilder { override def desc: String = "Simplify nested quantifiers." } -case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] { +case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] with LazyLogging { case object SimplifyNestedQuantifiersOrigin extends Origin { override def preferredName: String = "unknown" @@ -60,14 +60,20 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] private def one: IntegerValue[Pre] = IntegerValue(1) - // TODO: Supply information towards the expression equality checker when encountering a contract var equalityChecker: ExpressionEqualityCheck[Pre] = ExpressionEqualityCheck() override def dispatch(e: Expr[Pre]): Expr[Post] = { e match { case e: Binder[Pre] => rewriteLinearArray(e) match { - case None => rewriteDefault(e) + case None => + val res = rewriteDefault(e) + res match { + case Starall(_, triggers, _) if triggers.isEmpty => logger.info(f"Warning the binder '$res' contains no triggers") + case Forall(_, triggers, _) if triggers.isEmpty => logger.info(f"Warning the binder '$res' contains no triggers") + case _ => + } + res case Some(newE) => newE } @@ -340,9 +346,11 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // (forall i; a <= i <= b; ...Perm(ar, x)...) ====> b>=a ==> ...Perm(ar, x*(b-a+1))... independentConditions.addOne(GreaterEq(maxBound, minBound)) - body = Scale(Plus(one, Minus(maxBound, minBound)), body)( - PanicBlame("Error in SimplifyNestedQuantifiers class, implication should make sure scale is" + - " never negative when accessed.")) // rp.dispatch(body) + if(body.t == TResource()){ + body = Scale(Plus(one, Minus(maxBound, minBound)), body)( + PanicBlame("Error in SimplifyNestedQuantifiers class, implication should make sure scale is" + + " never negative when accessed.")) + } case _ => } } @@ -380,16 +388,19 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] lowerBounds.getOrElse(name, ArrayBuffer()).nonEmpty } + val zero = BigInt(0) for (name <- bindings) { - var oneZero = false - val zero = BigInt(0) - lowerBounds.getOrElse(name, ArrayBuffer()) - .foreach(lower => equalityChecker.isConstantInt(lower) match { - case Some(`zero`) => oneZero = true + def hasZero: Boolean = { + lowerBounds.getOrElse(name, ArrayBuffer()) + .foreach(lower => equalityChecker.isConstantInt(lower) match { + case Some(`zero`) => return true case _ => }) + false + } + //Exit when notAt least one zero, or no upper bounds - if (!oneZero || upperBounds.getOrElse(name, ArrayBuffer()).isEmpty) { + if (!hasZero || upperBounds.getOrElse(name, ArrayBuffer()).isEmpty) { return false } } @@ -403,14 +414,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] case class ForallSubstitute(subs: Map[Expr[Pre], Expr[Post]]) extends Rewriter[Pre] { - -// override def lookupSuccessor: Declaration[Pre] => Option[Declaration[Post]] = { -// val here = mainRewriter.lookupSuccessor -// decl => here(decl) -// } - - override def anySucc[RefDecl <: Declaration[Post]](decl: Declaration[Pre])(implicit tag: ClassTag[RefDecl]): Ref[Post, RefDecl] - = mainRewriter.anySucc(decl)(tag) + override val allScopes = mainRewriter.allScopes override def dispatch(e: Expr[Pre]): Expr[Post] = e match { case expr if subs.contains(expr) => subs(expr) @@ -424,17 +428,18 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] mainRewriter.variables.collect {linearAccesses.search(body)} match { case (bindings, Some(substituteForall)) => if(bindings.size != 1){ - ??? + Unreachable("Only one new variable should be declared with SimplifyNestedQuantifiers.") } val sub = ForallSubstitute(substituteForall.substituteOldVars) val newBody = sub.dispatch(body) val select = Seq(substituteForall.newBounds) ++ independentConditions.map(sub.dispatch) ++ dependentConditions.map(sub.dispatch) val main = if (select.nonEmpty) Implies(AstBuildHelpers.foldAnd(select), newBody) else newBody + @nowarn("msg=xhaust") val forall: Binder[Post] = originalBinder match { - case _: Forall[Pre] => Forall(bindings, substituteForall.newTrigger, main) - case originalBinder: Starall[Pre] => Starall(bindings, substituteForall.newTrigger, main)(originalBinder.blame) - case _ => ??? + case _: Forall[Pre] => Forall(bindings, substituteForall.newTrigger, main)(originalBinder.o) + case originalBinder: Starall[Pre] => + Starall(bindings, substituteForall.newTrigger, main)(originalBinder.blame)(originalBinder.o) } Some(forall) case (_, None) => result() @@ -469,8 +474,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // are not there anymore? @nowarn("msg=xhaust") val forall: Expr[Pre] = originalBinder match{ - case _: Forall[Pre] => Forall(bindings.toSeq, Seq(), new_body) - case e: Starall[Pre] => Starall(bindings.toSeq, Seq(), new_body)(e.blame) + case _: Forall[Pre] => Forall(bindings.toSeq, Seq(), new_body)(originalBinder.o) + case e: Starall[Pre] => Starall(bindings.toSeq, Seq(), new_body)(e.blame)(originalBinder.o) } Some(forall) } @@ -485,31 +490,50 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] def indepOf[G](bindings: mutable.Set[Variable[G]], e: Expr[G]): Boolean = e.transSubnodes.collectFirst { case Local(ref) if bindings.contains(ref.decl) => () }.isEmpty + sealed trait Subscript { + val index: Expr[Pre] + val subnodes: Seq[Node[Pre]] + } + final case class Array(e: ArraySubscript[Pre]) extends Subscript { + val index: Expr[Pre] = e.index + val subnodes: Seq[Node[Pre]] = e.subnodes + } + + final case class Pointer(e: PointerSubscript[Pre]) extends Subscript { + val index: Expr[Pre] = e.index + val subnodes: Seq[Node[Pre]] = e.subnodes + } + class FindLinearArrayAccesses(quantifierData: RewriteQuantifierData){ // Search for linear array expressions def search(e: Node[Pre]): Option[SubstituteForall] = { e match { - case e @ ArraySubscript(_, index) => - if (indepOf(quantifierData.bindings, index)) { - return None - } - linearExpression(e) match { - case Some(substituteForall) => Some(substituteForall) - case None => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} - } + case e @ ArraySubscript(_, _) => + testSubscript(Array(e)) + case e @ PointerSubscript(_, _) => + testSubscript(Pointer(e)) case _ => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} } } - def linearExpression(e: ArraySubscript[Pre]): Option[SubstituteForall] = { - val ArraySubscript(_, index) = e + def testSubscript(e: Subscript): Option[SubstituteForall] = { + if (indepOf(quantifierData.bindings, e.index)) { + return None + } + linearExpression(e) match { + case Some(substituteForall) => Some(substituteForall) + case None => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} + } + } + + def linearExpression(e: Subscript): Option[SubstituteForall] = { val pot = new PotentialLinearExpressions(e) - pot.visit(index) + pot.visit(e.index) pot.canRewrite() } - class PotentialLinearExpressions(val arrayIndex: ArraySubscript[Pre]){ + class PotentialLinearExpressions(val arrayIndex: Subscript){ val linearExpressions: mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() var constantExpression: Option[Expr[Pre]] = None var isLinear: Boolean = true @@ -623,28 +647,28 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] * (or equivalent a_i == Prod_{0 <= j < i} {n_j} * a_0 ) * * Further more we require that n_i > 0 and a_i > 0 (although I think a_0<0 is also valid) - * TODO: We are not checking n_i and a_i on this * We can than replace the forall with * b <= x_new < a_{k-1} * n_{k-1} + b && (x_new - b) % a_0 == 0 : ... ar[x_new] ... * and each x_i gets replaced by * x_i -> ((x_new - b) / a_i) % n_i - * And since we never go past a_{k-1} * n_{k-1} + b, no modulo needed here - * x_{k-1} -> (x_new - b) / a_{k-1}\ + * And since we never go past a_{k-1} * n_{k-1} + b, no modulo needed here + * x_{k-1} -> (x_new - b) / a_{k-1} */ def check_vars_list(vars: List[Variable[Pre]]): Option[SubstituteForall] = { val x_0 = vars.head val a_0 = linearExpressions(x_0) + if(!equalityChecker.isNonZero(a_0)) return None // x_{i-1}, a_{i-1}, n_{i-1} var x_i_last = x_0 var a_i_last = a_0 var n_i_last: Expr[Pre] = null - val ns : mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() + val ns: mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() - val new_name = vars.tail.foldLeft(vars.head.o.preferredName)(_ + "_" + _.o.preferredName) + val new_name = vars.map(_.o.preferredName).mkString("_") val x_new = new Variable[Post](TInt())(BinderOrigin(new_name)) - val newGen: Expr[Pre] => Expr[Post] = quantifierData.mainRewriter.dispatch(_) + val newGen: Expr[Pre] => Expr[Post] = quantifierData.mainRewriter.dispatch // x_base == (x_new -b) val x_base: Expr[Post]= constantExpression match { @@ -659,7 +683,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // Find a suitable upper bound for (n_i_last_candidate <- quantifierData.upperExclusiveBounds(x_i_last)) { - if( !found_valid_n && equalityChecker.equalExpressions(a_i, simplified_mult(a_i_last, n_i_last_candidate)) ) { + if( !found_valid_n && /*equalityChecker.isGreaterThanZero(n_i_last_candidate) && */ + equalityChecker.equalExpressions(a_i, simplified_mult(a_i_last, n_i_last_candidate)) ) { found_valid_n = true n_i_last = n_i_last_candidate ns(x_i_last) = n_i_last_candidate @@ -672,9 +697,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] replace_map(Local(x_i_last.ref)) = if(is_value(a_i_last, 1)) - Mod(x_base, newGen(n_i_last))(PanicBlame("TODO")) + Mod(x_base, newGen(n_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, n_i_last should not be zero")) else - Mod(FloorDiv(x_base, newGen(a_i_last))(PanicBlame("TODO")), newGen(n_i_last))(PanicBlame("TODO")) + Mod(FloorDiv(x_base, newGen(a_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, a_i_last should not be zero")), newGen(n_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, n_i_last should not be zero")) // Yay we are good up to now, go check out the next i x_i_last = x_i @@ -684,16 +709,21 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // We found a replacement! // Make the declaration final quantifierData.mainRewriter.variables.declare(x_new) -// x_new.declareDefault(quantifierData.mainRewriter) - val ArraySubscript(arr, index) = arrayIndex + // Replace the linear expression with the new variable val x_new_var: Expr[Post] = Local(x_new.ref) - replace_map(index) = x_new_var - val newTrigger : Seq[Seq[Expr[Post]]] = Seq(Seq(ArraySubscript(newGen(arr), x_new_var)(arrayIndex.blame))) + replace_map(arrayIndex.index) = x_new_var + val trigger: Expr[Post] = arrayIndex match { + case Array( node @ ArraySubscript(arr, _)) => ArraySubscript(newGen(arr), x_new_var)(node.blame) + case Pointer( node @ PointerSubscript(arr, _)) => PointerSubscript(newGen(arr), x_new_var)(node.blame) + } + val newTrigger : Seq[Seq[Expr[Post]]] = Seq(Seq(trigger)) // Add the last value, no need to do modulo - //TODO - replace_map(Local(x_i_last.ref)) = FloorDiv(x_base, newGen(a_i_last))(PanicBlame("TODO")) + // + replace_map(Local(x_i_last.ref)) = if(is_value(a_i_last, 1)) x_base + else FloorDiv(x_base, newGen(a_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, a_i_last should not be zero")) + // Get a random upperbound for x_i_last; n_i_last = quantifierData.upperExclusiveBounds(x_i_last).head ns(x_i_last) = n_i_last @@ -706,7 +736,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] new_bounds = if(is_value(a_0, 1)) new_bounds else And(new_bounds, //TODO - Eq( Mod(x_base, newGen(a_0))(PanicBlame("TODO")), + Eq( Mod(x_base, newGen(a_0))(PanicBlame("Error in SimplifyNestedQuantifiers, a_0 should not be zero")), IntegerValue(0)) ) From aebc48ba86d56c3d47658d70ae46954f3ff66b60 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 11:58:39 +0200 Subject: [PATCH 12/25] Add better origin names of C files --- col/src/main/java/vct/col/origin/Origin.scala | 11 +++++++++++ .../vct/col/newrewrite/lang/LangCToCol.scala | 16 +++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/col/src/main/java/vct/col/origin/Origin.scala b/col/src/main/java/vct/col/origin/Origin.scala index 8011c54be6..a7cafe0a8f 100644 --- a/col/src/main/java/vct/col/origin/Origin.scala +++ b/col/src/main/java/vct/col/origin/Origin.scala @@ -260,6 +260,17 @@ case class ReadableOrigin(readable: Readable, override def toString: String = f"$startText - $endText" } +case class InterpretedOriginVariable(name: String, original: Origin) + extends InputOrigin { + override def preferredName: String = name + + override def context: String = original.context + + override def inlineContext: String = original.inlineContext + + override def shortPosition: String = original.shortPosition +} + case class InterpretedOrigin(interpreted: Readable, startLineIdx: Int, endLineIdx: Int, cols: Option[(Int, Int)], diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index acec372516..294df155fb 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -4,7 +4,7 @@ import com.typesafe.scalalogging.LazyLogging import hre.util.ScopedStack import vct.col.ast._ import vct.col.newrewrite.lang.LangSpecificToCol.NotAValue -import vct.col.origin.{AbstractApplicable, Blame, CallableFailure, Origin, PanicBlame, TrueSatisfiable} +import vct.col.origin.{AbstractApplicable, Blame, CallableFailure, InterpretedOriginVariable, Origin, PanicBlame, TrueSatisfiable} import vct.col.ref.Ref import vct.col.resolve.{BuiltinField, BuiltinInstanceMethod, C, CInvocationTarget, CNameTarget, RefADTFunction, RefAxiomaticDataType, RefCFunctionDefinition, RefCGlobalDeclaration, RefCLocalDeclaration, RefCParam, RefCudaBlockDim, RefCudaBlockIdx, RefCudaGridDim, RefCudaThreadIdx, RefCudaVec, RefCudaVecDim, RefCudaVecX, RefCudaVecY, RefCudaVecZ, RefFunction, RefInstanceFunction, RefInstanceMethod, RefInstancePredicate, RefModelAction, RefModelField, RefModelProcess, RefPredicate, RefProcedure, RefVariable, SpecDerefTarget, SpecInvocationTarget} import vct.col.rewrite.{Generation, Rewritten} @@ -69,9 +69,18 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz cUnit.declarations.foreach(rw.dispatch) } + def cDeclToName(cDecl: CDeclarator[Pre]): String = cDecl match { + case CPointerDeclarator(_, inner) => cDeclToName(inner) + case CArrayDeclarator(_, _, inner) => cDeclToName(inner) + case CTypedFunctionDeclarator(_, _, inner) => cDeclToName(inner) + case CAnonymousFunctionDeclarator(_, inner) => cDeclToName(inner) + case CName(name: String) => name + } + def rewriteParam(cParam: CParam[Pre]): Unit = { cParam.drop() - val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(cParam.o) + val o = InterpretedOriginVariable(cDeclToName(cParam.declarator), cParam.o) + val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(o) cNameSuccessor(RefCParam(cParam)) = v rw.variables.declare(v) } @@ -98,7 +107,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz cCurrentDefinitionParamSubstitutions.having(subs) { rw.globalDeclarations.declare( if (func.specs.collectFirst { case CKernel() => () }.nonEmpty) { - kernelProcedure(func.o, contract, info, Some(func.body)) + val namedO = InterpretedOriginVariable(cDeclToName(func.declarator), func.o) + kernelProcedure(namedO, contract, info, Some(func.body)) } else { new Procedure[Post]( returnType = returnType, From fee6e7f0ccb97053ef18c843e2ad52b4bed980b9 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 12:00:22 +0200 Subject: [PATCH 13/25] Added pointer_length and smaller fixes --- col/src/main/java/vct/col/print/Printer.scala | 1 - parsers/lib/antlr4/SpecLexer.g4 | 1 + parsers/lib/antlr4/SpecParser.g4 | 1 + parsers/src/main/java/vct/parsers/transform/CToCol.scala | 4 ++++ src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala | 4 +++- .../java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala | 4 ++-- src/main/java/vct/options/Options.scala | 2 +- 7 files changed, 12 insertions(+), 5 deletions(-) diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 969dc13e4f..0d02d5a793 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -976,7 +976,6 @@ case class Printer(out: Appendable, phrase(spaced(local.modifiers.map(NodePhrase)), space, local.t, space, javaDecls(local.decls)) case defn: CFunctionDefinition[_] => control(phrase(spaced(defn.specs.map(NodePhrase)), space, defn.declarator), defn.body) - case decl: CGlobalDeclaration[_] => decl case ns: JavaNamespace[_] => phrase( if(ns.pkg.nonEmpty) statement("package", space, ns.pkg.get) else phrase(), diff --git a/parsers/lib/antlr4/SpecLexer.g4 b/parsers/lib/antlr4/SpecLexer.g4 index 150c4b0e60..23f1e6b7bd 100644 --- a/parsers/lib/antlr4/SpecLexer.g4 +++ b/parsers/lib/antlr4/SpecLexer.g4 @@ -135,6 +135,7 @@ POINTER: '\\pointer'; POINTER_INDEX: '\\pointer_index'; POINTER_BLOCK_LENGTH: '\\pointer_block_length'; POINTER_BLOCK_OFFSET: '\\pointer_block_offset'; +POINTER_LENGTH: '\\pointer_length'; VALUES: '\\values'; VCMP: '\\vcmp'; VREP: '\\vrep'; diff --git a/parsers/lib/antlr4/SpecParser.g4 b/parsers/lib/antlr4/SpecParser.g4 index c2bfbb40dc..8c98dcfc49 100644 --- a/parsers/lib/antlr4/SpecParser.g4 +++ b/parsers/lib/antlr4/SpecParser.g4 @@ -177,6 +177,7 @@ valPrimaryPermission | '\\pointer_index' '(' langExpr ',' langExpr ',' langExpr ')' # valPointerIndex | '\\pointer_block_length' '(' langExpr ')' # valPointerBlockLength | '\\pointer_block_offset' '(' langExpr ')' # valPointerBlockOffset + | '\\pointer_length' '(' langExpr ')' # valPointerLength ; valPrimaryBinder diff --git a/parsers/src/main/java/vct/parsers/transform/CToCol.scala b/parsers/src/main/java/vct/parsers/transform/CToCol.scala index db47daae3b..7ada5a9858 100644 --- a/parsers/src/main/java/vct/parsers/transform/CToCol.scala +++ b/parsers/src/main/java/vct/parsers/transform/CToCol.scala @@ -956,6 +956,10 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b case ValPointerIndex(_, _, ptr, _, idx, _, perm, _) => PermPointerIndex(convert(ptr), convert(idx), convert(perm)) case ValPointerBlockLength(_, _, ptr, _) => PointerBlockLength(convert(ptr))(blame(e)) case ValPointerBlockOffset(_, _, ptr, _) => PointerBlockOffset(convert(ptr))(blame(e)) + case ValPointerLength(_, _, ptr, _) => + val convertedPtr = convert(ptr) + val blameExpr = blame(e) + PointerBlockLength(convertedPtr)(blameExpr) - PointerBlockOffset(convertedPtr)(blameExpr) } def convert(implicit e: ValPrimaryBinderContext): Expr[G] = e match { diff --git a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala index fb9cb23db4..73bcc6c450 100644 --- a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala +++ b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala @@ -355,7 +355,9 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] override def dispatch(e: Expr[Pre]): Expr[Post] = if(simplificationDone.nonEmpty) rewriteDefault(e) else simplificationDone.having(()) { - Progress.nextPhase(s"`$e`") + //TODO: This progress "nextPhase" prints out a lot of garbage when having to much foralls (for instance when verifying CUDA/OpenCl programs + // Disabled it for now, probably a bug? + //Progress.nextPhase(s"`$e`") countApply = 0 countSuccess = 0 currentExpr = e diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index a9235d838b..7ae25d3e2e 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -69,8 +69,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] case None => val res = rewriteDefault(e) res match { - case Starall(_, triggers, _) if triggers.isEmpty => logger.info(f"Warning the binder '$res' contains no triggers") - case Forall(_, triggers, _) if triggers.isEmpty => logger.info(f"Warning the binder '$res' contains no triggers") + case Starall(_, triggers, _) if triggers.isEmpty => logger.warn(f"The binder '$res' contains no triggers") + case Forall(_, triggers, _) if triggers.isEmpty => logger.warn(f"The binder '$res' contains no triggers") case _ => } res diff --git a/src/main/java/vct/options/Options.scala b/src/main/java/vct/options/Options.scala index 42f22fb919..3cdb1b9252 100644 --- a/src/main/java/vct/options/Options.scala +++ b/src/main/java/vct/options/Options.scala @@ -132,7 +132,7 @@ case object Options { .text("Debug matched expressions in simplifications"), opt[Unit]("dev-simplify-debug-match-long").maybeHidden() .action((_, c) => c.copy(devSimplifyDebugMatchShort = false)) - .text("Use long form to print matched expressions in sipmlifications"), + .text("Use long form to print matched expressions in simplifications"), opt[Unit]("dev-simplify-debug-no-match").maybeHidden() .action((_, c) => c.copy(devSimplifyDebugNoMatch = true)) .text("Debug expressions that do not match in simplifications"), From f2938cb3262c47797447a8d6ab197f0bf62b4cd2 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 12:01:10 +0200 Subject: [PATCH 14/25] GPU blocks should always be non empty encoded --- .../vct/col/newrewrite/lang/LangCToCol.scala | 67 ++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index 294df155fb..584c9f0f18 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -132,13 +132,31 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz } def all(idx: CudaVec, dim: CudaVec, e: Expr[Post]): Expr[Post] = { + foldStar(unfoldStar(e).map(allOneExpr(idx, dim, _)))(e.o) + } + + def allOneExpr(idx: CudaVec, dim: CudaVec, e: Expr[Post]): Expr[Post] = { implicit val o: Origin = e.o - Starall(idx.indices.values.toSeq, Nil, Implies( - foldAnd(idx.indices.values.zip(dim.indices.values).map { - case (idx, dim) => const[Post](0) <= idx.get && idx.get < dim.get - }), - e - ))(PanicBlame("Where blame")) + val vars = findVars(e) + val filteredIdx = idx.indices.values.zip(dim.indices.values).filter{ case (i, _) => vars.contains(i)} + val otherIdx = idx.indices.values.zip(dim.indices.values).filterNot{ case (i, _) => vars.contains(i)} + + val body = otherIdx.map{case (_,range) => range}.foldLeft(e)((newE, scaleFactor) => Scale(scaleFactor.get, newE)(PanicBlame("Framed positive")) ) + if(filteredIdx.isEmpty){ + body + } else { + Starall(filteredIdx.map{case (v,_) => v}.toSeq, Nil, Implies( + foldAnd(filteredIdx.map { + case (idx, dim) => const[Post](0) <= idx.get && idx.get < dim.get + }), + body + ))(PanicBlame("Where blame")) + } + } + + def findVars(e: Node[Post], vars: Set[Variable[Post]] = Set()): Set[Variable[Post]] = e match { + case Local(ref) => vars + ref.decl + case _ => e.subnodes.foldLeft(vars)( (set, node) => set ++ findVars(node) ) } def allThreadsInBlock(e: Expr[Pre]): Expr[Post] = { @@ -147,7 +165,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz } def allThreadsInGrid(e: Expr[Pre]): Expr[Post] = { - val thread = new CudaVec(RefCudaThreadIdx())(e.o) val block = new CudaVec(RefCudaBlockIdx())(e.o) cudaCurrentBlockIdx.having(block) { all(block, cudaCurrentGridDim.top, allThreadsInBlock(e)) } } @@ -157,6 +174,19 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz val gridDim = new CudaVec(RefCudaGridDim())(o) cudaCurrentBlockDim.having(blockDim) { cudaCurrentGridDim.having(gridDim) { + val newArgs = blockDim.indices.values.toSeq ++ gridDim.indices.values.toSeq ++ rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 + val newGivenArgs = rw.variables.dispatch(contract.givenArgs) + val newYieldsArgs = rw.variables.dispatch(contract.yieldsArgs) + // We add the requirement that a GPU kernel must always have threads (non zero block or grid dimensions) + val nonZeroThreadsSeq: Seq[Expr[Post]] + = (blockDim.indices.values ++ gridDim.indices.values).map( v => Less(IntegerValue(0)(o), v.get(o))(o) ).toSeq + val nonZeroThreads = foldStar(nonZeroThreadsSeq)(o) + + // Ugly, but can't get it to type check otherwise + val nonZeroThreadsPred1: Seq[AccountedPredicate[Post]] = nonZeroThreadsSeq.map(UnitAccountedPredicate(_)(o)) + val nonZeroThreadsPred: AccountedPredicate[Post] = + nonZeroThreadsPred1.reduceLeft((l,r) => SplitAccountedPredicate(l, r)(o)) + val parBody = body.map(impl => { implicit val o: Origin = impl.o val threadIdx = new CudaVec(RefCudaThreadIdx()) @@ -173,15 +203,16 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz iters = blockIdx.indices.values.zip(gridDim.indices.values).map { case (index, dim) => IterVariable(index, const(0), dim.get) }.toSeq, - context_everywhere = allThreadsInBlock(contract.contextEverywhere), - requires = allThreadsInBlock(foldStar(unfoldPredicate(contract.requires))), - ensures = allThreadsInBlock(foldStar(unfoldPredicate(contract.ensures))), + context_everywhere = tt, //Star(nonZeroThreads, allThreadsInBlock(contract.contextEverywhere))(o), + requires = Star(nonZeroThreads, allThreadsInBlock(foldStar(unfoldPredicate(contract.requires)))), + ensures = Star(nonZeroThreads, allThreadsInBlock(foldStar(unfoldPredicate(contract.ensures)))), content = ParStatement(ParBlock( decl = blockDecl, iters = threadIdx.indices.values.zip(blockDim.indices.values).map { case (index, dim) => IterVariable(index, const(0), dim.get) }.toSeq, - context_everywhere = rw.dispatch(contract.contextEverywhere), + // Context is already inherited + context_everywhere = Star(nonZeroThreads, rw.dispatch(contract.contextEverywhere)), requires = rw.dispatch(foldStar(unfoldPredicate(contract.requires))), ensures = rw.dispatch(foldStar(unfoldPredicate(contract.ensures))), content = rw.dispatch(impl), @@ -195,16 +226,18 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz new Procedure[Post]( returnType = TVoid(), - args = blockDim.indices.values.toSeq ++ gridDim.indices.values.toSeq ++ rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1, + args = newArgs, outArgs = Nil, typeArgs = Nil, body = parBody, contract = ApplicableContract( - mapPredicate(contract.requires, allThreadsInGrid), - mapPredicate(contract.ensures, allThreadsInGrid), - allThreadsInGrid(contract.contextEverywhere), + SplitAccountedPredicate(nonZeroThreadsPred, mapPredicate(contract.requires, allThreadsInGrid))(o), + SplitAccountedPredicate(nonZeroThreadsPred, mapPredicate(contract.ensures, allThreadsInGrid))(o), + // Context everywhere is already passed down in the body + //allThreadsInGrid(contract.contextEverywhere), + tt, contract.signals.map(rw.dispatch), - rw.variables.dispatch(contract.givenArgs), - rw.variables.dispatch(contract.yieldsArgs), + newGivenArgs, + newYieldsArgs, contract.decreases.map(rw.dispatch), )(contract.blame)(contract.o) )(AbstractApplicable)(o) From dabf180bf605af6154935371aaed03afa8a41e6c Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 12:01:32 +0200 Subject: [PATCH 15/25] Single parblock required to be non-empty --- .../vct/col/newrewrite/ParBlockEncoder.scala | 115 ++++++++++++++---- 1 file changed, 88 insertions(+), 27 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala index 9dd89532b9..6f080a6b18 100644 --- a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala +++ b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala @@ -104,31 +104,73 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { def from(v: Variable[Pre]): Expr[Post] = range(v)._1 def to(v: Variable[Pre]): Expr[Post] = range(v)._2 - def quantify(block: ParBlock[Pre], expr: Expr[Pre])(implicit o: Origin): Expr[Post] = { - val quantVars = block.iters.map(_.variable).map(v => v -> new Variable[Pre](v.t)(v.o)).toMap - val body = Substitute(quantVars.map { case (l, r) => Local[Pre](l.ref) -> Local[Pre](r.ref) }.toMap[Expr[Pre], Expr[Pre]]).dispatch(expr) - block.iters.foldLeft(dispatch(body))((body, iter) => { - val v = quantVars(iter.variable) - Starall[Post]( - Seq(variables.dispatch(v)), - Nil, - (from(iter.variable) <= Local[Post](succ(v)) && Local[Post](succ(v)) < to(iter.variable)) ==> body - )(ParBlockNotInjective(block, expr)) - }) +// def quantify(block: ParBlock[Pre], expr: Expr[Pre])(implicit o: Origin): Expr[Post] = { +// val quantVars = block.iters.map(_.variable).map(v => v -> new Variable[Pre](v.t)(v.o)).toMap +// val body = Substitute(quantVars.map { case (l, r) => Local[Pre](l.ref) -> Local[Pre](r.ref) }.toMap[Expr[Pre], Expr[Pre]]).dispatch(expr) +// block.iters.foldLeft(dispatch(body))((body, iter) => { +// val v = quantVars(iter.variable) +// Starall[Post]( +// Seq(variables.dispatch(v)), +// Nil, +// SeqMember(Local[Post](succ(v)), Range(from(iter.variable), to(iter.variable))) ==> body +//// (from(iter.variable) <= Local[Post](succ(v)) && Local[Post](succ(v)) < to(iter.variable)) ==> body +// )(ParBlockNotInjective(block, expr)) +// }) +// } + + def depVars[G](bindings: Set[Variable[G]], e: Expr[G]): Set[Variable[G]] = { + val result: mutable.Set[Variable[G]] = mutable.Set() + e.transSubnodes.foreach { + case Local(ref) if bindings.contains(ref.decl) => result.addOne(ref.decl) + case _ => } + result.toSet } - def requires(region: ParRegion[Pre])(implicit o: Origin): Expr[Post] = region match { - case ParParallel(regions) => AstBuildHelpers.foldStar(regions.map(requires)) - case ParSequential(regions) => regions.headOption.map(requires).getOrElse(tt) + def quantify(block: ParBlock[Pre], expr: Expr[Pre], nonEmpty: Boolean)(implicit o: Origin): Expr[Post] = { + val exprs = AstBuildHelpers.unfoldStar(expr) + val vars = block.iters.map(_.variable).toSet + + val rewrittenExpr = exprs.map( + e => { + val quantVars = if (nonEmpty) depVars(vars, e) else vars + val nonQuantVars = vars.diff(quantVars) + val newQuantVars = quantVars.map(v => v -> new Variable[Pre](v.t)(v.o)).toMap + var body = dispatch(Substitute(newQuantVars.map { case (l, r) => Local[Pre](l.ref) -> Local[Pre](r.ref) }.toMap[Expr[Pre], Expr[Pre]]).dispatch(e)) + body = if (nonPermissionExpr(e)) { + body + } else { + // Scale the body if it contains permissions + nonQuantVars.foldLeft(body)((body, iter) => { + val scale = to(iter) - from(iter) + Scale(const[Post](1) /:/ scale, body)(PanicBlame("Par block was checked to be non-empty")) + }) + } + // Result, quantify over all the relevant variables + quantVars.foldLeft(body)((body, iter) => { + val v: Variable[Pre] = newQuantVars(iter) + Starall[Post]( + Seq(variables.dispatch(v)), + Nil, + (from(iter) <= Local[Post](succ(v)) && Local[Post](succ(v)) < to(iter)) ==> body + )(ParBlockNotInjective(block, e)) + }) + } + ) + AstBuildHelpers.foldStar(rewrittenExpr) + } + + def requires(region: ParRegion[Pre], nonEmpty: Boolean)(implicit o: Origin): Expr[Post] = region match { + case ParParallel(regions) => AstBuildHelpers.foldStar(regions.map(requires(_, nonEmpty))) + case ParSequential(regions) => regions.headOption.map(requires(_, nonEmpty)).getOrElse(tt) case block: ParBlock[Pre] => - quantify(block, block.context_everywhere &* block.requires) + quantify(block, block.context_everywhere &* block.requires, nonEmpty) } - def ensures(region: ParRegion[Pre])(implicit o: Origin): Expr[Post] = region match { - case ParParallel(regions) => AstBuildHelpers.foldStar(regions.map(ensures)) - case ParSequential(regions) => regions.lastOption.map(ensures).getOrElse(tt) + def ensures(region: ParRegion[Pre], nonEmpty: Boolean)(implicit o: Origin): Expr[Post] = region match { + case ParParallel(regions) => AstBuildHelpers.foldStar(regions.map(ensures(_, nonEmpty))) + case ParSequential(regions) => regions.lastOption.map(ensures(_, nonEmpty)).getOrElse(tt) case block: ParBlock[Pre] => - quantify(block, block.context_everywhere &* block.ensures) + quantify(block, block.context_everywhere &* block.ensures, nonEmpty) } def ranges(region: ParRegion[Pre], rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])]): Unit = region match { @@ -142,11 +184,11 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { } } - def execute(region: ParRegion[Pre]): Statement[Post] = { + def execute(region: ParRegion[Pre], nonEmpty: Boolean): Statement[Post] = { implicit val o: Origin = region.o Block(Seq( - Exhale(requires(region))(ParStatementExhaleFailed(region)), - Inhale(ensures(region)), + Exhale(requires(region, nonEmpty))(ParStatementExhaleFailed(region)), + Inhale(ensures(region, nonEmpty)), )) } @@ -157,7 +199,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { IndetBranch(regions.map(check) ++ regions.zip(regions.tail).map { case (leftRegion, rightRegion) => implicit val o: Origin = region.o - FramedProof(tt, Inhale(ensures(leftRegion)), requires(rightRegion))(ParSequenceProofFailed(rightRegion)) + FramedProof(tt, Inhale(ensures(leftRegion, false)), requires(rightRegion, false))(ParSequenceProofFailed(rightRegion)) })(region.o) case block @ ParBlock(decl, iters, context_everywhere, requires, ensures, content) => implicit val o: Origin = region.o @@ -180,6 +222,11 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(stat: Statement[Pre]): Statement[Rewritten[Pre]] = stat match { case ParStatement(region) => + val (isSingleBlock: Boolean, iters: Option[Seq[IterVariable[Pre]]]) = region match { + case pb : ParBlock[Pre] => (true, Some(pb.iters)) + case _ => (false, None) + } + implicit val o: Origin = stat.o val rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])] = mutable.Map() @@ -187,12 +234,17 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { ranges(region, rangeValues) currentRanges.having(rangeValues.toMap) { - Block(Seq( + var res: Statement[Post] = Block(Seq( IndetBranch(Seq( - execute(region), + execute(region, isSingleBlock), Block(Seq(check(region), Inhale(ff))) )), )) + if(isSingleBlock){ + val condition: Expr[Post] = foldAnd(rangeValues.values.map{case (low, hi) => low < hi}) + res = Branch(Seq((condition, res))) + } + res } case inv @ ParInvariant(decl, dependentInvariant, body) => @@ -226,7 +278,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { FramedProof( pre = tt, body = Block(Seq( - Inhale(quantify(block, requires)), + Inhale(quantify(block, requires, false)), Block(suspendedInvariants.map { case Ref(decl) => Inhale(dispatch(invDecl(decl))) }), dispatch(hint), Block(suspendedInvariants.reverse.map { @@ -235,7 +287,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { )(ParBarrierInvariantExhaleFailed(barrier)) }), )), - post = quantify(block, ensures), + post = quantify(block, ensures, false), )(ParBarrierProofFailed(barrier)), Inhale(ff), )), @@ -250,6 +302,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { } override def dispatch(e: Expr[Pre]): Expr[Rewritten[Pre]] = e match { +// case ScaleByParBlock(Ref(decl), res) if !nonPermissionExpr(res) => case ScaleByParBlock(Ref(decl), res) => implicit val o: Origin = e.o val block = blockDecl(decl) @@ -258,6 +311,14 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { val scale = to(v.variable) - from(v.variable) Implies(scale > const(0), Scale(const[Post](1) /:/ scale, res)(PanicBlame("framed positive"))) } +// case ScaleByParBlock(Ref(_), res) => dispatch(res) case other => rewriteDefault(other) } + + def nonPermissionExpr[G](e: Node[G]): Boolean = e match { + case _: Locator[G] => false + case _: PermPointer[G] => false + case _: PermPointerIndex[G] => false + case other => other.subnodes.forall(nonPermissionExpr) + } } From 6b8310b08a445b207d44167a02e8cacbb7f173bc Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 12:02:05 +0200 Subject: [PATCH 16/25] Rearange transformation to work better with simplify nested quantifiers --- .../java/vct/main/stages/Transformation.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/main/java/vct/main/stages/Transformation.scala b/src/main/java/vct/main/stages/Transformation.scala index 07007079ff..1f5426c4c3 100644 --- a/src/main/java/vct/main/stages/Transformation.scala +++ b/src/main/java/vct/main/stages/Transformation.scala @@ -146,14 +146,21 @@ case class SilverTransformation EncodeSendRecv, ParBlockEncoder, - // Encode proof helpers - EncodeProofHelpers, - // Encode exceptional behaviour (no more continue/break/return/try/throw) SpecifyImplicitLabels, SwitchToGoto, ContinueToBreak, EncodeBreakReturn, + + SplitQuantifiers, + ) ++ simplifyBeforeRelations ++ Seq( + SimplifyQuantifiedRelations, + SimplifyNestedQuantifiers, + ) ++ simplifyAfterRelations ++ Seq( + + // Encode proof helpers + EncodeProofHelpers, + // Resolve side effects including method invocations, for encodetrythrowsignals. ResolveExpressionSideEffects, EncodeTryThrowSignals, @@ -164,11 +171,6 @@ case class SilverTransformation CheckContractSatisfiability.withArg(checkSat), - SplitQuantifiers, - ) ++ simplifyBeforeRelations ++ Seq( - SimplifyQuantifiedRelations, - SimplifyNestedQuantifiers, - ) ++ simplifyAfterRelations ++ Seq( ResolveExpressionSideChecks, // Translate internal types to domains From d2e6a2b6adbef936679b9e08fee1f3f536cc1767 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 12:02:19 +0200 Subject: [PATCH 17/25] Add bool to cuda.h and opencl.h --- src/main/universal/res/c/cuda.h | 2 ++ src/main/universal/res/c/opencl.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/main/universal/res/c/cuda.h b/src/main/universal/res/c/cuda.h index fab02a2c51..e8d0c0cc62 100644 --- a/src/main/universal/res/c/cuda.h +++ b/src/main/universal/res/c/cuda.h @@ -3,6 +3,8 @@ #define __global__ __vercors_kernel__ +#define bool _Bool + #define cudaEvent_t int #define cudaMemcpyHostToDevice 0 #define cudaMemcpyDeviceToHost 1 diff --git a/src/main/universal/res/c/opencl.h b/src/main/universal/res/c/opencl.h index 405b1210a2..8ec3733f8c 100644 --- a/src/main/universal/res/c/opencl.h +++ b/src/main/universal/res/c/opencl.h @@ -8,6 +8,8 @@ #define barrier(locality) __vercors_barrier__(locality) +#define bool _Bool + extern /*@ pure @*/ int get_work_dim(); // Number of dimensions in use extern /*@ pure @*/ int get_global_size(int dimindx); // Number of global work-items From 47313475fc03185a5dd741866ff4adf2d7a71edc Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 5 Sep 2022 14:40:38 +0200 Subject: [PATCH 18/25] Merge and warning fixes --- .../vct/col/newrewrite/ParBlockEncoder.scala | 12 +- .../SimplifyNestedQuantifiers.scala | 204 ++---------------- 2 files changed, 27 insertions(+), 189 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala index 6f080a6b18..f5466ff6b9 100644 --- a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala +++ b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala @@ -316,9 +316,19 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { } def nonPermissionExpr[G](e: Node[G]): Boolean = e match { - case _: Locator[G] => false + case _: Perm[G] => false + case _: PointsTo[G] => false case _: PermPointer[G] => false case _: PermPointerIndex[G] => false + case _: ModelState[G] => false + case _: ModelSplit[G] => false + case _: ModelMerge[G] => false + case _: ModelChoose[G] => false + case _: ModelPerm[G] => false + case _: ActionPerm[G] => false + case _: PredicateApply[G] => false + case _: InstancePredicateApply[G] => false + case _: CoalesceInstancePredicateApply[G] => false case other => other.subnodes.forall(nonPermissionExpr) } } diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 7ae25d3e2e..0d38411ace 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -490,18 +490,18 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] def indepOf[G](bindings: mutable.Set[Variable[G]], e: Expr[G]): Boolean = e.transSubnodes.collectFirst { case Local(ref) if bindings.contains(ref.decl) => () }.isEmpty - sealed trait Subscript { - val index: Expr[Pre] - val subnodes: Seq[Node[Pre]] + sealed trait Subscript[G] { + val index: Expr[G] + val subnodes: Seq[Node[G]] } - final case class Array(e: ArraySubscript[Pre]) extends Subscript { - val index: Expr[Pre] = e.index - val subnodes: Seq[Node[Pre]] = e.subnodes + case class Array[G](e: ArraySubscript[G]) extends Subscript[G] { + val index: Expr[G] = e.index + val subnodes: Seq[Node[G]] = e.subnodes } - final case class Pointer(e: PointerSubscript[Pre]) extends Subscript { - val index: Expr[Pre] = e.index - val subnodes: Seq[Node[Pre]] = e.subnodes + case class Pointer[G](e: PointerSubscript[G]) extends Subscript[G] { + val index: Expr[G] = e.index + val subnodes: Seq[Node[G]] = e.subnodes } class FindLinearArrayAccesses(quantifierData: RewriteQuantifierData){ @@ -517,7 +517,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } } - def testSubscript(e: Subscript): Option[SubstituteForall] = { + def testSubscript(e: Subscript[Pre]): Option[SubstituteForall] = { if (indepOf(quantifierData.bindings, e.index)) { return None } @@ -527,13 +527,13 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } } - def linearExpression(e: Subscript): Option[SubstituteForall] = { + def linearExpression(e: Subscript[Pre]): Option[SubstituteForall] = { val pot = new PotentialLinearExpressions(e) pot.visit(e.index) pot.canRewrite() } - class PotentialLinearExpressions(val arrayIndex: Subscript){ + class PotentialLinearExpressions(val arrayIndex: Subscript[Pre]){ val linearExpressions: mutable.Map[Variable[Pre], Expr[Pre]] = mutable.Map() var constantExpression: Option[Expr[Pre]] = None var isLinear: Boolean = true @@ -796,180 +796,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] } } - // // The `new_forall_var` will be the name of variable of the newly made forall. - // // The `newBounds`, will contain all the new equations for "select" part of the forall. - // // The `substituteOldVars` contains a map, so we can replace the old forall variables with new expressions - // // We also store the `linearExpression`, so if we ever come across it, we can replace it with the new variable. + // The `newBounds`, will contain all the new equations for "select" part of the forall. + // The `substituteOldVars` contains a map, so we can replace the old forall variables with new expressions + // We also store the `linearExpression`, so if we ever come across it, we can replace it with the new variable. case class SubstituteForall(newBounds: Expr[Post], substituteOldVars: Map[Expr[Pre], Expr[Post]], newTrigger: Seq[Seq[Expr[Post]]]) -} - -// -// var equalityChecker: ExpressionEqualityCheck = ExpressionEqualityCheck() -// -// override def visit(special: ASTSpecial): Unit = { -// if(special.kind == ASTSpecial.Kind.Inhale){ -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = ASTUtils.conjuncts(special.args(0), StandardOperator.Star).asScala -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// -// result = create special(special.kind, rewrite(special.args):_*) -// -// equalityChecker = ExpressionEqualityCheck() -// -// } else { -// result = create special(special.kind, rewrite(special.args): _*) -// } -// } -// -// override def visit(c: ASTClass): Unit = { //checkPermission(c); -// val name = c.getName -// if (name == null) Abort("illegal class without name") -// else { -// Debug("rewriting class " + name) -// val new_pars = rewrite(c.parameters) -// val new_supers = rewrite(c.super_classes) -// val new_implemented = rewrite(c.implemented_classes) -// val res = new ASTClass(name, c.kind, new_pars, new_supers, new_implemented) -// res.setOrigin(c.getOrigin) -// currentTargetClass = res -// val contract = c.getContract -// if (currentContractBuilder == null) currentContractBuilder = new ContractBuilder -// if (contract != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(contract.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(contract.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// rewrite(contract, currentContractBuilder) -// equalityChecker = ExpressionEqualityCheck() -// } -// res.setContract(currentContractBuilder.getContract) -// currentContractBuilder = null -// -// for (i <- 0 until c.size()) { -// res.add(rewrite(c.get(i))) -// } -// result = res -// currentTargetClass = null -// } -// } -// -// override def visit(s: ForEachLoop): Unit = { -// val new_decl = rewrite(s.decls) -// val res = create.foreach(new_decl, rewrite(s.guard), rewrite(s.body)) -// -// val mc = s.getContract -// if (mc != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// res.setContract(rewrite(mc)) -// equalityChecker = ExpressionEqualityCheck() -// } else { -// res.setContract(rewrite(mc)) -// } -// -// -// res.set_before(rewrite(s.get_before)) -// res.set_after(rewrite(s.get_after)) -// result = res -// } -// -// override def visit(s: LoopStatement): Unit = { //checkPermission(s); -// val res = new LoopStatement -// var tmp = s.getInitBlock -// if (tmp != null) res.setInitBlock(tmp.apply(this)) -// tmp = s.getUpdateBlock -// if (tmp != null) res.setUpdateBlock(tmp.apply(this)) -// tmp = s.getEntryGuard -// if (tmp != null) res.setEntryGuard(tmp.apply(this)) -// tmp = s.getExitGuard -// if (tmp != null) res.setExitGuard(tmp.apply(this)) -// val mc = s.getContract -// if (mc != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// res.appendContract(rewrite(mc)) -// equalityChecker = ExpressionEqualityCheck() -// } else { -// res.appendContract(rewrite(mc)) -// } -// -// -// tmp = s.getBody -// res.setBody(tmp.apply(this)) -// res.set_before(rewrite(s.get_before)) -// res.set_after(rewrite(s.get_after)) -// res.setOrigin(s.getOrigin) -// result = res -// } -// -// override def visit(m: Method): Unit = { //checkPermission(m); -// val name = m.getName -// if (currentContractBuilder == null) { -// currentContractBuilder = new ContractBuilder -// } -// val args = rewrite(m.getArgs) -// val mc = m.getContract -// -// var c: Contract = null -// // Ensure we maintain the type of emptiness of mc -// // If the contract was null previously, the new contract can also be null -// // If the contract was non-null previously, the new contract cannot be null -// if (mc != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// -// rewrite(mc, currentContractBuilder) -// c = currentContractBuilder.getContract(false) -// equalityChecker = ExpressionEqualityCheck() -// } -// else { -// c = currentContractBuilder.getContract(true) -// } -// if (mc != null && c != null && c.getOrigin == null) { -// c.setOrigin(mc.getOrigin) -// } -// currentContractBuilder = null -// val kind = m.kind -// val rt = rewrite(m.getReturnType) -// val signals = rewrite(m.signals) -// val body = rewrite(m.getBody) -// result = create.method_kind(kind, rt, signals, c, name, args, m.usesVarArgs, body) -// } -// -// override def visit(expr: BindingExpression): Unit = { -// expr.binder match { -// case Binder.Forall | Binder.Star => -// val bindings = expr.getDeclarations.map(_.name).toSet -// val (select, main) = splitSelect(rewrite(expr.select), rewrite(expr.main)) -// val (independentSelect, potentialBounds) = select.partition(independentOf(bindings, _)) -// val (bounds, dependent_bounds) = getBounds(bindings, potentialBounds) -// //Only rewrite main, when the dependent bounds are not existing -// if(dependent_bounds.isEmpty && expr.binder != Binder.Star){ -// rewriteMain(bounds, main) match { -// case Some(main) => -// result = create expression(Implies, (independentSelect ++ bounds.selectNonEmpty).reduce(and), main); return -// case None => -// } -// } -// rewriteLinearArray(bounds, main, independentSelect, dependent_bounds, expr.binder, expr.result_type) match { -// case Some(new_forall) => -// result = new_forall; -// return -// case None => -// } -// super.visit(expr) -// case _ => -// super.visit(expr) -// } -// } -//} \ No newline at end of file +} \ No newline at end of file From e1db973a3d3725f648a313bc559c99f8365d4b25 Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 14 Sep 2022 12:15:53 +0200 Subject: [PATCH 19/25] Add shared memory support --- col/src/main/java/vct/col/ast/Node.scala | 3 + .../lang/GPUGlobalImpl.scala | 7 + .../lang/GPULocalImpl.scala | 7 + .../lang/SharedMemSizeImpl.scala | 7 + .../vct/col/coerce/CoercingRewriter.scala | 7 +- col/src/main/java/vct/col/print/Printer.scala | 7 + .../java/vct/col/util/AstBuildHelpers.scala | 11 +- parsers/lib/antlr4/LangCParser.g4 | 2 + parsers/lib/antlr4/LangGPGPULexer.g4 | 2 + parsers/lib/antlr4/LangGPGPUParser.g4 | 6 +- parsers/lib/antlr4/SpecLexer.g4 | 1 + parsers/lib/antlr4/SpecParser.g4 | 1 + .../main/java/vct/parsers/ColCParser.scala | 3 +- .../src}/main/java/vct/parsers/Language.scala | 4 +- .../main/java/vct/parsers/ParseResult.scala | 9 +- .../java/vct/parsers/transform/CToCol.scala | 3 + .../SimplifyNestedQuantifiers.scala | 193 ++------------- .../vct/col/newrewrite/lang/LangCToCol.scala | 221 +++++++++++++++--- .../newrewrite/lang/LangSpecificToCol.scala | 7 +- src/main/java/vct/main/stages/Parsing.scala | 9 +- .../java/vct/main/stages/Resolution.scala | 9 +- src/main/java/vct/main/util/Util.scala | 4 +- 22 files changed, 297 insertions(+), 226 deletions(-) create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPUGlobalImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPULocalImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/SharedMemSizeImpl.scala rename {src => parsers/src}/main/java/vct/parsers/Language.scala (83%) diff --git a/col/src/main/java/vct/col/ast/Node.scala b/col/src/main/java/vct/col/ast/Node.scala index 5d76eb22a8..3da3a74eed 100644 --- a/col/src/main/java/vct/col/ast/Node.scala +++ b/col/src/main/java/vct/col/ast/Node.scala @@ -526,6 +526,7 @@ final case class Length[G](arr: Expr[G])(val blame: Blame[ArrayNull])(implicit v final case class Size[G](obj: Expr[G])(implicit val o: Origin) extends Expr[G] with SizeImpl[G] final case class PointerBlockLength[G](pointer: Expr[G])(val blame: Blame[PointerNull])(implicit val o: Origin) extends Expr[G] with PointerBlockLengthImpl[G] final case class PointerBlockOffset[G](pointer: Expr[G])(val blame: Blame[PointerNull])(implicit val o: Origin) extends Expr[G] with PointerBlockOffsetImpl[G] +final case class SharedMemSize[G](pointer: Expr[G])(implicit val o: Origin) extends Expr[G] with SharedMemSizeImpl[G] final case class Cons[G](x: Expr[G], xs: Expr[G])(implicit val o: Origin) extends Expr[G] with ConsImpl[G] final case class Head[G](xs: Expr[G])(val blame: Blame[SeqBoundFailure])(implicit val o: Origin) extends Expr[G] with HeadImpl[G] @@ -623,6 +624,8 @@ sealed trait CStorageClassSpecifier[G] extends CDeclarationSpecifier[G] with CSt final case class CTypedef[G]()(implicit val o: Origin) extends CStorageClassSpecifier[G] with CTypedefImpl[G] final case class CExtern[G]()(implicit val o: Origin) extends CStorageClassSpecifier[G] with CExternImpl[G] final case class CStatic[G]()(implicit val o: Origin) extends CStorageClassSpecifier[G] with CStaticImpl[G] +final case class GPULocal[G]()(implicit val o: Origin) extends CStorageClassSpecifier[G] with GPULocalImpl[G] +final case class GPUGlobal[G]()(implicit val o: Origin) extends CStorageClassSpecifier[G] with GPUGlobalImpl[G] sealed trait CTypeSpecifier[G] extends CDeclarationSpecifier[G] with CTypeSpecifierImpl[G] final case class CVoid[G]()(implicit val o: Origin) extends CTypeSpecifier[G] with CVoidImpl[G] diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPUGlobalImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPUGlobalImpl.scala new file mode 100644 index 0000000000..6bce30b76d --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPUGlobalImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GPUGlobal + +trait GPUGlobalImpl[G] { this: GPUGlobal[G] => + +} diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPULocalImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPULocalImpl.scala new file mode 100644 index 0000000000..ff9bcbe4c4 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GPULocalImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GPULocal + +trait GPULocalImpl[G] {this: GPULocal[G] => + +} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/SharedMemSizeImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/SharedMemSizeImpl.scala new file mode 100644 index 0000000000..fd6085ac15 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/SharedMemSizeImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.{SharedMemSize, TInt, Type} + +trait SharedMemSizeImpl[G] { this: SharedMemSize[G] => + override def t: Type[G] = TInt() +} \ No newline at end of file diff --git a/col/src/main/java/vct/col/coerce/CoercingRewriter.scala b/col/src/main/java/vct/col/coerce/CoercingRewriter.scala index 3aae4590df..d0890180dd 100644 --- a/col/src/main/java/vct/col/coerce/CoercingRewriter.scala +++ b/col/src/main/java/vct/col/coerce/CoercingRewriter.scala @@ -405,7 +405,7 @@ abstract class CoercingRewriter[Pre <: Generation]() extends Rewriter[Pre] with e match { case ApplyCoercion(_, _) => - throw Unreachable("All instances of ApplyCoercion should be immediately rewritten by CoercingRewriter.disptach.") + throw Unreachable("All instances of ApplyCoercion should be immediately rewritten by CoercingRewriter.dispatch.") case ActionApply(action, args) => ActionApply(action, coerceArgs(args, action.decl)) @@ -991,6 +991,11 @@ abstract class CoercingRewriter[Pre <: Generation]() extends Rewriter[Pre] with val (right, TSet(rightT)) = set(ys) val sharedElement = Types.leastCommonSuperType(leftT, rightT) SetUnion(coerce(left, TSet(sharedElement)), coerce(right, TSet(sharedElement))) + case SharedMemSize(xs) => + firstOk(e, s"Expected operand to be a pointer or array, but got ${xs.t}.", + SharedMemSize(array(xs)._1), + SharedMemSize(pointer(xs)._1), + ) case SilverBagSize(xs) => SilverBagSize(bag(xs)._1) case SilverCurFieldPerm(obj, field) => diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 0d02d5a793..0e85849388 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -1213,6 +1213,13 @@ case class Printer(out: Appendable, case CSpecificationType(t) => say(t) case CTypeQualifierDeclarationSpecifier(typeQual) => say(typeQual) case CKernel() => say("__kernel") + case GPULocal() => say(syntax( + Cuda -> phrase("__shared__"), + OpenCL -> phrase("__local"), + )) + case GPUGlobal() => say(syntax( + OpenCL -> phrase("__global"), + )) } def printCTypeQualifier(node: CTypeQualifier[_]): Unit = node match { diff --git a/col/src/main/java/vct/col/util/AstBuildHelpers.scala b/col/src/main/java/vct/col/util/AstBuildHelpers.scala index f40731409f..42e1e90b9b 100644 --- a/col/src/main/java/vct/col/util/AstBuildHelpers.scala +++ b/col/src/main/java/vct/col/util/AstBuildHelpers.scala @@ -316,10 +316,13 @@ object AstBuildHelpers { exprs.reduceOption(And(_, _)).getOrElse(tt) def unfoldPredicate[G](p: AccountedPredicate[G]): Seq[Expr[G]] = p match { - case UnitAccountedPredicate(pred) => Seq(pred) + case UnitAccountedPredicate(pred) => unfoldStar(pred) case SplitAccountedPredicate(left, right) => unfoldPredicate(left) ++ unfoldPredicate(right) } + def filterPredicate[G](p: AccountedPredicate[G], f: Expr[G] => Boolean): AccountedPredicate[G] = + foldPredicate(unfoldPredicate(p).filter(f))(p.o) + def mapPredicate[G1, G2](p: AccountedPredicate[G1], f: Expr[G1] => Expr[G2]): AccountedPredicate[G2] = p match { case UnitAccountedPredicate(pred) => UnitAccountedPredicate(f(pred))(p.o) case SplitAccountedPredicate(left, right) => SplitAccountedPredicate(mapPredicate(left, f), mapPredicate(right, f))(p.o) @@ -354,6 +357,12 @@ object AstBuildHelpers { case SplitAccountedPredicate(left, right) => Star(foldStar(left), foldStar(right)) } + def foldPredicate[G](exprs: Seq[Expr[G]])(implicit o: Origin): AccountedPredicate[G] = + exprs + .map(e => UnitAccountedPredicate(e)(e.o)) + .reduceOption[AccountedPredicate[G]](SplitAccountedPredicate(_, _)) + .getOrElse(UnitAccountedPredicate(tt)) + def foldOr[G](exprs: Seq[Expr[G]])(implicit o: Origin): Expr[G] = exprs.reduceOption(Or(_, _)).getOrElse(ff) } diff --git a/parsers/lib/antlr4/LangCParser.g4 b/parsers/lib/antlr4/LangCParser.g4 index 916bd64a71..84c53b206f 100644 --- a/parsers/lib/antlr4/LangCParser.g4 +++ b/parsers/lib/antlr4/LangCParser.g4 @@ -259,6 +259,8 @@ storageClassSpecifier | '_Thread_local' | 'auto' | 'register' + | gpgpuLocalMemory + | gpgpuGlobalMemory ; typeSpecifier diff --git a/parsers/lib/antlr4/LangGPGPULexer.g4 b/parsers/lib/antlr4/LangGPGPULexer.g4 index b860622fd6..4d9a96914c 100644 --- a/parsers/lib/antlr4/LangGPGPULexer.g4 +++ b/parsers/lib/antlr4/LangGPGPULexer.g4 @@ -5,6 +5,8 @@ GPGPU_LOCAL_BARRIER: '__vercors_local_barrier__'; GPGPU_GLOBAL_BARRIER: '__vercors_global_barrier__'; GPGPU_KERNEL: '__vercors_kernel__'; GPGPU_ATOMIC: '__vercors_atomic__'; +GPGPU_GLOBAL_MEMORY: '__global'; +GPGPU_LOCAL_MEMORY: '__local'; GPGPU_CUDA_OPEN_EXEC_CONFIG: '<<<'; GPGPU_CUDA_CLOSE_EXEC_CONFIG: '>>>'; \ No newline at end of file diff --git a/parsers/lib/antlr4/LangGPGPUParser.g4 b/parsers/lib/antlr4/LangGPGPUParser.g4 index b493d5e9a3..dde7906c9d 100644 --- a/parsers/lib/antlr4/LangGPGPUParser.g4 +++ b/parsers/lib/antlr4/LangGPGPUParser.g4 @@ -16,4 +16,8 @@ gpgpuAtomicBlock : valEmbedWith? GPGPU_ATOMIC compoundStatement valEmbedThen? ; -gpgpuKernelSpecifier: GPGPU_KERNEL; \ No newline at end of file +gpgpuKernelSpecifier: GPGPU_KERNEL; + +gpgpuLocalMemory: GPGPU_LOCAL_MEMORY; + +gpgpuGlobalMemory: GPGPU_GLOBAL_MEMORY; \ No newline at end of file diff --git a/parsers/lib/antlr4/SpecLexer.g4 b/parsers/lib/antlr4/SpecLexer.g4 index 23f1e6b7bd..df5a5812e3 100644 --- a/parsers/lib/antlr4/SpecLexer.g4 +++ b/parsers/lib/antlr4/SpecLexer.g4 @@ -136,6 +136,7 @@ POINTER_INDEX: '\\pointer_index'; POINTER_BLOCK_LENGTH: '\\pointer_block_length'; POINTER_BLOCK_OFFSET: '\\pointer_block_offset'; POINTER_LENGTH: '\\pointer_length'; +SHARED_MEM_SIZE: '\\shared_mem_size'; VALUES: '\\values'; VCMP: '\\vcmp'; VREP: '\\vrep'; diff --git a/parsers/lib/antlr4/SpecParser.g4 b/parsers/lib/antlr4/SpecParser.g4 index 8c98dcfc49..27b26887c7 100644 --- a/parsers/lib/antlr4/SpecParser.g4 +++ b/parsers/lib/antlr4/SpecParser.g4 @@ -254,6 +254,7 @@ valPrimary | '\\type' '(' langType ')' # valTypeValue | 'held' '(' langExpr ')' # valHeld | LANG_ID_ESCAPE # valIdEscape + | '\\shared_mem_size' '(' langExpr ')' # valSharedMemSize ; // Out spec: defined meaning: a language local diff --git a/parsers/src/main/java/vct/parsers/ColCParser.scala b/parsers/src/main/java/vct/parsers/ColCParser.scala index f6cdbc9b03..dd534e0f9c 100644 --- a/parsers/src/main/java/vct/parsers/ColCParser.scala +++ b/parsers/src/main/java/vct/parsers/ColCParser.scala @@ -24,7 +24,8 @@ case class ColCParser(override val originProvider: OriginProvider, cc: Path, systemInclude: Path, otherIncludes: Seq[Path], - defines: Map[String, String]) extends Parser(originProvider, blameProvider) with LazyLogging { + defines: Map[String, String], + language: Language) extends Parser(originProvider, blameProvider) with LazyLogging { def interpret(localInclude: Seq[Path], input: String, output: String): Process = { var command = Seq(cc.toString, "-C", "-E") diff --git a/src/main/java/vct/parsers/Language.scala b/parsers/src/main/java/vct/parsers/Language.scala similarity index 83% rename from src/main/java/vct/parsers/Language.scala rename to parsers/src/main/java/vct/parsers/Language.scala index 60b22696a6..20780f923e 100644 --- a/src/main/java/vct/parsers/Language.scala +++ b/parsers/src/main/java/vct/parsers/Language.scala @@ -3,7 +3,8 @@ package vct.parsers case object Language { def fromFilename(filename: String): Option[Language] = filename.split('.').last match { - case "cl" | "c" | "cu" => Some(C) + case "cl" | "c" => Some(C) + case "cu" => Some(CUDA) case "i" => Some(InterpretedC) case "java" => Some(Java) case "pvl" => Some(PVL) @@ -12,6 +13,7 @@ case object Language { } case object C extends Language + case object CUDA extends Language case object InterpretedC extends Language case object Java extends Language case object PVL extends Language diff --git a/parsers/src/main/java/vct/parsers/ParseResult.scala b/parsers/src/main/java/vct/parsers/ParseResult.scala index 2acd2c9aea..b6a5ade416 100644 --- a/parsers/src/main/java/vct/parsers/ParseResult.scala +++ b/parsers/src/main/java/vct/parsers/ParseResult.scala @@ -4,11 +4,12 @@ import vct.col.ast.{GlobalDeclaration, VerificationContext} import vct.col.util.ExpectedError case object ParseResult { - def reduce[G](parses: Seq[ParseResult[G]]): ParseResult[G] = + def reduce[G](parses: Seq[(ParseResult[G], Option[Language])]): (ParseResult[G], Option[Language]) = parses.reduceOption((l, r) => (l, r) match { - case (ParseResult(declsLeft, expectedLeft), ParseResult(declsRight, expectedRight)) => - ParseResult(declsLeft ++ declsRight, expectedLeft ++ expectedRight) - }).getOrElse(ParseResult(Nil, Nil)) + case ((ParseResult(declsLeft, expectedLeft), l1), (ParseResult(declsRight, expectedRight), l2)) => + val lan = if(l1 == l2) l1 else None + (ParseResult(declsLeft ++ declsRight, expectedLeft ++ expectedRight), lan) + }).getOrElse((ParseResult(Nil, Nil), None)) } case class ParseResult[G](decls: Seq[GlobalDeclaration[G]], expectedErrors: Seq[ExpectedError]) \ No newline at end of file diff --git a/parsers/src/main/java/vct/parsers/transform/CToCol.scala b/parsers/src/main/java/vct/parsers/transform/CToCol.scala index 7ada5a9858..4ba15415c7 100644 --- a/parsers/src/main/java/vct/parsers/transform/CToCol.scala +++ b/parsers/src/main/java/vct/parsers/transform/CToCol.scala @@ -82,6 +82,8 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b case StorageClassSpecifier3(_) => ??(storageClass) case StorageClassSpecifier4(_) => ??(storageClass) case StorageClassSpecifier5(_) => ??(storageClass) + case StorageClassSpecifier6(_) => GPULocal() + case StorageClassSpecifier7(_) => GPUGlobal() } def convert(implicit typeSpec: TypeSpecifierContext): CTypeSpecifier[G] = typeSpec match { @@ -1037,6 +1039,7 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b case ValTypeValue(_, _, t, _) => TypeValue(convert(t)) case ValHeld(_, _, obj, _) => Held(convert(obj)) case ValIdEscape(text) => local(e, text.substring(1, text.length-1)) + case ValSharedMemSize(_, _, ptr, _) => SharedMemSize(convert(ptr)) } def convert(implicit e: ValExprContext): Expr[G] = e match { diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 7ae25d3e2e..0693f1c284 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -14,7 +14,6 @@ import vct.result.VerificationError.Unreachable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.annotation.nowarn -import scala.reflect.ClassTag /** * This rewrite pass simplifies expressions of roughly this form: @@ -125,7 +124,8 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] val yieldsArgs = variables.collect {contract.yieldsArgs.foreach(dispatch)}._1 val decreases = contract.decreases.map(element => rewriter.dispatch(element)) - ApplicableContract(requires, ensures, contextEverywhere, signals, givenArgs, yieldsArgs, decreases)(contract.blame)(contract.o) + ApplicableContract(requires, ensures, contextEverywhere, signals, givenArgs, yieldsArgs, decreases + )(contract.blame)(contract.o) } def rewriteLinearArray(e: Binder[Pre]): Option[Expr[Post]] = { @@ -255,6 +255,18 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] */ @nowarn("msg=xhaust") def addSingleBound(v: Variable[Pre], right: Expr[Pre], comp: Comparison): Unit = { + right match { + // Simplify rules from simplify.pvl come up with these kind of rules (specialize_range_right_i), + // but we want the original bounds + case Select(Less(e1, e2), e3, e4) => + if(e1 == e3 && e2 == e4 || e1 == e4 && e2 == e3){ + addSingleBound(v, e1, comp) + addSingleBound(v, e2, comp) + return + } + case _ => + } + comp match { // v < right case Comparison.LESS => @@ -699,7 +711,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] if(is_value(a_i_last, 1)) Mod(x_base, newGen(n_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, n_i_last should not be zero")) else - Mod(FloorDiv(x_base, newGen(a_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, a_i_last should not be zero")), newGen(n_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, n_i_last should not be zero")) + Mod(FloorDiv(x_base, newGen(a_i_last))( + PanicBlame("Error in SimplifyNestedQuantifiers, a_i_last should not be zero")), + newGen(n_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, n_i_last should not be zero")) // Yay we are good up to now, go check out the next i x_i_last = x_i @@ -801,175 +815,4 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // // The `substituteOldVars` contains a map, so we can replace the old forall variables with new expressions // // We also store the `linearExpression`, so if we ever come across it, we can replace it with the new variable. case class SubstituteForall(newBounds: Expr[Post], substituteOldVars: Map[Expr[Pre], Expr[Post]], newTrigger: Seq[Seq[Expr[Post]]]) -} - -// -// var equalityChecker: ExpressionEqualityCheck = ExpressionEqualityCheck() -// -// override def visit(special: ASTSpecial): Unit = { -// if(special.kind == ASTSpecial.Kind.Inhale){ -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = ASTUtils.conjuncts(special.args(0), StandardOperator.Star).asScala -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// -// result = create special(special.kind, rewrite(special.args):_*) -// -// equalityChecker = ExpressionEqualityCheck() -// -// } else { -// result = create special(special.kind, rewrite(special.args): _*) -// } -// } -// -// override def visit(c: ASTClass): Unit = { //checkPermission(c); -// val name = c.getName -// if (name == null) Abort("illegal class without name") -// else { -// Debug("rewriting class " + name) -// val new_pars = rewrite(c.parameters) -// val new_supers = rewrite(c.super_classes) -// val new_implemented = rewrite(c.implemented_classes) -// val res = new ASTClass(name, c.kind, new_pars, new_supers, new_implemented) -// res.setOrigin(c.getOrigin) -// currentTargetClass = res -// val contract = c.getContract -// if (currentContractBuilder == null) currentContractBuilder = new ContractBuilder -// if (contract != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(contract.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(contract.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// rewrite(contract, currentContractBuilder) -// equalityChecker = ExpressionEqualityCheck() -// } -// res.setContract(currentContractBuilder.getContract) -// currentContractBuilder = null -// -// for (i <- 0 until c.size()) { -// res.add(rewrite(c.get(i))) -// } -// result = res -// currentTargetClass = null -// } -// } -// -// override def visit(s: ForEachLoop): Unit = { -// val new_decl = rewrite(s.decls) -// val res = create.foreach(new_decl, rewrite(s.guard), rewrite(s.body)) -// -// val mc = s.getContract -// if (mc != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// res.setContract(rewrite(mc)) -// equalityChecker = ExpressionEqualityCheck() -// } else { -// res.setContract(rewrite(mc)) -// } -// -// -// res.set_before(rewrite(s.get_before)) -// res.set_after(rewrite(s.get_after)) -// result = res -// } -// -// override def visit(s: LoopStatement): Unit = { //checkPermission(s); -// val res = new LoopStatement -// var tmp = s.getInitBlock -// if (tmp != null) res.setInitBlock(tmp.apply(this)) -// tmp = s.getUpdateBlock -// if (tmp != null) res.setUpdateBlock(tmp.apply(this)) -// tmp = s.getEntryGuard -// if (tmp != null) res.setEntryGuard(tmp.apply(this)) -// tmp = s.getExitGuard -// if (tmp != null) res.setExitGuard(tmp.apply(this)) -// val mc = s.getContract -// if (mc != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// res.appendContract(rewrite(mc)) -// equalityChecker = ExpressionEqualityCheck() -// } else { -// res.appendContract(rewrite(mc)) -// } -// -// -// tmp = s.getBody -// res.setBody(tmp.apply(this)) -// res.set_before(rewrite(s.get_before)) -// res.set_after(rewrite(s.get_after)) -// res.setOrigin(s.getOrigin) -// result = res -// } -// -// override def visit(m: Method): Unit = { //checkPermission(m); -// val name = m.getName -// if (currentContractBuilder == null) { -// currentContractBuilder = new ContractBuilder -// } -// val args = rewrite(m.getArgs) -// val mc = m.getContract -// -// var c: Contract = null -// // Ensure we maintain the type of emptiness of mc -// // If the contract was null previously, the new contract can also be null -// // If the contract was non-null previously, the new contract cannot be null -// if (mc != null) { -// val info_getter = new AnnotationVariableInfoGetter() -// val annotations = LazyList(ASTUtils.conjuncts(mc.pre_condition, StandardOperator.Star).asScala -// , ASTUtils.conjuncts(mc.invariant, StandardOperator.Star).asScala).flatten -// -// equalityChecker = ExpressionEqualityCheck(Some(info_getter.getInfo(annotations))) -// -// rewrite(mc, currentContractBuilder) -// c = currentContractBuilder.getContract(false) -// equalityChecker = ExpressionEqualityCheck() -// } -// else { -// c = currentContractBuilder.getContract(true) -// } -// if (mc != null && c != null && c.getOrigin == null) { -// c.setOrigin(mc.getOrigin) -// } -// currentContractBuilder = null -// val kind = m.kind -// val rt = rewrite(m.getReturnType) -// val signals = rewrite(m.signals) -// val body = rewrite(m.getBody) -// result = create.method_kind(kind, rt, signals, c, name, args, m.usesVarArgs, body) -// } -// -// override def visit(expr: BindingExpression): Unit = { -// expr.binder match { -// case Binder.Forall | Binder.Star => -// val bindings = expr.getDeclarations.map(_.name).toSet -// val (select, main) = splitSelect(rewrite(expr.select), rewrite(expr.main)) -// val (independentSelect, potentialBounds) = select.partition(independentOf(bindings, _)) -// val (bounds, dependent_bounds) = getBounds(bindings, potentialBounds) -// //Only rewrite main, when the dependent bounds are not existing -// if(dependent_bounds.isEmpty && expr.binder != Binder.Star){ -// rewriteMain(bounds, main) match { -// case Some(main) => -// result = create expression(Implies, (independentSelect ++ bounds.selectNonEmpty).reduce(and), main); return -// case None => -// } -// } -// rewriteLinearArray(bounds, main, independentSelect, dependent_bounds, expr.binder, expr.result_type) match { -// case Some(new_forall) => -// result = new_forall; -// return -// case None => -// } -// super.visit(expr) -// case _ => -// super.visit(expr) -// } -// } -//} \ No newline at end of file +} \ No newline at end of file diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index 584c9f0f18..12bf25bd6a 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -4,15 +4,17 @@ import com.typesafe.scalalogging.LazyLogging import hre.util.ScopedStack import vct.col.ast._ import vct.col.newrewrite.lang.LangSpecificToCol.NotAValue -import vct.col.origin.{AbstractApplicable, Blame, CallableFailure, InterpretedOriginVariable, Origin, PanicBlame, TrueSatisfiable} +import vct.col.origin.{AbstractApplicable, InterpretedOriginVariable, Origin, PanicBlame} import vct.col.ref.Ref -import vct.col.resolve.{BuiltinField, BuiltinInstanceMethod, C, CInvocationTarget, CNameTarget, RefADTFunction, RefAxiomaticDataType, RefCFunctionDefinition, RefCGlobalDeclaration, RefCLocalDeclaration, RefCParam, RefCudaBlockDim, RefCudaBlockIdx, RefCudaGridDim, RefCudaThreadIdx, RefCudaVec, RefCudaVecDim, RefCudaVecX, RefCudaVecY, RefCudaVecZ, RefFunction, RefInstanceFunction, RefInstanceMethod, RefInstancePredicate, RefModelAction, RefModelField, RefModelProcess, RefPredicate, RefProcedure, RefVariable, SpecDerefTarget, SpecInvocationTarget} +import vct.col.resolve.{BuiltinField, BuiltinInstanceMethod, C, CNameTarget, RefADTFunction, RefAxiomaticDataType, RefCFunctionDefinition, RefCGlobalDeclaration, RefCLocalDeclaration, RefCParam, RefCudaBlockDim, RefCudaBlockIdx, RefCudaGridDim, RefCudaThreadIdx, RefCudaVec, RefCudaVecDim, RefCudaVecX, RefCudaVecY, RefCudaVecZ, RefFunction, RefInstanceFunction, RefInstanceMethod, RefInstancePredicate, RefModelAction, RefModelField, RefModelProcess, RefPredicate, RefProcedure, RefVariable, SpecInvocationTarget} import vct.col.rewrite.{Generation, Rewritten} -import vct.col.util.{Substitute, SuccessionMap} +import vct.col.util.SuccessionMap import vct.col.util.AstBuildHelpers._ +import vct.parsers.Language import vct.result.VerificationError.UserError import scala.collection.immutable.ListMap +import scala.collection.mutable case object LangCToCol { case class CGlobalStateNotSupported(example: CInit[_]) extends UserError { @@ -21,6 +23,16 @@ case object LangCToCol { example.o.messageInContext("Global variables in C are not supported.") } + case class WrongGPUKernelParameterType(param: CParam[_]) extends UserError { + override def code: String = "wrongParameterType" + override def text: String = s"The parameter `$param` has a type that is not allowed`as parameter in a GPU kernel." + } + + case class WrongGPUType(param: CParam[_]) extends UserError { + override def code: String = "wrongGPUType" + override def text: String = s"The parameter `$param` has a type that is not allowed`outside of a GPU kernel." + } + case class CDoubleContracted(decl: CGlobalDeclaration[_], defn: CFunctionDefinition[_]) extends UserError { override def code: String = "multipleContracts" override def text: String = @@ -31,7 +43,7 @@ case object LangCToCol { } } -case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends LazyLogging { +case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: Option[Language]) extends LazyLogging { import LangCToCol._ type Post = Rewritten[Pre] implicit val implicitRewriter: AbstractRewriter[Pre, Post] = rw @@ -49,6 +61,13 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz val cudaCurrentGrid: ScopedStack[ParBlockDecl[Post]] = ScopedStack() val cudaCurrentBlock: ScopedStack[ParBlockDecl[Post]] = ScopedStack() + private val dynamicSharedMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() + private val dynamicSharedMemLengthVar: mutable.Map[CNameTarget[Pre], Variable[Post]] = mutable.Map() + private val staticSharedMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() + private val globalMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() + private var inKernel: Boolean = false + private var inKernelArgs: Boolean = false + case class CudaIndexVariableOrigin(dim: RefCudaVecDim[_]) extends Origin { override def preferredName: String = dim.vec.name + dim.name.toUpperCase @@ -65,6 +84,22 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz ): _*) } + private def hasNoSharedMemNames(node: Node[Pre]): Boolean = { + val allowedNonRefs = Set("get_local_id", "get_group_id", "get_local_size", "get_num_groups") + node match { + // SharedMemSize gets rewritten towards the length of a shared memory name, so is valid in global context + case _: SharedMemSize[Pre] => + case l: CLocal[Pre] => l.ref match { + case Some(ref: RefCParam[Pre]) + if dynamicSharedMemNames.contains(ref) || staticSharedMemNames.contains(ref) => return false + case None => if(!allowedNonRefs.contains(l.name)) ??? + case _ => + } + case e => if(!e.subnodes.forall(hasNoSharedMemNames)) return false + } + true + } + def rewriteUnit(cUnit: CTranslationUnit[Pre]): Unit = { cUnit.declarations.foreach(rw.dispatch) } @@ -77,11 +112,68 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz case CName(name: String) => name } + def sharedSize(shared: SharedMemSize[Pre]): Expr[Post] = { + val SharedMemSize(pointer) = shared + + pointer match { + case loc: CLocal[Pre] => Local[Post](dynamicSharedMemLengthVar(loc.ref.get).ref)(shared.o) + case _ => ??? + } + } + def rewriteParam(cParam: CParam[Pre]): Unit = { cParam.drop() val o = InterpretedOriginVariable(cDeclToName(cParam.declarator), cParam.o) + + var array = false + var pointer = false + var global = false + var shared = false + var extern = false + var innerType: Option[Type[Pre]] = None + + var isDynamicShared = false + + val cRef = RefCParam(cParam) + + cParam.specifiers.foreach{ + case GPULocal() => shared = true + case GPUGlobal() => global = true + case CSpecificationType(TPointer(t)) => + pointer = true + innerType = Some(t) + case CSpecificationType(TArray(t)) => + array = true + innerType = Some(t) + case CExtern() => extern = true + case _ => + } + + if((shared || global) && !inKernel) throw WrongGPUType(cParam) + if(inKernelArgs){ + language.get match { + case Language.C => + if(global && !shared && (pointer || array) && !extern) globalMemNames.add(cRef) + else if(shared && !global && (pointer || array) && !extern) { + dynamicSharedMemNames.add(cRef) + // Create Var with array type here and return + isDynamicShared = true + val v = new Variable[Post](TArray[Post]( rw.dispatch(innerType.get) )(o) )(o) + cNameSuccessor(cRef) = v + return + } + else if(!shared && !global && !pointer && !array && !extern) () + else throw WrongGPUKernelParameterType(cParam) + case Language.CUDA => + if(!global && !shared && (pointer || array) && !extern) globalMemNames.add(cRef) + else if(!shared && !global && !pointer && !array && !extern) () + else throw WrongGPUKernelParameterType(cParam) + case _ => ??? // Should not happen + } + } + val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(o) - cNameSuccessor(RefCParam(cParam)) = v + cNameSuccessor(cRef) = v rw.variables.declare(v) } @@ -89,7 +181,13 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz func.drop() val info = C.getDeclaratorInfo(func.declarator) val returnType = func.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) + if (func.specs.collectFirst { case CKernel() => () }.nonEmpty){ + inKernel = true + inKernelArgs = true + } val params = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 + inKernel = false + inKernelArgs = false val (contract, subs: Map[CParam[Pre], CParam[Pre]]) = func.ref match { case Some(RefCGlobalDeclaration(decl, idx)) if decl.decl.contract.nonEmpty => @@ -170,11 +268,36 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz } def kernelProcedure(o: Origin, contract: ApplicableContract[Pre], info: C.DeclaratorInfo[Pre], body: Option[Statement[Pre]]): Procedure[Post] = { + dynamicSharedMemNames.clear() + staticSharedMemNames.clear() + inKernel = true + val blockDim = new CudaVec(RefCudaBlockDim())(o) val gridDim = new CudaVec(RefCudaGridDim())(o) cudaCurrentBlockDim.having(blockDim) { cudaCurrentGridDim.having(gridDim) { - val newArgs = blockDim.indices.values.toSeq ++ gridDim.indices.values.toSeq ++ rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 + inKernelArgs = true + val args = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 + inKernelArgs = false + rw.variables.collect { dynamicSharedMemNames.foreach(d => rw.variables.declare(cNameSuccessor(d)) ) } + val (sharedMemSizes, sharedMemInit: Seq[Statement[Post]]) = rw.variables.collect { + var result: Seq[Statement[Post]] = Seq() + dynamicSharedMemNames.foreach(d => + { + implicit val o: Origin = d.decl.o + val v = new Variable[Post](TInt()) + dynamicSharedMemLengthVar(d) = v + rw.variables.declare(v) + val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) + val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) + , NewArray[Post](v.t, Seq(Local(v.ref)), 0))(PanicBlame("Assign should work")) + result = result ++ Seq(decl, assign) + }) + result + } + + val newArgs = blockDim.indices.values.toSeq ++ gridDim.indices.values.toSeq ++ args ++ sharedMemSizes + val newGivenArgs = rw.variables.dispatch(contract.givenArgs) val newYieldsArgs = rw.variables.dispatch(contract.yieldsArgs) // We add the requirement that a GPU kernel must always have threads (non zero block or grid dimensions) @@ -182,10 +305,10 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz = (blockDim.indices.values ++ gridDim.indices.values).map( v => Less(IntegerValue(0)(o), v.get(o))(o) ).toSeq val nonZeroThreads = foldStar(nonZeroThreadsSeq)(o) - // Ugly, but can't get it to type check otherwise - val nonZeroThreadsPred1: Seq[AccountedPredicate[Post]] = nonZeroThreadsSeq.map(UnitAccountedPredicate(_)(o)) - val nonZeroThreadsPred: AccountedPredicate[Post] = - nonZeroThreadsPred1.reduceLeft((l,r) => SplitAccountedPredicate(l, r)(o)) + val nonZeroThreadsPred = + nonZeroThreadsSeq + .map(UnitAccountedPredicate(_)(o)) + .reduceLeft[AccountedPredicate[Post]](SplitAccountedPredicate(_,_)(o)) val parBody = body.map(impl => { implicit val o: Origin = impl.o @@ -198,15 +321,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz cudaCurrentBlockIdx.having(blockIdx) { cudaCurrentGrid.having(gridDecl) { cudaCurrentBlock.having(blockDecl) { - ParStatement(ParBlock( - decl = gridDecl, - iters = blockIdx.indices.values.zip(gridDim.indices.values).map { - case (index, dim) => IterVariable(index, const(0), dim.get) - }.toSeq, - context_everywhere = tt, //Star(nonZeroThreads, allThreadsInBlock(contract.contextEverywhere))(o), - requires = Star(nonZeroThreads, allThreadsInBlock(foldStar(unfoldPredicate(contract.requires)))), - ensures = Star(nonZeroThreads, allThreadsInBlock(foldStar(unfoldPredicate(contract.ensures)))), - content = ParStatement(ParBlock( + val contextBlock = foldStar(unfoldStar(contract.contextEverywhere) + .filter(hasNoSharedMemNames).map(allThreadsInBlock)) + val innerContent = ParStatement(ParBlock( decl = blockDecl, iters = threadIdx.indices.values.zip(blockDim.indices.values).map { case (index, dim) => IterVariable(index, const(0), dim.get) @@ -217,6 +334,29 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz ensures = rw.dispatch(foldStar(unfoldPredicate(contract.ensures))), content = rw.dispatch(impl), )(PanicBlame("where blame?"))) + ParStatement(ParBlock( + decl = gridDecl, + iters = blockIdx.indices.values.zip(gridDim.indices.values).map { + case (index, dim) => IterVariable(index, const(0), dim.get) + }.toSeq, + // Context is added to requires and ensures here + context_everywhere = tt, + requires = Star( + Star(nonZeroThreads, + foldStar( + unfoldPredicate(contract.requires) + .filter(hasNoSharedMemNames) + .map(allThreadsInBlock))) + , contextBlock), + ensures = Star( + Star(nonZeroThreads, + foldStar( + unfoldPredicate(contract.ensures) + .filter(hasNoSharedMemNames) + .map(allThreadsInBlock)) ) + , contextBlock), + // Add shared memory initialization before beginning of inner parallel block + content = Block[Post](sharedMemInit ++ Seq(innerContent)) )(PanicBlame("where blame?"))) } } @@ -224,16 +364,25 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz } }) + val gridContext: AccountedPredicate[Post] = + foldPredicate(unfoldStar(contract.contextEverywhere) + .filter(hasNoSharedMemNames) + .map(allThreadsInBlock))(o) new Procedure[Post]( returnType = TVoid(), args = newArgs, outArgs = Nil, typeArgs = Nil, body = parBody, contract = ApplicableContract( - SplitAccountedPredicate(nonZeroThreadsPred, mapPredicate(contract.requires, allThreadsInGrid))(o), - SplitAccountedPredicate(nonZeroThreadsPred, mapPredicate(contract.ensures, allThreadsInGrid))(o), + SplitAccountedPredicate( + SplitAccountedPredicate(nonZeroThreadsPred, + mapPredicate(filterPredicate(contract.requires, hasNoSharedMemNames), allThreadsInGrid))(o), + gridContext)(o), + SplitAccountedPredicate( + SplitAccountedPredicate(nonZeroThreadsPred, + mapPredicate(filterPredicate(contract.ensures, hasNoSharedMemNames), allThreadsInGrid))(o), + gridContext)(o), // Context everywhere is already passed down in the body - //allThreadsInGrid(contract.contextEverywhere), tt, contract.signals.map(rw.dispatch), newGivenArgs, @@ -334,7 +483,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz def local(local: CLocal[Pre]): Expr[Post] = { implicit val o: Origin = local.o local.ref.get match { - case RefAxiomaticDataType(decl) => throw NotAValue(local) + case RefAxiomaticDataType(_) => throw NotAValue(local) case RefVariable(decl) => Local(rw.succ(decl)) case RefModelField(decl) => ModelDeref[Post](rw.currentThis.top, rw.succ(decl))(local.blame) case ref: RefCParam[Pre] => @@ -342,10 +491,10 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz Local(cNameSuccessor.ref(RefCParam(cCurrentDefinitionParamSubstitutions.top.getOrElse(ref.decl, ref.decl)))) else Local(cNameSuccessor.ref(ref)) - case RefCFunctionDefinition(decl) => throw NotAValue(local) - case RefCGlobalDeclaration(decls, initIdx) => throw NotAValue(local) + case RefCFunctionDefinition(_) => throw NotAValue(local) + case RefCGlobalDeclaration(_, _) => throw NotAValue(local) case ref: RefCLocalDeclaration[Pre] => Local(cNameSuccessor.ref(ref)) - case ref: RefCudaVec[Pre] => throw NotAValue(local) + case _: RefCudaVec[Pre] => throw NotAValue(local) } } @@ -392,10 +541,22 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz ProcedureInvocation[Post](cFunctionSuccessor.ref(ref.decl), args.map(rw.dispatch), Nil, Nil, givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) - case RefCGlobalDeclaration(decls, initIdx) => - ProcedureInvocation[Post](cFunctionDeclSuccessor.ref((decls, initIdx)), args.map(rw.dispatch), Nil, Nil, - givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, - yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) + case e@ RefCGlobalDeclaration(decls, initIdx) => + val arg = if(args.size == 1){ + args.head match { + case IntegerValue(i) if i >= 0 && i < 3 => Some(i.toInt) + case _ => None + } + } else None + (e.name, arg) match { + case ("get_local_id", Some(i)) => cudaCurrentThreadIdx.top.indices.values.toSeq.apply(i).get + case ("get_group_id", Some(i)) => cudaCurrentBlockIdx.top.indices.values.toSeq.apply(i).get + case ("get_local_size", Some(i)) => cudaCurrentBlockDim.top.indices.values.toSeq.apply(i).get + case ("get_num_groups", Some(i)) => cudaCurrentGridDim.top.indices.values.toSeq.apply(i).get + case _ => ProcedureInvocation[Post](cFunctionDeclSuccessor.ref((decls, initIdx)), args.map(rw.dispatch), Nil, Nil, + givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, + yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) + } } } } diff --git a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala index 6d73fe340d..ce34f232f4 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala @@ -7,9 +7,11 @@ import vct.col.ast._ import vct.col.origin._ import vct.col.resolve._ import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} +import vct.parsers.Language import vct.result.VerificationError.UserError case object LangSpecificToCol extends RewriterBuilder { + def apply[Pre <: Generation](): AbstractRewriter[Pre, _ <: Generation] = apply(None) override def key: String = "langSpecific" override def desc: String = "Translate language-specific constructs to a common subset of nodes." @@ -26,9 +28,9 @@ case object LangSpecificToCol extends RewriterBuilder { } } -case class LangSpecificToCol[Pre <: Generation]() extends Rewriter[Pre] with LazyLogging { +case class LangSpecificToCol[Pre <: Generation](language: Option[Language]) extends Rewriter[Pre] with LazyLogging { val java: LangJavaToCol[Pre] = LangJavaToCol(this) - val c: LangCToCol[Pre] = LangCToCol(this) + val c: LangCToCol[Pre] = LangCToCol(this, language) val pvl: LangPVLToCol[Pre] = LangPVLToCol(this) val silver: LangSilverToCol[Pre] = LangSilverToCol(this) @@ -124,6 +126,7 @@ case class LangSpecificToCol[Pre <: Generation]() extends Rewriter[Pre] with Laz case local: CLocal[Pre] => c.local(local) case deref: CStructAccess[Pre] => c.deref(deref) case inv: CInvocation[Pre] => c.invocation(inv) + case shared: SharedMemSize[Pre] => c.sharedSize(shared) case inv: SilverPartialADTFunctionInvocation[Pre] => silver.adtInvocation(inv) case map: SilverUntypedNonemptyLiteralMap[Pre] => silver.nonemptyMap(map) diff --git a/src/main/java/vct/main/stages/Parsing.scala b/src/main/java/vct/main/stages/Parsing.scala index 054d0b0ab9..4249727ce5 100644 --- a/src/main/java/vct/main/stages/Parsing.scala +++ b/src/main/java/vct/main/stages/Parsing.scala @@ -36,11 +36,11 @@ case class Parsing[G <: Generation] cSystemInclude: Path = Resources.getCIncludePath, cOtherIncludes: Seq[Path] = Nil, cDefines: Map[String, String] = Map.empty, -) extends Stage[Seq[Readable], ParseResult[G]] { +) extends Stage[Seq[Readable], (ParseResult[G], Option[Language])] { override def friendlyName: String = "Parsing" override def progressWeight: Int = 4 - override def run(in: Seq[Readable]): ParseResult[G] = + override def run(in: Seq[Readable]): (ParseResult[G], Option[Language]) = ParseResult.reduce(in.map { readable => val language = forceLanguage .orElse(Language.fromFilename(readable.fileName)) @@ -49,13 +49,14 @@ case class Parsing[G <: Generation] val originProvider = ReadableOriginProvider(readable) val parser = language match { - case Language.C => ColCParser(originProvider, blameProvider, cc, cSystemInclude, cOtherIncludes, cDefines) + case Language.C | Language.CUDA + => ColCParser(originProvider, blameProvider, cc, cSystemInclude, cOtherIncludes, cDefines, language) case Language.InterpretedC => ColIParser(originProvider, blameProvider) case Language.Java => ColJavaParser(originProvider, blameProvider) case Language.PVL => ColPVLParser(originProvider, blameProvider) case Language.Silver => ColSilverParser(originProvider, blameProvider) } - parser.parse[G](readable) + (parser.parse[G](readable), Some(language)) }) } \ No newline at end of file diff --git a/src/main/java/vct/main/stages/Resolution.scala b/src/main/java/vct/main/stages/Resolution.scala index 934b5b26a8..67e5b0fd46 100644 --- a/src/main/java/vct/main/stages/Resolution.scala +++ b/src/main/java/vct/main/stages/Resolution.scala @@ -12,7 +12,7 @@ import vct.main.Main.TemporarilyUnsupported import vct.main.stages.Resolution.InputResolutionError import vct.main.stages.Transformation.TransformationCheckError import vct.options.Options -import vct.parsers.ParseResult +import vct.parsers.{Language, ParseResult} import vct.parsers.transform.BlameProvider import vct.resources.Resources import vct.result.VerificationError.UserError @@ -38,11 +38,12 @@ case class Resolution[G <: Generation] blameProvider: BlameProvider, withJava: Boolean = true, javaLibraryPath: Path = Resources.getJrePath, -) extends Stage[ParseResult[G], VerificationContext[_ <: Generation]] { +) extends Stage[(ParseResult[G], Option[Language]), VerificationContext[_ <: Generation]] { override def friendlyName: String = "Name Resolution" override def progressWeight: Int = 1 - override def run(in: ParseResult[G]): VerificationContext[_ <: Generation] = { + override def run(inLanguage: (ParseResult[G], Option[Language]) ): VerificationContext[_ <: Generation] = { + val (in, language) = inLanguage in.decls.foreach(_.transSubnodes.foreach { case decl: CGlobalDeclaration[_] => decl.decl.inits.foreach(init => { if(C.getDeclaratorInfo(init.decl).params.isEmpty) { @@ -64,7 +65,7 @@ case class Resolution[G <: Generation] case Nil => // ok case some => throw InputResolutionError(some) } - val resolvedProgram = LangSpecificToCol().dispatch(typedProgram) + val resolvedProgram = LangSpecificToCol(language).dispatch(typedProgram) resolvedProgram.check match { case Nil => // ok case some => throw TransformationCheckError(some) diff --git a/src/main/java/vct/main/util/Util.scala b/src/main/java/vct/main/util/Util.scala index 329d3fc26b..bc1730ebae 100644 --- a/src/main/java/vct/main/util/Util.scala +++ b/src/main/java/vct/main/util/Util.scala @@ -5,7 +5,7 @@ import vct.col.ast.Program import vct.col.newrewrite.Disambiguate import vct.col.origin.{Blame, VerificationFailure} import vct.main.stages.Resolution -import vct.parsers.ColPVLParser +import vct.parsers.{ColPVLParser, Language} import vct.parsers.transform.{BlameProvider, ConstantBlameProvider, ReadableOriginProvider} import vct.result.VerificationError.UserError @@ -23,7 +23,7 @@ case object Util { def loadPVLLibraryFile[G](readable: Readable): Program[G] = { val res = ColPVLParser(ReadableOriginProvider(readable), ConstantBlameProvider(LibraryFileBlame)).parse(readable) - val context = Resolution(ConstantBlameProvider(LibraryFileBlame), withJava = false).run(res) + val context = Resolution(ConstantBlameProvider(LibraryFileBlame), withJava = false).run((res, Some(Language.PVL))) assert(context.expectedErrors.isEmpty) val unambiguousProgram: Program[_] = Disambiguate().dispatch(context.program) unambiguousProgram.asInstanceOf[Program[G]] From 722e7347309d51e342365bf83f56deb8680c183c Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 14 Sep 2022 18:03:27 +0200 Subject: [PATCH 20/25] Some fixes to SimplifNestedQuantifiers to work with 'new' permissions --- col/src/main/java/vct/col/print/Printer.scala | 2 + .../vct/col/newrewrite/Disambiguate.scala | 9 +++- .../SimplifyNestedQuantifiers.scala | 41 ++++++++++--------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 0e85849388..8169ed6cc2 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -754,6 +754,8 @@ case class Printer(out: Appendable, (phrase(assoc(100, obj), ".", name(ref.decl)), 100) case DerefPointer(pointer) => (phrase("*", assoc(90, pointer)), 90) + case PointerAdd(pointer, offset) => + (phrase(assoc(70, pointer), space, "+", space, assoc(70, offset)), 70) case AddrOf(e) => (phrase("&", assoc(90, e)), 90) case PredicateApply(ref, args, perm) => diff --git a/src/main/java/vct/col/newrewrite/Disambiguate.scala b/src/main/java/vct/col/newrewrite/Disambiguate.scala index 9b174cb8ee..fdfcdec647 100644 --- a/src/main/java/vct/col/newrewrite/Disambiguate.scala +++ b/src/main/java/vct/col/newrewrite/Disambiguate.scala @@ -22,7 +22,7 @@ case class Disambiguate[Pre <: Generation]() extends Rewriter[Pre] { else Mult(dispatch(left), dispatch(right)) case op @ AmbiguousPlus(left, right) => if(op.isProcessOp) ProcessChoice(dispatch(left), dispatch(right)) - else if(op.isPointerOp) PointerAdd(dispatch(left), dispatch(right))(op.blame) + else if(op.isPointerOp) unfoldPointerAdd(PointerAdd(dispatch(left), dispatch(right))(op.blame)) else if(op.isSeqOp) Concat(dispatch(left), dispatch(right)) else if(op.isSetOp) SetUnion(dispatch(left), dispatch(right)) else if(op.isBagOp) BagAdd(dispatch(left), dispatch(right)) @@ -78,4 +78,11 @@ case class Disambiguate[Pre <: Generation]() extends Rewriter[Pre] { case other => rewriteDefault(other) } } + + def unfoldPointerAdd[G](e: PointerAdd[G]): PointerAdd[G] = e.pointer match { + case inner @ PointerAdd(_, _) => + val PointerAdd(pointerInner, offsetInner) = unfoldPointerAdd(inner) + PointerAdd(pointerInner, Plus(offsetInner, e.offset)(e.o) )(e.blame)(e.o) + case _ => e + } } diff --git a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala index 1d079d9151..b649eac20a 100644 --- a/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala +++ b/src/main/java/vct/col/newrewrite/SimplifyNestedQuantifiers.scala @@ -4,7 +4,7 @@ import com.typesafe.scalalogging.LazyLogging import vct.col.ast.{ArraySubscript, _} import vct.col.ast.util.{AnnotationVariableInfoGetter, ExpressionEqualityCheck} import vct.col.newrewrite.util.Comparison -import vct.col.origin.{Origin, PanicBlame} +import vct.col.origin.{ArrayInsufficientPermission, Origin, PanicBlame, PointerBounds} import vct.col.ref.Ref import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} import vct.col.util.AstBuildHelpers._ @@ -449,9 +449,9 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] val main = if (select.nonEmpty) Implies(AstBuildHelpers.foldAnd(select), newBody) else newBody @nowarn("msg=xhaust") val forall: Binder[Post] = originalBinder match { - case _: Forall[Pre] => Forall(bindings, substituteForall.newTrigger, main)(originalBinder.o) + case _: Forall[Pre] => Forall(bindings, substituteForall.newTriggers, main)(originalBinder.o) case originalBinder: Starall[Pre] => - Starall(bindings, substituteForall.newTrigger, main)(originalBinder.blame)(originalBinder.o) + Starall(bindings, substituteForall.newTriggers, main)(originalBinder.blame)(originalBinder.o) } Some(forall) case (_, None) => result() @@ -506,25 +506,23 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] val index: Expr[G] val subnodes: Seq[Node[G]] } - case class Array[G](e: ArraySubscript[G]) extends Subscript[G] { - val index: Expr[G] = e.index - val subnodes: Seq[Node[G]] = e.subnodes - } + case class Array[G](index: Expr[G], subnodes: Seq[Node[G]], array: Expr[G]) extends Subscript[G] - case class Pointer[G](e: PointerSubscript[G]) extends Subscript[G] { - val index: Expr[G] = e.index - val subnodes: Seq[Node[G]] = e.subnodes - } + case class Pointer[G](index: Expr[G], subnodes: Seq[Node[G]], array: Expr[G]) extends Subscript[G] class FindLinearArrayAccesses(quantifierData: RewriteQuantifierData){ // Search for linear array expressions def search(e: Node[Pre]): Option[SubstituteForall] = { e match { + case e @ ArrayLocation(_, _) => + testSubscript(Array(e.subscript, e.subnodes, e.array)) case e @ ArraySubscript(_, _) => - testSubscript(Array(e)) + testSubscript(Array(e.index, e.subnodes, e.arr)) case e @ PointerSubscript(_, _) => - testSubscript(Pointer(e)) + testSubscript(Pointer(e.index, e.subnodes, e.pointer)) + case e @ PointerAdd(_, _) => + testSubscript(Pointer(e.offset, e.subnodes, e.pointer)) case _ => e.subnodes.to(LazyList).map(search).collectFirst{case Some(sub) => sub} } } @@ -727,14 +725,17 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // Replace the linear expression with the new variable val x_new_var: Expr[Post] = Local(x_new.ref) replace_map(arrayIndex.index) = x_new_var - val trigger: Expr[Post] = arrayIndex match { - case Array( node @ ArraySubscript(arr, _)) => ArraySubscript(newGen(arr), x_new_var)(node.blame) - case Pointer( node @ PointerSubscript(arr, _)) => PointerSubscript(newGen(arr), x_new_var)(node.blame) + val newTriggers : Seq[Seq[Expr[Post]]] = arrayIndex match { + case arrayIndex: Array[Pre] => + Seq(Seq(ArraySubscript(newGen(arrayIndex.array), x_new_var)(PanicBlame("Only used as trigger, not as access"))), + // Seq(ArrayLocation(newGen(arrayIndex.array), x_new_var)(PanicBlame("Only used as trigger, not as access"))) + ) + case arrayIndex: Pointer[Pre] => + Seq(Seq(PointerSubscript(newGen(arrayIndex.array), x_new_var)(PanicBlame("Only used as trigger, not as access"))), + Seq(PointerAdd(newGen(arrayIndex.array), x_new_var)(PanicBlame("Only used as trigger, not as access")))) } - val newTrigger : Seq[Seq[Expr[Post]]] = Seq(Seq(trigger)) // Add the last value, no need to do modulo - // replace_map(Local(x_i_last.ref)) = if(is_value(a_i_last, 1)) x_base else FloorDiv(x_base, newGen(a_i_last))(PanicBlame("Error in SimplifyNestedQuantifiers, a_i_last should not be zero")) @@ -773,7 +774,7 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] new_bounds = And(Less(IntegerValue(0), newGen(n_i)), new_bounds) } - Some(SubstituteForall(new_bounds, replace_map.toMap, newTrigger)) + Some(SubstituteForall(new_bounds, replace_map.toMap, newTriggers)) } def simplified_mult(lhs: Expr[Pre], rhs: Expr[Pre]): Expr[Pre] = { @@ -813,5 +814,5 @@ case class SimplifyNestedQuantifiers[Pre <: Generation]() extends Rewriter[Pre] // The `newBounds`, will contain all the new equations for "select" part of the forall. // The `substituteOldVars` contains a map, so we can replace the old forall variables with new expressions // We also store the `linearExpression`, so if we ever come across it, we can replace it with the new variable. - case class SubstituteForall(newBounds: Expr[Post], substituteOldVars: Map[Expr[Pre], Expr[Post]], newTrigger: Seq[Seq[Expr[Post]]]) + case class SubstituteForall(newBounds: Expr[Post], substituteOldVars: Map[Expr[Pre], Expr[Post]], newTriggers: Seq[Seq[Expr[Post]]]) } \ No newline at end of file From 0565c7de5da6025aeccf7c72167a16949fea8111 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 19 Sep 2022 14:05:22 +0200 Subject: [PATCH 21/25] Added barriers with memory fences --- col/src/main/java/vct/col/ast/Node.scala | 10 +- .../lang/GpgpuBarrierImpl.scala | 7 + .../lang/GpgpuGlobalBarrierImpl.scala | 7 - .../lang/GpgpuLocalBarrierImpl.scala | 7 - .../lang/GpuGlobalMemoryFenceImpl.scala | 7 + .../lang/GpuLocalMemoryFenceImpl.scala | 7 + .../lang/GpuMemoryFenceImpl.scala | 7 + .../lang/GpuZeroMemoryFenceImpl.scala | 7 + .../vct/col/coerce/CoercingRewriter.scala | 4 +- .../java/vct/col/feature/FeatureRainbow.scala | 5 +- col/src/main/java/vct/col/print/Printer.scala | 16 +- .../vct/col/rewrite/NonLatchingRewriter.scala | 1 + parsers/lib/antlr4/LangCParser.g4 | 3 +- parsers/lib/antlr4/LangGPGPULexer.g4 | 8 +- parsers/lib/antlr4/LangGPGPUParser.g4 | 19 +- .../java/vct/parsers/transform/CToCol.scala | 26 +- .../vct/col/newrewrite/lang/LangCToCol.scala | 233 +++++++++++++----- .../newrewrite/lang/LangSpecificToCol.scala | 3 +- src/main/universal/res/c/cuda.h | 3 +- src/main/universal/res/c/opencl.h | 9 +- 20 files changed, 270 insertions(+), 119 deletions(-) create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuBarrierImpl.scala delete mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuGlobalBarrierImpl.scala delete mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuLocalBarrierImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuGlobalMemoryFenceImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuLocalMemoryFenceImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuMemoryFenceImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuZeroMemoryFenceImpl.scala diff --git a/col/src/main/java/vct/col/ast/Node.scala b/col/src/main/java/vct/col/ast/Node.scala index 32153e903c..2c1b8a434b 100644 --- a/col/src/main/java/vct/col/ast/Node.scala +++ b/col/src/main/java/vct/col/ast/Node.scala @@ -700,8 +700,14 @@ final case class CGoto[G](label: String)(implicit val o: Origin) extends CStatem var ref: Option[LabelDecl[G]] = None } -final case class GpgpuLocalBarrier[G](requires: Expr[G], ensures: Expr[G])(implicit val o: Origin) extends CStatement[G] with GpgpuLocalBarrierImpl[G] -final case class GpgpuGlobalBarrier[G](requires: Expr[G], ensures: Expr[G])(implicit val o: Origin) extends CStatement[G] with GpgpuGlobalBarrierImpl[G] +sealed trait GpuMemoryFence[G] extends NodeFamily[G] with GpuMemoryFenceImpl[G] + +final case class GpuLocalMemoryFence[G]()(implicit val o: Origin) extends GpuMemoryFence[G] with GpuLocalMemoryFenceImpl[G] +final case class GpuGlobalMemoryFence[G]()(implicit val o: Origin) extends GpuMemoryFence[G] with GpuGlobalMemoryFenceImpl[G] +final case class GpuZeroMemoryFence[G](value: BigInt)(implicit val o: Origin) extends GpuMemoryFence[G] with GpuZeroMemoryFenceImpl[G] + + +final case class GpgpuBarrier[G](requires: Expr[G], ensures: Expr[G], specifiers: Seq[GpuMemoryFence[G]])(implicit val o: Origin) extends CStatement[G] with GpgpuBarrierImpl[G] final case class GpgpuAtomic[G](impl: Statement[G], before: Statement[G], after: Statement[G])(implicit val o: Origin) extends CStatement[G] with GpgpuAtomicImpl[G] sealed trait CExpr[G] extends Expr[G] with CExprImpl[G] diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuBarrierImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuBarrierImpl.scala new file mode 100644 index 0000000000..4613dbe419 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuBarrierImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GpgpuBarrier + +trait GpgpuBarrierImpl[G] { this: GpgpuBarrier[G] => + +} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuGlobalBarrierImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuGlobalBarrierImpl.scala deleted file mode 100644 index d79486e985..0000000000 --- a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuGlobalBarrierImpl.scala +++ /dev/null @@ -1,7 +0,0 @@ -package vct.col.ast.temporaryimplpackage.lang - -import vct.col.ast.GpgpuGlobalBarrier - -trait GpgpuGlobalBarrierImpl[G] { this: GpgpuGlobalBarrier[G] => - -} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuLocalBarrierImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuLocalBarrierImpl.scala deleted file mode 100644 index c2f654bca1..0000000000 --- a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpgpuLocalBarrierImpl.scala +++ /dev/null @@ -1,7 +0,0 @@ -package vct.col.ast.temporaryimplpackage.lang - -import vct.col.ast.GpgpuLocalBarrier - -trait GpgpuLocalBarrierImpl[G] { this: GpgpuLocalBarrier[G] => - -} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuGlobalMemoryFenceImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuGlobalMemoryFenceImpl.scala new file mode 100644 index 0000000000..73ac4048a9 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuGlobalMemoryFenceImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GpuGlobalMemoryFence + +trait GpuGlobalMemoryFenceImpl[G] { this: GpuGlobalMemoryFence[G] => + +} diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuLocalMemoryFenceImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuLocalMemoryFenceImpl.scala new file mode 100644 index 0000000000..cd6feb2d66 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuLocalMemoryFenceImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GpuLocalMemoryFence + +trait GpuLocalMemoryFenceImpl[G] { this: GpuLocalMemoryFence[G] => + +} diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuMemoryFenceImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuMemoryFenceImpl.scala new file mode 100644 index 0000000000..654d44203b --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuMemoryFenceImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GpuMemoryFence + +trait GpuMemoryFenceImpl[G] { this: GpuMemoryFence[G] => + +} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuZeroMemoryFenceImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuZeroMemoryFenceImpl.scala new file mode 100644 index 0000000000..a579d77851 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/GpuZeroMemoryFenceImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.GpuZeroMemoryFence + +trait GpuZeroMemoryFenceImpl[G] { this: GpuZeroMemoryFence[G] => + +} diff --git a/col/src/main/java/vct/col/coerce/CoercingRewriter.scala b/col/src/main/java/vct/col/coerce/CoercingRewriter.scala index bb6e025a03..b2242f60d9 100644 --- a/col/src/main/java/vct/col/coerce/CoercingRewriter.scala +++ b/col/src/main/java/vct/col/coerce/CoercingRewriter.scala @@ -211,6 +211,7 @@ abstract class CoercingRewriter[Pre <: Generation]() extends Rewriter[Pre] with case node: CPointer[Pre] => node case node: CInit[Pre] => node case node: CDeclaration[Pre] => node + case node: GpuMemoryFence[Pre] => node case node: JavaModifier[Pre] => node case node: JavaImport[Pre] => node case node: JavaName[Pre] => node @@ -1138,8 +1139,7 @@ abstract class CoercingRewriter[Pre <: Generation]() extends Rewriter[Pre] with case proof @ FramedProof(pre, body, post) => FramedProof(res(pre), body, res(post))(proof.blame) case Goto(lbl) => Goto(lbl) case GpgpuAtomic(impl, before, after) => GpgpuAtomic(impl, before, after) - case GpgpuGlobalBarrier(requires, ensures) => GpgpuGlobalBarrier(res(requires), res(ensures)) - case GpgpuLocalBarrier(requires, ensures) => GpgpuLocalBarrier(res(requires), res(ensures)) + case GpgpuBarrier(requires, ensures, specifier) => GpgpuBarrier(res(requires), res(ensures), specifier) case Havoc(loc) => Havoc(loc) case IndetBranch(branches) => IndetBranch(branches) case Inhale(assn) => Inhale(res(assn)) diff --git a/col/src/main/java/vct/col/feature/FeatureRainbow.scala b/col/src/main/java/vct/col/feature/FeatureRainbow.scala index 1d42197d92..076afeffab 100644 --- a/col/src/main/java/vct/col/feature/FeatureRainbow.scala +++ b/col/src/main/java/vct/col/feature/FeatureRainbow.scala @@ -171,6 +171,7 @@ class FeatureRainbow[G] { case node: PointerSubscript[G] => Pointers case node: PointerBlockLength[G] => Pointers case node: PointerBlockOffset[G] => Pointers + case node: PointerLength[G] => Pointers case node: Length[G] => Arrays case node: Size[G] => return Nil case node: Cons[G] => SugarCollectionOperator @@ -449,9 +450,9 @@ class FeatureRainbow[G] { case node: CGlobalDeclaration[G] => return Nil case node: CDeclarationStatement[G] => return Nil case node: CGoto[G] => return Nil - case node: GpgpuLocalBarrier[G] => return Nil - case node: GpgpuGlobalBarrier[G] => return Nil + case node: GpgpuBarrier[G] => return Nil case node: GpgpuAtomic[G] => return Nil + case node: GpuMemoryFence[G] => return Nil case node: CLocal[G] => return Nil case node: CInvocation[G] => return Nil case node: CStructAccess[G] => return Nil diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 8169ed6cc2..c5c7c6bf35 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -415,15 +415,11 @@ case class Printer(out: Appendable, statement(syntax(C -> phrase(decl.decl.specs, commas(decl.decl.inits.map(NodePhrase))))) case ref @ CGoto(label) => statement(syntax(C -> phrase("goto", space, Text(ref.ref.map(name).getOrElse(label))))) - case GpgpuLocalBarrier(requires, ensures) => + case GpgpuBarrier(requires, ensures, specifier) => statement(spec(clauses(requires, "requires"), clauses(ensures, "ensures")), syntax( Cuda -> phrase("__syncthreads();"), - OpenCL -> phrase("barrier(CLK_LOCAL_MEM_FENCE);"), + OpenCL -> phrase("barrier(", intersperse(" | ", specifier.map(NodePhrase)), ");"), )) - case GpgpuGlobalBarrier(requires, ensures) => - statement(spec(clauses(requires, "requires"), clauses(ensures, "ensures")), syntax( - OpenCL -> phrase("barrier(CLK_GLOBAL_MEM_FENCE);", - ))) case GpgpuAtomic(impl, before, after) => syntax(C -> statement("__vercors_atomic__", space, forceInline(impl), space, spec("with", space, before, space, "then", space, before))) @@ -1214,6 +1210,7 @@ case class Printer(out: Appendable, case CTypedefName(name) => say(name) case CSpecificationType(t) => say(t) case CTypeQualifierDeclarationSpecifier(typeQual) => say(typeQual) + case CExtern() => say("extern") case CKernel() => say("__kernel") case GPULocal() => say(syntax( Cuda -> phrase("__shared__"), @@ -1240,6 +1237,12 @@ case class Printer(out: Appendable, case None => say(node.decl) } + def printGpuMemoryFence(node: GpuMemoryFence[_]): Unit = node match { + case GpuLocalMemoryFence() => say("CLK_LOCAL_MEM_FENCE") + case GpuGlobalMemoryFence() => say("CLK_GLOBAL_MEM_FENCE") + case GpuZeroMemoryFence(i) => say(f"$i") + } + def printJavaModifier(node: JavaModifier[_]): Unit = node match { case JavaPublic() => say("public") case JavaProtected() => say("protected") @@ -1294,6 +1297,7 @@ case class Printer(out: Appendable, case node: CTypeQualifier[_] => printCTypeQualifier(node) case node: CPointer[_] => printCPointer(node) case node: CInit[_] => printCInit(node) + case node: GpuMemoryFence[_] => printGpuMemoryFence(node) case node: JavaModifier[_] => printJavaModifier(node) case node: JavaImport[_] => printJavaImport(node) case node: JavaName[_] => printJavaName(node) diff --git a/col/src/main/java/vct/col/rewrite/NonLatchingRewriter.scala b/col/src/main/java/vct/col/rewrite/NonLatchingRewriter.scala index 31c503f294..9ef22a469a 100644 --- a/col/src/main/java/vct/col/rewrite/NonLatchingRewriter.scala +++ b/col/src/main/java/vct/col/rewrite/NonLatchingRewriter.scala @@ -16,6 +16,7 @@ class NonLatchingRewriter[Pre, Post]() extends AbstractRewriter[Pre, Post] { override def dispatch(node: DecreasesClause[Pre]): DecreasesClause[Post] = rewriteDefault(node) override def dispatch(node: AccountedPredicate[Pre]): AccountedPredicate[Post] = rewriteDefault(node) override def dispatch(node: ApplicableContract[Pre]): ApplicableContract[Post] = rewriteDefault(node) + override def dispatch(node: GpuMemoryFence[Pre]): GpuMemoryFence[Post] = rewriteDefault(node) override def dispatch(node: LoopContract[Pre]): LoopContract[Post] = rewriteDefault(node) override def dispatch(parRegion: ParRegion[Pre]): ParRegion[Post] = rewriteDefault(parRegion) diff --git a/parsers/lib/antlr4/LangCParser.g4 b/parsers/lib/antlr4/LangCParser.g4 index 84c53b206f..005bda6755 100644 --- a/parsers/lib/antlr4/LangCParser.g4 +++ b/parsers/lib/antlr4/LangCParser.g4 @@ -542,8 +542,7 @@ blockItem | statement | valEmbedStatementBlock | {specLevel>0}? valStatement - | gpgpuLocalBarrier - | gpgpuGlobalBarrier + | gpgpuBarrier | gpgpuAtomicBlock ; diff --git a/parsers/lib/antlr4/LangGPGPULexer.g4 b/parsers/lib/antlr4/LangGPGPULexer.g4 index 4d9a96914c..caacc6218c 100644 --- a/parsers/lib/antlr4/LangGPGPULexer.g4 +++ b/parsers/lib/antlr4/LangGPGPULexer.g4 @@ -1,12 +1,12 @@ lexer grammar LangGPGPULexer; GPGPU_BARRIER: '__vercors_barrier__'; -GPGPU_LOCAL_BARRIER: '__vercors_local_barrier__'; -GPGPU_GLOBAL_BARRIER: '__vercors_global_barrier__'; +GPGPU_LOCAL_MEMORY_FENCE: '__vercors_local_mem_fence__'; +GPGPU_GLOBAL_MEMORY_FENCE: '__vercors_global_mem_fence__'; GPGPU_KERNEL: '__vercors_kernel__'; GPGPU_ATOMIC: '__vercors_atomic__'; -GPGPU_GLOBAL_MEMORY: '__global'; -GPGPU_LOCAL_MEMORY: '__local'; +GPGPU_GLOBAL_MEMORY: '__vercors_global_memory__'; +GPGPU_LOCAL_MEMORY: '__vercors_local_memory__'; GPGPU_CUDA_OPEN_EXEC_CONFIG: '<<<'; GPGPU_CUDA_CLOSE_EXEC_CONFIG: '>>>'; \ No newline at end of file diff --git a/parsers/lib/antlr4/LangGPGPUParser.g4 b/parsers/lib/antlr4/LangGPGPUParser.g4 index dde7906c9d..bee08e5546 100644 --- a/parsers/lib/antlr4/LangGPGPUParser.g4 +++ b/parsers/lib/antlr4/LangGPGPUParser.g4 @@ -1,11 +1,7 @@ parser grammar LangGPGPUParser; -gpgpuLocalBarrier - : valEmbedContract? GPGPU_BARRIER '(' GPGPU_LOCAL_BARRIER ')' - ; - -gpgpuGlobalBarrier - : valEmbedContract? GPGPU_BARRIER '(' GPGPU_GLOBAL_BARRIER ')' +gpgpuBarrier + : valEmbedContract? GPGPU_BARRIER '(' gpgpuMemFenceList ')' ; gpgpuCudaKernelInvocation @@ -16,6 +12,17 @@ gpgpuAtomicBlock : valEmbedWith? GPGPU_ATOMIC compoundStatement valEmbedThen? ; +gpgpuMemFenceList + : gpgpuMemFence + | gpgpuMemFenceList '|' gpgpuMemFence + ; + +gpgpuMemFence + : GPGPU_LOCAL_MEMORY_FENCE + | GPGPU_GLOBAL_MEMORY_FENCE + | Constant + ; + gpgpuKernelSpecifier: GPGPU_KERNEL; gpgpuLocalMemory: GPGPU_LOCAL_MEMORY; diff --git a/parsers/src/main/java/vct/parsers/transform/CToCol.scala b/parsers/src/main/java/vct/parsers/transform/CToCol.scala index 6ad0bc24c9..61b8f05718 100644 --- a/parsers/src/main/java/vct/parsers/transform/CToCol.scala +++ b/parsers/src/main/java/vct/parsers/transform/CToCol.scala @@ -211,16 +211,25 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b case BlockItem1(stat) => convert(stat) case BlockItem2(embedStats) => convert(embedStats) case BlockItem3(embedStat) => convert(embedStat) - case BlockItem4(GpgpuLocalBarrier0(contract, _, _, _, _)) => withContract(contract, c => { - GpgpuLocalBarrier(AstBuildHelpers.foldStar[G](c.consume(c.requires)), AstBuildHelpers.foldStar[G](c.consume(c.ensures))) + case BlockItem4(GpgpuBarrier0(contract, _, _, specifier, _)) => withContract(contract, c => { + GpgpuBarrier(AstBuildHelpers.foldStar[G](c.consume(c.requires)), AstBuildHelpers.foldStar[G](c.consume(c.ensures)) + , convert(specifier)) }) - case BlockItem5(GpgpuGlobalBarrier0(contract, _, _, _, _)) => withContract(contract, c => { - GpgpuGlobalBarrier(AstBuildHelpers.foldStar[G](c.consume(c.requires)), AstBuildHelpers.foldStar[G](c.consume(c.ensures))) - }) - case BlockItem6(GpgpuAtomicBlock0(whiff, _, impl, den)) => + case BlockItem5(GpgpuAtomicBlock0(whiff, _, impl, den)) => GpgpuAtomic(convert(impl), whiff.map(convert(_)).getOrElse(Block(Nil)), den.map(convert(_)).getOrElse(Block(Nil))) } + def convert(implicit spec: GpgpuMemFenceListContext): Seq[GpuMemoryFence[G]] = spec match { + case GpgpuMemFenceList0(argument) => Seq(convert(argument)) + case GpgpuMemFenceList1(init, _, last) => convert(init) :+ convert(last) + } + + def convert(implicit spec: GpgpuMemFenceContext): GpuMemoryFence[G] = spec match { + case GpgpuMemFence0(_) => GpuLocalMemoryFence() + case GpgpuMemFence1(_) => GpuGlobalMemoryFence() + case GpgpuMemFence2(i) => GpuZeroMemoryFence(Integer.parseInt(i)) + } + def convert(implicit stat: LabeledStatementContext): Statement[G] = stat match { case LabeledStatement0(label, _, inner) => Label(new LabelDecl()(SourceNameOrigin(convert(label), originProvider(stat))), convert(inner)) @@ -958,10 +967,7 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b case ValPointerIndex(_, _, ptr, _, idx, _, perm, _) => PermPointerIndex(convert(ptr), convert(idx), convert(perm)) case ValPointerBlockLength(_, _, ptr, _) => PointerBlockLength(convert(ptr))(blame(e)) case ValPointerBlockOffset(_, _, ptr, _) => PointerBlockOffset(convert(ptr))(blame(e)) - case ValPointerLength(_, _, ptr, _) => - val convertedPtr = convert(ptr) - val blameExpr = blame(e) - PointerBlockLength(convertedPtr)(blameExpr) - PointerBlockOffset(convertedPtr)(blameExpr) + case ValPointerLength(_, _, ptr, _) => PointerLength((convert(ptr))(blame(e))) } def convert(implicit e: ValPrimaryBinderContext): Expr[G] = e match { diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index 12bf25bd6a..7ad2ec181d 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -6,12 +6,16 @@ import vct.col.ast._ import vct.col.newrewrite.lang.LangSpecificToCol.NotAValue import vct.col.origin.{AbstractApplicable, InterpretedOriginVariable, Origin, PanicBlame} import vct.col.ref.Ref -import vct.col.resolve.{BuiltinField, BuiltinInstanceMethod, C, CNameTarget, RefADTFunction, RefAxiomaticDataType, RefCFunctionDefinition, RefCGlobalDeclaration, RefCLocalDeclaration, RefCParam, RefCudaBlockDim, RefCudaBlockIdx, RefCudaGridDim, RefCudaThreadIdx, RefCudaVec, RefCudaVecDim, RefCudaVecX, RefCudaVecY, RefCudaVecZ, RefFunction, RefInstanceFunction, RefInstanceMethod, RefInstancePredicate, RefModelAction, RefModelField, RefModelProcess, RefPredicate, RefProcedure, RefVariable, SpecInvocationTarget} +import vct.col.resolve.{BuiltinField, BuiltinInstanceMethod, C, CNameTarget, RefADTFunction, RefAxiomaticDataType, + RefCFunctionDefinition, RefCGlobalDeclaration, RefCLocalDeclaration, RefCParam, RefCudaBlockDim, RefCudaBlockIdx, + RefCudaGridDim, RefCudaThreadIdx, RefCudaVec, RefCudaVecDim, RefCudaVecX, RefCudaVecY, RefCudaVecZ, RefFunction, + RefInstanceFunction, RefInstanceMethod, RefInstancePredicate, RefModelAction, RefModelField, RefModelProcess, + RefPredicate, RefProcedure, RefVariable, SpecInvocationTarget} import vct.col.rewrite.{Generation, Rewritten} import vct.col.util.SuccessionMap import vct.col.util.AstBuildHelpers._ import vct.parsers.Language -import vct.result.VerificationError.UserError +import vct.result.VerificationError.{Unreachable, UserError} import scala.collection.immutable.ListMap import scala.collection.mutable @@ -33,6 +37,27 @@ case object LangCToCol { override def text: String = s"The parameter `$param` has a type that is not allowed`outside of a GPU kernel." } + case class NotDynamicSharedMem(e: Expr[_]) extends UserError { + override def code: String = "notDynamicSharedMem" + override def text: String = s"The expression \\shared_mem_size(`$e`) is not referencing to a shared memory location." + } + + case class WrongBarrierSpecifier(b: GpgpuBarrier[_]) extends UserError { + override def code: String = "wrongBarrierSpecifier" + override def text: String = s"The barrier `$b` has incorrect specifiers." + } + + case class UnsupportedBarrierPermission(e: Node[_]) extends UserError { + override def code: String = "unsupportedBarrierPermission" + override def text: String = s"The permission `$e` is unsupported for barrier for now." + } + + case class RedistributingBarrier(v: CNameTarget[_], global: Boolean) extends UserError { + def memFence: String = if(global) "CLK_GLOBAL_MEM_FENCE" else "CLK_LOCAL_MEM_FENCE" + override def code: String = "redistributingBarrier" + override def text: String = s"Trying to redistribute the variable `$v` in a GPU barrier, but need the fence `$memFence` to do this." + } + case class CDoubleContracted(decl: CGlobalDeclaration[_], defn: CFunctionDefinition[_]) extends UserError { override def code: String = "multipleContracts" override def text: String = @@ -66,7 +91,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O private val staticSharedMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() private val globalMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() private var inKernel: Boolean = false - private var inKernelArgs: Boolean = false case class CudaIndexVariableOrigin(dim: RefCudaVecDim[_]) extends Origin { override def preferredName: String = @@ -86,15 +110,21 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O private def hasNoSharedMemNames(node: Node[Pre]): Boolean = { val allowedNonRefs = Set("get_local_id", "get_group_id", "get_local_size", "get_num_groups") - node match { - // SharedMemSize gets rewritten towards the length of a shared memory name, so is valid in global context - case _: SharedMemSize[Pre] => - case l: CLocal[Pre] => l.ref match { + + def varIsNotShared(l: CLocal[Pre]): Boolean = { + l.ref match { case Some(ref: RefCParam[Pre]) if dynamicSharedMemNames.contains(ref) || staticSharedMemNames.contains(ref) => return false - case None => if(!allowedNonRefs.contains(l.name)) ??? + case None => if (!allowedNonRefs.contains(l.name)) ??? case _ => } + true + } + + node match { + // SharedMemSize gets rewritten towards the length of a shared memory name, so is valid in global context + case _: SharedMemSize[Pre] => + case l: CLocal[Pre] => return varIsNotShared(l) case e => if(!e.subnodes.forall(hasNoSharedMemNames)) return false } true @@ -121,73 +151,74 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O } } - def rewriteParam(cParam: CParam[Pre]): Unit = { + def rewriteGPUParam(cParam: CParam[Pre]): Unit = { cParam.drop() val o = InterpretedOriginVariable(cDeclToName(cParam.declarator), cParam.o) - var array = false - var pointer = false + var arrayOrPointer = false var global = false var shared = false var extern = false var innerType: Option[Type[Pre]] = None - var isDynamicShared = false - val cRef = RefCParam(cParam) cParam.specifiers.foreach{ case GPULocal() => shared = true case GPUGlobal() => global = true case CSpecificationType(TPointer(t)) => - pointer = true + arrayOrPointer = true innerType = Some(t) case CSpecificationType(TArray(t)) => - array = true + arrayOrPointer = true innerType = Some(t) case CExtern() => extern = true case _ => } - if((shared || global) && !inKernel) throw WrongGPUType(cParam) - if(inKernelArgs){ - language.get match { - case Language.C => - if(global && !shared && (pointer || array) && !extern) globalMemNames.add(cRef) - else if(shared && !global && (pointer || array) && !extern) { + language match { + case Some(Language.C) => + if(global && !shared && arrayOrPointer && !extern) globalMemNames.add(cRef) + else if(shared && !global && arrayOrPointer && !extern) { dynamicSharedMemNames.add(cRef) // Create Var with array type here and return - isDynamicShared = true val v = new Variable[Post](TArray[Post]( rw.dispatch(innerType.get) )(o) )(o) cNameSuccessor(cRef) = v return } - else if(!shared && !global && !pointer && !array && !extern) () + else if(!shared && !global && !arrayOrPointer && !extern) () else throw WrongGPUKernelParameterType(cParam) - case Language.CUDA => - if(!global && !shared && (pointer || array) && !extern) globalMemNames.add(cRef) - else if(!shared && !global && !pointer && !array && !extern) () + case Some(Language.CUDA) => + if(!global && !shared && arrayOrPointer && !extern) globalMemNames.add(cRef) + else if(!shared && !global && !arrayOrPointer && !extern) () else throw WrongGPUKernelParameterType(cParam) - case _ => ??? // Should not happen + case _ => throw Unreachable(f"The language '$language' should not have GPU kernels.'") } - } val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(o) cNameSuccessor(cRef) = v rw.variables.declare(v) } + def rewriteParam(cParam: CParam[Pre]): Unit = { + if(inKernel) return rewriteGPUParam(cParam) + cParam.specifiers.collectFirst{ + case GPULocal() => throw WrongGPUType(cParam) + case GPUGlobal() => throw WrongGPUType(cParam) + } + + cParam.drop() + val o = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) + + val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(o) + cNameSuccessor(RefCParam(cParam)) = v + rw.variables.declare(v) + } + def rewriteFunctionDef(func: CFunctionDefinition[Pre]): Unit = { func.drop() val info = C.getDeclaratorInfo(func.declarator) val returnType = func.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) - if (func.specs.collectFirst { case CKernel() => () }.nonEmpty){ - inKernel = true - inKernelArgs = true - } - val params = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 - inKernel = false - inKernelArgs = false val (contract, subs: Map[CParam[Pre], CParam[Pre]]) = func.ref match { case Some(RefCGlobalDeclaration(decl, idx)) if decl.decl.contract.nonEmpty => @@ -208,6 +239,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val namedO = InterpretedOriginVariable(cDeclToName(func.declarator), func.o) kernelProcedure(namedO, contract, info, Some(func.body)) } else { + val params = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 new Procedure[Post]( returnType = returnType, args = params, @@ -276,9 +308,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val gridDim = new CudaVec(RefCudaGridDim())(o) cudaCurrentBlockDim.having(blockDim) { cudaCurrentGridDim.having(gridDim) { - inKernelArgs = true val args = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 - inKernelArgs = false rw.variables.collect { dynamicSharedMemNames.foreach(d => rw.variables.declare(cNameSuccessor(d)) ) } val (sharedMemSizes, sharedMemInit: Seq[Statement[Post]]) = rw.variables.collect { var result: Seq[Statement[Post]] = Seq() @@ -435,11 +465,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val v = new Variable[Post](t)(init.o) cNameSuccessor(RefCLocalDeclaration(decl, idx)) = v implicit val o: Origin = init.o - init.init match { - case Some(value) => - Block(Seq(LocalDecl(v), assignLocal(v.get, rw.dispatch(value)))) - case None => LocalDecl(v) - } + init.init + .map(value => Block(Seq(LocalDecl(v), assignLocal(v.get, rw.dispatch(value))))) + .getOrElse(LocalDecl(v)) } })(decl.o) } @@ -447,8 +475,30 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def rewriteGoto(goto: CGoto[Pre]): Statement[Post] = Goto[Post](rw.succ(goto.ref.getOrElse(???)))(goto.o) - def localBarrier(barrier: GpgpuLocalBarrier[Pre]): Statement[Post] = { + def gpuBarrier(barrier: GpgpuBarrier[Pre]): Statement[Post] = { implicit val o: Origin = barrier.o + + var globalFence = false + var localFence = false + + barrier.specifiers.foreach { + case GpuLocalMemoryFence() => localFence = true + case GpuGlobalMemoryFence() => globalFence = true + case GpuZeroMemoryFence(i) => if(i != 0) throw WrongBarrierSpecifier(barrier) + } + // TODO: create requirement that shared memory arrays are not NULL? + if(!globalFence || !localFence){ + val redist = permissionScanner(barrier) + if(!globalFence) + redist + .intersect(globalMemNames.toSet) + .foreach(v => throw RedistributingBarrier(v, global = true)) + if(!localFence) + redist + .intersect(dynamicSharedMemNames.union(dynamicSharedMemNames).toSet) + .foreach(v => throw RedistributingBarrier(v, global = false)) + } + ParBarrier[Post]( block = cudaCurrentBlock.top.ref, invs = Nil, @@ -458,15 +508,55 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O )(PanicBlame("more panic")) } - def globalBarrier(barrier: GpgpuGlobalBarrier[Pre]): Statement[Post] = { - implicit val o: Origin = barrier.o - ParBarrier[Post]( - block = cudaCurrentGrid.top.ref, - invs = Nil, - requires = rw.dispatch(barrier.requires), - ensures = rw.dispatch(barrier.ensures), - content = Block(Nil), - )(PanicBlame("more panic")) + def isPointer(t: Type[Pre]) : Boolean = t match { + case TPointer(_) => true + case CPrimitiveType(specs) => + specs.collectFirst{case CSpecificationType(TPointer(_)) => }.nonEmpty + case _ => false + } + + def isNumeric(t: Type[Pre]): Boolean = t match{ + case _: NumericType[Pre] => true + case CPrimitiveType(specs) => + specs.collectFirst{case CSpecificationType(_ : NumericType[Pre]) =>}.nonEmpty + case _ => false + } + + def searchNames(e: Expr[Pre], original: Node[Pre]): Seq[CNameTarget[Pre]] = e match { + case arr : CLocal[Pre] => Seq(arr.ref.get) + case PointerAdd(arr : CLocal[Pre], _) => Seq(arr.ref.get) + case AmbiguousSubscript(arr : CLocal[Pre], _) => Seq(arr.ref.get) + case AmbiguousPlus(l, r) if isPointer(l.t) && isNumeric(r.t) => searchNames(l, original) + case _ => throw UnsupportedBarrierPermission(original) + } + + def searchNames(loc: Location[Pre], original: Node[Pre]): Seq[CNameTarget[Pre]] = loc match { + case ArrayLocation(arr : CLocal[Pre], _) => Seq(arr.ref.get) + case PointerLocation(arr : CLocal[Pre]) => Seq(arr.ref.get) + case PointerLocation(PointerAdd(arr : CLocal[Pre], _)) => Seq(arr.ref.get) + case AmbiguousLocation(expr) => searchNames(expr, original) + case _ => throw UnsupportedBarrierPermission(original) + } + + def searchPermission(e: Node[Pre]): Seq[CNameTarget[Pre]] = { + e match { + case e: Expr[Pre] if e.t != TResource[Pre]() => return Seq() + case Perm(loc, _) => searchNames(loc, e) + case PointsTo(loc, _, _) => searchNames(loc, e) + case CurPerm(loc) => searchNames(loc, e) + case PermPointer(pointer, _, _) => searchNames(pointer, e) + case PermPointerIndex(pointer, _, _) => searchNames(pointer, e) + case _ => e.subnodes.flatMap(searchPermission) + } + } + + def permissionScanner(barrier: GpgpuBarrier[Pre]): Set[CNameTarget[Pre]] ={ + val pres = unfoldStar(barrier.requires).toSet + val posts = unfoldStar(barrier.ensures).toSet + val context = pres intersect posts + // Only scan the non context permissions + val nonContext = (pres union posts) diff context + nonContext.flatMap(searchPermission) } def result(ref: RefCFunctionDefinition[Pre])(implicit o: Origin): Expr[Post] = @@ -500,16 +590,19 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def deref(deref: CStructAccess[Pre]): Expr[Post] = { implicit val o: Origin = deref.o + + def getCuda(dim: RefCudaVecDim[Pre]): Expr[Post] = dim.vec match { + case RefCudaThreadIdx() => cudaCurrentThreadIdx.top.indices(dim).get + case RefCudaBlockIdx() => cudaCurrentBlockIdx.top.indices(dim).get + case RefCudaBlockDim() => cudaCurrentBlockDim.top.indices(dim).get + case RefCudaGridDim() => cudaCurrentGridDim.top.indices(dim).get + } + deref.ref.get match { case RefModelField(decl) => ModelDeref[Post](rw.currentThis.top, rw.succ(decl))(deref.blame) case BuiltinField(f) => rw.dispatch(f(deref.struct)) case target: SpecInvocationTarget[Pre] => ??? - case dim: RefCudaVecDim[Pre] => dim.vec match { - case RefCudaThreadIdx() => cudaCurrentThreadIdx.top.indices(dim).get - case RefCudaBlockIdx() => cudaCurrentBlockIdx.top.indices(dim).get - case RefCudaBlockDim() => cudaCurrentBlockDim.top.indices(dim).get - case RefCudaGridDim() => cudaCurrentGridDim.top.indices(dim).get - } + case dim: RefCudaVecDim[Pre] => getCuda(dim) } } @@ -541,10 +634,19 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O ProcedureInvocation[Post](cFunctionSuccessor.ref(ref.decl), args.map(rw.dispatch), Nil, Nil, givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) - case e@ RefCGlobalDeclaration(decls, initIdx) => - val arg = if(args.size == 1){ - args.head match { - case IntegerValue(i) if i >= 0 && i < 3 => Some(i.toInt) + case e: RefCGlobalDeclaration[Pre] => globalInvocation(e, inv) + + } + } + + def globalInvocation(e: RefCGlobalDeclaration[Pre], inv: CInvocation[Pre]): Expr[Post] = { + val CInvocation(_, args, givenMap, yields) = inv + val RefCGlobalDeclaration(decls, initIdx) = e + implicit val o: Origin = inv.o + + val arg = if(args.size == 1){ + args.head match { + case IntegerValue(i) if i >= 0 && i < 3 => Some(i.toInt) case _ => None } } else None @@ -553,10 +655,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O case ("get_group_id", Some(i)) => cudaCurrentBlockIdx.top.indices.values.toSeq.apply(i).get case ("get_local_size", Some(i)) => cudaCurrentBlockDim.top.indices.values.toSeq.apply(i).get case ("get_num_groups", Some(i)) => cudaCurrentGridDim.top.indices.values.toSeq.apply(i).get - case _ => ProcedureInvocation[Post](cFunctionDeclSuccessor.ref((decls, initIdx)), args.map(rw.dispatch), Nil, Nil, - givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, - yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) - } + case _ => ProcedureInvocation[Post](cFunctionDeclSuccessor.ref((decls, initIdx)), args.map(rw.dispatch), Nil, Nil, + givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, + yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) } } } diff --git a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala index ce34f232f4..4b72549431 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala @@ -88,8 +88,7 @@ case class LangSpecificToCol[Pre <: Generation](language: Option[Language]) exte case CDeclarationStatement(decl) => c.rewriteLocal(decl) case goto: CGoto[Pre] => c.rewriteGoto(goto) - case barrier: GpgpuLocalBarrier[Pre] => c.localBarrier(barrier) - case barrier: GpgpuGlobalBarrier[Pre] => c.globalBarrier(barrier) + case barrier: GpgpuBarrier[Pre] => c.gpuBarrier(barrier) case other => rewriteDefault(other) } diff --git a/src/main/universal/res/c/cuda.h b/src/main/universal/res/c/cuda.h index e8d0c0cc62..f2e46c3689 100644 --- a/src/main/universal/res/c/cuda.h +++ b/src/main/universal/res/c/cuda.h @@ -2,6 +2,7 @@ #define CUDA_H #define __global__ __vercors_kernel__ +#define __shared__ __vercors_local_memory__ #define bool _Bool @@ -45,7 +46,7 @@ extern /*@ pure @*/ int get_enqueued_num_sub_groups (); // extern /*@ pure @*/ int get_sub_group_id (); // Sub-group ID -#define __syncthreads() __vercors_barrier__(__vercors_local_barrier__) +#define __syncthreads() __vercors_barrier__(__vercors_local_mem_fence__ | __vercors_global_mem_fence__) extern /*@ pure @*/ int get_sub_group_local_id (); // Unique work-item ID diff --git a/src/main/universal/res/c/opencl.h b/src/main/universal/res/c/opencl.h index 8ec3733f8c..a5e6f1fae2 100644 --- a/src/main/universal/res/c/opencl.h +++ b/src/main/universal/res/c/opencl.h @@ -3,11 +3,16 @@ #define __kernel __vercors_kernel__ -#define CLK_GLOBAL_MEM_FENCE __vercors_global_barrier__ -#define CLK_LOCAL_MEM_FENCE __vercors_local_barrier__ +#define CLK_GLOBAL_MEM_FENCE __vercors_global_mem_fence__ +#define CLK_LOCAL_MEM_FENCE __vercors_local_mem_fence__ #define barrier(locality) __vercors_barrier__(locality) +#define __global __vercors_global_memory__ +#define global __vercors_global_memory__ +#define __local __vercors_local_memory__ +#define local __vercors_local_memory__ + #define bool _Bool extern /*@ pure @*/ int get_work_dim(); // Number of dimensions in use From a9c4e99db3aadb54aa2cbc6314408ebf89ad99a5 Mon Sep 17 00:00:00 2001 From: Lars Date: Tue, 20 Sep 2022 15:43:58 +0200 Subject: [PATCH 22/25] Changes based on pull request review --- col/src/main/java/vct/col/ast/Node.scala | 1 + .../expr/heap/read/PointerLengthtImpl.scala | 7 + .../vct/col/coerce/CoercingRewriter.scala | 2 + col/src/main/java/vct/col/print/Printer.scala | 16 ++ .../java/vct/col/util/AstBuildHelpers.scala | 11 +- .../java/vct/parsers/transform/CToCol.scala | 2 +- .../col/newrewrite/ApplyTermRewriter.scala | 4 +- .../java/vct/col/newrewrite/ImportADT.scala | 2 + .../vct/col/newrewrite/ParBlockEncoder.scala | 116 ++++++----- .../vct/col/newrewrite/lang/LangCToCol.scala | 181 +++++++++--------- .../newrewrite/lang/LangSpecificToCol.scala | 1 + 11 files changed, 175 insertions(+), 168 deletions(-) create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/expr/heap/read/PointerLengthtImpl.scala diff --git a/col/src/main/java/vct/col/ast/Node.scala b/col/src/main/java/vct/col/ast/Node.scala index 2c1b8a434b..df3a5e9cc1 100644 --- a/col/src/main/java/vct/col/ast/Node.scala +++ b/col/src/main/java/vct/col/ast/Node.scala @@ -536,6 +536,7 @@ final case class Length[G](arr: Expr[G])(val blame: Blame[ArrayNull])(implicit v final case class Size[G](obj: Expr[G])(implicit val o: Origin) extends Expr[G] with SizeImpl[G] final case class PointerBlockLength[G](pointer: Expr[G])(val blame: Blame[PointerNull])(implicit val o: Origin) extends Expr[G] with PointerBlockLengthImpl[G] final case class PointerBlockOffset[G](pointer: Expr[G])(val blame: Blame[PointerNull])(implicit val o: Origin) extends Expr[G] with PointerBlockOffsetImpl[G] +final case class PointerLength[G](pointer: Expr[G])(val blame: Blame[PointerNull])(implicit val o: Origin) extends Expr[G] with PointerLengthtImpl[G] final case class SharedMemSize[G](pointer: Expr[G])(implicit val o: Origin) extends Expr[G] with SharedMemSizeImpl[G] final case class Cons[G](x: Expr[G], xs: Expr[G])(implicit val o: Origin) extends Expr[G] with ConsImpl[G] diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/heap/read/PointerLengthtImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/heap/read/PointerLengthtImpl.scala new file mode 100644 index 0000000000..822535863d --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/heap/read/PointerLengthtImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.expr.heap.read + +import vct.col.ast.{PointerLength, TInt, Type} + +trait PointerLengthtImpl[G] { this: PointerLength[G] => + override def t: Type[G] = TInt() +} diff --git a/col/src/main/java/vct/col/coerce/CoercingRewriter.scala b/col/src/main/java/vct/col/coerce/CoercingRewriter.scala index b2242f60d9..49ab9abffc 100644 --- a/col/src/main/java/vct/col/coerce/CoercingRewriter.scala +++ b/col/src/main/java/vct/col/coerce/CoercingRewriter.scala @@ -915,6 +915,8 @@ abstract class CoercingRewriter[Pre <: Generation]() extends Rewriter[Pre] with PointerBlockLength(pointer(p)._1)(len.blame) case off @ PointerBlockOffset(p) => PointerBlockOffset(pointer(p)._1)(off.blame) + case len @ PointerLength(p) => + PointerLength(pointer(p)._1)(len.blame) case get @ PointerSubscript(p, index) => PointerSubscript(pointer(p)._1, int(index))(get.blame) case PointsTo(loc, perm, value) => diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index c5c7c6bf35..cb6508e93e 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -853,6 +853,8 @@ case class Printer(out: Appendable, (phrase("pointer_block(", pointer ,")"), 100) case PointerBlockLength(pointer) => (phrase("block_length(", pointer ,")"), 100) + case PointerLength(pointer) => + (phrase("pointer_length(", pointer ,")"), 100) case Cons(x, xs) => (phrase(bind(87, x), space, "::", space, assoc(87, xs)), 87) case Head(xs) => @@ -1269,6 +1271,19 @@ case class Printer(out: Appendable, def printJavaName(node: JavaName[_]): Unit = say(node.names.mkString(".")) + def printLocation(loc: Location[_]): Unit = loc match { +// case FieldLocation(obj, field) => +// case ModelLocation(obj, field) => +// case SilverFieldLocation(obj, field) => + case ArrayLocation(array, subscript) => (phrase(assoc(100, array), "[", subscript, "]"), 100) + case PointerLocation(pointer) => say(pointer) +// case PredicateLocation(predicate, args) => +// case InstancePredicateLocation(predicate, obj, args) => + case AmbiguousLocation(expr) => say(expr) + case x => + say(s"Unknown node type in Printer.scala: ${x.getClass.getCanonicalName}") + } + def printVerification(node: Verification[_]): Unit = node.tasks.foreach(print) @@ -1301,6 +1316,7 @@ case class Printer(out: Appendable, case node: JavaModifier[_] => printJavaModifier(node) case node: JavaImport[_] => printJavaImport(node) case node: JavaName[_] => printJavaName(node) + case node : Location[_] => printLocation(node) case node: Verification[_] => printVerification(node) case node: VerificationContext[_] => printVerificationContext(node) case x => diff --git a/col/src/main/java/vct/col/util/AstBuildHelpers.scala b/col/src/main/java/vct/col/util/AstBuildHelpers.scala index d58ca6846c..204fb5b53d 100644 --- a/col/src/main/java/vct/col/util/AstBuildHelpers.scala +++ b/col/src/main/java/vct/col/util/AstBuildHelpers.scala @@ -316,13 +316,10 @@ object AstBuildHelpers { exprs.reduceOption(And(_, _)).getOrElse(tt) def unfoldPredicate[G](p: AccountedPredicate[G]): Seq[Expr[G]] = p match { - case UnitAccountedPredicate(pred) => unfoldStar(pred) + case UnitAccountedPredicate(pred) => Seq(pred) case SplitAccountedPredicate(left, right) => unfoldPredicate(left) ++ unfoldPredicate(right) } - def filterPredicate[G](p: AccountedPredicate[G], f: Expr[G] => Boolean): AccountedPredicate[G] = - foldPredicate(unfoldPredicate(p).filter(f))(p.o) - def mapPredicate[G1, G2](p: AccountedPredicate[G1], f: Expr[G1] => Expr[G2]): AccountedPredicate[G2] = p match { case UnitAccountedPredicate(pred) => UnitAccountedPredicate(f(pred))(p.o) case SplitAccountedPredicate(left, right) => SplitAccountedPredicate(mapPredicate(left, f), mapPredicate(right, f))(p.o) @@ -357,12 +354,6 @@ object AstBuildHelpers { case SplitAccountedPredicate(left, right) => Star(foldStar(left), foldStar(right)) } - def foldPredicate[G](exprs: Seq[Expr[G]])(implicit o: Origin): AccountedPredicate[G] = - exprs - .map(e => UnitAccountedPredicate(e)(e.o)) - .reduceOption[AccountedPredicate[G]](SplitAccountedPredicate(_, _)) - .getOrElse(UnitAccountedPredicate(tt)) - def foldOr[G](exprs: Seq[Expr[G]])(implicit o: Origin): Expr[G] = exprs.reduceOption(Or(_, _)).getOrElse(ff) } diff --git a/parsers/src/main/java/vct/parsers/transform/CToCol.scala b/parsers/src/main/java/vct/parsers/transform/CToCol.scala index 61b8f05718..b31c13f296 100644 --- a/parsers/src/main/java/vct/parsers/transform/CToCol.scala +++ b/parsers/src/main/java/vct/parsers/transform/CToCol.scala @@ -967,7 +967,7 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b case ValPointerIndex(_, _, ptr, _, idx, _, perm, _) => PermPointerIndex(convert(ptr), convert(idx), convert(perm)) case ValPointerBlockLength(_, _, ptr, _) => PointerBlockLength(convert(ptr))(blame(e)) case ValPointerBlockOffset(_, _, ptr, _) => PointerBlockOffset(convert(ptr))(blame(e)) - case ValPointerLength(_, _, ptr, _) => PointerLength((convert(ptr))(blame(e))) + case ValPointerLength(_, _, ptr, _) => PointerLength(convert(ptr))(blame(e)) } def convert(implicit e: ValPrimaryBinderContext): Expr[G] = e match { diff --git a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala index 73bcc6c450..fb9cb23db4 100644 --- a/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala +++ b/src/main/java/vct/col/newrewrite/ApplyTermRewriter.scala @@ -355,9 +355,7 @@ case class ApplyTermRewriter[Rule, Pre <: Generation] override def dispatch(e: Expr[Pre]): Expr[Post] = if(simplificationDone.nonEmpty) rewriteDefault(e) else simplificationDone.having(()) { - //TODO: This progress "nextPhase" prints out a lot of garbage when having to much foralls (for instance when verifying CUDA/OpenCl programs - // Disabled it for now, probably a bug? - //Progress.nextPhase(s"`$e`") + Progress.nextPhase(s"`$e`") countApply = 0 countSuccess = 0 currentExpr = e diff --git a/src/main/java/vct/col/newrewrite/ImportADT.scala b/src/main/java/vct/col/newrewrite/ImportADT.scala index a86cc0daf9..f24268634b 100644 --- a/src/main/java/vct/col/newrewrite/ImportADT.scala +++ b/src/main/java/vct/col/newrewrite/ImportADT.scala @@ -578,6 +578,8 @@ case class ImportADT[Pre <: Generation](importer: ImportADTImporter) extends Coe typeArgs = Seq(TAxiomatic[Post](pointerAdt.ref, Nil)), Nil, Nil, )(NoContext(PointerNullPreconditionFailed(off.blame, pointer)))) ) + case len @ PointerLength(pointer) => + dispatch(Minus(PointerBlockLength(pointer)(len.blame), PointerBlockOffset(pointer)(len.blame))) case other => rewriteDefault(other) } } diff --git a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala index f5466ff6b9..99d0d55c7a 100644 --- a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala +++ b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala @@ -104,20 +104,6 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { def from(v: Variable[Pre]): Expr[Post] = range(v)._1 def to(v: Variable[Pre]): Expr[Post] = range(v)._2 -// def quantify(block: ParBlock[Pre], expr: Expr[Pre])(implicit o: Origin): Expr[Post] = { -// val quantVars = block.iters.map(_.variable).map(v => v -> new Variable[Pre](v.t)(v.o)).toMap -// val body = Substitute(quantVars.map { case (l, r) => Local[Pre](l.ref) -> Local[Pre](r.ref) }.toMap[Expr[Pre], Expr[Pre]]).dispatch(expr) -// block.iters.foldLeft(dispatch(body))((body, iter) => { -// val v = quantVars(iter.variable) -// Starall[Post]( -// Seq(variables.dispatch(v)), -// Nil, -// SeqMember(Local[Post](succ(v)), Range(from(iter.variable), to(iter.variable))) ==> body -//// (from(iter.variable) <= Local[Post](succ(v)) && Local[Post](succ(v)) < to(iter.variable)) ==> body -// )(ParBlockNotInjective(block, expr)) -// }) -// } - def depVars[G](bindings: Set[Variable[G]], e: Expr[G]): Set[Variable[G]] = { val result: mutable.Set[Variable[G]] = mutable.Set() e.transSubnodes.foreach { @@ -135,15 +121,17 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { val quantVars = if (nonEmpty) depVars(vars, e) else vars val nonQuantVars = vars.diff(quantVars) val newQuantVars = quantVars.map(v => v -> new Variable[Pre](v.t)(v.o)).toMap - var body = dispatch(Substitute(newQuantVars.map { case (l, r) => Local[Pre](l.ref) -> Local[Pre](r.ref) }.toMap[Expr[Pre], Expr[Pre]]).dispatch(e)) - body = if (nonPermissionExpr(e)) { - body - } else { + var body = dispatch(Substitute( + newQuantVars.map { case (l, r) => Local[Pre](l.ref) -> Local[Pre](r.ref) }.toMap[Expr[Pre], Expr[Pre]] + ).dispatch(e)) + body = if (e.t == TResource[Pre]()) { // Scale the body if it contains permissions nonQuantVars.foldLeft(body)((body, iter) => { val scale = to(iter) - from(iter) - Scale(const[Post](1) /:/ scale, body)(PanicBlame("Par block was checked to be non-empty")) + Scale(scale, body)(PanicBlame("Par block was checked to be non-empty")) }) + } else { + body } // Result, quantify over all the relevant variables quantVars.foldLeft(body)((body, iter) => { @@ -173,15 +161,40 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { quantify(block, block.context_everywhere &* block.ensures, nonEmpty) } - def ranges(region: ParRegion[Pre], rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])]): Unit = region match { - case ParParallel(regions) => regions.foreach(ranges(_, rangeValues)) - case ParSequential(regions) => regions.foreach(ranges(_, rangeValues)) + def constantExpression(e: Expr[_]): Boolean = e match { + case _: Constant[_, _] => true + case op: BinExpr[_] => constantExpression(op.left) && constantExpression(op.right) + /* TODO: This is the hard part, do we do an analysis on locals to see if they are never changed in the body of a a + * parallel block? We need this, otherwise we can never hope to rewrite nested foralls, since the bounds will + * depend on the hi_var en var_low values. And on that the nested quantifier pass will get stuck. + */ + case _: Local[_] => false + case _ => false + } + + def ranges(region: ParRegion[Pre], rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])]): Statement[Post] = region match { + case ParParallel(regions) => Block(regions.map(ranges(_, rangeValues)))(region.o) + case ParSequential(regions) => Block(regions.map(ranges(_, rangeValues)))(region.o) case block @ ParBlock(decl, iters, _, _, _, _) => decl.drop() blockDecl(decl) = block - iters.foreach { v => - rangeValues(v.variable) = (dispatch(v.from), dispatch(v.to)) - } + Block(iters.foldLeft(Seq[Statement[Post]]()) { case(res,v) => + implicit val o: Origin = v.o + val from = dispatch(v.from) + val to = dispatch(v.to) + if(constantExpression(from) && constantExpression (to)){ + rangeValues(v.variable) = (from, to) + res + } else { + val lo = variables.declare(new Variable[Post](TInt())(LowEvalOrigin(v))) + val hi = variables.declare(new Variable[Post](TInt())(HighEvalOrigin(v))) + rangeValues(v.variable) = (lo.get, hi.get) + res ++ Seq( + assignLocal(lo.get, dispatch(v.from)), + assignLocal(hi.get, dispatch(v.to)), + ) + } + })(region.o) } def execute(region: ParRegion[Pre], nonEmpty: Boolean): Statement[Post] = { @@ -207,9 +220,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { Block(iters.map { v => dispatch(v.variable) assignLocal(Local[Post](succ(v.variable)), IndeterminateInteger(from(v.variable), to(v.variable))) - }) - } - + }) } Scope(vars, Block(Seq( init, FramedProof( @@ -220,31 +231,32 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { ))) } - override def dispatch(stat: Statement[Pre]): Statement[Rewritten[Pre]] = stat match { - case ParStatement(region) => - val (isSingleBlock: Boolean, iters: Option[Seq[IterVariable[Pre]]]) = region match { - case pb : ParBlock[Pre] => (true, Some(pb.iters)) - case _ => (false, None) - } + def isParBlock(stat: ParRegion[Pre]): Boolean = stat match { + case _: ParBlock[Pre] => true + case _ => false + } + override def dispatch(stat: Statement[Pre]): Statement[Post] = stat match { + case ParStatement(region) => + val isSingleBlock: Boolean = isParBlock(region) implicit val o: Origin = stat.o - val rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])] = mutable.Map() - - ranges(region, rangeValues) + val (vars, evalRanges) = variables.collect { + ranges(region, rangeValues) + } currentRanges.having(rangeValues.toMap) { - var res: Statement[Post] = Block(Seq( + var res: Statement[Post] = IndetBranch(Seq( execute(region, isSingleBlock), Block(Seq(check(region), Inhale(ff))) - )), - )) + )) if(isSingleBlock){ val condition: Expr[Post] = foldAnd(rangeValues.values.map{case (low, hi) => low < hi}) res = Branch(Seq((condition, res))) } - res + Scope(vars, + Block(Seq(evalRanges,res))) } case inv @ ParInvariant(decl, dependentInvariant, body) => @@ -302,8 +314,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { } override def dispatch(e: Expr[Pre]): Expr[Rewritten[Pre]] = e match { -// case ScaleByParBlock(Ref(decl), res) if !nonPermissionExpr(res) => - case ScaleByParBlock(Ref(decl), res) => + case ScaleByParBlock(Ref(decl), res) if e.t == TResource[Pre]() => implicit val o: Origin = e.o val block = blockDecl(decl) block.iters.foldLeft(dispatch(res)) { @@ -311,24 +322,7 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { val scale = to(v.variable) - from(v.variable) Implies(scale > const(0), Scale(const[Post](1) /:/ scale, res)(PanicBlame("framed positive"))) } -// case ScaleByParBlock(Ref(_), res) => dispatch(res) + case ScaleByParBlock(Ref(_), res) => dispatch(res) case other => rewriteDefault(other) } - - def nonPermissionExpr[G](e: Node[G]): Boolean = e match { - case _: Perm[G] => false - case _: PointsTo[G] => false - case _: PermPointer[G] => false - case _: PermPointerIndex[G] => false - case _: ModelState[G] => false - case _: ModelSplit[G] => false - case _: ModelMerge[G] => false - case _: ModelChoose[G] => false - case _: ModelPerm[G] => false - case _: ActionPerm[G] => false - case _: PredicateApply[G] => false - case _: InstancePredicateApply[G] => false - case _: CoalesceInstancePredicateApply[G] => false - case other => other.subnodes.forall(nonPermissionExpr) - } } diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index 7ad2ec181d..d607dbef51 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -134,26 +134,21 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O cUnit.declarations.foreach(rw.dispatch) } - def cDeclToName(cDecl: CDeclarator[Pre]): String = cDecl match { - case CPointerDeclarator(_, inner) => cDeclToName(inner) - case CArrayDeclarator(_, _, inner) => cDeclToName(inner) - case CTypedFunctionDeclarator(_, _, inner) => cDeclToName(inner) - case CAnonymousFunctionDeclarator(_, inner) => cDeclToName(inner) - case CName(name: String) => name - } - def sharedSize(shared: SharedMemSize[Pre]): Expr[Post] = { val SharedMemSize(pointer) = shared - pointer match { - case loc: CLocal[Pre] => Local[Post](dynamicSharedMemLengthVar(loc.ref.get).ref)(shared.o) - case _ => ??? + val res = pointer match { + case loc: CLocal[Pre] => + loc.ref flatMap { dynamicSharedMemLengthVar.get } map {v => Local[Post](v.ref)(shared.o)} + case _ => None } + res.getOrElse(throw NotDynamicSharedMem(pointer)) } def rewriteGPUParam(cParam: CParam[Pre]): Unit = { cParam.drop() - val o = InterpretedOriginVariable(cDeclToName(cParam.declarator), cParam.o) + val varO = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) + implicit val o: Origin = cParam.o var arrayOrPointer = false var global = false @@ -180,22 +175,23 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O case Some(Language.C) => if(global && !shared && arrayOrPointer && !extern) globalMemNames.add(cRef) else if(shared && !global && arrayOrPointer && !extern) { - dynamicSharedMemNames.add(cRef) - // Create Var with array type here and return - val v = new Variable[Post](TArray[Post]( rw.dispatch(innerType.get) )(o) )(o) - cNameSuccessor(cRef) = v - return - } + dynamicSharedMemNames.add(cRef) + // Create Var with array type here and return + val v = new Variable[Post](TArray[Post](rw.dispatch(innerType.get)) )(varO) + cNameSuccessor(cRef) = v + return + } else if(!shared && !global && !arrayOrPointer && !extern) () - else throw WrongGPUKernelParameterType(cParam) + else throw WrongGPUKernelParameterType(cParam) case Some(Language.CUDA) => if(!global && !shared && arrayOrPointer && !extern) globalMemNames.add(cRef) else if(!shared && !global && !arrayOrPointer && !extern) () - else throw WrongGPUKernelParameterType(cParam) + else throw WrongGPUKernelParameterType(cParam) case _ => throw Unreachable(f"The language '$language' should not have GPU kernels.'") - } + } - val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(o) + val v = new Variable[Post](cParam.specifiers.collectFirst + { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(varO) cNameSuccessor(cRef) = v rw.variables.declare(v) } @@ -208,9 +204,10 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O } cParam.drop() - val o = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) + val varO = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) - val v = new Variable[Post](cParam.specifiers.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(o) + val v = new Variable[Post](cParam.specifiers.collectFirst + { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(varO) cNameSuccessor(RefCParam(cParam)) = v rw.variables.declare(v) } @@ -236,7 +233,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O cCurrentDefinitionParamSubstitutions.having(subs) { rw.globalDeclarations.declare( if (func.specs.collectFirst { case CKernel() => () }.nonEmpty) { - val namedO = InterpretedOriginVariable(cDeclToName(func.declarator), func.o) + val namedO = InterpretedOriginVariable(C.getDeclaratorInfo(func.declarator).name, func.o) kernelProcedure(namedO, contract, info, Some(func.body)) } else { val params = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 @@ -268,8 +265,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def allOneExpr(idx: CudaVec, dim: CudaVec, e: Expr[Post]): Expr[Post] = { implicit val o: Origin = e.o val vars = findVars(e) - val filteredIdx = idx.indices.values.zip(dim.indices.values).filter{ case (i, _) => vars.contains(i)} - val otherIdx = idx.indices.values.zip(dim.indices.values).filterNot{ case (i, _) => vars.contains(i)} + val (filteredIdx, otherIdx) = idx.indices.values.zip(dim.indices.values).partition{ case (i, _) => vars.contains(i)} val body = otherIdx.map{case (_,range) => range}.foldLeft(e)((newE, scaleFactor) => Scale(scaleFactor.get, newE)(PanicBlame("Framed positive")) ) if(filteredIdx.isEmpty){ @@ -286,7 +282,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def findVars(e: Node[Post], vars: Set[Variable[Post]] = Set()): Set[Variable[Post]] = e match { case Local(ref) => vars + ref.decl - case _ => e.subnodes.foldLeft(vars)( (set, node) => set ++ findVars(node) ) + case _ => e.transSubnodes.collect { case Local(Ref(v)) => v}.toSet } def allThreadsInBlock(e: Expr[Pre]): Expr[Post] = { @@ -315,30 +311,28 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O dynamicSharedMemNames.foreach(d => { implicit val o: Origin = d.decl.o - val v = new Variable[Post](TInt()) + val varO: Origin = InterpretedOriginVariable(s"${C.getDeclaratorInfo(d.decl.declarator).name}_size", d.decl.o) + val v = new Variable[Post](TInt())(varO) dynamicSharedMemLengthVar(d) = v rw.variables.declare(v) val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) , NewArray[Post](v.t, Seq(Local(v.ref)), 0))(PanicBlame("Assign should work")) - result = result ++ Seq(decl, assign) + result ++= Seq(decl, assign) }) result } val newArgs = blockDim.indices.values.toSeq ++ gridDim.indices.values.toSeq ++ args ++ sharedMemSizes - val newGivenArgs = rw.variables.dispatch(contract.givenArgs) val newYieldsArgs = rw.variables.dispatch(contract.yieldsArgs) // We add the requirement that a GPU kernel must always have threads (non zero block or grid dimensions) - val nonZeroThreadsSeq: Seq[Expr[Post]] - = (blockDim.indices.values ++ gridDim.indices.values).map( v => Less(IntegerValue(0)(o), v.get(o))(o) ).toSeq - val nonZeroThreads = foldStar(nonZeroThreadsSeq)(o) - - val nonZeroThreadsPred = - nonZeroThreadsSeq - .map(UnitAccountedPredicate(_)(o)) - .reduceLeft[AccountedPredicate[Post]](SplitAccountedPredicate(_,_)(o)) + val nonZeroThreads: Expr[Post] = foldStar( + (blockDim.indices.values ++ gridDim.indices.values) + .map( v => Less(IntegerValue(0)(o), v.get(o))(o)) + .toSeq)(o) + val UnitAccountedPredicate(contractRequires: Expr[Pre]) = contract.requires + val UnitAccountedPredicate(contractEnsures: Expr[Pre]) = contract.ensures val parBody = body.map(impl => { implicit val o: Origin = impl.o @@ -360,8 +354,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O }.toSeq, // Context is already inherited context_everywhere = Star(nonZeroThreads, rw.dispatch(contract.contextEverywhere)), - requires = rw.dispatch(foldStar(unfoldPredicate(contract.requires))), - ensures = rw.dispatch(foldStar(unfoldPredicate(contract.ensures))), + requires = rw.dispatch(contractRequires), + ensures = rw.dispatch(contractEnsures), content = rw.dispatch(impl), )(PanicBlame("where blame?"))) ParStatement(ParBlock( @@ -371,20 +365,20 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O }.toSeq, // Context is added to requires and ensures here context_everywhere = tt, - requires = Star( - Star(nonZeroThreads, + requires = Star(nonZeroThreads, + Star(contextBlock, foldStar( - unfoldPredicate(contract.requires) + unfoldStar(contractRequires) .filter(hasNoSharedMemNames) .map(allThreadsInBlock))) - , contextBlock), - ensures = Star( - Star(nonZeroThreads, + ), + ensures = Star(nonZeroThreads, + Star(contextBlock, foldStar( - unfoldPredicate(contract.ensures) + unfoldStar(contractEnsures) .filter(hasNoSharedMemNames) .map(allThreadsInBlock)) ) - , contextBlock), + ), // Add shared memory initialization before beginning of inner parallel block content = Block[Post](sharedMemInit ++ Seq(innerContent)) )(PanicBlame("where blame?"))) @@ -394,24 +388,23 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O } }) - val gridContext: AccountedPredicate[Post] = - foldPredicate(unfoldStar(contract.contextEverywhere) + val gridContext: Expr[Post] = + foldStar(unfoldStar(contract.contextEverywhere) .filter(hasNoSharedMemNames) - .map(allThreadsInBlock))(o) - new Procedure[Post]( + // TODO: unsure if we need to map over all threads. context anywhere feels like it should apply as general knowledge unrelated to thread ids + .map(allThreadsInGrid))(o) + val requires: Expr[Post] = foldStar(Seq(gridContext, nonZeroThreads) + ++ unfoldStar(contractRequires).filter(hasNoSharedMemNames).map(allThreadsInGrid) )(o) + val ensures: Expr[Post] = foldStar(Seq(gridContext, nonZeroThreads) + ++ unfoldStar(contractEnsures).filter(hasNoSharedMemNames).map(allThreadsInGrid) )(o) + val result = new Procedure[Post]( returnType = TVoid(), args = newArgs, outArgs = Nil, typeArgs = Nil, body = parBody, contract = ApplicableContract( - SplitAccountedPredicate( - SplitAccountedPredicate(nonZeroThreadsPred, - mapPredicate(filterPredicate(contract.requires, hasNoSharedMemNames), allThreadsInGrid))(o), - gridContext)(o), - SplitAccountedPredicate( - SplitAccountedPredicate(nonZeroThreadsPred, - mapPredicate(filterPredicate(contract.ensures, hasNoSharedMemNames), allThreadsInGrid))(o), - gridContext)(o), + UnitAccountedPredicate(requires)(o), + UnitAccountedPredicate(ensures)(o), // Context everywhere is already passed down in the body tt, contract.signals.map(rw.dispatch), @@ -420,35 +413,36 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O contract.decreases.map(rw.dispatch), )(contract.blame)(contract.o) )(AbstractApplicable)(o) + inKernel = false + + result } } } def rewriteGlobalDecl(decl: CGlobalDeclaration[Pre]): Unit = { val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) - for((init, idx) <- decl.decl.inits.zipWithIndex) { - if(init.ref.isEmpty) { - // Otherwise, skip the declaration: the definition is used instead. - val info = C.getDeclaratorInfo(init.decl) - info.params match { - case Some(params) => - cFunctionDeclSuccessor((decl, idx)) = rw.globalDeclarations.declare( - if(decl.decl.specs.collectFirst { case CKernel() => () }.nonEmpty) { - kernelProcedure(init.o, decl.decl.contract, info, None) - } else { - new Procedure[Post]( - returnType = t, - args = rw.variables.collect { params.foreach(rw.dispatch) }._1, - outArgs = Nil, - typeArgs = Nil, - body = None, - contract = rw.dispatch(decl.decl.contract), - )(AbstractApplicable)(init.o) - } - ) - case None => - throw CGlobalStateNotSupported(init) - } + for((init, idx) <- decl.decl.inits.zipWithIndex if init.ref.isEmpty) { + // If the reference is empty , skip the declaration: the definition is used instead. + val info = C.getDeclaratorInfo(init.decl) + info.params match { + case Some(params) => + cFunctionDeclSuccessor((decl, idx)) = rw.globalDeclarations.declare( + if(decl.decl.specs.collectFirst { case CKernel() => () }.nonEmpty) { + kernelProcedure(init.o, decl.decl.contract, info, None) + } else { + new Procedure[Post]( + returnType = t, + args = rw.variables.collect { params.foreach(rw.dispatch) }._1, + outArgs = Nil, + typeArgs = Nil, + body = None, + contract = rw.dispatch(decl.decl.contract), + )(AbstractApplicable)(init.o) + } + ) + case None => + throw CGlobalStateNotSupported(init) } } } @@ -462,7 +456,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O info.params match { case Some(params) => ??? case None => - val v = new Variable[Post](t)(init.o) + val varO: Origin = InterpretedOriginVariable(C.getDeclaratorInfo(init.decl).name, init.o) + val v = new Variable[Post](t)(varO) cNameSuccessor(RefCLocalDeclaration(decl, idx)) = v implicit val o: Origin = init.o init.init @@ -647,14 +642,14 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val arg = if(args.size == 1){ args.head match { case IntegerValue(i) if i >= 0 && i < 3 => Some(i.toInt) - case _ => None - } - } else None - (e.name, arg) match { - case ("get_local_id", Some(i)) => cudaCurrentThreadIdx.top.indices.values.toSeq.apply(i).get - case ("get_group_id", Some(i)) => cudaCurrentBlockIdx.top.indices.values.toSeq.apply(i).get - case ("get_local_size", Some(i)) => cudaCurrentBlockDim.top.indices.values.toSeq.apply(i).get - case ("get_num_groups", Some(i)) => cudaCurrentGridDim.top.indices.values.toSeq.apply(i).get + case _ => None + } + } else None + (e.name, arg) match { + case ("get_local_id", Some(i)) => cudaCurrentThreadIdx.top.indices.values.toSeq.apply(i).get + case ("get_group_id", Some(i)) => cudaCurrentBlockIdx.top.indices.values.toSeq.apply(i).get + case ("get_local_size", Some(i)) => cudaCurrentBlockDim.top.indices.values.toSeq.apply(i).get + case ("get_num_groups", Some(i)) => cudaCurrentGridDim.top.indices.values.toSeq.apply(i).get case _ => ProcedureInvocation[Post](cFunctionDeclSuccessor.ref((decls, initIdx)), args.map(rw.dispatch), Nil, Nil, givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, yields.map { case (Ref(e), Ref(v)) => (rw.succ(e), rw.succ(v)) })(inv.blame) diff --git a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala index 4b72549431..49e1e0d853 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala @@ -102,6 +102,7 @@ case class LangSpecificToCol[Pre <: Generation](language: Option[Language]) exte case RefFunction(decl) => Result[Post](anySucc(decl)) case RefProcedure(decl) => Result[Post](anySucc(decl)) case RefJavaMethod(decl) => Result[Post](java.javaMethod.ref(decl)) + case RefJavaAnnotationMethod(decL) => ??? case RefInstanceFunction(decl) => Result[Post](anySucc(decl)) case RefInstanceMethod(decl) => Result[Post](anySucc(decl)) } From b348bad99ce736d41aeec4c42fc19ad7c8576050 Mon Sep 17 00:00:00 2001 From: Lars Date: Thu, 22 Sep 2022 11:57:18 +0200 Subject: [PATCH 23/25] Added analysis to decide if an expression remains constant --- .../vct/col/newrewrite/ParBlockEncoder.scala | 94 +++++++++++++++++-- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala index 99d0d55c7a..87362b4a86 100644 --- a/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala +++ b/src/main/java/vct/col/newrewrite/ParBlockEncoder.scala @@ -161,17 +161,92 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { quantify(block, block.context_everywhere &* block.ensures, nonEmpty) } - def constantExpression(e: Expr[_]): Boolean = e match { + def constantExpression[A](e: Expr[A], nonConstVars: Set[Variable[A]]): Boolean = e match { case _: Constant[_, _] => true - case op: BinExpr[_] => constantExpression(op.left) && constantExpression(op.right) - /* TODO: This is the hard part, do we do an analysis on locals to see if they are never changed in the body of a a - * parallel block? We need this, otherwise we can never hope to rewrite nested foralls, since the bounds will - * depend on the hi_var en var_low values. And on that the nested quantifier pass will get stuck. - */ - case _: Local[_] => false + case op: BinExpr[_] => constantExpression(op.left, nonConstVars) && constantExpression(op.right, nonConstVars) + case Local(v) => !nonConstVars.contains(v.decl) case _ => false } + def applyFunc[A,B](f: A => Option[Set[B]], xs: Iterable[A]): Option[Set[B]] = + xs.foldLeft(Some(Set()): Option[Set[B]]){case (res, x) => res.flatMap(r => f(x).map(s => r ++ s))} + + def combine[A](xs: Option[Set[A]], ys: Option[Set[A]]): Option[Set[A]] = + xs.flatMap(xs => ys.map(ys => xs ++ ys)) + + def scanIter(xs: Iterable[Statement[Pre]]): Option[Set[Variable[Pre]]] = + applyFunc(scanForAssign, xs) + + def scanIterE(xs: Iterable[Expr[Pre]]): Option[Set[Variable[Pre]]] = + applyFunc(scanForAssignE, xs) + + /* Scans statements and returns a set of variables, which will change value. Meaning other variables, did not change value. + * The analysis is conservative, if we are not sure, we return None, meaning we are not sure about any variable if it + * is potentially constant + */ + def scanForAssign(s: Statement[Pre]): Option[Set[Variable[Pre]]] = s match { + case Block(statements) => scanIter(statements) + case Scope(_, statement) => scanForAssign(statement) + case Branch(branches) => + for { + set1 <- scanIter(branches.map(_._2)) + set2 <- scanIterE(branches.map(_._1)) + } yield set1 ++ set2 + case IndetBranch(statements) => scanIter(statements) + case Switch(expr, body) => combine(scanForAssignE(expr), scanForAssign(body)) + case Loop(init, cond, update, _, body) => + scanForAssign(init).flatMap(r1 => scanForAssignE(cond).flatMap(r2 => scanForAssign(update) + .flatMap(r3 => scanForAssign(body).map(r4 => r1 ++ r2 ++ r3 ++ r4)))) + case Assign(target, value) => combine(scanAssignTarget(target), scanForAssignE(value)) + case ParStatement(par) => scanForAssignP(par) + case ParBarrier(_, _, _, _, content) => scanForAssign(content) + case Eval(e) => scanForAssignE(e) + case _ => None + } + + def scanForAssignP(par: ParRegion[Pre]): Option[Set[Variable[Pre]]] = par match { + case ParParallel(pars) => applyFunc(scanForAssignP, pars) + case ParSequential(pars) => applyFunc(scanForAssignP, pars) + case ParBlock(_, _, _, _, _, content) => scanForAssign(content) + } + + def scanForAssignE(e: Expr[Pre]): Option[Set[Variable[Pre]]] = e match { + case _ : Constant[Pre, _] => Some(Set()) + case Local(_) => Some(Set()) + case op : BinExpr[Pre] => combine(scanForAssignE(op.left), scanForAssignE(op.right)) + case ProcedureInvocation(_, args, outArgs, _, givenMap, yields) => + // Primitive types cannot be changed due to a function call, so we filter them + // Other arguments can possibly be changed, so we collect them + applyFunc(scanAssignTarget, args.filterNot(e => isPrimitive(e.t))) + .flatMap(r1 => scanIterE(givenMap.map(_._2)) + // TODO This is how yield works right? The first variable of a yield tupple gets assigned to? + .map(r2 => r1 ++ r2 ++ outArgs.map(a => a.decl) ++ yields.map(y => y._1.decl))) + case FunctionInvocation(_, args, _, givenMap, yields) => + applyFunc(scanAssignTarget, args.filterNot(e => isPrimitive(e.t))) + .flatMap(r1 => scanIterE(givenMap.map(_._2)) + .map(r2 => r1 ++ r2 ++ yields.map(y => y._1.decl))) + case PreAssignExpression(target, value) => combine(scanAssignTarget(target), scanForAssignE(value)) + case PointerSubscript(pointer, index) => combine(scanForAssignE(pointer), scanForAssignE(index)) + case ArraySubscript(pointer, index) => combine(scanForAssignE(pointer), scanForAssignE(index)) + case SeqSubscript(seq, index) => combine(scanForAssignE(seq), scanForAssignE(index)) + case _ => None + } + + def scanAssignTarget(e: Expr[Pre]): Option[Set[Variable[Pre]]] = e match { + case Local(v) => Some(Set(v.decl)) + case ArraySubscript(arr, index) => combine(scanAssignTarget(arr), scanForAssignE(index)) + case PointerSubscript(pointer, index) => combine(scanAssignTarget(pointer), scanForAssignE(index)) + case SeqSubscript(seq, index) => combine(scanAssignTarget(seq), scanForAssignE(index)) + case _ => None + } + + def isPrimitive(t: Type[_]): Boolean = t match { + case _: PrimitiveType[_] => true + case _ => false + } + + case class NonConstantVariables(vars: Set[Variable[Pre]]) + def ranges(region: ParRegion[Pre], rangeValues: mutable.Map[Variable[Pre], (Expr[Post], Expr[Post])]): Statement[Post] = region match { case ParParallel(regions) => Block(regions.map(ranges(_, rangeValues)))(region.o) case ParSequential(regions) => Block(regions.map(ranges(_, rangeValues)))(region.o) @@ -182,9 +257,10 @@ case class ParBlockEncoder[Pre <: Generation]() extends Rewriter[Pre] { implicit val o: Origin = v.o val from = dispatch(v.from) val to = dispatch(v.to) - if(constantExpression(from) && constantExpression (to)){ + val nonConstVars = scanForAssign(block.content) + if(nonConstVars.nonEmpty && constantExpression(v.from, nonConstVars.get) && constantExpression (v.to, nonConstVars.get)){ rangeValues(v.variable) = (from, to) - res + res } else { val lo = variables.declare(new Variable[Post](TInt())(LowEvalOrigin(v))) val hi = variables.declare(new Variable[Post](TInt())(HighEvalOrigin(v))) From bbf3b2b0d7ef6dd9b919abfeeed2b20e4d93878e Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 26 Sep 2022 11:11:14 +0200 Subject: [PATCH 24/25] Added static shared memory --- .../ambiguous/AmbiguousSubscriptImpl.scala | 4 +- .../statement/composite/ScopeImpl.scala | 2 +- col/src/main/java/vct/col/print/Printer.scala | 25 +- col/src/main/java/vct/col/resolve/C.scala | 2 +- .../main/java/vct/col/resolve/Resolve.scala | 32 ++- .../vct/col/newrewrite/lang/LangCToCol.scala | 244 ++++++++++++++---- .../col/newrewrite/lang/LangTypesToCol.scala | 8 +- 7 files changed, 244 insertions(+), 73 deletions(-) diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/ambiguous/AmbiguousSubscriptImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/ambiguous/AmbiguousSubscriptImpl.scala index a9a858495d..e354bdb44a 100644 --- a/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/ambiguous/AmbiguousSubscriptImpl.scala +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/expr/ambiguous/AmbiguousSubscriptImpl.scala @@ -2,6 +2,7 @@ package vct.col.ast.temporaryimplpackage.expr.ambiguous import vct.col.ast.{AmbiguousSubscript, Type} import vct.col.coerce.CoercionUtils +import vct.result.VerificationError.{Unreachable} trait AmbiguousSubscriptImpl[G] { this: AmbiguousSubscript[G] => def isSeqOp: Boolean = CoercionUtils.getAnySeqCoercion(collection.t).isDefined @@ -13,5 +14,6 @@ trait AmbiguousSubscriptImpl[G] { this: AmbiguousSubscript[G] => if (isSeqOp) collection.t.asSeq.get.element else if (isArrayOp) collection.t.asArray.get.element else if (isPointerOp) collection.t.asPointer.get.element - else collection.t.asMap.get.value + else if (isMapOp) collection.t.asMap.get.value + else throw Unreachable(s"Trying to subscript ($this) a non subscriptable variable with type $collection.t") } \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/statement/composite/ScopeImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/statement/composite/ScopeImpl.scala index 22a413ff6f..585d52bccf 100644 --- a/col/src/main/java/vct/col/ast/temporaryimplpackage/statement/composite/ScopeImpl.scala +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/statement/composite/ScopeImpl.scala @@ -6,5 +6,5 @@ import vct.col.resolve.ResolveReferences trait ScopeImpl[G] { this: Scope[G] => override def enterCheckContext(context: CheckContext[G]): CheckContext[G] = - context.withScope((locals ++ ResolveReferences.scanScope(body)).toSet) + context.withScope((locals ++ ResolveReferences.scanScope(body, inGPUKernel = false)).toSet) } \ No newline at end of file diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 0f6f505477..1782aedfc1 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -38,7 +38,7 @@ sealed trait PrinterState { case class InLine(lastWasSpace: Boolean, specDepth: Int, banNewlinesDepth: Int, indent: Int) extends PrinterState { override def say(text: String)(implicit printer: Printer): PrinterState = { printer.out.append(text) - InLine(text.last == ' ', specDepth, banNewlinesDepth, indent) + InLine(if(text.nonEmpty) text.last == ' ' else false, specDepth, banNewlinesDepth, indent) } override def space()(implicit printer: Printer): PrinterState = { @@ -412,7 +412,7 @@ case class Printer(out: Appendable, def printStatement(stat: Statement[_]): Unit = say(stat match { case CDeclarationStatement(decl) => - statement(syntax(C -> phrase(decl.decl.specs, commas(decl.decl.inits.map(NodePhrase))))) + statement(syntax(C -> phrase(intersperse(" ", decl.decl.specs.map(NodePhrase)), space, commas(decl.decl.inits.map(NodePhrase))))) case ref @ CGoto(label) => statement(syntax(C -> phrase("goto", space, Text(ref.ref.map(name).getOrElse(label))))) case GpgpuBarrier(requires, ensures, specifier) => @@ -1128,6 +1128,10 @@ case class Printer(out: Appendable, statement(field.t, space, name(field)) case variable: Variable[_] => phrase(variable.t, space, name(variable)) + case decl: CLocalDeclaration[_] => + phrase(decl.decl) + case decl: CGlobalDeclaration[_] => + phrase(decl.decl) case decl: LabelDecl[_] => ??? case decl: ParBlockDecl[_] => @@ -1185,7 +1189,9 @@ case class Printer(out: Appendable, def printCDeclarator(node: CDeclarator[_]): Unit = node match { case CPointerDeclarator(pointers, inner) => say("*".repeat(pointers.size), inner) - case CArrayDeclarator(qualifiers, size, inner) => + case CArrayDeclarator(qualifiers, Some(size), inner) => + say(inner, "[", size ,"]") + case CArrayDeclarator(qualifiers, None, inner) => say(inner, "[]") case CTypedFunctionDeclarator(params, varargs, inner) => say(inner, "(", commas(params.map(NodePhrase)), ")") @@ -1220,7 +1226,7 @@ case class Printer(out: Appendable, )) case GPUGlobal() => say(syntax( OpenCL -> phrase("__global"), - )) + )) } def printCTypeQualifier(node: CTypeQualifier[_]): Unit = node match { @@ -1295,6 +1301,16 @@ case class Printer(out: Appendable, print(node.program) } + def printCDeclaration(node: CDeclaration[_]): Unit = { + say(newline, node.contract, newline) + node.kernelInvariant match { + case BooleanValue(true) => + case _ => say("kernel_invariant: ", node.kernelInvariant, space) + } + say(spaced(node.specs.map(NodePhrase)), space) + say(spaced(node.inits.map(NodePhrase))) + } + def print(node: Node[_]): Unit = node match { case program: Program[_] => printProgram(program) case stat: Statement[_] => printStatement(stat) @@ -1319,6 +1335,7 @@ case class Printer(out: Appendable, case node : Location[_] => printLocation(node) case node: Verification[_] => printVerification(node) case node: VerificationContext[_] => printVerificationContext(node) + case node: CDeclaration[_] => printCDeclaration(node) case x => say(s"Unknown node type in Printer.scala: ${x.getClass.getCanonicalName}") } diff --git a/col/src/main/java/vct/col/resolve/C.scala b/col/src/main/java/vct/col/resolve/C.scala index 02f19cfd9d..4655627bb5 100644 --- a/col/src/main/java/vct/col/resolve/C.scala +++ b/col/src/main/java/vct/col/resolve/C.scala @@ -49,7 +49,7 @@ case object C { case CArrayDeclarator(_, _, inner) => val innerInfo = getDeclaratorInfo(inner) // TODO PB: I think pointer is not correct here. - DeclaratorInfo(innerInfo.params, t => TPointer(innerInfo.typeOrReturnType(t)), innerInfo.name) + DeclaratorInfo(innerInfo.params, t => TArray(innerInfo.typeOrReturnType(t)), innerInfo.name) case CTypedFunctionDeclarator(params, _, inner) => val innerInfo = getDeclaratorInfo(inner) DeclaratorInfo(params=Some(params), typeOrReturnType=(t => t), innerInfo.name) diff --git a/col/src/main/java/vct/col/resolve/Resolve.scala b/col/src/main/java/vct/col/resolve/Resolve.scala index fd4bbe338a..8052cfdb8d 100644 --- a/col/src/main/java/vct/col/resolve/Resolve.scala +++ b/col/src/main/java/vct/col/resolve/Resolve.scala @@ -78,9 +78,13 @@ case object ResolveReferences { resolve(program, ReferenceResolutionContext[G]()) } - def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G]): Seq[CheckError] = { - val innerCtx = enterContext(node, ctx) - val childErrors = node.subnodes.flatMap(resolve(_, innerCtx)) + def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean=false): Seq[CheckError] = { + val inGPU = inGPUKernel || (node match { + case f: CFunctionDefinition[G] => f.specs.collectFirst{case _: CKernel[G] => ()}.isDefined + case _ => false + }) + val innerCtx = enterContext(node, ctx, inGPU) + val childErrors = node.subnodes.flatMap(resolve(_, innerCtx, inGPU)) if(childErrors.nonEmpty) childErrors else { @@ -89,12 +93,14 @@ case object ResolveReferences { } } - def scanScope[G](node: Node[G]): Seq[Declaration[G]] = node match { + def scanScope[G](node: Node[G], inGPUKernel: Boolean): Seq[Declaration[G]] = node match { case _: Scope[G] => Nil - case CDeclarationStatement(decl) => Seq(decl) + // Remove shared memory locations from the body level of a GPU kernel, we want to reason about them at the top level + case CDeclarationStatement(decl) if !(inGPUKernel && decl.decl.specs.collectFirst{case GPULocal() => ()}.isDefined) + => Seq(decl) case JavaLocalDeclarationStatement(decl) => Seq(decl) case LocalDecl(v) => Seq(v) - case other => other.subnodes.flatMap(scanScope) + case other => other.subnodes.flatMap(scanScope(_, inGPUKernel)) } def scanLabels[G](node: Node[G]): Seq[Declaration[G]] = node.transSubnodes.collect { @@ -108,7 +114,11 @@ case object ResolveReferences { case block: ParBlock[G] => Seq(block) } - def enterContext[G](node: Node[G], ctx: ReferenceResolutionContext[G]): ReferenceResolutionContext[G] = (node match { + def scanShared[G](node: Node[G]): Seq[Declaration[G]] = node.transSubnodes.collect { + case decl: CLocalDeclaration[G] if decl.decl.specs.collectFirst{case GPULocal() => ()}.isDefined => decl + } + + def enterContext[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean = false): ReferenceResolutionContext[G] = (node match { case ns: JavaNamespace[G] => ctx .copy(currentJavaNamespace=Some(ns)).declare(ns.declarations) case cls: JavaClassOrInterface[G] => ctx @@ -137,9 +147,13 @@ case object ResolveReferences { case TArray(elem) => elem case _ => throw WrongArrayInitializer(init) })) - case func: CFunctionDefinition[G] => ctx + case func: CFunctionDefinition[G] => + var res = ctx .copy(currentResult=Some(RefCFunctionDefinition(func))) .declare(C.paramsFromDeclarator(func.declarator) ++ scanLabels(func.body)) // FIXME suspect wrt contract declarations and stuff + if(func.specs.collectFirst{case CKernel() => ()}.isDefined) + res = res.declare(scanShared(func.body)) + res case func: CGlobalDeclaration[G] => ctx // PB: This is a bit dubious. It's like this because one global declaration can contain multiple forward function // declarations, but the contract is before the whole declaration. @@ -147,7 +161,7 @@ case object ResolveReferences { case par: ParStatement[G] => ctx .declare(scanBlocks(par.impl).map(_.decl)) case Scope(locals, body) => ctx - .declare(locals ++ scanScope(body)) + .declare(locals ++ scanScope(body, inGPUKernel)) case app: Applicable[G] => ctx .declare(app.declarations ++ app.body.map(scanLabels).getOrElse(Nil)) case declarator: Declarator[G] => ctx diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index d607dbef51..ed5509d400 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -27,6 +27,11 @@ case object LangCToCol { example.o.messageInContext("Global variables in C are not supported.") } + case class MultipleSharedMemoryDeclaration(decl: Node[_]) extends UserError { + override def code: String = "multipleSharedMemoryDeclaration" + override def text: String = s"We don't support declaring multiple shared memory variables at a single line: '$decl'." + } + case class WrongGPUKernelParameterType(param: CParam[_]) extends UserError { override def code: String = "wrongParameterType" override def text: String = s"The parameter `$param` has a type that is not allowed`as parameter in a GPU kernel." @@ -34,12 +39,23 @@ case object LangCToCol { case class WrongGPUType(param: CParam[_]) extends UserError { override def code: String = "wrongGPUType" - override def text: String = s"The parameter `$param` has a type that is not allowed`outside of a GPU kernel." + override def text: String = s"The parameter `$param` has a type that is not allowed outside of a GPU kernel." + } + + case class WrongCType(decl: CLocalDeclaration[_]) extends UserError { + override def code: String = "wrongCType" + override def text: String = s"The declaration `$decl` has a type that is not supported." + } + + // TODO: LvdH: How do I get prettier error messages that refer to the origin and etc? + case class WrongGPULocalType(local: CLocalDeclaration[_]) extends UserError { + override def code: String = "wrongGPULocalType" + override def text: String = s"The local declaration `$local` has a type that is not allowed inside a GPU kernel." } case class NotDynamicSharedMem(e: Expr[_]) extends UserError { override def code: String = "notDynamicSharedMem" - override def text: String = s"The expression \\shared_mem_size(`$e`) is not referencing to a shared memory location." + override def text: String = s"The expression `\\shared_mem_size($e)` is not referencing to a dynamic shared memory location." } case class WrongBarrierSpecifier(b: GpgpuBarrier[_]) extends UserError { @@ -86,9 +102,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val cudaCurrentGrid: ScopedStack[ParBlockDecl[Post]] = ScopedStack() val cudaCurrentBlock: ScopedStack[ParBlockDecl[Post]] = ScopedStack() - private val dynamicSharedMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() + private val dynamicSharedMemNames: mutable.Set[CNameTarget[Pre]] = mutable.Set() private val dynamicSharedMemLengthVar: mutable.Map[CNameTarget[Pre], Variable[Post]] = mutable.Map() - private val staticSharedMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() + private val staticSharedMemNames: mutable.Map[CNameTarget[Pre], BigInt] = mutable.Map() private val globalMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() private var inKernel: Boolean = false @@ -113,7 +129,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def varIsNotShared(l: CLocal[Pre]): Boolean = { l.ref match { - case Some(ref: RefCParam[Pre]) + case Some(ref: CNameTarget[Pre]) if dynamicSharedMemNames.contains(ref) || staticSharedMemNames.contains(ref) => return false case None => if (!allowedNonRefs.contains(l.name)) ??? case _ => @@ -149,45 +165,25 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O cParam.drop() val varO = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) implicit val o: Origin = cParam.o - - var arrayOrPointer = false - var global = false - var shared = false - var extern = false - var innerType: Option[Type[Pre]] = None - val cRef = RefCParam(cParam) - - cParam.specifiers.foreach{ - case GPULocal() => shared = true - case GPUGlobal() => global = true - case CSpecificationType(TPointer(t)) => - arrayOrPointer = true - innerType = Some(t) - case CSpecificationType(TArray(t)) => - arrayOrPointer = true - innerType = Some(t) - case CExtern() => extern = true - case _ => - } + val tp = new TypeProperties(cParam.specifiers, cParam.declarator) language match { case Some(Language.C) => - if(global && !shared && arrayOrPointer && !extern) globalMemNames.add(cRef) - else if(shared && !global && arrayOrPointer && !extern) { - dynamicSharedMemNames.add(cRef) - // Create Var with array type here and return - val v = new Variable[Post](TArray[Post](rw.dispatch(innerType.get)) )(varO) - cNameSuccessor(cRef) = v + if(tp.isGlobal && tp.arrayOrPointer && !tp.extern) globalMemNames.add(cRef) + else if(tp.isShared && tp.arrayOrPointer && !tp.extern && tp.innerType.isDefined) { + addDynamicShared(cRef, tp.innerType.get, varO) + // Return, since shared memory locations are declared and initialized at thread block level, not kernel level return } - else if(!shared && !global && !arrayOrPointer && !extern) () + else if(!tp.shared && !tp.global && !tp.arrayOrPointer && !tp.extern) () else throw WrongGPUKernelParameterType(cParam) case Some(Language.CUDA) => - if(!global && !shared && arrayOrPointer && !extern) globalMemNames.add(cRef) - else if(!shared && !global && !arrayOrPointer && !extern) () + if(!tp.global && !tp.shared && tp.arrayOrPointer && !tp.extern) globalMemNames.add(cRef) + else if(!tp.shared && !tp.global && !tp.arrayOrPointer && !tp.extern) () else throw WrongGPUKernelParameterType(cParam) - case _ => throw Unreachable(f"The language '$language' should not have GPU kernels.'") + case Some(l) => throw Unreachable(f"The language '$l' should not have GPU kernels.") + case None => throw Unreachable(f"We have GPU kernels, but the source code language could not be determined.") } val v = new Variable[Post](cParam.specifiers.collectFirst @@ -295,6 +291,50 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O cudaCurrentBlockIdx.having(block) { all(block, cudaCurrentGridDim.top, allThreadsInBlock(e)) } } + def getCDecl(d: CNameTarget[Pre]): CDeclarator[Pre] = d match { + case RefCParam(decl) => decl.declarator + case RefCFunctionDefinition(decl) => decl.declarator + case RefCGlobalDeclaration(decls, initIdx) => decls.decl.inits(initIdx).decl + case RefCLocalDeclaration(decls, initIdx) => decls.decl.inits(initIdx).decl + case _ => throw Unreachable("Should not happen") + } + + def getInnerType(t: Type[Post]): Type[Post] = t match { + case TArray(element) => element + case TPointer(element) => element + case _ => throw Unreachable("Already checked on pointer or array type") + } + + def declareSharedMemory(): (Seq[Variable[Post]], Seq[Statement[Post]]) = { + rw.variables.collect { + var result: Seq[Statement[Post]] = Seq() + dynamicSharedMemNames.foreach(d => + { + implicit val o: Origin = getCDecl(d).o + val varO: Origin = InterpretedOriginVariable(s"${C.getDeclaratorInfo(getCDecl(d)).name}_size", o) + val v = new Variable[Post](TInt())(varO) + dynamicSharedMemLengthVar(d) = v + rw.variables.declare(v) + val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) + val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) + , NewArray[Post](getInnerType(cNameSuccessor(d).t), Seq(Local(v.ref)), 0))(PanicBlame("Assign should work")) + result ++= Seq(decl, assign) + }) + staticSharedMemNames.foreach{case (d,size) => + { + implicit val o: Origin = getCDecl(d).o + val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) + val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) + , NewArray[Post](getInnerType(cNameSuccessor(d).t), Seq(IntegerValue(size)), 0))(PanicBlame("Assign should work")) + result ++= Seq(decl, assign) + }} + + result + } + + + } + def kernelProcedure(o: Origin, contract: ApplicableContract[Pre], info: C.DeclaratorInfo[Pre], body: Option[Statement[Pre]]): Procedure[Post] = { dynamicSharedMemNames.clear() staticSharedMemNames.clear() @@ -306,22 +346,9 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O cudaCurrentGridDim.having(gridDim) { val args = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 rw.variables.collect { dynamicSharedMemNames.foreach(d => rw.variables.declare(cNameSuccessor(d)) ) } - val (sharedMemSizes, sharedMemInit: Seq[Statement[Post]]) = rw.variables.collect { - var result: Seq[Statement[Post]] = Seq() - dynamicSharedMemNames.foreach(d => - { - implicit val o: Origin = d.decl.o - val varO: Origin = InterpretedOriginVariable(s"${C.getDeclaratorInfo(d.decl.declarator).name}_size", d.decl.o) - val v = new Variable[Post](TInt())(varO) - dynamicSharedMemLengthVar(d) = v - rw.variables.declare(v) - val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) - val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) - , NewArray[Post](v.t, Seq(Local(v.ref)), 0))(PanicBlame("Assign should work")) - result ++= Seq(decl, assign) - }) - result - } + rw.variables.collect { staticSharedMemNames.foreach(d => rw.variables.declare(cNameSuccessor(d._1)) ) } + val implFiltered = body.map(init => filterSharedDecl(init)) + val (sharedMemSizes, sharedMemInit: Seq[Statement[Post]]) = declareSharedMemory() val newArgs = blockDim.indices.values.toSeq ++ gridDim.indices.values.toSeq ++ args ++ sharedMemSizes val newGivenArgs = rw.variables.dispatch(contract.givenArgs) @@ -356,7 +383,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O context_everywhere = Star(nonZeroThreads, rw.dispatch(contract.contextEverywhere)), requires = rw.dispatch(contractRequires), ensures = rw.dispatch(contractEnsures), - content = rw.dispatch(impl), + content = rw.dispatch(implFiltered.get), )(PanicBlame("where blame?"))) ParStatement(ParBlock( decl = gridDecl, @@ -420,6 +447,100 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O } } + class TypeProperties(specs: Seq[CDeclarationSpecifier[Pre]], decl: CDeclarator[Pre]){ + var arrayOrPointer = false + var global = false + var shared = false + var extern = false + var innerType: Option[Type[Pre]] = None + var maybeInnerType: Option[Type[Pre]] = None + + specs.foreach { + case GPULocal() => shared = true + case GPUGlobal() => global = true + case CSpecificationType(TPointer(t)) => + arrayOrPointer = true + innerType = Some(t) + case CSpecificationType(TArray(t)) => + arrayOrPointer = true + innerType = Some(t) + case CSpecificationType(t) => maybeInnerType = Some(t) + case CExtern() => extern = true + case _ => + } + + decl match { + case _ :CArrayDeclarator[Pre] => + arrayOrPointer = true + innerType = maybeInnerType + case _ => + } + + def isShared: Boolean = shared && !global + def isGlobal: Boolean = !shared && global + } + + def addDynamicShared(cRef: CNameTarget[Pre], t: Type[Pre], o: Origin): Unit = { + dynamicSharedMemNames.add(cRef) + val v = new Variable[Post](TArray[Post](rw.dispatch(t)))(o) + cNameSuccessor(cRef) = v + } + + def addStaticShared(decl: CDeclarator[Pre], cRef: CNameTarget[Pre], t: Type[Pre], o: Origin, declStatement: CLocalDeclaration[Pre]): Unit = decl match { + case CArrayDeclarator(Seq(), Some(IntegerValue(size)), _) => + val v = new Variable[Post](TArray[Post](rw.dispatch(t)))(o) + staticSharedMemNames(cRef) = size + cNameSuccessor(cRef) = v + case _ => throw WrongGPULocalType(declStatement) + } + + + def isShared(s: Statement[Pre]): Boolean = { + if(!s.isInstanceOf[CDeclarationStatement[Pre]]) + return false + + val CDeclarationStatement(decl) = s + + if (decl.decl.inits.size != 1) + throw MultipleSharedMemoryDeclaration(decl) + + val prop = new TypeProperties(decl.decl.specs, decl.decl.inits.head.decl) + if (!prop.shared) return false + val init: CInit[Pre] = decl.decl.inits.head + val varO = InterpretedOriginVariable(C.getDeclaratorInfo(init.decl).name, decl.o) + val cRef = RefCLocalDeclaration(decl, 0) + + language match { + case Some(Language.CUDA) => + if (!prop.global && prop.arrayOrPointer && prop.innerType.isDefined && prop.extern) { + addDynamicShared(cRef, prop.innerType.get, varO) + return true + } + if (!prop.global && prop.arrayOrPointer && prop.innerType.isDefined && !prop.extern) { + addStaticShared(init.decl, cRef, prop.innerType.get, varO, decl) + return true + } + case Some(Language.C) => + if (!prop.global && prop.arrayOrPointer && prop.innerType.isDefined && !prop.extern) { + addStaticShared(init.decl, cRef, prop.innerType.get, varO, decl) + return true + } + case Some(l) => throw Unreachable(f"It should not be possible for the language '$l' to have GPU kernels.") + case None => throw Unreachable(f"We have GPU kernels, but the source code language could not be determined.") + } + + // We are shared, but couldn't add it + throw WrongGPULocalType(decl) + } + + def filterSharedDecl(s: Statement[Pre]): Statement[Pre] = { + s match { + case Scope(locals, block@Block(stats)) => + Scope(locals, Block(stats.filterNot(isShared))(block.o))(s.o) + case _ => s + } + } + def rewriteGlobalDecl(decl: CGlobalDeclaration[Pre]): Unit = { val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) for((init, idx) <- decl.decl.inits.zipWithIndex if init.ref.isEmpty) { @@ -451,12 +572,23 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O decl.drop() // PB: this is correct because Seq[CInit]'s are flattened, but the structure is a bit stupid. val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) + decl.decl.specs.foreach { + case _: CSpecificationType[Pre] => + case _ => throw WrongCType(decl) + } Block(for((init, idx) <- decl.decl.inits.zipWithIndex) yield { val info = C.getDeclaratorInfo(init.decl) - info.params match { - case Some(params) => ??? - case None => - val varO: Origin = InterpretedOriginVariable(C.getDeclaratorInfo(init.decl).name, init.o) + val varO: Origin = InterpretedOriginVariable(info.name, init.o) + (info.params, init) match { + case (Some(params), _) => ??? + case (None, CInit(CArrayDeclarator(qualifiers, size, inner), _)) => + if(qualifiers.nonEmpty || init.init.isDefined) throw WrongCType(decl) + implicit val o: Origin = init.o + val v = new Variable[Post](TArray(t))(varO) + cNameSuccessor(RefCLocalDeclaration(decl, idx)) = v + val newArr = NewArray[Post](t, Seq(rw.dispatch(size.get)), 0) + Block(Seq(LocalDecl(v), assignLocal(v.get, newArr))) + case (None, CInit(d, _)) => val v = new Variable[Post](t)(varO) cNameSuccessor(RefCLocalDeclaration(decl, idx)) = v implicit val o: Origin = init.o diff --git a/src/main/java/vct/col/newrewrite/lang/LangTypesToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangTypesToCol.scala index b5623faefd..5225814d28 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangTypesToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangTypesToCol.scala @@ -1,6 +1,6 @@ package vct.col.newrewrite.lang -import vct.col.ast.{AxiomaticDataType, CBool, CChar, CDeclaration, CDeclarationSpecifier, CDeclarator, CDouble, CFloat, CFunctionDefinition, CGlobalDeclaration, CInit, CLocalDeclaration, CLong, CName, CParam, CPrimitiveType, CSpecificationType, CTypeSpecifier, CTypedFunctionDeclarator, CTypedefName, CVoid, Declaration, JavaNamedType, JavaTClass, Model, Node, PVLNamedType, SilverPartialTAxiomatic, TAxiomatic, TBool, TChar, TClass, TFloat, TInt, TModel, TNotAValue, TUnion, TVar, TVoid, Type} +import vct.col.ast.{AxiomaticDataType, CArrayDeclarator, CBool, CChar, CDeclaration, CDeclarationSpecifier, CDeclarator, CDouble, CFloat, CFunctionDefinition, CGlobalDeclaration, CInit, CLocalDeclaration, CLong, CName, CParam, CPrimitiveType, CSpecificationType, CTypeSpecifier, CTypedFunctionDeclarator, CTypedefName, CVoid, Declaration, JavaNamedType, JavaTClass, Model, Node, PVLNamedType, SilverPartialTAxiomatic, TAxiomatic, TBool, TChar, TClass, TFloat, TInt, TModel, TNotAValue, TUnion, TVar, TVoid, Type} import vct.col.origin.Origin import vct.col.resolve.{C, RefAxiomaticDataType, RefClass, RefJavaClass, RefModel, RefVariable, SpecTypeNameTarget} import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder, Rewritten} @@ -82,6 +82,12 @@ case class LangTypesToCol[Pre <: Generation]() extends Rewriter[Pre] { case None => CName[Post](info.name) } + declarator match { + case CArrayDeclarator(_, Some(size), _) => + val spec = CSpecificationType[Post](dispatch(baseType)) +: otherSpecifiers + return (spec, CArrayDeclarator(Seq(), Some(dispatch(size)), newDeclarator)) + case _ => + } (newSpecifiers, newDeclarator) } From 1b4296b1d4ef75a0d2c6a0d77f0937e9f2eca499 Mon Sep 17 00:00:00 2001 From: Lars Date: Mon, 26 Sep 2022 15:14:10 +0200 Subject: [PATCH 25/25] Safe specific kernel type as cuda or opencl --- col/src/main/java/vct/col/ast/Node.scala | 3 +- .../lang/CKernelImpl.scala | 7 - .../lang/CUDAKernelImpl.scala | 7 + .../lang/OpenCLKernelImpl.scala | 7 + .../java/vct/col/feature/FeatureRainbow.scala | 3 +- col/src/main/java/vct/col/print/Printer.scala | 3 +- col/src/main/java/vct/col/resolve/C.scala | 2 +- .../main/java/vct/col/resolve/Resolve.scala | 4 +- parsers/lib/antlr4/LangGPGPULexer.g4 | 3 +- parsers/lib/antlr4/LangGPGPUParser.g4 | 5 +- .../main/java/vct/parsers/ColCParser.scala | 3 +- .../main/java/vct/parsers/ParseResult.scala | 9 +- .../java/vct/parsers/transform/CToCol.scala | 3 +- .../vct/col/newrewrite/lang/LangCToCol.scala | 187 +++++++++--------- .../newrewrite/lang/LangSpecificToCol.scala | 6 +- src/main/java/vct/main/stages/Parsing.scala | 9 +- .../java/vct/main/stages/Resolution.scala | 9 +- src/main/java/vct/main/util/Util.scala | 2 +- .../main/java/vct/parsers/Language.scala | 4 +- src/main/universal/res/c/cuda.h | 2 +- src/main/universal/res/c/opencl.h | 2 +- 21 files changed, 142 insertions(+), 138 deletions(-) delete mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CKernelImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CUDAKernelImpl.scala create mode 100644 col/src/main/java/vct/col/ast/temporaryimplpackage/lang/OpenCLKernelImpl.scala rename {parsers/src => src}/main/java/vct/parsers/Language.scala (83%) diff --git a/col/src/main/java/vct/col/ast/Node.scala b/col/src/main/java/vct/col/ast/Node.scala index c25e8208be..eff783ff9e 100644 --- a/col/src/main/java/vct/col/ast/Node.scala +++ b/col/src/main/java/vct/col/ast/Node.scala @@ -666,7 +666,8 @@ sealed trait CFunctionSpecifier[G] extends CDeclarationSpecifier[G] with CFuncti sealed trait CAlignmentSpecifier[G] extends CDeclarationSpecifier[G] with CAlignmentSpecifierImpl[G] sealed trait CGpgpuKernelSpecifier[G] extends CDeclarationSpecifier[G] with CGpgpuKernelSpecifierImpl[G] -final case class CKernel[G]()(implicit val o: Origin) extends CGpgpuKernelSpecifier[G] with CKernelImpl[G] +final case class CUDAKernel[G]()(implicit val o: Origin) extends CGpgpuKernelSpecifier[G] with CUDAKernelImpl[G] +final case class OpenCLKernel[G]()(implicit val o: Origin) extends CGpgpuKernelSpecifier[G] with OpenCLKernelImpl[G] final case class CPointer[G](qualifiers: Seq[CTypeQualifier[G]])(implicit val o: Origin) extends NodeFamily[G] with CPointerImpl[G] diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CKernelImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CKernelImpl.scala deleted file mode 100644 index 47bb2afb49..0000000000 --- a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CKernelImpl.scala +++ /dev/null @@ -1,7 +0,0 @@ -package vct.col.ast.temporaryimplpackage.lang - -import vct.col.ast.CKernel - -trait CKernelImpl[G] { this: CKernel[G] => - -} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CUDAKernelImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CUDAKernelImpl.scala new file mode 100644 index 0000000000..accfcdf98d --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/CUDAKernelImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.CUDAKernel + +trait CUDAKernelImpl[G] { this: CUDAKernel[G] => + +} \ No newline at end of file diff --git a/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/OpenCLKernelImpl.scala b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/OpenCLKernelImpl.scala new file mode 100644 index 0000000000..226da50194 --- /dev/null +++ b/col/src/main/java/vct/col/ast/temporaryimplpackage/lang/OpenCLKernelImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.temporaryimplpackage.lang + +import vct.col.ast.OpenCLKernel + +trait OpenCLKernelImpl[G] { this: OpenCLKernel[G] => + +} \ No newline at end of file diff --git a/col/src/main/java/vct/col/feature/FeatureRainbow.scala b/col/src/main/java/vct/col/feature/FeatureRainbow.scala index 076afeffab..2ca9394728 100644 --- a/col/src/main/java/vct/col/feature/FeatureRainbow.scala +++ b/col/src/main/java/vct/col/feature/FeatureRainbow.scala @@ -435,7 +435,8 @@ class FeatureRainbow[G] { case node: CRestrict[G] => return Nil case node: CVolatile[G] => return Nil case node: CAtomic[G] => return Nil - case node: CKernel[G] => return Nil + case node: CUDAKernel[G] => return Nil + case node: OpenCLKernel[G] => return Nil case node: CPointer[G] => return Nil case node: CParam[G] => return Nil case node: CPointerDeclarator[G] => return Nil diff --git a/col/src/main/java/vct/col/print/Printer.scala b/col/src/main/java/vct/col/print/Printer.scala index 1782aedfc1..54d75027f6 100644 --- a/col/src/main/java/vct/col/print/Printer.scala +++ b/col/src/main/java/vct/col/print/Printer.scala @@ -1219,7 +1219,8 @@ case class Printer(out: Appendable, case CSpecificationType(t) => say(t) case CTypeQualifierDeclarationSpecifier(typeQual) => say(typeQual) case CExtern() => say("extern") - case CKernel() => say("__kernel") + case OpenCLKernel() => say("__kernel") + case CUDAKernel() => say("__global__") case GPULocal() => say(syntax( Cuda -> phrase("__shared__"), OpenCL -> phrase("__local"), diff --git a/col/src/main/java/vct/col/resolve/C.scala b/col/src/main/java/vct/col/resolve/C.scala index 4655627bb5..02f19cfd9d 100644 --- a/col/src/main/java/vct/col/resolve/C.scala +++ b/col/src/main/java/vct/col/resolve/C.scala @@ -49,7 +49,7 @@ case object C { case CArrayDeclarator(_, _, inner) => val innerInfo = getDeclaratorInfo(inner) // TODO PB: I think pointer is not correct here. - DeclaratorInfo(innerInfo.params, t => TArray(innerInfo.typeOrReturnType(t)), innerInfo.name) + DeclaratorInfo(innerInfo.params, t => TPointer(innerInfo.typeOrReturnType(t)), innerInfo.name) case CTypedFunctionDeclarator(params, _, inner) => val innerInfo = getDeclaratorInfo(inner) DeclaratorInfo(params=Some(params), typeOrReturnType=(t => t), innerInfo.name) diff --git a/col/src/main/java/vct/col/resolve/Resolve.scala b/col/src/main/java/vct/col/resolve/Resolve.scala index 8052cfdb8d..11a3857b1d 100644 --- a/col/src/main/java/vct/col/resolve/Resolve.scala +++ b/col/src/main/java/vct/col/resolve/Resolve.scala @@ -80,7 +80,7 @@ case object ResolveReferences { def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean=false): Seq[CheckError] = { val inGPU = inGPUKernel || (node match { - case f: CFunctionDefinition[G] => f.specs.collectFirst{case _: CKernel[G] => ()}.isDefined + case f: CFunctionDefinition[G] => f.specs.collectFirst{case _: CGpgpuKernelSpecifier[G] => ()}.isDefined case _ => false }) val innerCtx = enterContext(node, ctx, inGPU) @@ -151,7 +151,7 @@ case object ResolveReferences { var res = ctx .copy(currentResult=Some(RefCFunctionDefinition(func))) .declare(C.paramsFromDeclarator(func.declarator) ++ scanLabels(func.body)) // FIXME suspect wrt contract declarations and stuff - if(func.specs.collectFirst{case CKernel() => ()}.isDefined) + if(func.specs.collectFirst{case _: CGpgpuKernelSpecifier[G] => ()}.isDefined) res = res.declare(scanShared(func.body)) res case func: CGlobalDeclaration[G] => ctx diff --git a/parsers/lib/antlr4/LangGPGPULexer.g4 b/parsers/lib/antlr4/LangGPGPULexer.g4 index caacc6218c..f8e3b8ef24 100644 --- a/parsers/lib/antlr4/LangGPGPULexer.g4 +++ b/parsers/lib/antlr4/LangGPGPULexer.g4 @@ -3,7 +3,8 @@ lexer grammar LangGPGPULexer; GPGPU_BARRIER: '__vercors_barrier__'; GPGPU_LOCAL_MEMORY_FENCE: '__vercors_local_mem_fence__'; GPGPU_GLOBAL_MEMORY_FENCE: '__vercors_global_mem_fence__'; -GPGPU_KERNEL: '__vercors_kernel__'; +OPENCL_KERNEL: '__opencl_kernel__'; +CUDA_KERNEL: '__cuda_kernel__'; GPGPU_ATOMIC: '__vercors_atomic__'; GPGPU_GLOBAL_MEMORY: '__vercors_global_memory__'; GPGPU_LOCAL_MEMORY: '__vercors_local_memory__'; diff --git a/parsers/lib/antlr4/LangGPGPUParser.g4 b/parsers/lib/antlr4/LangGPGPUParser.g4 index bee08e5546..f413e79a4b 100644 --- a/parsers/lib/antlr4/LangGPGPUParser.g4 +++ b/parsers/lib/antlr4/LangGPGPUParser.g4 @@ -23,7 +23,10 @@ gpgpuMemFence | Constant ; -gpgpuKernelSpecifier: GPGPU_KERNEL; +gpgpuKernelSpecifier + : CUDA_KERNEL + | OPENCL_KERNEL + ; gpgpuLocalMemory: GPGPU_LOCAL_MEMORY; diff --git a/parsers/src/main/java/vct/parsers/ColCParser.scala b/parsers/src/main/java/vct/parsers/ColCParser.scala index dd534e0f9c..f6cdbc9b03 100644 --- a/parsers/src/main/java/vct/parsers/ColCParser.scala +++ b/parsers/src/main/java/vct/parsers/ColCParser.scala @@ -24,8 +24,7 @@ case class ColCParser(override val originProvider: OriginProvider, cc: Path, systemInclude: Path, otherIncludes: Seq[Path], - defines: Map[String, String], - language: Language) extends Parser(originProvider, blameProvider) with LazyLogging { + defines: Map[String, String]) extends Parser(originProvider, blameProvider) with LazyLogging { def interpret(localInclude: Seq[Path], input: String, output: String): Process = { var command = Seq(cc.toString, "-C", "-E") diff --git a/parsers/src/main/java/vct/parsers/ParseResult.scala b/parsers/src/main/java/vct/parsers/ParseResult.scala index b6a5ade416..2acd2c9aea 100644 --- a/parsers/src/main/java/vct/parsers/ParseResult.scala +++ b/parsers/src/main/java/vct/parsers/ParseResult.scala @@ -4,12 +4,11 @@ import vct.col.ast.{GlobalDeclaration, VerificationContext} import vct.col.util.ExpectedError case object ParseResult { - def reduce[G](parses: Seq[(ParseResult[G], Option[Language])]): (ParseResult[G], Option[Language]) = + def reduce[G](parses: Seq[ParseResult[G]]): ParseResult[G] = parses.reduceOption((l, r) => (l, r) match { - case ((ParseResult(declsLeft, expectedLeft), l1), (ParseResult(declsRight, expectedRight), l2)) => - val lan = if(l1 == l2) l1 else None - (ParseResult(declsLeft ++ declsRight, expectedLeft ++ expectedRight), lan) - }).getOrElse((ParseResult(Nil, Nil), None)) + case (ParseResult(declsLeft, expectedLeft), ParseResult(declsRight, expectedRight)) => + ParseResult(declsLeft ++ declsRight, expectedLeft ++ expectedRight) + }).getOrElse(ParseResult(Nil, Nil)) } case class ParseResult[G](decls: Seq[GlobalDeclaration[G]], expectedErrors: Seq[ExpectedError]) \ No newline at end of file diff --git a/parsers/src/main/java/vct/parsers/transform/CToCol.scala b/parsers/src/main/java/vct/parsers/transform/CToCol.scala index 61ab359312..b9aaaf06d1 100644 --- a/parsers/src/main/java/vct/parsers/transform/CToCol.scala +++ b/parsers/src/main/java/vct/parsers/transform/CToCol.scala @@ -122,7 +122,8 @@ case class CToCol[G](override val originProvider: OriginProvider, override val b } def convert(implicit kernel: GpgpuKernelSpecifierContext): CGpgpuKernelSpecifier[G] = kernel match { - case GpgpuKernelSpecifier0(_) => CKernel() + case GpgpuKernelSpecifier0(_) => CUDAKernel() + case GpgpuKernelSpecifier1(_) => OpenCLKernel() } def convert(implicit decls: InitDeclaratorListContext): Seq[CInit[G]] = decls match { diff --git a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala index ed5509d400..3d2b1e669d 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangCToCol.scala @@ -14,7 +14,6 @@ import vct.col.resolve.{BuiltinField, BuiltinInstanceMethod, C, CNameTarget, Ref import vct.col.rewrite.{Generation, Rewritten} import vct.col.util.SuccessionMap import vct.col.util.AstBuildHelpers._ -import vct.parsers.Language import vct.result.VerificationError.{Unreachable, UserError} import scala.collection.immutable.ListMap @@ -29,49 +28,57 @@ case object LangCToCol { case class MultipleSharedMemoryDeclaration(decl: Node[_]) extends UserError { override def code: String = "multipleSharedMemoryDeclaration" - override def text: String = s"We don't support declaring multiple shared memory variables at a single line: '$decl'." + override def text: String = + decl.o.messageInContext(s"We don't support declaring multiple shared memory variables at a single line.") } case class WrongGPUKernelParameterType(param: CParam[_]) extends UserError { override def code: String = "wrongParameterType" - override def text: String = s"The parameter `$param` has a type that is not allowed`as parameter in a GPU kernel." + override def text: String = + param.o.messageInContext(s"This parameter has a type that is not allowed as parameter in a GPU kernel.") } case class WrongGPUType(param: CParam[_]) extends UserError { override def code: String = "wrongGPUType" - override def text: String = s"The parameter `$param` has a type that is not allowed outside of a GPU kernel." + override def text: String = + param.o.messageInContext(s"This parameter has a type that is not allowed outside of a GPU kernel.") } case class WrongCType(decl: CLocalDeclaration[_]) extends UserError { override def code: String = "wrongCType" - override def text: String = s"The declaration `$decl` has a type that is not supported." + override def text: String = + decl.o.messageInContext(s"This declaration has a type that is not supported.") } - // TODO: LvdH: How do I get prettier error messages that refer to the origin and etc? case class WrongGPULocalType(local: CLocalDeclaration[_]) extends UserError { override def code: String = "wrongGPULocalType" - override def text: String = s"The local declaration `$local` has a type that is not allowed inside a GPU kernel." + override def text: String = + local.o.messageInContext(s"This local declaration has a type that is not allowed inside a GPU kernel.") } case class NotDynamicSharedMem(e: Expr[_]) extends UserError { override def code: String = "notDynamicSharedMem" - override def text: String = s"The expression `\\shared_mem_size($e)` is not referencing to a dynamic shared memory location." + override def text: String = + e.o.messageInContext(s"`\\shared_mem_size` should reference a dynamic shared memory location.") } case class WrongBarrierSpecifier(b: GpgpuBarrier[_]) extends UserError { override def code: String = "wrongBarrierSpecifier" - override def text: String = s"The barrier `$b` has incorrect specifiers." + override def text: String = + b.o.messageInContext(s"The barrier has incorrect specifiers.") } case class UnsupportedBarrierPermission(e: Node[_]) extends UserError { override def code: String = "unsupportedBarrierPermission" - override def text: String = s"The permission `$e` is unsupported for barrier for now." + override def text: String = + e.o.messageInContext(s"This is unsupported for barrier for now.") } - case class RedistributingBarrier(v: CNameTarget[_], global: Boolean) extends UserError { + case class RedistributingBarrier(v: CNameTarget[_], barrier: GpgpuBarrier[_], global: Boolean) extends UserError { def memFence: String = if(global) "CLK_GLOBAL_MEM_FENCE" else "CLK_LOCAL_MEM_FENCE" override def code: String = "redistributingBarrier" - override def text: String = s"Trying to redistribute the variable `$v` in a GPU barrier, but need the fence `$memFence` to do this." + override def text: String = barrier.o.messageInContext( + s"Trying to redistribute the variable `$v` in a GPU barrier, but need the fence `$memFence` to do this.") } case class CDoubleContracted(decl: CGlobalDeclaration[_], defn: CFunctionDefinition[_]) extends UserError { @@ -84,7 +91,7 @@ case object LangCToCol { } } -case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: Option[Language]) extends LazyLogging { +case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends LazyLogging { import LangCToCol._ type Post = Rewritten[Pre] implicit val implicitRewriter: AbstractRewriter[Pre, Post] = rw @@ -106,7 +113,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O private val dynamicSharedMemLengthVar: mutable.Map[CNameTarget[Pre], Variable[Post]] = mutable.Map() private val staticSharedMemNames: mutable.Map[CNameTarget[Pre], BigInt] = mutable.Map() private val globalMemNames: mutable.Set[RefCParam[Pre]] = mutable.Set() - private var inKernel: Boolean = false + private var kernelSpecifier: Option[CGpgpuKernelSpecifier[Pre]] = None case class CudaIndexVariableOrigin(dim: RefCudaVecDim[_]) extends Origin { override def preferredName: String = @@ -131,7 +138,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O l.ref match { case Some(ref: CNameTarget[Pre]) if dynamicSharedMemNames.contains(ref) || staticSharedMemNames.contains(ref) => return false - case None => if (!allowedNonRefs.contains(l.name)) ??? + case None => if (!allowedNonRefs.contains(l.name)) Unreachable("Should not happen") case _ => } true @@ -161,15 +168,15 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O res.getOrElse(throw NotDynamicSharedMem(pointer)) } - def rewriteGPUParam(cParam: CParam[Pre]): Unit = { + def rewriteGPUParam(cParam: CParam[Pre], kernelSpecifier: CGpgpuKernelSpecifier[Pre]): Unit = { cParam.drop() val varO = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) implicit val o: Origin = cParam.o val cRef = RefCParam(cParam) val tp = new TypeProperties(cParam.specifiers, cParam.declarator) - language match { - case Some(Language.C) => + kernelSpecifier match { + case OpenCLKernel() => if(tp.isGlobal && tp.arrayOrPointer && !tp.extern) globalMemNames.add(cRef) else if(tp.isShared && tp.arrayOrPointer && !tp.extern && tp.innerType.isDefined) { addDynamicShared(cRef, tp.innerType.get, varO) @@ -178,22 +185,20 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O } else if(!tp.shared && !tp.global && !tp.arrayOrPointer && !tp.extern) () else throw WrongGPUKernelParameterType(cParam) - case Some(Language.CUDA) => + case CUDAKernel() => if(!tp.global && !tp.shared && tp.arrayOrPointer && !tp.extern) globalMemNames.add(cRef) else if(!tp.shared && !tp.global && !tp.arrayOrPointer && !tp.extern) () else throw WrongGPUKernelParameterType(cParam) - case Some(l) => throw Unreachable(f"The language '$l' should not have GPU kernels.") - case None => throw Unreachable(f"We have GPU kernels, but the source code language could not be determined.") } val v = new Variable[Post](cParam.specifiers.collectFirst - { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(varO) + { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.get)(varO) cNameSuccessor(cRef) = v rw.variables.declare(v) } def rewriteParam(cParam: CParam[Pre]): Unit = { - if(inKernel) return rewriteGPUParam(cParam) + if(kernelSpecifier.isDefined) return rewriteGPUParam(cParam, kernelSpecifier.get) cParam.specifiers.collectFirst{ case GPULocal() => throw WrongGPUType(cParam) case GPUGlobal() => throw WrongGPUType(cParam) @@ -203,7 +208,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val varO = InterpretedOriginVariable(C.getDeclaratorInfo(cParam.declarator).name, cParam.o) val v = new Variable[Post](cParam.specifiers.collectFirst - { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???))(varO) + { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.get)(varO) cNameSuccessor(RefCParam(cParam)) = v rw.variables.declare(v) } @@ -211,7 +216,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def rewriteFunctionDef(func: CFunctionDefinition[Pre]): Unit = { func.drop() val info = C.getDeclaratorInfo(func.declarator) - val returnType = func.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) + val returnType = func.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.get val (contract, subs: Map[CParam[Pre], CParam[Pre]]) = func.ref match { case Some(RefCGlobalDeclaration(decl, idx)) if decl.decl.contract.nonEmpty => @@ -225,23 +230,23 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O (func.contract, Map.empty) } + val namedO = InterpretedOriginVariable(C.getDeclaratorInfo(func.declarator).name, func.o) val proc = cCurrentDefinitionParamSubstitutions.having(subs) { rw.globalDeclarations.declare( - if (func.specs.collectFirst { case CKernel() => () }.nonEmpty) { - val namedO = InterpretedOriginVariable(C.getDeclaratorInfo(func.declarator).name, func.o) - kernelProcedure(namedO, contract, info, Some(func.body)) - } else { - val params = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 - new Procedure[Post]( - returnType = returnType, - args = params, - outArgs = Nil, - typeArgs = Nil, - body = Some(rw.dispatch(func.body)), - contract = rw.dispatch(contract), - )(func.blame)(func.o) - } + func.specs.collectFirst { case k: CGpgpuKernelSpecifier[Pre] + => kernelProcedure(namedO, contract, info, Some(func.body), k) } + .getOrElse( { + val params = rw.variables.collect { info.params.get.foreach(rw.dispatch) }._1 + new Procedure[Post]( + returnType = returnType, + args = params, + outArgs = Nil, + typeArgs = Nil, + body = Some(rw.dispatch(func.body)), + contract = rw.dispatch(contract), + )(func.blame)(namedO) + } ) ) } @@ -305,40 +310,36 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O case _ => throw Unreachable("Already checked on pointer or array type") } - def declareSharedMemory(): (Seq[Variable[Post]], Seq[Statement[Post]]) = { - rw.variables.collect { - var result: Seq[Statement[Post]] = Seq() - dynamicSharedMemNames.foreach(d => - { - implicit val o: Origin = getCDecl(d).o - val varO: Origin = InterpretedOriginVariable(s"${C.getDeclaratorInfo(getCDecl(d)).name}_size", o) - val v = new Variable[Post](TInt())(varO) - dynamicSharedMemLengthVar(d) = v - rw.variables.declare(v) - val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) - val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) - , NewArray[Post](getInnerType(cNameSuccessor(d).t), Seq(Local(v.ref)), 0))(PanicBlame("Assign should work")) - result ++= Seq(decl, assign) - }) - staticSharedMemNames.foreach{case (d,size) => - { - implicit val o: Origin = getCDecl(d).o - val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) - val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) - , NewArray[Post](getInnerType(cNameSuccessor(d).t), Seq(IntegerValue(size)), 0))(PanicBlame("Assign should work")) - result ++= Seq(decl, assign) - }} - - result + def declareSharedMemory(): (Seq[Variable[Post]], Seq[Statement[Post]]) = rw.variables.collect { + var result: Seq[Statement[Post]] = Seq() + dynamicSharedMemNames.foreach(d => + { + implicit val o: Origin = getCDecl(d).o + val varO: Origin = InterpretedOriginVariable(s"${C.getDeclaratorInfo(getCDecl(d)).name}_size", o) + val v = new Variable[Post](TInt())(varO) + dynamicSharedMemLengthVar(d) = v + rw.variables.declare(v) + val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) + val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) + , NewArray[Post](getInnerType(cNameSuccessor(d).t), Seq(Local(v.ref)), 0))(PanicBlame("Assign should work")) + result ++= Seq(decl, assign) + }) + staticSharedMemNames.foreach{case (d,size) => + implicit val o: Origin = getCDecl(d).o + val decl: Statement[Post] = LocalDecl(cNameSuccessor(d)) + val assign: Statement[Post] = Assign[Post](Local(cNameSuccessor(d).ref) + , NewArray[Post](getInnerType(cNameSuccessor(d).t), Seq(IntegerValue(size)), 0))(PanicBlame("Assign should work")) + result ++= Seq(decl, assign) } - + result } - def kernelProcedure(o: Origin, contract: ApplicableContract[Pre], info: C.DeclaratorInfo[Pre], body: Option[Statement[Pre]]): Procedure[Post] = { + def kernelProcedure(o: Origin, contract: ApplicableContract[Pre], info: C.DeclaratorInfo[Pre], body: Option[Statement[Pre]] + , kernelSpec: CGpgpuKernelSpecifier[Pre]): Procedure[Post] = { dynamicSharedMemNames.clear() staticSharedMemNames.clear() - inKernel = true + kernelSpecifier = Some(kernelSpec) val blockDim = new CudaVec(RefCudaBlockDim())(o) val gridDim = new CudaVec(RefCudaGridDim())(o) @@ -418,7 +419,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val gridContext: Expr[Post] = foldStar(unfoldStar(contract.contextEverywhere) .filter(hasNoSharedMemNames) - // TODO: unsure if we need to map over all threads. context anywhere feels like it should apply as general knowledge unrelated to thread ids .map(allThreadsInGrid))(o) val requires: Expr[Post] = foldStar(Seq(gridContext, nonZeroThreads) ++ unfoldStar(contractRequires).filter(hasNoSharedMemNames).map(allThreadsInGrid) )(o) @@ -440,7 +440,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O contract.decreases.map(rw.dispatch), )(contract.blame)(contract.o) )(AbstractApplicable)(o) - inKernel = false + kernelSpecifier = None result } @@ -494,7 +494,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O case _ => throw WrongGPULocalType(declStatement) } - def isShared(s: Statement[Pre]): Boolean = { if(!s.isInstanceOf[CDeclarationStatement[Pre]]) return false @@ -510,8 +509,8 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val varO = InterpretedOriginVariable(C.getDeclaratorInfo(init.decl).name, decl.o) val cRef = RefCLocalDeclaration(decl, 0) - language match { - case Some(Language.CUDA) => + kernelSpecifier match { + case Some(CUDAKernel()) => if (!prop.global && prop.arrayOrPointer && prop.innerType.isDefined && prop.extern) { addDynamicShared(cRef, prop.innerType.get, varO) return true @@ -520,13 +519,12 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O addStaticShared(init.decl, cRef, prop.innerType.get, varO, decl) return true } - case Some(Language.C) => + case Some(OpenCLKernel()) => if (!prop.global && prop.arrayOrPointer && prop.innerType.isDefined && !prop.extern) { addStaticShared(init.decl, cRef, prop.innerType.get, varO, decl) return true } - case Some(l) => throw Unreachable(f"It should not be possible for the language '$l' to have GPU kernels.") - case None => throw Unreachable(f"We have GPU kernels, but the source code language could not be determined.") + case None => throw Unreachable(f"This should have been called from inside a GPU kernel scope.") } // We are shared, but couldn't add it @@ -542,25 +540,25 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O } def rewriteGlobalDecl(decl: CGlobalDeclaration[Pre]): Unit = { - val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) + val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.get for((init, idx) <- decl.decl.inits.zipWithIndex if init.ref.isEmpty) { // If the reference is empty , skip the declaration: the definition is used instead. val info = C.getDeclaratorInfo(init.decl) info.params match { case Some(params) => cFunctionDeclSuccessor((decl, idx)) = rw.globalDeclarations.declare( - if(decl.decl.specs.collectFirst { case CKernel() => () }.nonEmpty) { - kernelProcedure(init.o, decl.decl.contract, info, None) - } else { - new Procedure[Post]( - returnType = t, - args = rw.variables.collect { params.foreach(rw.dispatch) }._1, - outArgs = Nil, - typeArgs = Nil, - body = None, - contract = rw.dispatch(decl.decl.contract), - )(AbstractApplicable)(init.o) - } + decl.decl.specs.collectFirst { case k: CGpgpuKernelSpecifier[Pre] + => kernelProcedure(init.o, decl.decl.contract, info, None, k) } + .getOrElse( + new Procedure[Post]( + returnType = t, + args = rw.variables.collect { params.foreach(rw.dispatch) }._1, + outArgs = Nil, + typeArgs = Nil, + body = None, + contract = rw.dispatch(decl.decl.contract), + )(AbstractApplicable)(init.o) + ) ) case None => throw CGlobalStateNotSupported(init) @@ -571,7 +569,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def rewriteLocal(decl: CLocalDeclaration[Pre]): Statement[Post] = { decl.drop() // PB: this is correct because Seq[CInit]'s are flattened, but the structure is a bit stupid. - val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.getOrElse(???) + val t = decl.decl.specs.collectFirst { case t: CSpecificationType[Pre] => rw.dispatch(t.t) }.get decl.decl.specs.foreach { case _: CSpecificationType[Pre] => case _ => throw WrongCType(decl) @@ -581,14 +579,14 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O val varO: Origin = InterpretedOriginVariable(info.name, init.o) (info.params, init) match { case (Some(params), _) => ??? - case (None, CInit(CArrayDeclarator(qualifiers, size, inner), _)) => + case (None, CInit(CArrayDeclarator(qualifiers, size, _), _)) => if(qualifiers.nonEmpty || init.init.isDefined) throw WrongCType(decl) implicit val o: Origin = init.o val v = new Variable[Post](TArray(t))(varO) cNameSuccessor(RefCLocalDeclaration(decl, idx)) = v val newArr = NewArray[Post](t, Seq(rw.dispatch(size.get)), 0) Block(Seq(LocalDecl(v), assignLocal(v.get, newArr))) - case (None, CInit(d, _)) => + case (None, CInit(_, _)) => val v = new Variable[Post](t)(varO) cNameSuccessor(RefCLocalDeclaration(decl, idx)) = v implicit val o: Origin = init.o @@ -613,17 +611,16 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O case GpuGlobalMemoryFence() => globalFence = true case GpuZeroMemoryFence(i) => if(i != 0) throw WrongBarrierSpecifier(barrier) } - // TODO: create requirement that shared memory arrays are not NULL? if(!globalFence || !localFence){ val redist = permissionScanner(barrier) if(!globalFence) redist .intersect(globalMemNames.toSet) - .foreach(v => throw RedistributingBarrier(v, global = true)) + .foreach(v => throw RedistributingBarrier(v, barrier, global = true)) if(!localFence) redist .intersect(dynamicSharedMemNames.union(dynamicSharedMemNames).toSet) - .foreach(v => throw RedistributingBarrier(v, global = false)) + .foreach(v => throw RedistributingBarrier(v, barrier, global = false)) } ParBarrier[Post]( @@ -667,7 +664,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre], language: O def searchPermission(e: Node[Pre]): Seq[CNameTarget[Pre]] = { e match { - case e: Expr[Pre] if e.t != TResource[Pre]() => return Seq() + case e: Expr[Pre] if e.t != TResource[Pre]() => Seq() case Perm(loc, _) => searchNames(loc, e) case PointsTo(loc, _, _) => searchNames(loc, e) case CurPerm(loc) => searchNames(loc, e) diff --git a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala index 49e1e0d853..500c815281 100644 --- a/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala +++ b/src/main/java/vct/col/newrewrite/lang/LangSpecificToCol.scala @@ -7,11 +7,9 @@ import vct.col.ast._ import vct.col.origin._ import vct.col.resolve._ import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder} -import vct.parsers.Language import vct.result.VerificationError.UserError case object LangSpecificToCol extends RewriterBuilder { - def apply[Pre <: Generation](): AbstractRewriter[Pre, _ <: Generation] = apply(None) override def key: String = "langSpecific" override def desc: String = "Translate language-specific constructs to a common subset of nodes." @@ -28,9 +26,9 @@ case object LangSpecificToCol extends RewriterBuilder { } } -case class LangSpecificToCol[Pre <: Generation](language: Option[Language]) extends Rewriter[Pre] with LazyLogging { +case class LangSpecificToCol[Pre <: Generation]() extends Rewriter[Pre] with LazyLogging { val java: LangJavaToCol[Pre] = LangJavaToCol(this) - val c: LangCToCol[Pre] = LangCToCol(this, language) + val c: LangCToCol[Pre] = LangCToCol(this) val pvl: LangPVLToCol[Pre] = LangPVLToCol(this) val silver: LangSilverToCol[Pre] = LangSilverToCol(this) diff --git a/src/main/java/vct/main/stages/Parsing.scala b/src/main/java/vct/main/stages/Parsing.scala index 4249727ce5..054d0b0ab9 100644 --- a/src/main/java/vct/main/stages/Parsing.scala +++ b/src/main/java/vct/main/stages/Parsing.scala @@ -36,11 +36,11 @@ case class Parsing[G <: Generation] cSystemInclude: Path = Resources.getCIncludePath, cOtherIncludes: Seq[Path] = Nil, cDefines: Map[String, String] = Map.empty, -) extends Stage[Seq[Readable], (ParseResult[G], Option[Language])] { +) extends Stage[Seq[Readable], ParseResult[G]] { override def friendlyName: String = "Parsing" override def progressWeight: Int = 4 - override def run(in: Seq[Readable]): (ParseResult[G], Option[Language]) = + override def run(in: Seq[Readable]): ParseResult[G] = ParseResult.reduce(in.map { readable => val language = forceLanguage .orElse(Language.fromFilename(readable.fileName)) @@ -49,14 +49,13 @@ case class Parsing[G <: Generation] val originProvider = ReadableOriginProvider(readable) val parser = language match { - case Language.C | Language.CUDA - => ColCParser(originProvider, blameProvider, cc, cSystemInclude, cOtherIncludes, cDefines, language) + case Language.C => ColCParser(originProvider, blameProvider, cc, cSystemInclude, cOtherIncludes, cDefines) case Language.InterpretedC => ColIParser(originProvider, blameProvider) case Language.Java => ColJavaParser(originProvider, blameProvider) case Language.PVL => ColPVLParser(originProvider, blameProvider) case Language.Silver => ColSilverParser(originProvider, blameProvider) } - (parser.parse[G](readable), Some(language)) + parser.parse[G](readable) }) } \ No newline at end of file diff --git a/src/main/java/vct/main/stages/Resolution.scala b/src/main/java/vct/main/stages/Resolution.scala index 9005ff4806..712b7c7b92 100644 --- a/src/main/java/vct/main/stages/Resolution.scala +++ b/src/main/java/vct/main/stages/Resolution.scala @@ -12,7 +12,7 @@ import vct.main.Main.TemporarilyUnsupported import vct.main.stages.Resolution.InputResolutionError import vct.main.stages.Transformation.TransformationCheckError import vct.options.Options -import vct.parsers.{Language, ParseResult} +import vct.parsers.{ParseResult} import vct.parsers.transform.BlameProvider import vct.resources.Resources import vct.result.VerificationError.UserError @@ -38,12 +38,11 @@ case class Resolution[G <: Generation] blameProvider: BlameProvider, withJava: Boolean = true, javaLibraryPath: Path = Resources.getJrePath, -) extends Stage[(ParseResult[G], Option[Language]), VerificationContext[_ <: Generation]] { +) extends Stage[ParseResult[G], VerificationContext[_ <: Generation]] { override def friendlyName: String = "Name Resolution" override def progressWeight: Int = 1 - override def run(inLanguage: (ParseResult[G], Option[Language]) ): VerificationContext[_ <: Generation] = { - val (in, language) = inLanguage + override def run(in: ParseResult[G]): VerificationContext[_ <: Generation] = { in.decls.foreach(_.transSubnodes.foreach { case decl: CGlobalDeclaration[_] => decl.decl.inits.foreach(init => { if(C.getDeclaratorInfo(init.decl).params.isEmpty) { @@ -64,7 +63,7 @@ case class Resolution[G <: Generation] case Nil => // ok case some => throw InputResolutionError(some) } - val resolvedProgram = LangSpecificToCol(language).dispatch(typedProgram) + val resolvedProgram = LangSpecificToCol().dispatch(typedProgram) resolvedProgram.check match { case Nil => // ok case some => throw TransformationCheckError(some) diff --git a/src/main/java/vct/main/util/Util.scala b/src/main/java/vct/main/util/Util.scala index bc1730ebae..825d0cb723 100644 --- a/src/main/java/vct/main/util/Util.scala +++ b/src/main/java/vct/main/util/Util.scala @@ -23,7 +23,7 @@ case object Util { def loadPVLLibraryFile[G](readable: Readable): Program[G] = { val res = ColPVLParser(ReadableOriginProvider(readable), ConstantBlameProvider(LibraryFileBlame)).parse(readable) - val context = Resolution(ConstantBlameProvider(LibraryFileBlame), withJava = false).run((res, Some(Language.PVL))) + val context = Resolution(ConstantBlameProvider(LibraryFileBlame), withJava = false).run(res) assert(context.expectedErrors.isEmpty) val unambiguousProgram: Program[_] = Disambiguate().dispatch(context.program) unambiguousProgram.asInstanceOf[Program[G]] diff --git a/parsers/src/main/java/vct/parsers/Language.scala b/src/main/java/vct/parsers/Language.scala similarity index 83% rename from parsers/src/main/java/vct/parsers/Language.scala rename to src/main/java/vct/parsers/Language.scala index 20780f923e..60b22696a6 100644 --- a/parsers/src/main/java/vct/parsers/Language.scala +++ b/src/main/java/vct/parsers/Language.scala @@ -3,8 +3,7 @@ package vct.parsers case object Language { def fromFilename(filename: String): Option[Language] = filename.split('.').last match { - case "cl" | "c" => Some(C) - case "cu" => Some(CUDA) + case "cl" | "c" | "cu" => Some(C) case "i" => Some(InterpretedC) case "java" => Some(Java) case "pvl" => Some(PVL) @@ -13,7 +12,6 @@ case object Language { } case object C extends Language - case object CUDA extends Language case object InterpretedC extends Language case object Java extends Language case object PVL extends Language diff --git a/src/main/universal/res/c/cuda.h b/src/main/universal/res/c/cuda.h index f2e46c3689..a36b7c406e 100644 --- a/src/main/universal/res/c/cuda.h +++ b/src/main/universal/res/c/cuda.h @@ -1,7 +1,7 @@ #ifndef CUDA_H #define CUDA_H -#define __global__ __vercors_kernel__ +#define __global__ __cuda_kernel__ #define __shared__ __vercors_local_memory__ #define bool _Bool diff --git a/src/main/universal/res/c/opencl.h b/src/main/universal/res/c/opencl.h index a5e6f1fae2..491fd93c59 100644 --- a/src/main/universal/res/c/opencl.h +++ b/src/main/universal/res/c/opencl.h @@ -1,7 +1,7 @@ #ifndef OPENCL_H #define OPENCL_H -#define __kernel __vercors_kernel__ +#define __kernel __opencl_kernel__ #define CLK_GLOBAL_MEM_FENCE __vercors_global_mem_fence__ #define CLK_LOCAL_MEM_FENCE __vercors_local_mem_fence__