Skip to content

Commit

Permalink
Use PolyFunction instead of ErasedFunction
Browse files Browse the repository at this point in the history
We generalize the meaning of `PolyFunction` to mean any kind of refined
lambda encoding. These refinements support any type with the following
shape as a lambda type:

```scala
PolyFunction {
  def apply[[T1, ..., Tn]]([given] [erased] x1: X1, ..., [erased] xn: Xn): R
}
```
  • Loading branch information
nicolasstucki committed Jul 26, 2023
1 parent 5625107 commit f701f28
Show file tree
Hide file tree
Showing 21 changed files with 48 additions and 90 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
def isStructuralTermSelect(tree: Select) =
def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
case defn.PolyOrErasedFunctionOf(_) =>
case defn.PolyFunctionOf(_) =>
false
case RefinedType(parent, rname, rinfo) =>
rname == tree.name || hasRefinement(parent)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ extends tpd.TreeTraverser:
val mt = ContextualMethodType(paramName :: Nil)(
_ => paramType :: Nil,
mt => if isLast then res else expandThrowsAlias(res, mt :: encl))
val fntpe = RefinedType(defn.ErasedFunctionClass.typeRef, nme.apply, mt)
val fntpe = RefinedType(defn.PolyFunctionClass.typeRef, nme.apply, mt)
if !encl.isEmpty && isLast then
val cs = CaptureSet(encl.map(_.paramRefs.head)*)
CapturingType(fntpe, cs, boxed = false)
Expand Down
56 changes: 8 additions & 48 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1112,12 +1112,12 @@ class Definitions {
def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =
val mt = MethodType.companion(isContextual, false)(args, resultType)
if mt.hasErasedParams then
RefinedType(ErasedFunctionClass.typeRef, nme.apply, mt)
RefinedType(PolyFunctionClass.typeRef, nme.apply, mt)
else
FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil)
def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = {
ft.dealias match
case ErasedFunctionOf(mt) =>
case PolyFunctionOf(mt: MethodType) =>
Some(mt.paramInfos, mt.resType, mt.isContextualMethod)
case dft =>
val tsym = dft.typeSymbol
Expand All @@ -1129,38 +1129,14 @@ class Definitions {
}
}

object PolyOrErasedFunctionOf {
/** Matches a refined `PolyFunction` or `ErasedFunction` type and extracts the apply info.
*
* Pattern: `(PolyFunction | ErasedFunction) { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodicType)
if parent.derivesFrom(defn.PolyFunctionClass) || parent.derivesFrom(defn.ErasedFunctionClass) =>
Some(mt)
case _ => None
}

object PolyFunctionOf {
/** Matches a refined `PolyFunction` type and extracts the apply info.
*
* Pattern: `PolyFunction { def apply: $pt }`
* Pattern: `PolyFunction { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[PolyType] = ft.dealias match
case RefinedType(parent, nme.apply, pt: PolyType)
def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodicType)
if parent.derivesFrom(defn.PolyFunctionClass) =>
Some(pt)
case _ => None
}

object ErasedFunctionOf {
/** Matches a refined `ErasedFunction` type and extracts the apply info.
*
* Pattern: `ErasedFunction { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodType)
if parent.derivesFrom(defn.ErasedFunctionClass) =>
Some(mt)
case _ => None
}
Expand Down Expand Up @@ -1514,9 +1490,6 @@ class Definitions {
lazy val PolyFunctionClass = requiredClass("scala.PolyFunction")
def PolyFunctionType = PolyFunctionClass.typeRef

lazy val ErasedFunctionClass = requiredClass("scala.runtime.ErasedFunction")
def ErasedFunctionType = ErasedFunctionClass.typeRef

/** If `cls` is a class in the scala package, its name, otherwise EmptyTypeName */
def scalaClassName(cls: Symbol)(using Context): TypeName = cls.denot match
case clsd: ClassDenotation if clsd.owner eq ScalaPackageClass =>
Expand Down Expand Up @@ -1579,8 +1552,6 @@ class Definitions {
/** Is a synthetic function class
* - FunctionN for N > 22
* - ContextFunctionN for N >= 0
* - ErasedFunctionN for N > 0
* - ErasedContextFunctionN for N > 0
*/
def isSyntheticFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isSyntheticFunction

Expand All @@ -1596,8 +1567,6 @@ class Definitions {
* - FunctionN for 22 > N >= 0 remains as FunctionN
* - ContextFunctionN for N > 22 becomes FunctionXXL
* - ContextFunctionN for N <= 22 becomes FunctionN
* - ErasedFunctionN becomes Function0
* - ImplicitErasedFunctionN becomes Function0
* - anything else becomes a NoType
*/
def functionTypeErasure(cls: Symbol): Type =
Expand Down Expand Up @@ -1756,13 +1725,11 @@ class Definitions {
/** Returns whether `tp` is an instance or a refined instance of:
* - scala.FunctionN
* - scala.ContextFunctionN
* - ErasedFunction
* - PolyFunction
*/
def isFunctionType(tp: Type)(using Context): Boolean =
isFunctionNType(tp)
|| tp.derivesFrom(defn.PolyFunctionClass) // TODO check for refinement?
|| tp.derivesFrom(defn.ErasedFunctionClass) // TODO check for refinement?

private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
if !ctx.settings.Yscala2Stdlib.value then
Expand Down Expand Up @@ -1866,7 +1833,7 @@ class Definitions {
tp.stripTypeVar.dealias match
case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) =>
asContextFunctionType(TypeComparer.bounds(tp1).hiBound)
case tp1 @ ErasedFunctionOf(mt) if mt.isContextualMethod =>
case tp1 @ PolyFunctionOf(mt: MethodType) if mt.isContextualMethod =>
tp1
case tp1 =>
if tp1.typeSymbol.name.isContextFunction && isFunctionNType(tp1) then tp1
Expand All @@ -1886,21 +1853,14 @@ class Definitions {
atPhase(erasurePhase)(unapply(tp))
else
asContextFunctionType(tp) match
case ErasedFunctionOf(mt) =>
case PolyFunctionOf(mt: MethodType) =>
Some((mt.paramInfos, mt.resType, mt.erasedParams))
case tp1 if tp1.exists =>
val args = tp1.functionArgInfos
val erasedParams = erasedFunctionParameters(tp1)
val erasedParams = List.fill(functionArity(tp1)) { false }
Some((args.init, args.last, erasedParams))
case _ => None

/* Returns a list of erased booleans marking whether parameters are erased, for a function type. */
def erasedFunctionParameters(tp: Type)(using Context): List[Boolean] = tp.dealias match {
case ErasedFunctionOf(mt) => mt.erasedParams
case tp if isFunctionNType(tp) => List.fill(functionArity(tp)) { false }
case _ => Nil
}

/** A whitelist of Scala-2 classes that are known to be pure */
def isAssuredNoInits(sym: Symbol): Boolean =
(sym `eq` SomeClass) || isTupleClass(sym)
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ object StdNames {
inline val MODULE_INSTANCE_FIELD = "MODULE$"

inline val Function = "Function"
inline val ErasedFunction = "ErasedFunction"
inline val ContextFunction = "ContextFunction"
inline val ErasedContextFunction = "ErasedContextFunction"
inline val AbstractFunction = "AbstractFunction"
Expand Down Expand Up @@ -214,7 +213,6 @@ object StdNames {
final val Throwable: N = "Throwable"
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"
final val ErasedFunction: N = "ErasedFunction"

final val Abs: N = "Abs"
final val And: N = "&&"
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,10 @@ class TypeApplications(val self: Type) extends AnyVal {
case _ => Nil

/** If this is an encoding of a function type, return its arguments, otherwise return Nil.
* Handles `ErasedFunction`s and poly functions gracefully.
* Handles poly functions gracefully.
*/
final def functionArgInfos(using Context): List[Type] = self.dealias match
case defn.ErasedFunctionOf(mt) => (mt.paramInfos :+ mt.resultType)
case defn.PolyFunctionOf(mt: MethodType) => (mt.paramInfos :+ mt.resultType)
case _ => self.dropDependentRefinement.dealias.argInfos

/** Argument types where existential types in arguments are disallowed */
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

if defn.isFunctionType(tp2) then
if tp2.derivesFrom(defn.PolyFunctionClass) then
// TODO should we handle ErasedFunction is this same way?
tp1.member(nme.apply).info match
case info1: PolyType =>
return isSubInfo(info1, tp2.refinedInfo)
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ object TypeErasure {
case _ => false
}

/** The erasure of `(PolyFunction | ErasedFunction) { def apply: $applyInfo }` */
/** The erasure of `PolyFunction { def apply: $applyInfo }` */
def eraseRefinedFunctionApply(applyInfo: Type)(using Context): Type =
def functionType(info: Type): Type = info match {
case info: PolyType =>
Expand Down Expand Up @@ -654,7 +654,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
else SuperType(eThis, eSuper)
case ExprType(rt) =>
defn.FunctionType(0)
case defn.PolyOrErasedFunctionOf(mt) =>
case defn.PolyFunctionOf(mt) =>
eraseRefinedFunctionApply(mt)
case tp: TypeVar if !tp.isInstantiated =>
assert(inSigName, i"Cannot erase uninstantiated type variable $tp")
Expand Down Expand Up @@ -936,7 +936,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
sigName(defn.FunctionOf(Nil, rt))
case tp: TypeVar if !tp.isInstantiated =>
tpnme.Uninstantiated
case tp @ defn.PolyOrErasedFunctionOf(_) =>
case tp @ defn.PolyFunctionOf(_) =>
// we need this case rather than falling through to the default
// because RefinedTypes <: TypeProxy and it would be caught by
// the case immediately below
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,7 @@ object Types {
formals1 mapConserve (_.translateFromRepeated(toArray = isJava)),
result1, isContextual)
if mt.hasErasedParams then
RefinedType(defn.ErasedFunctionType, nme.apply, mt)
RefinedType(defn.PolyFunctionType, nme.apply, mt)
else if alwaysDependent || mt.isResultDependent then
RefinedType(nonDependentFunType, nme.apply, mt)
else nonDependentFunType
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ object Erasure {
// Instead, we manually lookup the type of `apply` in the qualifier.
inContext(preErasureCtx) {
val qualTp = tree.qualifier.typeOpt.widen
if qualTp.derivesFrom(defn.PolyFunctionClass) || qualTp.derivesFrom(defn.ErasedFunctionClass) then
if qualTp.derivesFrom(defn.PolyFunctionClass) then
eraseRefinedFunctionApply(qualTp.select(nme.apply).widen).classSymbol
else
NoSymbol
Expand Down
7 changes: 2 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,8 @@ object TreeChecker {
assert(tree.isTerm || !ctx.isAfterTyper, tree.show + " at " + ctx.phase)
val tpe = tree.typeOpt

// PolyFunction and ErasedFunction apply methods stay structural until Erasure
val isRefinedFunctionApply = (tree.name eq nme.apply) && {
val qualTpe = tree.qualifier.typeOpt
qualTpe.derivesFrom(defn.PolyFunctionClass) || qualTpe.derivesFrom(defn.ErasedFunctionClass)
}
// PolyFunction apply method stay structural until Erasure
val isRefinedFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass)

// Outer selects are pickled specially so don't require a symbol
val isOuterSelect = tree.name.is(OuterSelectName)
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
expected =:= defn.FunctionOf(actualArgs, actualRet,
defn.isContextFunctionType(baseFun))
val arity: Int =
if fun.derivesFrom(defn.ErasedFunctionClass) then -1 // TODO support?
else if defn.isFunctionNType(fun) then
if defn.isFunctionNType(fun) then
// TupledFunction[(...) => R, ?]
fun.functionArgInfos match
case funArgs :+ funRet
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
if defn.isNonRefinedFunction(parent) && formals.length == defaultArity =>
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
case defn.ErasedFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity =>
case defn.PolyFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity =>
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
case SAMType(mt @ MethodTpe(_, formals, _), samParent) =>
val restpe = mt.resultType match
Expand Down Expand Up @@ -1433,7 +1433,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val resTpt = TypeTree(mt.nonDependentResultApprox).withSpan(body.span)
val typeArgs = appDef.termParamss.head.map(_.tpt) :+ resTpt
val core =
if mt.hasErasedParams then TypeTree(defn.ErasedFunctionClass.typeRef)
if mt.hasErasedParams then TypeTree(defn.PolyFunctionClass.typeRef)
else
val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure)
val tycon = TypeTree(funSym.typeRef)
Expand Down Expand Up @@ -3220,7 +3220,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}

val erasedParams = pt match {
case defn.ErasedFunctionOf(mt: MethodType) => mt.erasedParams
case defn.PolyFunctionOf(mt: MethodType) => mt.erasedParams
case _ => paramTypes.map(_ => false)
}

Expand Down
11 changes: 8 additions & 3 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,12 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def isContextFunctionType: Boolean =
dotc.core.Symbols.defn.isContextFunctionType(self)
def isErasedFunctionType: Boolean =
self.derivesFrom(dotc.core.Symbols.defn.ErasedFunctionClass)
self match
case dotc.core.Symbols.defn.PolyFunctionOf(mt) =>
mt match
case mt: MethodType => mt.hasErasedParams
case PolyType(_, _, mt1) => mt1.hasErasedParams
case _ => false
def isDependentFunctionType: Boolean =
val tpNoRefinement = self.dropDependentRefinement
tpNoRefinement != self
Expand Down Expand Up @@ -2823,13 +2828,13 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def FunctionClass(arity: Int, isImplicit: Boolean = false, isErased: Boolean = false): Symbol =
if arity < 0 then throw IllegalArgumentException(s"arity: $arity")
if isErased then
throw new Exception("Erased function classes are not supported. Use a refined `scala.runtime.ErasedFunction`")
throw new Exception("Erased function classes are not supported. Use a refined `scala.PolyFunction`")
else dotc.core.Symbols.defn.FunctionSymbol(arity, isImplicit)
def FunctionClass(arity: Int): Symbol =
FunctionClass(arity, false, false)
def FunctionClass(arity: Int, isContextual: Boolean): Symbol =
FunctionClass(arity, isContextual, false)
def ErasedFunctionClass = dotc.core.Symbols.defn.ErasedFunctionClass
def PolyFunctionClass = dotc.core.Symbols.defn.PolyFunctionClass
def TupleClass(arity: Int): Symbol =
dotc.core.Symbols.defn.TupleType(arity).nn.classSymbol.asClass
def isTupleClass(sym: Symbol): Boolean =
Expand Down
4 changes: 2 additions & 2 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4316,9 +4316,9 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
@experimental
def FunctionClass(arity: Int, isContextual: Boolean): Symbol

/** The `scala.runtime.ErasedFunction` built-in trait. */
/** The `scala.PolyFunction` built-in trait. */
@experimental
def ErasedFunctionClass: Symbol
def PolyFunctionClass: Symbol

/** Function-like object that maps arity to symbols for classes `scala.TupleX`.
* - 0th element is `NoSymbol`
Expand Down
4 changes: 3 additions & 1 deletion library/src/scala/runtime/ErasedFunction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ import scala.annotation.experimental
* This trait will be refined with an `apply` method with erased parameters:
* ErasedFunction { def apply([erased] x_1: P_1, ..., [erased] x_N: P_N): R }
* This type will be erased to FunctionL, where L = N - count(erased).
*
* Note: Now we use `scala.PolyFunction` instead. This will be removed.
*/
@experimental trait ErasedFunction
@experimental trait ErasedFunction // TODO delete. Cannot be deleted until the reference compiler stops using it.
4 changes: 1 addition & 3 deletions tests/run-custom-args/erased/erased-15.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import scala.runtime.ErasedFunction

object Test {

def main(args: Array[String]): Unit = {
Expand All @@ -12,7 +10,7 @@ object Test {
}
}

class Foo extends ErasedFunction {
class Foo extends PolyFunction {
def apply(erased x: Int): Int = {
println("Foo.apply")
42
Expand Down
8 changes: 4 additions & 4 deletions tests/run-custom-args/erased/quotes-reflection.check
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ method m2: (i: scala.Int) isGiven=false isImplicit=false erasedArgs=List(true)
method m3: (i: scala.Int, j: scala.Int) isGiven=false isImplicit=false erasedArgs=List(false, true)
method m4: (i: EC) isGiven=false isImplicit=false erasedArgs=List(true)
val l1: scala.ContextFunction1[scala.Int, scala.Int]
val l2: scala.runtime.ErasedFunction with apply: (x: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(true)
val l3: scala.runtime.ErasedFunction with apply: (x: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=true erasedParams=List(true)
val l4: scala.runtime.ErasedFunction with apply: (x: scala.Int, y: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(false, true)
val l5: scala.runtime.ErasedFunction with apply: (x: EC @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(true)
val l2: scala.PolyFunction with apply: (x: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(true)
val l3: scala.PolyFunction with apply: (x: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=true erasedParams=List(true)
val l4: scala.PolyFunction with apply: (x: scala.Int, y: scala.Int @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(false, true)
val l5: scala.PolyFunction with apply: (x: EC @scala.annotation.internal.ErasedParam) isImplicit=false erasedParams=List(true)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def inspect2[A: Type](using Quotes): Expr[String] = {
s"method $name: $paramStr"
case vd @ ValDef(name, tpt, body) =>
tpt.tpe match
case Refinement(parent, "apply", tpe: MethodType) if parent == defn.ErasedFunctionClass.typeRef =>
case Refinement(parent, "apply", tpe: MethodType) if parent == defn.PolyFunctionClass.typeRef =>
assert(tpt.tpe.isErasedFunctionType)

val params = tpe.paramNames.zip(tpe.paramTypes).map((n, t) => s"$n: ${t.show}").mkString("(", ", ", ")")
Expand Down
Loading

0 comments on commit f701f28

Please sign in to comment.