From 56251079218efae3d8fc9dd618ad009e7a396a86 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Wed, 26 Jul 2023 13:21:19 +0200 Subject: [PATCH] Support polymorphic functions with erased parameters This adds support for ```scala [T1, ..., Tn] => ([erased] x1: X1, ..., [erased] xm: Xm) => R ``` Polymorphic function types with erased parameters are represented as using a refinement on `PolyFunction`. `ErasedFunction` is not needed. ```scala PolyFunction { def apply[T1, ..., Tn]([erased] x1: X1, ..., [erased] xm: Xm): R } ``` --- .../src/dotty/tools/dotc/ast/Desugar.scala | 22 ++++++++------- .../dotty/tools/dotc/parsing/Parsers.scala | 19 ------------- .../src/dotty/tools/dotc/typer/Typer.scala | 11 -------- .../erased/poly-functions.scala | 16 ----------- .../polymorphic-erased-functions-types.check | 28 +++++++++++++++++++ .../polymorphic-erased-functions-types.scala | 7 +++++ .../polymorphic-erased-functions-used.check | 8 ++++++ .../polymorphic-erased-functions-used.scala | 4 +++ tests/pos/poly-erased-functions.scala | 14 ++++++++++ tests/run/polymorphic-erased-functions.scala | 22 +++++++++++++++ 10 files changed, 95 insertions(+), 56 deletions(-) delete mode 100644 tests/neg-custom-args/erased/poly-functions.scala create mode 100644 tests/neg/polymorphic-erased-functions-types.check create mode 100644 tests/neg/polymorphic-erased-functions-types.scala create mode 100644 tests/neg/polymorphic-erased-functions-used.check create mode 100644 tests/neg/polymorphic-erased-functions-used.scala create mode 100644 tests/pos/poly-erased-functions.scala create mode 100644 tests/run/polymorphic-erased-functions.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 1e90fdcb2017..c104c603422d 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1100,19 +1100,21 @@ object desugar { */ def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked - val funFlags = fun match + val paramFlags = fun match case fun: FunctionWithMods => - fun.mods.flags - case _ => EmptyFlags + // TODO: make use of this in the desugaring when pureFuns is enabled. + // val isImpure = funFlags.is(Impure) - // TODO: make use of this in the desugaring when pureFuns is enabled. - // val isImpure = funFlags.is(Impure) + // Function flags to be propagated to each parameter in the desugared method type. + val givenFlag = fun.mods.flags.toTermFlags & Given + fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag) + case _ => + vparamTypes.map(_ => EmptyFlags) - // Function flags to be propagated to each parameter in the desugared method type. - val paramFlags = funFlags.toTermFlags & Given - val vparams = vparamTypes.zipWithIndex.map: - case (p: ValDef, _) => p.withAddedFlags(paramFlags) - case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) + val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map { + case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags) + case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags) + }.toList RefinedTypeTree(ref(defn.PolyFunctionType), List( DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 224bd2fa7776..81118831c8fa 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1426,23 +1426,6 @@ object Parsers { case _ => None } - private def checkFunctionNotErased(f: Function, context: String) = - def fail(span: Span) = - syntaxError(em"Implementation restriction: erased parameters are not supported in $context", span) - // erased parameter in type - val hasErasedParam = f match - case f: FunctionWithMods => f.hasErasedParams - case _ => false - if hasErasedParam then - fail(f.span) - // erased parameter in term - val hasErasedMods = f.args.collectFirst { - case v: ValDef if v.mods.is(Flags.Erased) => v - } - hasErasedMods match - case Some(param) => fail(param.span) - case _ => - /** CaptureRef ::= ident | `this` */ def captureRef(): Tree = @@ -1592,7 +1575,6 @@ object Parsers { atSpan(start, arrowOffset) { getFunction(body) match { case Some(f) => - checkFunctionNotErased(f, "poly function") PolyFunction(tparams, body) case None => syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset) @@ -2159,7 +2141,6 @@ object Parsers { atSpan(start, arrowOffset) { getFunction(body) match case Some(f) => - checkFunctionNotErased(f, "poly function") PolyFunction(tparams, f) case None => syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 4ea1cd481e43..a26ec9747903 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2604,17 +2604,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos) if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then sym.setFlag(Erased) - if - sym.info.isInstanceOf[PolyType] && - ((sym.name eq nme.ANON_FUN) || - (sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass)) - then - mdef match - case DefDef(_, _ :: vparams :: Nil, _, _) => - vparams.foreach: vparam => - if vparam.symbol.is(Erased) then - report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos) - case _ => def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = { val TypeDef(name, rhs) = tdef diff --git a/tests/neg-custom-args/erased/poly-functions.scala b/tests/neg-custom-args/erased/poly-functions.scala deleted file mode 100644 index 000a2ca49cc9..000000000000 --- a/tests/neg-custom-args/erased/poly-functions.scala +++ /dev/null @@ -1,16 +0,0 @@ -object Test: - // Poly functions with erased parameters are disallowed as an implementation restriction - - type T1 = [X] => (erased x: X, y: Int) => Int // error - type T2 = [X] => (x: X, erased y: Int) => X // error - - val t1 = [X] => (erased x: X, y: Int) => y // error - val t2 = [X] => (x: X, erased y: Int) => x // error - - // Erased classes should be detected too - erased class A - - type T3 = [X] => (x: A, y: X) => X // error - - val t3 = [X] => (x: A, y: X) => y // error - diff --git a/tests/neg/polymorphic-erased-functions-types.check b/tests/neg/polymorphic-erased-functions-types.check new file mode 100644 index 000000000000..39d2720023cf --- /dev/null +++ b/tests/neg/polymorphic-erased-functions-types.check @@ -0,0 +1,28 @@ +-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:3:28 --------------------------------- +3 |def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: [T] => (erased t: T) => Unit + | Required: [T] => (x$1: T) => Unit + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:4:37 --------------------------------- +4 |def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error + | ^^^^^^^^^^^^^^^^^^^ + | Found: [T] => (t: T) => Unit + | Required: [T] => (erased x$1: T) => Unit + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:6:36 --------------------------------- +6 |def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: [T, U] => (t: T, erased u: U) => Unit + | Required: [T, U] => (x$1: T, x$2: U) => Unit + | + | longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:7:43 --------------------------------- +7 |def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: [T, U] => (t: T, u: U) => Unit + | Required: [T, U] => (x$1: T, erased x$2: U) => Unit + | + | longer explanation available when compiling with `-explain` diff --git a/tests/neg/polymorphic-erased-functions-types.scala b/tests/neg/polymorphic-erased-functions-types.scala new file mode 100644 index 000000000000..d453c4602bad --- /dev/null +++ b/tests/neg/polymorphic-erased-functions-types.scala @@ -0,0 +1,7 @@ +import language.experimental.erasedDefinitions + +def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error +def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error + +def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error +def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error diff --git a/tests/neg/polymorphic-erased-functions-used.check b/tests/neg/polymorphic-erased-functions-used.check new file mode 100644 index 000000000000..6eb5abb0e235 --- /dev/null +++ b/tests/neg/polymorphic-erased-functions-used.check @@ -0,0 +1,8 @@ +-- Error: tests/neg/polymorphic-erased-functions-used.scala:3:33 ------------------------------------------------------- +3 |def t1 = [T] => (erased t: T) => t // error + | ^ + | parameter t is declared as `erased`, but is in fact used +-- Error: tests/neg/polymorphic-erased-functions-used.scala:4:42 ------------------------------------------------------- +4 |def t2 = [T, U] => (t: T, erased u: U) => u // error + | ^ + | parameter u is declared as `erased`, but is in fact used diff --git a/tests/neg/polymorphic-erased-functions-used.scala b/tests/neg/polymorphic-erased-functions-used.scala new file mode 100644 index 000000000000..73ca48b133ee --- /dev/null +++ b/tests/neg/polymorphic-erased-functions-used.scala @@ -0,0 +1,4 @@ +import language.experimental.erasedDefinitions + +def t1 = [T] => (erased t: T) => t // error +def t2 = [T, U] => (t: T, erased u: U) => u // error diff --git a/tests/pos/poly-erased-functions.scala b/tests/pos/poly-erased-functions.scala new file mode 100644 index 000000000000..8c7385edb86a --- /dev/null +++ b/tests/pos/poly-erased-functions.scala @@ -0,0 +1,14 @@ +import language.experimental.erasedDefinitions + +object Test: + type T1 = [X] => (erased x: X, y: Int) => Int + type T2 = [X] => (x: X, erased y: Int) => X + + val t1 = [X] => (erased x: X, y: Int) => y + val t2 = [X] => (x: X, erased y: Int) => x + + erased class A + + type T3 = [X] => (x: A, y: X) => X + + val t3 = [X] => (x: A, y: X) => y diff --git a/tests/run/polymorphic-erased-functions.scala b/tests/run/polymorphic-erased-functions.scala new file mode 100644 index 000000000000..4086423d8c6a --- /dev/null +++ b/tests/run/polymorphic-erased-functions.scala @@ -0,0 +1,22 @@ +import language.experimental.erasedDefinitions + +object Test extends App { + + // Types + type F1 = [T] => (erased T) => Int + type F2 = [T, U] => (T, erased U) => T + + // Terms + val t1 = [T] => (erased t: T) => 3 + assert(t1(List(1, 2, 3)) == 3) + val t1a: F1 = t1 + val t1b: F1 = [T] => (erased t) => 3 + assert(t1b(List(1, 2, 3)) == 3) + + val t2 = [T, U] => (t: T, erased u: U) => t + assert(t2(1, "abc") == 1) + val t2a: F2 = t2 + val t2b: F2 = [T, U] => (t, erased u) => t + assert(t2b(1, "abc") == 1) + +}