Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support polymorphic functions with erased parameters #18293

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1100,19 +1100,21 @@ object desugar {
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val funFlags = fun match
val paramFlags = fun match
case fun: FunctionWithMods =>
fun.mods.flags
case _ => EmptyFlags
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)
// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

// Function flags to be propagated to each parameter in the desugared method type.
val paramFlags = funFlags.toTermFlags & Given
val vparams = vparamTypes.zipWithIndex.map:
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
Expand Down
19 changes: 0 additions & 19 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1426,23 +1426,6 @@ object Parsers {
case _ => None
}

private def checkFunctionNotErased(f: Function, context: String) =
def fail(span: Span) =
syntaxError(em"Implementation restriction: erased parameters are not supported in $context", span)
// erased parameter in type
val hasErasedParam = f match
case f: FunctionWithMods => f.hasErasedParams
case _ => false
if hasErasedParam then
fail(f.span)
// erased parameter in term
val hasErasedMods = f.args.collectFirst {
case v: ValDef if v.mods.is(Flags.Erased) => v
}
hasErasedMods match
case Some(param) => fail(param.span)
case _ =>

/** CaptureRef ::= ident | `this`
*/
def captureRef(): Tree =
Expand Down Expand Up @@ -1592,7 +1575,6 @@ object Parsers {
atSpan(start, arrowOffset) {
getFunction(body) match {
case Some(f) =>
checkFunctionNotErased(f, "poly function")
PolyFunction(tparams, body)
case None =>
syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
Expand Down Expand Up @@ -2159,7 +2141,6 @@ object Parsers {
atSpan(start, arrowOffset) {
getFunction(body) match
case Some(f) =>
checkFunctionNotErased(f, "poly function")
PolyFunction(tparams, f)
case None =>
syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset)
Expand Down
11 changes: 0 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2604,17 +2604,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
sym.setFlag(Erased)
if
sym.info.isInstanceOf[PolyType] &&
((sym.name eq nme.ANON_FUN) ||
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
then
mdef match
case DefDef(_, _ :: vparams :: Nil, _, _) =>
vparams.foreach: vparam =>
if vparam.symbol.is(Erased) then
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
case _ =>

def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
val TypeDef(name, rhs) = tdef
Expand Down
16 changes: 0 additions & 16 deletions tests/neg-custom-args/erased/poly-functions.scala

This file was deleted.

28 changes: 28 additions & 0 deletions tests/neg/polymorphic-erased-functions-types.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:3:28 ---------------------------------
3 |def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: [T] => (erased t: T) => Unit
| Required: [T] => (x$1: T) => Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:4:37 ---------------------------------
4 |def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error
| ^^^^^^^^^^^^^^^^^^^
| Found: [T] => (t: T) => Unit
| Required: [T] => (erased x$1: T) => Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:6:36 ---------------------------------
6 |def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: [T, U] => (t: T, erased u: U) => Unit
| Required: [T, U] => (x$1: T, x$2: U) => Unit
|
| longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:7:43 ---------------------------------
7 |def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: [T, U] => (t: T, u: U) => Unit
| Required: [T, U] => (x$1: T, erased x$2: U) => Unit
|
| longer explanation available when compiling with `-explain`
7 changes: 7 additions & 0 deletions tests/neg/polymorphic-erased-functions-types.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import language.experimental.erasedDefinitions

def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error
def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error

def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error
def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error
8 changes: 8 additions & 0 deletions tests/neg/polymorphic-erased-functions-used.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/polymorphic-erased-functions-used.scala:3:33 -------------------------------------------------------
3 |def t1 = [T] => (erased t: T) => t // error
| ^
| parameter t is declared as `erased`, but is in fact used
-- Error: tests/neg/polymorphic-erased-functions-used.scala:4:42 -------------------------------------------------------
4 |def t2 = [T, U] => (t: T, erased u: U) => u // error
| ^
| parameter u is declared as `erased`, but is in fact used
4 changes: 4 additions & 0 deletions tests/neg/polymorphic-erased-functions-used.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import language.experimental.erasedDefinitions

def t1 = [T] => (erased t: T) => t // error
def t2 = [T, U] => (t: T, erased u: U) => u // error
14 changes: 14 additions & 0 deletions tests/pos/poly-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import language.experimental.erasedDefinitions

object Test:
type T1 = [X] => (erased x: X, y: Int) => Int
type T2 = [X] => (x: X, erased y: Int) => X

val t1 = [X] => (erased x: X, y: Int) => y
val t2 = [X] => (x: X, erased y: Int) => x

erased class A

type T3 = [X] => (x: A, y: X) => X

val t3 = [X] => (x: A, y: X) => y
22 changes: 22 additions & 0 deletions tests/run/polymorphic-erased-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import language.experimental.erasedDefinitions

object Test extends App {

// Types
type F1 = [T] => (erased T) => Int
type F2 = [T, U] => (T, erased U) => T

// Terms
val t1 = [T] => (erased t: T) => 3
assert(t1(List(1, 2, 3)) == 3)
val t1a: F1 = t1
val t1b: F1 = [T] => (erased t) => 3
assert(t1b(List(1, 2, 3)) == 3)

val t2 = [T, U] => (t: T, erased u: U) => t
assert(t2(1, "abc") == 1)
val t2a: F2 = t2
val t2b: F2 = [T, U] => (t, erased u) => t
assert(t2b(1, "abc") == 1)

}