Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Make ensuring inline #1624

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import core.Annotations._
import util.SourcePosition

import scala.collection.mutable.{Map => MutableMap}
import dotty.tools.dotc.transform.Inlining

trait ASTExtractors {
val dottyCtx: DottyContext
Expand Down Expand Up @@ -527,6 +528,63 @@ trait ASTExtractors {
}
}

object ExInlinedCall {
/** Extracts an inlined function or method call, returning a 5-tuple
* containing the inlined call's receiver, function or method symbol,
* type arguments, term arguments, and expansion.
*
* Unlift the receiver and arguments to their definitions if they are
* inline proxies.
*/
def unapply(tree: tpd.Tree): Option[(Option[tpd.Tree], Symbol, Seq[tpd.Tree], Seq[tpd.Tree], tpd.Tree)] = tree match {
case Block(stats, Inlined(ExCall(rec, sym, tps, args), _, expansion))
if stats.forall(_.symbol.isOneOf(InlineProxy | Synthetic)) =>

val unliftInlineProxies =
new tpd.TreeMap:
/** Adapted from Dotty's `tpd.MapToUnderlying.transform`. */
override def transform(tree: tpd.Tree)(using DottyContext): tpd.Tree =
tree match
// The `InlineProxy` flag does not seem to be set for inline
// argument proxies, not sure why. It should be set at:
// https://github.com/scala/scala3/blob/20e6f11f4fe47982259eba949eea78d65765142f/compiler/src/dotty/tools/dotc/inlines/Inliner.scala#L224.
// Including `Synthetic` to make this work, like in Dotty's
// `tpd.TreeOps.underlyingArgument`, but this might be an
// over-approximation.
case tree: tpd.Ident if tree.symbol.isOneOf(InlineProxy | Synthetic) =>
stats.find(_.symbol == tree.symbol) match
case Some(defTree: tpd.ValOrDefDef) =>
val rhs = defTree.rhs
assert(!rhs.isEmpty)
// Contrary to Dotty's `tpd.MapToUnderlying`, do not
// recurse over rhs.
rhs
case defTree => tree
case tree =>
super.transform(tree)

def unwrapExtraction(tree: tpd.Tree): tpd.Tree =
tree match
case Inlined(tpd.EmptyTree, _, expansion) =>
unwrapExtraction(expansion)
case Typed(expansion, _) =>
unwrapExtraction(expansion)
case _ =>
tree

Some(
rec.map(unliftInlineProxies.transform),
sym,
tps,
args.map(unliftInlineProxies.transform),
unliftInlineProxies.transform(unwrapExtraction(expansion))
)
case Inlined(ExCall(rec, sym, tps, args), _, expansion) =>
Some(rec, sym, tps, args, expansion)
case _ => None
}
}

object ExClassConstruction {
def unapply(tree: tpd.Tree): Option[(Type, Seq[tpd.Tree])] = tree match {
case Apply(Select(New(tpt), nme.CONSTRUCTOR), args) =>
Expand Down Expand Up @@ -580,7 +638,7 @@ trait ASTExtractors {
def unapply(tree: tpd.Tree): Option[tpd.Tree] = tree match {
case Apply(
ExSymbol("scala", "Predef$", "Ensuring") |
ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring"),
ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"),
Seq(arg)) => Some(arg)

case Apply(ExSymbol("stainless", "lang", "package$", "Throwing"), Seq(arg)) => Some(arg)
Expand Down Expand Up @@ -1180,18 +1238,27 @@ trait ASTExtractors {
}

object ExEnsuredExpression {
/** Extracts an `ensuring` call.
*
* When matching, returngs a triple containing the receiver, the contract
* and a boolean indicating if the check is static.
*/
def unapply(tree: tpd.Tree): Option[(tpd.Tree, tpd.Tree, Boolean)] = tree match {
// Dynamic check (Predef.ensuring)
// An optional message may comes after `contract`, but we do not make use of it.
case ExCall(Some(rec),
ExSymbol("scala", "Predef$", "Ensuring", "ensuring"),
_, contract +: _
) => Some((rec, contract, false))

// Ditto
case ExCall(Some(rec),
ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring", "ensuring"),
_, contract +: _
) => Some((rec, contract, true))
// Static check (stainless.lang.StaticChecks.ensuring)
case ExInlinedCall(
_,
ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"),
_,
Seq(_, contract, message),
expansion
) => Some((expansion, contract, true))

case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1327,10 +1327,6 @@ class CodeExtraction(inoxCtx: inox.Context,
case (v, vd) => v.symbol -> (() => vd.toVariable)
})))

case Block(es, e) =>
val b = extractBlock(es :+ e)
xt.exprOps.flattenBlocks(b)

case Try(body, cses, fin) =>
val rb = extractTree(body)
val rc = cses.map(extractMatchCase)
Expand Down Expand Up @@ -1382,6 +1378,11 @@ class CodeExtraction(inoxCtx: inox.Context,
})).setPos(post)
})

// Needs to be after `ExEnsuredExpression`, as it matches blocks.
case Block(es, e) =>
val b = extractBlock(es :+ e)
xt.exprOps.flattenBlocks(b)

case ExThrowingExpression(body, contract) =>
val pred = extractTree(contract)
val b = extractTreeOrNoTree(body)
Expand All @@ -1393,7 +1394,7 @@ class CodeExtraction(inoxCtx: inox.Context,
val vd = xt.ValDef.fresh("res", tpe).setPos(other)
xt.Lambda(Seq(vd), xt.Application(other, Seq(vd.toVariable)).setPos(other)).setPos(other)
})

case t @ ExHoldsWithProofExpression(body, ExMaybeBecauseExpressionWrapper(proof)) =>
val vd = xt.ValDef.fresh("holds", xt.BooleanType().setPos(tr.sourcePos)).setPos(tr.sourcePos)
val p = extractTreeOrNoTree(proof)
Expand Down
16 changes: 10 additions & 6 deletions frontends/library/stainless/lang/StaticChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import stainless.annotation._

object StaticChecks {

@library
implicit class Ensuring[A](val x: A) extends AnyVal {
def ensuring(@ghost cond: A => Boolean): A = x

def ensuring(@ghost cond: A => Boolean, msg: => String): A = x
}
extension [A](x: A)
@library
inline def ensuring(cond: A => Boolean, msg: => String = ""): A = x

//@library
//implicit class Ensuring[A](val x: A) extends AnyVal {
// def ensuring(@ghost cond: A => Boolean): A = x
//
// def ensuring(@ghost cond: A => Boolean, msg: => String): A = x
//}

@library @ignore
implicit class WhileDecorations(val u: Unit) {
Expand Down
Loading