diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala index 6b9d8d51d..a6269063e 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala @@ -18,6 +18,7 @@ import core.Annotations._ import util.SourcePosition import scala.collection.mutable.{Map => MutableMap} +import dotty.tools.dotc.transform.Inlining trait ASTExtractors { val dottyCtx: DottyContext @@ -527,6 +528,63 @@ trait ASTExtractors { } } + object ExInlinedCall { + /** Extracts an inlined function or method call, returning a 5-tuple + * containing the inlined call's receiver, function or method symbol, + * type arguments, term arguments, and expansion. + * + * Unlift the receiver and arguments to their definitions if they are + * inline proxies. + */ + def unapply(tree: tpd.Tree): Option[(Option[tpd.Tree], Symbol, Seq[tpd.Tree], Seq[tpd.Tree], tpd.Tree)] = tree match { + case Block(stats, Inlined(ExCall(rec, sym, tps, args), _, expansion)) + if stats.forall(_.symbol.isOneOf(InlineProxy | Synthetic)) => + + val unliftInlineProxies = + new tpd.TreeMap: + /** Adapted from Dotty's `tpd.MapToUnderlying.transform`. */ + override def transform(tree: tpd.Tree)(using DottyContext): tpd.Tree = + tree match + // The `InlineProxy` flag does not seem to be set for inline + // argument proxies, not sure why. It should be set at: + // https://github.com/scala/scala3/blob/20e6f11f4fe47982259eba949eea78d65765142f/compiler/src/dotty/tools/dotc/inlines/Inliner.scala#L224. + // Including `Synthetic` to make this work, like in Dotty's + // `tpd.TreeOps.underlyingArgument`, but this might be an + // over-approximation. + case tree: tpd.Ident if tree.symbol.isOneOf(InlineProxy | Synthetic) => + stats.find(_.symbol == tree.symbol) match + case Some(defTree: tpd.ValOrDefDef) => + val rhs = defTree.rhs + assert(!rhs.isEmpty) + // Contrary to Dotty's `tpd.MapToUnderlying`, do not + // recurse over rhs. + rhs + case defTree => tree + case tree => + super.transform(tree) + + def unwrapExtraction(tree: tpd.Tree): tpd.Tree = + tree match + case Inlined(tpd.EmptyTree, _, expansion) => + unwrapExtraction(expansion) + case Typed(expansion, _) => + unwrapExtraction(expansion) + case _ => + tree + + Some( + rec.map(unliftInlineProxies.transform), + sym, + tps, + args.map(unliftInlineProxies.transform), + unliftInlineProxies.transform(unwrapExtraction(expansion)) + ) + case Inlined(ExCall(rec, sym, tps, args), _, expansion) => + Some(rec, sym, tps, args, expansion) + case _ => None + } + } + object ExClassConstruction { def unapply(tree: tpd.Tree): Option[(Type, Seq[tpd.Tree])] = tree match { case Apply(Select(New(tpt), nme.CONSTRUCTOR), args) => @@ -580,7 +638,7 @@ trait ASTExtractors { def unapply(tree: tpd.Tree): Option[tpd.Tree] = tree match { case Apply( ExSymbol("scala", "Predef$", "Ensuring") | - ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring"), + ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"), Seq(arg)) => Some(arg) case Apply(ExSymbol("stainless", "lang", "package$", "Throwing"), Seq(arg)) => Some(arg) @@ -1180,18 +1238,27 @@ trait ASTExtractors { } object ExEnsuredExpression { + /** Extracts an `ensuring` call. + * + * When matching, returngs a triple containing the receiver, the contract + * and a boolean indicating if the check is static. + */ def unapply(tree: tpd.Tree): Option[(tpd.Tree, tpd.Tree, Boolean)] = tree match { + // Dynamic check (Predef.ensuring) // An optional message may comes after `contract`, but we do not make use of it. case ExCall(Some(rec), ExSymbol("scala", "Predef$", "Ensuring", "ensuring"), _, contract +: _ ) => Some((rec, contract, false)) - // Ditto - case ExCall(Some(rec), - ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring", "ensuring"), - _, contract +: _ - ) => Some((rec, contract, true)) + // Static check (stainless.lang.StaticChecks.ensuring) + case ExInlinedCall( + _, + ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"), + _, + Seq(_, contract, message), + expansion + ) => Some((expansion, contract, true)) case _ => None } diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala index 0553d3d7a..3c49959ca 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala @@ -1327,10 +1327,6 @@ class CodeExtraction(inoxCtx: inox.Context, case (v, vd) => v.symbol -> (() => vd.toVariable) }))) - case Block(es, e) => - val b = extractBlock(es :+ e) - xt.exprOps.flattenBlocks(b) - case Try(body, cses, fin) => val rb = extractTree(body) val rc = cses.map(extractMatchCase) @@ -1382,6 +1378,11 @@ class CodeExtraction(inoxCtx: inox.Context, })).setPos(post) }) + // Needs to be after `ExEnsuredExpression`, as it matches blocks. + case Block(es, e) => + val b = extractBlock(es :+ e) + xt.exprOps.flattenBlocks(b) + case ExThrowingExpression(body, contract) => val pred = extractTree(contract) val b = extractTreeOrNoTree(body) @@ -1393,7 +1394,7 @@ class CodeExtraction(inoxCtx: inox.Context, val vd = xt.ValDef.fresh("res", tpe).setPos(other) xt.Lambda(Seq(vd), xt.Application(other, Seq(vd.toVariable)).setPos(other)).setPos(other) }) - + case t @ ExHoldsWithProofExpression(body, ExMaybeBecauseExpressionWrapper(proof)) => val vd = xt.ValDef.fresh("holds", xt.BooleanType().setPos(tr.sourcePos)).setPos(tr.sourcePos) val p = extractTreeOrNoTree(proof) diff --git a/frontends/library/stainless/lang/StaticChecks.scala b/frontends/library/stainless/lang/StaticChecks.scala index 26492947f..b17ec4ab5 100644 --- a/frontends/library/stainless/lang/StaticChecks.scala +++ b/frontends/library/stainless/lang/StaticChecks.scala @@ -4,12 +4,16 @@ import stainless.annotation._ object StaticChecks { - @library - implicit class Ensuring[A](val x: A) extends AnyVal { - def ensuring(@ghost cond: A => Boolean): A = x - - def ensuring(@ghost cond: A => Boolean, msg: => String): A = x - } + extension [A](x: A) + @library + inline def ensuring(cond: A => Boolean, msg: => String = ""): A = x + + //@library + //implicit class Ensuring[A](val x: A) extends AnyVal { + // def ensuring(@ghost cond: A => Boolean): A = x + // + // def ensuring(@ghost cond: A => Boolean, msg: => String): A = x + //} @library @ignore implicit class WhileDecorations(val u: Unit) {