From fa188580ac328f610d1e5c538c6a943b0e36d9a3 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 12 Dec 2024 18:16:28 +0000 Subject: [PATCH 1/2] POC: Make ensuring inline --- .../frontends/dotc/ASTExtractors.scala | 43 +++++++++++++++++-- .../frontends/dotc/CodeExtraction.scala | 11 ++--- .../library/stainless/lang/StaticChecks.scala | 16 ++++--- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala index 6b9d8d51df..2507d5e272 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala @@ -18,6 +18,7 @@ import core.Annotations._ import util.SourcePosition import scala.collection.mutable.{Map => MutableMap} +import dotty.tools.dotc.transform.Inlining trait ASTExtractors { val dottyCtx: DottyContext @@ -527,6 +528,32 @@ trait ASTExtractors { } } + object ExInlinedCall { + /** Extracts an inlined function or method call, returning a 4-tuple + * containing the receiver, function or method symbol, type arguments, + * and term arguments. + * + * Unlift the receiver and arguments to their definitions if they are + * inline proxies. + */ + def unapply(tree: tpd.Tree): Option[(Option[tpd.Tree], Symbol, Seq[tpd.Tree], Seq[tpd.Tree])] = tree match { + // TODO(mbovel): should probably use InlineOrProxy instead of + // Synthetic. Why is this not working? + case Block(stats, Inlined(ExCall(rec, sym, tps, args), _, _)) if stats.forall(_.symbol.is(Synthetic)) => + def unliftArg(arg: tpd.Tree) = arg match { + case arg @ Ident(_) if arg.symbol.is(Synthetic) => + stats.find(_.symbol == arg.symbol) match + case Some(proxyDef: tpd.ValOrDefDef) => proxyDef.rhs + case _ => throw new IllegalStateException(s"Could not find inline proxy definition for $arg") + case arg => arg + } + Some(rec.map(unliftArg), sym, tps, args.map(unliftArg)) + case Inlined(ExCall(rec, sym, tps, args), _, _) => + Some(rec, sym, tps, args) + case _ => None + } + } + object ExClassConstruction { def unapply(tree: tpd.Tree): Option[(Type, Seq[tpd.Tree])] = tree match { case Apply(Select(New(tpt), nme.CONSTRUCTOR), args) => @@ -1180,17 +1207,25 @@ trait ASTExtractors { } object ExEnsuredExpression { + /** Extracts an `ensuring` call. + * + * When matching, returngs a triple containing the receiver, the contract + * and a boolean indicating if the check is static. + */ def unapply(tree: tpd.Tree): Option[(tpd.Tree, tpd.Tree, Boolean)] = tree match { + // Dynamic check (Predef.ensuring) // An optional message may comes after `contract`, but we do not make use of it. case ExCall(Some(rec), ExSymbol("scala", "Predef$", "Ensuring", "ensuring"), _, contract +: _ ) => Some((rec, contract, false)) - // Ditto - case ExCall(Some(rec), - ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring", "ensuring"), - _, contract +: _ + // Static check (stainless.lang.StaticChecks.ensuring) + case ExInlinedCall( + _, + ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"), + _, + Seq(rec, contract, message) ) => Some((rec, contract, true)) case _ => None diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala index 0553d3d7a0..3c49959ca6 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala @@ -1327,10 +1327,6 @@ class CodeExtraction(inoxCtx: inox.Context, case (v, vd) => v.symbol -> (() => vd.toVariable) }))) - case Block(es, e) => - val b = extractBlock(es :+ e) - xt.exprOps.flattenBlocks(b) - case Try(body, cses, fin) => val rb = extractTree(body) val rc = cses.map(extractMatchCase) @@ -1382,6 +1378,11 @@ class CodeExtraction(inoxCtx: inox.Context, })).setPos(post) }) + // Needs to be after `ExEnsuredExpression`, as it matches blocks. + case Block(es, e) => + val b = extractBlock(es :+ e) + xt.exprOps.flattenBlocks(b) + case ExThrowingExpression(body, contract) => val pred = extractTree(contract) val b = extractTreeOrNoTree(body) @@ -1393,7 +1394,7 @@ class CodeExtraction(inoxCtx: inox.Context, val vd = xt.ValDef.fresh("res", tpe).setPos(other) xt.Lambda(Seq(vd), xt.Application(other, Seq(vd.toVariable)).setPos(other)).setPos(other) }) - + case t @ ExHoldsWithProofExpression(body, ExMaybeBecauseExpressionWrapper(proof)) => val vd = xt.ValDef.fresh("holds", xt.BooleanType().setPos(tr.sourcePos)).setPos(tr.sourcePos) val p = extractTreeOrNoTree(proof) diff --git a/frontends/library/stainless/lang/StaticChecks.scala b/frontends/library/stainless/lang/StaticChecks.scala index 26492947f4..e404527d54 100644 --- a/frontends/library/stainless/lang/StaticChecks.scala +++ b/frontends/library/stainless/lang/StaticChecks.scala @@ -4,12 +4,16 @@ import stainless.annotation._ object StaticChecks { - @library - implicit class Ensuring[A](val x: A) extends AnyVal { - def ensuring(@ghost cond: A => Boolean): A = x - - def ensuring(@ghost cond: A => Boolean, msg: => String): A = x - } + extension [A](x: A) + @library + inline def ensuring(@ghost inline cond: A => Boolean, inline msg: String = ""): A = x + + //@library + //implicit class Ensuring[A](val x: A) extends AnyVal { + // def ensuring(@ghost cond: A => Boolean): A = x + // + // def ensuring(@ghost cond: A => Boolean, msg: => String): A = x + //} @library @ignore implicit class WhileDecorations(val u: Unit) { From 7224d8cb835ffeda074720b81e925c49bc79ed87 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Fri, 13 Dec 2024 12:42:13 +0000 Subject: [PATCH 2/2] Partially fix ExInlinedCall --- .../frontends/dotc/ASTExtractors.scala | 74 +++++++++++++------ .../library/stainless/lang/StaticChecks.scala | 2 +- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala index 2507d5e272..a6269063e6 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala @@ -529,27 +529,58 @@ trait ASTExtractors { } object ExInlinedCall { - /** Extracts an inlined function or method call, returning a 4-tuple - * containing the receiver, function or method symbol, type arguments, - * and term arguments. - * + /** Extracts an inlined function or method call, returning a 5-tuple + * containing the inlined call's receiver, function or method symbol, + * type arguments, term arguments, and expansion. + * * Unlift the receiver and arguments to their definitions if they are * inline proxies. */ - def unapply(tree: tpd.Tree): Option[(Option[tpd.Tree], Symbol, Seq[tpd.Tree], Seq[tpd.Tree])] = tree match { - // TODO(mbovel): should probably use InlineOrProxy instead of - // Synthetic. Why is this not working? - case Block(stats, Inlined(ExCall(rec, sym, tps, args), _, _)) if stats.forall(_.symbol.is(Synthetic)) => - def unliftArg(arg: tpd.Tree) = arg match { - case arg @ Ident(_) if arg.symbol.is(Synthetic) => - stats.find(_.symbol == arg.symbol) match - case Some(proxyDef: tpd.ValOrDefDef) => proxyDef.rhs - case _ => throw new IllegalStateException(s"Could not find inline proxy definition for $arg") - case arg => arg - } - Some(rec.map(unliftArg), sym, tps, args.map(unliftArg)) - case Inlined(ExCall(rec, sym, tps, args), _, _) => - Some(rec, sym, tps, args) + def unapply(tree: tpd.Tree): Option[(Option[tpd.Tree], Symbol, Seq[tpd.Tree], Seq[tpd.Tree], tpd.Tree)] = tree match { + case Block(stats, Inlined(ExCall(rec, sym, tps, args), _, expansion)) + if stats.forall(_.symbol.isOneOf(InlineProxy | Synthetic)) => + + val unliftInlineProxies = + new tpd.TreeMap: + /** Adapted from Dotty's `tpd.MapToUnderlying.transform`. */ + override def transform(tree: tpd.Tree)(using DottyContext): tpd.Tree = + tree match + // The `InlineProxy` flag does not seem to be set for inline + // argument proxies, not sure why. It should be set at: + // https://github.com/scala/scala3/blob/20e6f11f4fe47982259eba949eea78d65765142f/compiler/src/dotty/tools/dotc/inlines/Inliner.scala#L224. + // Including `Synthetic` to make this work, like in Dotty's + // `tpd.TreeOps.underlyingArgument`, but this might be an + // over-approximation. + case tree: tpd.Ident if tree.symbol.isOneOf(InlineProxy | Synthetic) => + stats.find(_.symbol == tree.symbol) match + case Some(defTree: tpd.ValOrDefDef) => + val rhs = defTree.rhs + assert(!rhs.isEmpty) + // Contrary to Dotty's `tpd.MapToUnderlying`, do not + // recurse over rhs. + rhs + case defTree => tree + case tree => + super.transform(tree) + + def unwrapExtraction(tree: tpd.Tree): tpd.Tree = + tree match + case Inlined(tpd.EmptyTree, _, expansion) => + unwrapExtraction(expansion) + case Typed(expansion, _) => + unwrapExtraction(expansion) + case _ => + tree + + Some( + rec.map(unliftInlineProxies.transform), + sym, + tps, + args.map(unliftInlineProxies.transform), + unliftInlineProxies.transform(unwrapExtraction(expansion)) + ) + case Inlined(ExCall(rec, sym, tps, args), _, expansion) => + Some(rec, sym, tps, args, expansion) case _ => None } } @@ -607,7 +638,7 @@ trait ASTExtractors { def unapply(tree: tpd.Tree): Option[tpd.Tree] = tree match { case Apply( ExSymbol("scala", "Predef$", "Ensuring") | - ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring"), + ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"), Seq(arg)) => Some(arg) case Apply(ExSymbol("stainless", "lang", "package$", "Throwing"), Seq(arg)) => Some(arg) @@ -1225,8 +1256,9 @@ trait ASTExtractors { _, ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"), _, - Seq(rec, contract, message) - ) => Some((rec, contract, true)) + Seq(_, contract, message), + expansion + ) => Some((expansion, contract, true)) case _ => None } diff --git a/frontends/library/stainless/lang/StaticChecks.scala b/frontends/library/stainless/lang/StaticChecks.scala index e404527d54..b17ec4ab5e 100644 --- a/frontends/library/stainless/lang/StaticChecks.scala +++ b/frontends/library/stainless/lang/StaticChecks.scala @@ -6,7 +6,7 @@ object StaticChecks { extension [A](x: A) @library - inline def ensuring(@ghost inline cond: A => Boolean, inline msg: String = ""): A = x + inline def ensuring(cond: A => Boolean, msg: => String = ""): A = x //@library //implicit class Ensuring[A](val x: A) extends AnyVal {