Skip to content

Commit

Permalink
Implement polymorphic lambdas using Closure nodes for efficiency
Browse files Browse the repository at this point in the history
Previously, we desugared them manually into anonymous class instances,
but by using a Closure node instead, we ensure that they get translated
into indy lambdas on the JVM.

Also cleaned up and added a TODO in the desugaring of polymorphic function types
into refinement types since I realized that purity wasn't taken into account.
  • Loading branch information
smarter committed May 21, 2023
1 parent 75ab141 commit 7535ede
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 77 deletions.
97 changes: 41 additions & 56 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,40 @@ object desugar {
name
}

/** Strip parens and empty blocks around the body of `tree`. */
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
def stripped(body: Tree): Tree = body match
case Parens(body1) =>
stripped(body1)
case Block(Nil, body1) =>
stripped(body1)
case _ => body
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]

/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val funFlags = fun match
case fun: FunctionWithMods =>
fun.mods.flags
case _ => EmptyFlags

// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val paramFlags = funFlags.toTermFlags & Given
val vparams = vparamTypes.zipWithIndex.map:
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
)).withSpan(tree.span)
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
val str = impl match
Expand Down Expand Up @@ -1413,14 +1447,17 @@ object desugar {
}

/** Make closure corresponding to function.
* params => body
* [tparams] => params => body
* ==>
* def $anonfun(params) = body
* def $anonfun[tparams](params) = body
* Closure($anonfun)
*/
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
val paramss: List[ParamClause] =
if tparams.isEmpty then vparams :: Nil
else tparams :: vparams :: Nil
Block(
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
.withSpan(span)
.withMods(synthetic | Artifact),
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
Expand Down Expand Up @@ -1712,56 +1749,6 @@ object desugar {
}
}

def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
case Parens(body1) =>
makePolyFunction(targs, body1, pt)
case Block(Nil, body1) =>
makePolyFunction(targs, body1, pt)
case Function(vargs, res) =>
assert(targs.nonEmpty)
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
val mods = body match {
case body: FunctionWithMods => body.mods
case _ => untpd.EmptyModifiers
}
val polyFunctionTpt = ref(defn.PolyFunctionType)
val applyTParams = targs.asInstanceOf[List[TypeDef]]
if (ctx.mode.is(Mode.Type)) {
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }

val applyVParams = vargs.zipWithIndex.map {
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
}
RefinedTypeTree(polyFunctionTpt, List(
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
))
}
else {
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.

def typeTree(tp: Type) = tp match
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
untpd.DependentTypeTree((tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
case _ => TypeTree()

val applyVParams = vargs.asInstanceOf[List[ValDef]]
.map(varg => varg.withAddedFlags(mods.flags | Param))
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
))
}
case _ =>
// may happen for erroneous input. An error will already have been reported.
assert(ctx.reporter.errorsReported)
EmptyTree
}

// begin desugar

// Special case for `Parens` desugaring: unlike all the desugarings below,
Expand All @@ -1774,8 +1761,6 @@ object desugar {
}

val desugared = tree match {
case PolyFunction(targs, body) =>
makePolyFunction(targs, body, pt) orElse tree
case SymbolLit(str) =>
Apply(
ref(defn.ScalaSymbolClass.companionModule.termRef),
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,8 @@ object Types {
if alwaysDependent || mt.isResultDependent then
RefinedType(funType, nme.apply, mt)
else funType
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
RefinedType(defn.PolyFunctionType, nme.apply, poly)
}

/** The signature of this type. This is by default NotAMethod,
Expand Down
57 changes: 40 additions & 17 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1633,12 +1633,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
)
cpy.ValDef(param)(tpt = paramTpt)
if isErased then param0.withAddedFlags(Flags.Erased) else param0
desugared = desugar.makeClosure(inferredParams, fnBody, resultTpt, tree.span)
desugared = desugar.makeClosure(Nil, inferredParams, fnBody, resultTpt, tree.span)

typed(desugared, pt)
.showing(i"desugared fun $tree --> $desugared with pt = $pt", typr)
}


def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked

val resultTpt = pt.dealias match
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
untpd.DependentTypeTree((tsyms, vsyms) =>
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
case _ => untpd.TypeTree()

val desugared = desugar.makeClosure(tparams, vparams, body, resultTpt, tree.span)
typed(desugared, pt)
end typedPolyFunctionValue

def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
val env1 = tree.env mapconserve (typed(_))
val meth1 = typedUnadapted(tree.meth)
Expand Down Expand Up @@ -1676,6 +1696,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
else
EmptyTree
}
case _: PolyType =>
// Polymorphic SAMs are not currently supported (#6904).
EmptyTree
case tp =>
if !tp.isErroneous then
throw new java.lang.Error(i"internal error: closing over non-method $tp, pos = ${tree.span}")
Expand Down Expand Up @@ -2433,7 +2456,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case rhs => typedExpr(rhs, tpt1.tpe.widenExpr)
}
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
postProcessInfo(sym)
postProcessInfo(vdef, sym)
vdef1.setDefTree
}

Expand Down Expand Up @@ -2536,19 +2559,31 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

val ddef2 = assignType(cpy.DefDef(ddef)(name, paramss1, tpt1, rhs1), sym)

postProcessInfo(sym)
postProcessInfo(ddef2, sym)
ddef2.setDefTree
//todo: make sure dependent method types do not depend on implicits or by-name params
}

/** (1) Check that the signature of the class member does not return a repeated parameter type
* (2) If info is an erased class, set erased flag of member
* (3) Check that erased classes are not parameters of polymorphic functions.
*/
private def postProcessInfo(sym: Symbol)(using Context): Unit =
private def postProcessInfo(mdef: MemberDef, sym: Symbol)(using Context): Unit =
if (!sym.isOneOf(Synthetic | InlineProxy | Param) && sym.info.finalResultType.isRepeatedParam)
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
sym.setFlag(Erased)
if
sym.info.isInstanceOf[PolyType] &&
((sym.name eq nme.ANON_FUN) ||
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
then
mdef match
case DefDef(_, _ :: vparams :: Nil, _, _) =>
vparams.foreach: vparam =>
if vparam.symbol.is(Erased) then
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
case _ =>

def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
val TypeDef(name, rhs) = tdef
Expand Down Expand Up @@ -2695,19 +2730,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// check value class constraints
checkDerivedValueClass(cls, body1)

// check PolyFunction constraints (no erased functions!)
if parents1.exists(_.tpe.classSymbol eq defn.PolyFunctionClass) then
body1.foreach {
case ddef: DefDef =>
ddef.paramss.foreach { params =>
val erasedParam = params.collectFirst { case vdef: ValDef if vdef.symbol.is(Erased) => vdef }
erasedParam.foreach { p =>
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", p.srcPos)
}
}
case _ =>
}

val effectiveOwner = cls.owner.skipWeakOwner
if !cls.isRefinementClass
&& !cls.isAllOf(PrivateLocal)
Expand Down Expand Up @@ -3059,6 +3081,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case tree: untpd.Block => typedBlock(desugar.block(tree), pt)(using ctx.fresh.setNewScope)
case tree: untpd.If => typedIf(tree, pt)
case tree: untpd.Function => typedFunction(tree, pt)
case tree: untpd.PolyFunction => typedPolyFunction(tree, pt)
case tree: untpd.Closure => typedClosure(tree, pt)
case tree: untpd.Import => typedImport(tree)
case tree: untpd.Export => typedExport(tree)
Expand Down
8 changes: 4 additions & 4 deletions tests/neg/polymorphic-functions1.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 ---------------------------------------------
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:33 ---------------------------------------------
1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error
| ^
| Found: [T] => (x: Int) => x.type
| Required: [T] => (x: T) => x.type
| ^^^^^^^^^^^^^^^^^^^^
| Found: [T] => (x: Int) => x.type
| Required: [T] => (x: T) => x.type
|
| longer explanation available when compiling with `-explain`

0 comments on commit 7535ede

Please sign in to comment.