Skip to content

Commit

Permalink
Handle binding of beta reduced inlined lambdas
Browse files Browse the repository at this point in the history
Handle all inline beta-reduction in the InlineReducer. All these
applications will contain `Inlined` nodes that need to be handled
without changing the nestedness of expressions in inlining scopes.

Fixes #16374
  • Loading branch information
nicolasstucki committed Nov 22, 2022
1 parent b4f8eef commit d0a17ce
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 29 deletions.
2 changes: 0 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -743,8 +743,6 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
Some(meth)
case Block(Nil, expr) =>
unapply(expr)
case Inlined(_, bindings, expr) if bindings.forall(isPureBinding) =>
unapply(expr)
case _ =>
None
}
Expand Down
65 changes: 38 additions & 27 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,35 +158,46 @@ class InlineReducer(inliner: Inliner)(using Context):
*
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
* refs among the ei's directly without creating an intermediate binding.
*
* This variant of beta-reduction preserves the integrity of `Inlined` tree nodes.
*/
def betaReduce(tree: Tree)(using Context): Tree = tree match {
case Apply(Select(cl @ closureDef(ddef), nme.apply), args) if defn.isFunctionType(cl.tpe) =>
// closureDef also returns a result for closures wrapped in Inlined nodes.
// These need to be preserved.
def recur(cl: Tree): Tree = cl match
case Inlined(call, bindings, expr) =>
cpy.Inlined(cl)(call, bindings, recur(expr))
case _ => ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
val bindingsBuf = new DefBuffer
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
arg.tpe.dealias match {
case ref @ TermRef(NoPrefix, _) => ref.symbol
case _ =>
paramBindingDef(name, paramtp, arg, bindingsBuf)(
using ctx.withSource(cl.source)
).symbol
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
val bindingsBuf = new DefBuffer
def recur(cl: Tree): Option[Tree] = cl match
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
recur(expr).map(cpy.Inlined(cl)(call, bindings, _))
case Block(Nil, expr) =>
recur(expr).map(cpy.Block(cl)(Nil, _))
case Typed(expr, tpt) =>
recur(expr).map(cpy.Typed(cl)(_, tpt))
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
arg.tpe.dealias match {
case ref @ TermRef(NoPrefix, _) => ref.symbol
case _ =>
paramBindingDef(name, paramtp, arg, bindingsBuf)(
using ctx.withSource(cl.source)
).symbol
}
}
}
val expander = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = ddef.paramss.head.map(_.symbol),
substTo = argSyms)
Block(bindingsBuf.toList, expander.transform(ddef.rhs)).withSpan(tree.span)
case _ => tree
recur(cl)
case _ => tree
val expander = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = ddef.paramss.head.map(_.symbol),
substTo = argSyms)
Some(expander.transform(ddef.rhs))
case _ => None
case _ => None
recur(cl) match
case Some(reduced) =>
Block(bindingsBuf.toList, reduced).withSpan(tree.span)
case None =>
tree
case _ =>
tree
}

/** The result type of reducing a match. It consists optionally of a list of bindings
Expand Down Expand Up @@ -281,7 +292,7 @@ class InlineReducer(inliner: Inliner)(using Context):
// Test case is pos-macros/i15971
val tptBinds = getBinds(Set.empty[TypeSymbol], tpt)
val binds: Set[TypeSymbol] = pat match {
case UnApply(TypeApply(_, tpts), _, _) =>
case UnApply(TypeApply(_, tpts), _, _) =>
getBinds(Set.empty[TypeSymbol], tpts) ++ tptBinds
case _ => tptBinds
}
Expand Down
7 changes: 7 additions & 0 deletions tests/pos/i16374a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def method(using String): String = ???

inline def inlineMethod(inline op: String => Unit)(using String): Unit =
println(op(method))

def test(using String) =
inlineMethod(c => print(c))
9 changes: 9 additions & 0 deletions tests/pos/i16374b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def method(using String): String = ???

inline def identity[T](inline x: T): T = x

inline def inlineMethod(inline op: String => Unit)(using String): Unit =
println(identity(op)(method))

def test(using String) =
inlineMethod(c => print(c))

0 comments on commit d0a17ce

Please sign in to comment.