Skip to content

Commit

Permalink
Improve type inference for functions like fold (#18780)
Browse files Browse the repository at this point in the history
When calling a fold with an accumulator like `Nil` or `List()` one used
to have add an explicit type ascription. This is now no longer
necessary. When instantiating type variables that occur invariantly in
the expected type of a lambda, we now replace covariant occurrences of
`Nothing` in the (possibly widened) instance type of the type variable
with fresh type variables.

In the case of fold, the accumulator determines the instance type of a
type variable that appears both in the parameter list and in the result
type of the closure, which makes it invariant. So the accumulator type
is improved in the way described above.

The idea is that a fresh type variable in such places is always better
than Nothing. For module values such as `Nil` we widen to `List[<fresh
var>]`. This does possibly cause a new type error if the fold really
wanted a `Nil` instance. But that case seems very rare, so it looks like
a good bet in general to do the widening.
  • Loading branch information
odersky authored Nov 14, 2023
2 parents e4ba788 + 6f1a09a commit 563fab9
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 39 deletions.
12 changes: 3 additions & 9 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Flags.*
import config.Config
import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import transform.TypeUtils.isTransparent
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet
Expand Down Expand Up @@ -566,13 +567,6 @@ trait ConstraintHandling {
inst
end approximation

private def isTransparent(tp: Type, traitOnly: Boolean)(using Context): Boolean = tp match
case AndType(tp1, tp2) =>
isTransparent(tp1, traitOnly) && isTransparent(tp2, traitOnly)
case _ =>
val cls = tp.underlyingClassRef(refinementOK = false).typeSymbol
cls.isTransparentClass && (!traitOnly || cls.is(Trait))

/** If `tp` is an intersection such that some operands are transparent trait instances
* and others are not, replace as many transparent trait instances as possible with Any
* as long as the result is still a subtype of `bound`. But fall back to the
Expand All @@ -585,7 +579,7 @@ trait ConstraintHandling {
var dropped: List[Type] = List() // the types dropped so far, last one on top

def dropOneTransparentTrait(tp: Type): Type =
if isTransparent(tp, traitOnly = true) && !kept.contains(tp) then
if tp.isTransparent(traitOnly = true) && !kept.contains(tp) then
dropped = tp :: dropped
defn.AnyType
else tp match
Expand Down Expand Up @@ -658,7 +652,7 @@ trait ConstraintHandling {
def widenOr(tp: Type) =
if widenUnions then
val tpw = tp.widenUnion
if (tpw ne tp) && !isTransparent(tpw, traitOnly = false) && (tpw <:< bound) then tpw else tp
if (tpw ne tp) && !tpw.isTransparent() && (tpw <:< bound) then tpw else tp
else tp.hardenUnions

def widenSingle(tp: Type) =
Expand Down
13 changes: 9 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4908,6 +4908,9 @@ object Types {
tp
}

def typeToInstantiateWith(fromBelow: Boolean)(using Context): Type =
TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)

/** Instantiate variable from the constraints over its `origin`.
* If `fromBelow` is true, the variable is instantiated to the lub
* of its lower bounds in the current constraint; otherwise it is
Expand All @@ -4916,7 +4919,7 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
val tp = typeToInstantiateWith(fromBelow)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
myInst
else
Expand Down Expand Up @@ -5812,11 +5815,13 @@ object Types {
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
tp.derivedLambdaType(tp.paramNames, formals, restpe)

protected def mapArg(arg: Type, tparam: ParamInfo): Type = arg match
case arg: TypeBounds => this(arg)
case arg => atVariance(variance * tparam.paramVarianceSign)(this(arg))

protected def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match
case arg :: otherArgs if tparams.nonEmpty =>
val arg1 = arg match
case arg: TypeBounds => this(arg)
case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg))
val arg1 = mapArg(arg, tparams.head)
val otherArgs1 = mapArgs(otherArgs, tparams.tail)
if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args
else arg1 :: otherArgs1
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@ package transform

import core.*
import TypeErasure.ErasedValueType
import Types.*
import Contexts.*
import Symbols.*
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
import Names.Name

import dotty.tools.dotc.core.Decorators.*

object TypeUtils {
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
Expand Down Expand Up @@ -98,5 +94,15 @@ object TypeUtils {
def takesImplicitParams(using Context): Boolean = self.stripPoly match
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
case _ => false

/** Is this a type deriving only from transparent classes?
* @param traitOnly if true, all class symbols must be transparent traits
*/
def isTransparent(traitOnly: Boolean = false)(using Context): Boolean = self match
case AndType(tp1, tp2) =>
tp1.isTransparent(traitOnly) && tp2.isTransparent(traitOnly)
case _ =>
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
}
}
111 changes: 91 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import ProtoTypes.*
import NameKinds.UniqueName
import util.Spans.*
import util.{Stats, SimpleIdentityMap, SimpleIdentitySet, SrcPos}
import Decorators.*
import transform.TypeUtils.isTransparent
import Decorators._
import config.Printers.{gadts, typr}
import annotation.tailrec
import reporting.*
Expand Down Expand Up @@ -60,7 +61,9 @@ object Inferencing {
def instantiateSelected(tp: Type, tvars: List[Type])(using Context): Unit =
if (tvars.nonEmpty)
IsFullyDefinedAccumulator(
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
new ForceDegree.Value(IfBottom.flip):
override def appliesTo(tvar: TypeVar) = tvars.contains(tvar),
minimizeSelected = true
).process(tp)

/** Instantiate any type variables in `tp` whose bounds contain a reference to
Expand Down Expand Up @@ -154,15 +157,66 @@ object Inferencing {
* their lower bound. Record whether successful.
* 2nd Phase: If first phase was successful, instantiate all remaining type variables
* to their upper bound.
*
* Instance types can be improved by replacing covariant occurrences of Nothing
* with fresh type variables, if `force` allows this in its `canImprove` implementation.
*/
private class IsFullyDefinedAccumulator(force: ForceDegree.Value, minimizeSelected: Boolean = false)
(using Context) extends TypeAccumulator[Boolean] {

private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
/** Replace toplevel-covariant occurrences (i.e. covariant without double flips)
* of Nothing by fresh type variables. Double-flips are not covered to be
* conservative and save a bit of time on traversals; we could probably
* generalize that if we see use cases.
* For singleton types and references to module classes: try to
* improve the widened type. For module classes, the widened type
* is the intersection of all its non-transparent parent types.
*/
private def improve(tvar: TypeVar) = new TypeMap:
def apply(t: Type) = trace(i"improve $t", show = true):
def tryWidened(widened: Type): Type =
val improved = apply(widened)
if improved ne widened then improved else mapOver(t)
if variance > 0 then
t match
case t: TypeRef =>
if t.symbol == defn.NothingClass then
newTypeVar(TypeBounds.empty, nestingLevel = tvar.nestingLevel)
else if t.symbol.is(ModuleClass) then
tryWidened(t.parents.filter(!_.isTransparent())
.foldLeft(defn.AnyType: Type)(TypeComparer.andType(_, _)))
else
mapOver(t)
case t: TermRef =>
tryWidened(t.widen)
case _ =>
mapOver(t)
else t

// Don't map Nothing arguments for higher-kinded types; we'd get the wrong kind */
override def mapArg(arg: Type, tparam: ParamInfo): Type =
if tparam.paramInfo.isLambdaSub then arg
else super.mapArg(arg, tparam)
end improve

/** Instantiate type variable with possibly improved computed instance type.
* @return true if variable was instantiated with improved type, which
* in this case should not be instantiated further, false otherwise.
*/
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Boolean =
if fromBelow && force.canImprove(tvar) then
val inst = tvar.typeToInstantiateWith(fromBelow = true)
if apply(true, inst) then
// need to recursively check before improving, since improving adds type vars
// which should not be instantiated at this point
val better = improve(tvar)(inst)
if better <:< TypeComparer.fullUpperBound(tvar.origin) then
typr.println(i"forced instantiation of invariant ${tvar.origin} = $inst, improved to $better")
tvar.instantiateWith(better)
return true
val inst = tvar.instantiate(fromBelow)
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
inst
}
false

private var toMaximize: List[TypeVar] = Nil

Expand All @@ -178,26 +232,27 @@ object Inferencing {
&& ctx.typerState.constraint.contains(tvar)
&& {
var fail = false
var skip = false
val direction = instDirection(tvar.origin)
if minimizeSelected then
if direction <= 0 && tvar.hasLowerBound then
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if direction >= 0 && tvar.hasUpperBound then
instantiate(tvar, fromBelow = false)
skip = instantiate(tvar, fromBelow = false)
// else hold off instantiating unbounded unconstrained variable
else if direction != 0 then
instantiate(tvar, fromBelow = direction < 0)
skip = instantiate(tvar, fromBelow = direction < 0)
else if variance >= 0 && tvar.hasLowerBound then
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
&& force.ifBottom == IfBottom.ok
then // if variance == 0, prefer upper bound if one is given
instantiate(tvar, fromBelow = true)
skip = instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
fail = true
else
toMaximize = tvar :: toMaximize
!fail && foldOver(x, tvar)
!fail && (skip || foldOver(x, tvar))
}
case tp => foldOver(x, tp)
}
Expand Down Expand Up @@ -467,7 +522,7 @@ object Inferencing {
*
* we want to instantiate U to x.type right away. No need to wait further.
*/
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
Stats.record("variances")
val constraint = ctx.typerState.constraint

Expand Down Expand Up @@ -769,14 +824,30 @@ trait Inferencing { this: Typer =>
}

/** An enumeration controlling the degree of forcing in "is-fully-defined" checks. */
@sharable object ForceDegree {
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom):
override def toString = s"ForceDegree.Value(.., $ifBottom)"
val none: Value = new Value(_ => false, IfBottom.ok) { override def toString = "ForceDegree.none" }
val all: Value = new Value(_ => true, IfBottom.ok) { override def toString = "ForceDegree.all" }
val failBottom: Value = new Value(_ => true, IfBottom.fail) { override def toString = "ForceDegree.failBottom" }
val flipBottom: Value = new Value(_ => true, IfBottom.flip) { override def toString = "ForceDegree.flipBottom" }
}
@sharable object ForceDegree:
class Value(val ifBottom: IfBottom):

/** Does `tv` need to be instantiated? */
def appliesTo(tv: TypeVar): Boolean = true

/** Should we try to improve the computed instance type by replacing bottom types
* with fresh type variables?
*/
def canImprove(tv: TypeVar): Boolean = false

override def toString = s"ForceDegree.Value($ifBottom)"
end Value

val none: Value = new Value(IfBottom.ok):
override def appliesTo(tv: TypeVar) = false
override def toString = "ForceDegree.none"
val all: Value = new Value(IfBottom.ok):
override def toString = "ForceDegree.all"
val failBottom: Value = new Value(IfBottom.fail):
override def toString = "ForceDegree.failBottom"
val flipBottom: Value = new Value(IfBottom.flip):
override def toString = "ForceDegree.flipBottom"
end ForceDegree

enum IfBottom:
case ok, fail, flip
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1634,14 +1634,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case _ =>

if desugared.isEmpty then
val forceDegree =
if pt.isValueType then
// Allow variables that appear invariantly in `pt` to be improved by mapping
// bottom types in their instance types to fresh type variables
new ForceDegree.Value(IfBottom.fail):
val tvmap = variances(pt)
override def canImprove(tvar: TypeVar) =
tvmap.computedVariance(tvar) == (0: Integer)
else
ForceDegree.failBottom

val inferredParams: List[untpd.ValDef] =
for ((param, i) <- params.zipWithIndex) yield
if (!param.tpt.isEmpty) param
else
val (formalBounds, isErased) = protoFormal(i)
val formal = formalBounds.loBound
val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
val knownFormal = isFullyDefined(formal, forceDegree)
// If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound,
// try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala)
val paramType =
Expand Down
7 changes: 7 additions & 0 deletions tests/neg/foldinf-ill-kinded.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg/foldinf-ill-kinded.scala:9:16 -------------------------------------------------
9 | ys.combine(x) // error
| ^^^^^^^^^^^^^
| Found: Foo[List]
| Required: Foo[Nothing]
|
| longer explanation available when compiling with `-explain`
10 changes: 10 additions & 0 deletions tests/neg/foldinf-ill-kinded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class Foo[+T[_]]:
def combine[T1[x] >: T[x]](x: T1[Int]): Foo[T1] = new Foo
object Foo:
def empty: Foo[Nothing] = new Foo

object X:
def test(xs: List[List[Int]]): Unit =
xs.foldLeft(Foo.empty)((ys, x) =>
ys.combine(x) // error
)
34 changes: 34 additions & 0 deletions tests/pos/folds.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

object Test:
extension [A](xs: List[A])
def foldl[B](acc: B)(f: (A, B) => B): B = ???

val xs = List(1, 2, 3)

val _ = xs.foldl(List())((y, ys) => y :: ys)

val _ = xs.foldl(Nil)((y, ys) => y :: ys)

def partition[a](xs: List[a], pred: a => Boolean): Tuple2[List[a], List[a]] = {
xs.foldRight/*[Tuple2[List[a], List[a]]]*/((List(), List())) {
(x, p) => if (pred (x)) (x :: p._1, p._2) else (p._1, x :: p._2)
}
}

def snoc[A](xs: List[A], x: A) = x :: xs

def reverse[A](xs: List[A]) =
xs.foldLeft(Nil)(snoc)

def reverse2[A](xs: List[A]) =
xs.foldLeft(List())(snoc)

val ys: Seq[Int] = xs
ys.foldLeft(Seq())((ys, y) => y +: ys)
ys.foldLeft(Nil)((ys, y) => y +: ys)

def dup[A](xs: List[A]) =
xs.foldRight(Nil)((x, xs) => x :: x :: xs)

def toSet[A](xs: Seq[A]) =
xs.foldLeft(Set.empty)(_ + _)

0 comments on commit 563fab9

Please sign in to comment.