Skip to content

Commit

Permalink
Support polymorphic functions with erased parameters
Browse files Browse the repository at this point in the history
This adds support for
```scala
[T1, ..., Tn] => ([erased] x1: X1, ..., [erased] xm: Xm) => R
```

Polymorphic function types with erased parameters are represented as
using a refinement on `PolyFunction`. `ErasedFunction` is not needed.
```scala
PolyFunction {
  def apply[T1, ..., Tn]([erased] x1: X1, ..., [erased] xm: Xm): R
}
```
  • Loading branch information
nicolasstucki committed Jul 26, 2023
1 parent 97677cc commit 2fb1d1f
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 56 deletions.
22 changes: 12 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1100,19 +1100,21 @@ object desugar {
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val funFlags = fun match
val paramFlags = fun match
case fun: FunctionWithMods =>
fun.mods.flags
case _ => EmptyFlags
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)
// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

// Function flags to be propagated to each parameter in the desugared method type.
val paramFlags = funFlags.toTermFlags & Given
val vparams = vparamTypes.zipWithIndex.map:
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
Expand Down
19 changes: 0 additions & 19 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1426,23 +1426,6 @@ object Parsers {
case _ => None
}

private def checkFunctionNotErased(f: Function, context: String) =
def fail(span: Span) =
syntaxError(em"Implementation restriction: erased parameters are not supported in $context", span)
// erased parameter in type
val hasErasedParam = f match
case f: FunctionWithMods => f.hasErasedParams
case _ => false
if hasErasedParam then
fail(f.span)
// erased parameter in term
val hasErasedMods = f.args.collectFirst {
case v: ValDef if v.mods.is(Flags.Erased) => v
}
hasErasedMods match
case Some(param) => fail(param.span)
case _ =>

/** CaptureRef ::= ident | `this`
*/
def captureRef(): Tree =
Expand Down Expand Up @@ -1592,7 +1575,6 @@ object Parsers {
atSpan(start, arrowOffset) {
getFunction(body) match {
case Some(f) =>
checkFunctionNotErased(f, "poly function")
PolyFunction(tparams, body)
case None =>
syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
Expand Down Expand Up @@ -2159,7 +2141,6 @@ object Parsers {
atSpan(start, arrowOffset) {
getFunction(body) match
case Some(f) =>
checkFunctionNotErased(f, "poly function")
PolyFunction(tparams, f)
case None =>
syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset)
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2604,17 +2604,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
sym.setFlag(Erased)
if
sym.info.isInstanceOf[PolyType] &&
((sym.name eq nme.ANON_FUN) ||
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
then
mdef match
case DefDef(_, _ :: vparams :: Nil, _, _) =>
vparams.foreach: vparam =>
if vparam.symbol.is(Erased) then
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
case _ =>

def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
val TypeDef(name, rhs) = tdef
Expand Down
16 changes: 0 additions & 16 deletions tests/neg-custom-args/erased/poly-functions.scala

This file was deleted.

8 changes: 8 additions & 0 deletions tests/neg/polymorphic-erased-functions.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/polymorphic-erased-functions.scala:3:33 ------------------------------------------------------------
3 |def t1 = [T] => (erased t: T) => t // error
| ^
| parameter t is declared as `erased`, but is in fact used
-- Error: tests/neg/polymorphic-erased-functions.scala:4:42 ------------------------------------------------------------
4 |def t2 = [T, U] => (t: T, erased u: U) => u // error
| ^
| parameter u is declared as `erased`, but is in fact used
4 changes: 4 additions & 0 deletions tests/neg/polymorphic-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import language.experimental.erasedDefinitions

def t1 = [T] => (erased t: T) => t // error
def t2 = [T, U] => (t: T, erased u: U) => u // error
15 changes: 15 additions & 0 deletions tests/pos/poly-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import language.experimental.erasedDefinitions

object Test:
type T1 = [X] => (erased x: X, y: Int) => Int
type T2 = [X] => (x: X, erased y: Int) => X

val t1 = [X] => (erased x: X, y: Int) => y
val t2 = [X] => (x: X, erased y: Int) => x

erased class A

type T3 = [X] => (x: A, y: X) => X

val t3 = [X] => (x: A, y: X) => y

22 changes: 22 additions & 0 deletions tests/run/polymorphic-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import language.experimental.erasedDefinitions

object Test extends App {

// Types
type F1 = [T] => (erased T) => Int
type F2 = [T, U] => (T, erased U) => T

// Terms
val t1 = [T] => (erased t: T) => 3
assert(t1(List(1, 2, 3)) == 3)
val t1a: F1 = t1
val t1b: F1 = [T] => (erased t) => 3
assert(t1b(List(1, 2, 3)) == 3)

val t2 = [T, U] => (t: T, erased u: U) => t
assert(t2(1, "abc") == 1)
val t2a: F2 = t2
val t2b: F2 = [T, U] => (t, erased u) => t
assert(t2b(1, "abc") == 1)

}

0 comments on commit 2fb1d1f

Please sign in to comment.