Skip to content

Commit

Permalink
Support polymorphic functions with erased parameters (#18293)
Browse files Browse the repository at this point in the history
This adds support for
```scala
[T1, ..., Tn] => ([erased] x1: X1, ..., [erased] xm: Xm) => r: R
```

Polymorphic function types with erased parameters are represented as
using a refinement on `PolyFunction`. `ErasedFunction` is not needed.
```scala
PolyFunction {
  def apply[[T1, ..., Tn]]([given] [erased] x1: X1, ..., [erased] xm: Xm): R
}
```
  • Loading branch information
nicolasstucki authored Jul 27, 2023
2 parents 97677cc + 5625107 commit 182331b
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 56 deletions.
22 changes: 12 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1100,19 +1100,21 @@ object desugar {
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val funFlags = fun match
val paramFlags = fun match
case fun: FunctionWithMods =>
fun.mods.flags
case _ => EmptyFlags
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)
// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

// Function flags to be propagated to each parameter in the desugared method type.
val paramFlags = funFlags.toTermFlags & Given
val vparams = vparamTypes.zipWithIndex.map:
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
Expand Down
19 changes: 0 additions & 19 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1426,23 +1426,6 @@ object Parsers {
case _ => None
}

private def checkFunctionNotErased(f: Function, context: String) =
def fail(span: Span) =
syntaxError(em"Implementation restriction: erased parameters are not supported in $context", span)
// erased parameter in type
val hasErasedParam = f match
case f: FunctionWithMods => f.hasErasedParams
case _ => false
if hasErasedParam then
fail(f.span)
// erased parameter in term
val hasErasedMods = f.args.collectFirst {
case v: ValDef if v.mods.is(Flags.Erased) => v
}
hasErasedMods match
case Some(param) => fail(param.span)
case _ =>

/** CaptureRef ::= ident | `this`
*/
def captureRef(): Tree =
Expand Down Expand Up @@ -1592,7 +1575,6 @@ object Parsers {
atSpan(start, arrowOffset) {
getFunction(body) match {
case Some(f) =>
checkFunctionNotErased(f, "poly function")
PolyFunction(tparams, body)
case None =>
syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
Expand Down Expand Up @@ -2159,7 +2141,6 @@ object Parsers {
atSpan(start, arrowOffset) {
getFunction(body) match
case Some(f) =>
checkFunctionNotErased(f, "poly function")
PolyFunction(tparams, f)
case None =>
syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset)
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2604,17 +2604,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
sym.setFlag(Erased)
if
sym.info.isInstanceOf[PolyType] &&
((sym.name eq nme.ANON_FUN) ||
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
then
mdef match
case DefDef(_, _ :: vparams :: Nil, _, _) =>
vparams.foreach: vparam =>
if vparam.symbol.is(Erased) then
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
case _ =>

def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
val TypeDef(name, rhs) = tdef
Expand Down
16 changes: 0 additions & 16 deletions tests/neg-custom-args/erased/poly-functions.scala

This file was deleted.

28 changes: 28 additions & 0 deletions tests/neg/polymorphic-erased-functions-types.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:3:28 ---------------------------------
3 |def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: [T] => (erased t: T) => Unit
| Required: [T] => (x$1: T) => Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:4:37 ---------------------------------
4 |def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error
| ^^^^^^^^^^^^^^^^^^^
| Found: [T] => (t: T) => Unit
| Required: [T] => (erased x$1: T) => Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:6:36 ---------------------------------
6 |def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: [T, U] => (t: T, erased u: U) => Unit
| Required: [T, U] => (x$1: T, x$2: U) => Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:7:43 ---------------------------------
7 |def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: [T, U] => (t: T, u: U) => Unit
| Required: [T, U] => (x$1: T, erased x$2: U) => Unit
|
| longer explanation available when compiling with `-explain`
7 changes: 7 additions & 0 deletions tests/neg/polymorphic-erased-functions-types.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import language.experimental.erasedDefinitions

def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error
def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error

def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error
def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error
8 changes: 8 additions & 0 deletions tests/neg/polymorphic-erased-functions-used.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/polymorphic-erased-functions-used.scala:3:33 -------------------------------------------------------
3 |def t1 = [T] => (erased t: T) => t // error
| ^
| parameter t is declared as `erased`, but is in fact used
-- Error: tests/neg/polymorphic-erased-functions-used.scala:4:42 -------------------------------------------------------
4 |def t2 = [T, U] => (t: T, erased u: U) => u // error
| ^
| parameter u is declared as `erased`, but is in fact used
4 changes: 4 additions & 0 deletions tests/neg/polymorphic-erased-functions-used.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import language.experimental.erasedDefinitions

def t1 = [T] => (erased t: T) => t // error
def t2 = [T, U] => (t: T, erased u: U) => u // error
14 changes: 14 additions & 0 deletions tests/pos/poly-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import language.experimental.erasedDefinitions

object Test:
type T1 = [X] => (erased x: X, y: Int) => Int
type T2 = [X] => (x: X, erased y: Int) => X

val t1 = [X] => (erased x: X, y: Int) => y
val t2 = [X] => (x: X, erased y: Int) => x

erased class A

type T3 = [X] => (x: A, y: X) => X

val t3 = [X] => (x: A, y: X) => y
22 changes: 22 additions & 0 deletions tests/run/polymorphic-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import language.experimental.erasedDefinitions

object Test extends App {

// Types
type F1 = [T] => (erased T) => Int
type F2 = [T, U] => (T, erased U) => T

// Terms
val t1 = [T] => (erased t: T) => 3
assert(t1(List(1, 2, 3)) == 3)
val t1a: F1 = t1
val t1b: F1 = [T] => (erased t) => 3
assert(t1b(List(1, 2, 3)) == 3)

val t2 = [T, U] => (t: T, erased u: U) => t
assert(t2(1, "abc") == 1)
val t2a: F2 = t2
val t2b: F2 = [T, U] => (t, erased u) => t
assert(t2b(1, "abc") == 1)

}

0 comments on commit 182331b

Please sign in to comment.