Skip to content

Commit

Permalink
Remember the TyApply is always on a Rho
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Dec 11, 2023
1 parent a4db25d commit 1dbc7d2
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 31 deletions.
2 changes: 1 addition & 1 deletion cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ object ProtoConverter {
} yield Type.exists(args, inT)

case Value.TypeApply(TypeApply(left, right, _)) =>
(tpe(left), tpe(right)).mapN(Type.TyApply(_, _))
(tpe(left), tpe(right)).mapN(Type.apply1(_, _))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ final case class DefinedType[+A](

def fnTypeOf(cf: ConstructorFn)(implicit ev: A <:< Kind.Arg): Type = {
// evidence to prove that we only ask for this after inference
val tc: Type = Type.const(packageName, name)
val tc: Type.Rho = Type.const(packageName, name)

val res = typeParams.foldLeft(tc) { (res, v) =>
Type.TyApply(res, Type.TyVar(v))
Expand Down
22 changes: 13 additions & 9 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ object Infer {
t: Type,
region: Region): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type.Rho)] = {

// Invariant: if t is Rho, then result._3 is Rho
def loop(t: Type, path: Variance): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type)] =
t match {
case q: Type.Quantified =>
Expand All @@ -426,9 +427,12 @@ object Infer {
case ta@Type.TyApply(left, right) =>
// Rule PRFUN
// we know the kind of left is k -> x, and right has kind k
// since left: Rho, we know loop(left, path)._3 is Rho
(varianceOfCons(ta, region), loop(left, path))
.flatMapN {
case (consVar, (sksl, el, ltpe)) =>
case (consVar, (sksl, el, ltpe0)) =>
// due to loop invariant
val ltpe: Type.Rho = ltpe0.asInstanceOf[Type.Rho]
val rightPath = consVar * path
loop(right, rightPath)
.map { case (sksr, er, rtpe) =>
Expand Down Expand Up @@ -634,7 +638,7 @@ object Infer {
case Variance.Invariant =>
unifyType(r1, r2, left, right)
}
_ <- subsCheck(l1, l2, left, right)
_ <- subsCheckRho2(l1, l2, left, right)
ks <- checkedKinds
} yield TypedExpr.coerceRho(ta, ks)
case (ta@Type.TyApply(l1, r1), rho2) =>
Expand All @@ -656,7 +660,7 @@ object Infer {
case Variance.Invariant =>
unifyType(r1, r2, left, right)
}
_ <- subsCheck(l1, l2, left, right)
_ <- subsCheckRho2(l1, l2, left, right)
ks <- checkedKinds
} yield TypedExpr.coerceRho(rho2, ks)
case (t1, t2) =>
Expand Down Expand Up @@ -719,7 +723,7 @@ object Infer {

// destructure apType in left[right]
// invariant apType is being checked against some rho with validated kind: lKind[rKind]
def unifyTyApp(apType: Type.Rho, lKind: Kind, rKind: Kind, apRegion: Region, evidenceRegion: Region): Infer[(Type, Type)] =
def unifyTyApp(apType: Type.Rho, lKind: Kind, rKind: Kind, apRegion: Region, evidenceRegion: Region): Infer[(Type.Rho, Type)] =
apType match {
case ta @ Type.TyApply(left, right) =>
// this branch only happens when checking ta <:< (rho: lKind[rKind])
Expand Down Expand Up @@ -822,7 +826,7 @@ object Infer {
case (t1 @ Type.TyApply(a1, b1), t2 @ Type.TyApply(a2, b2)) =>
validateKinds(t1, r1) &>
validateKinds(t2, r2) &>
unifyType(a1, a2, r1, r2) &>
unify(a1, a2, r1, r2) &>
unifyType(b1, b2, r1, r2)
case (Type.TyConst(c1), Type.TyConst(c2)) if c1 == c2 => unit
case (Type.TyVar(v1), Type.TyVar(v2)) if v1 == v2 => unit
Expand Down Expand Up @@ -1588,9 +1592,9 @@ object Infer {
pushDownCovariant(rest, nextRFA, left) match {
case Type.ForAll(bs, l) =>
// TODO: I think we can push down existentials too
Type.forAll(bs, Type.TyApply(l, nextRight))
Type.forAll(bs, Type.apply1(l, nextRight))
case rho /*: Type.Rho */ =>
Type.TyApply(rho, nextRight)
Type.apply1(rho, nextRight)
}
case (_ :: rest, Type.TyApply(left, right)) =>
val rightFree = Type.freeBoundTyVars(right :: Nil).toSet
Expand All @@ -1600,9 +1604,9 @@ object Infer {
Type.forAll(keptRight.reverse, pushDownCovariant(rest, lefts, left)) match {
case Type.ForAll(bs, l) =>
// TODO: we could possibly have an existential here?
Type.forAll(bs, Type.TyApply(l, right))
Type.forAll(bs, Type.apply1(l, right))
case rho /*: Type.Rho */=>
Type.TyApply(rho, right)
Type.apply1(rho, right)
}
case _ =>
Type.forAll(revForAlls.reverse, sigma)
Expand Down
75 changes: 61 additions & 14 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ object Type {
/**
* A type with no top level quantification
*/
sealed abstract class Rho extends Type
sealed abstract class Rho extends Type {
override def normalize: Rho
}

object Rho {
implicit val orderRho: Order[Rho] =
Expand Down Expand Up @@ -50,7 +52,7 @@ object Type {
}

sealed abstract class Leaf extends Rho {
def normalize: Type = this
override def normalize: Leaf = this
}
type Tau = Rho // no forall or exists anywhere

Expand Down Expand Up @@ -152,8 +154,8 @@ object Type {
}
}

case class TyApply(on: Type, arg: Type) extends Rho {
lazy val normalize: Type = TyApply(on.normalize, arg.normalize)
case class TyApply(on: Rho, arg: Type) extends Rho {
lazy val normalize: Rho = TyApply(on.normalize, arg.normalize)
}
case class TyConst(tpe: Const) extends Leaf
case class TyVar(toVar: Var) extends Leaf
Expand Down Expand Up @@ -190,13 +192,56 @@ object Type {
implicit val typeOrdering: Ordering[Type] = typeOrder.toOrdering

@annotation.tailrec
def applyAll(fn: Type, args: List[Type]): Type =
def applyAllRho(rho: Rho, args: List[Type]): Rho =
args match {
case Nil => fn
case a :: as =>
applyAll(TyApply(fn, a), as)
case Nil => rho
case a :: as => applyAllRho(TyApply(rho, a), as)
}

def apply1(fn: Type, arg: Type): Type =
fn match {
case rho: Rho => TyApply(rho, arg)
case q => applyAll(q, arg :: Nil)
}

def applyAll(fn: Type, args: List[Type]): Type =
fn match {
case rho: Rho => applyAllRho(rho, args)
case Quantified(q, rho) =>
val freeBound = freeBoundTyVars(fn :: args)
if (freeBound.isEmpty) {
Quantified(q, applyAllRho(rho, args))
}
else {
val freeBoundSet: Set[Var.Bound] = freeBound.toSet
val collisions = q.vars.exists { case (b, _) => freeBoundSet(b) }
if (!collisions) {
// we don't need to rename the vars
Quantified(q, applyAllRho(rho, args))
}
else {
// we have to to rename the collisions so the free set
// is unchanged
val fa1 = alignBinders(q.forallList, freeBoundSet)
val ex1 = alignBinders(q.existList, freeBoundSet ++ fa1.map(_._2))
val subMap = (fa1.iterator ++ ex1.iterator).map {
case ((b0, _), b1) => (b0, TyVar(b1))
}
.toMap[Var, Rho]

val rho1 = substituteRhoVar(rho, subMap)

val q1 = Quantification.fromLists(
forallList = fa1.map { case ((_, k), b) => (b, k)},
existList = ex1.map { case ((_, k), b) => (b, k)}
)
.get // this Option must be defined because we started with a defined q

Quantified(q1, applyAllRho(rho1, args))
}
}
}

def unapplyAll(fn: Type): (Type, List[Type]) = {
@annotation.tailrec
def loop(fn: Type, acc: List[Type]): (Type, List[Type]) =
Expand Down Expand Up @@ -353,7 +398,8 @@ object Type {
def substituteVar(t: Type, env: Map[Type.Var, Type]): Type =
if (env.isEmpty) t
else (t match {
case TyApply(on, arg) => TyApply(substituteVar(on, env), substituteVar(arg, env))
case TyApply(on, arg) =>
apply1(substituteVar(on, env), substituteVar(arg, env))
case v@TyVar(n) =>
env.get(n) match {
case Some(rho) => rho
Expand All @@ -370,7 +416,8 @@ object Type {

def substituteRhoVar(t: Type.Rho, env: Map[Type.Var, Type.Rho]): Type.Rho =
t match {
case TyApply(on, arg) => TyApply(substituteVar(on, env), substituteVar(arg, env))
case TyApply(on, arg) =>
TyApply(substituteRhoVar(on, env), substituteVar(arg, env))
case v@TyVar(n) =>
env.get(n) match {
case Some(rho) => rho
Expand Down Expand Up @@ -602,7 +649,7 @@ object Type {
val TestType: Type.TyConst = TyConst(Const.predef("Test"))
val UnitType: Type.TyConst = TyConst(Type.Const.predef("Unit"))

def const(pn: PackageName, name: TypeName): Type =
def const(pn: PackageName, name: TypeName): Type.Rho =
TyConst(Type.Const.Defined(pn, name))

object Fun {
Expand Down Expand Up @@ -635,7 +682,7 @@ object Type {

def apply(from: NonEmptyList[Type], to: Type): Type.Rho = {
val arityFn = FnType.maybeFakeName(from.length)
val withArgs = from.foldLeft(arityFn: Type)(TyApply(_, _))
val withArgs = from.foldLeft(arityFn: Type.Rho)(TyApply(_, _))
TyApply(withArgs, to)
}
def apply(from: Type, to: Type): Type.Rho =
Expand Down Expand Up @@ -684,7 +731,7 @@ object Type {

def apply(ts: List[Type]): Type = {
val sz = ts.size
val root: Type = Arity(sz)
val root: Type.Rho = Arity(sz)
ts.foldLeft(root) { (acc, t) => TyApply(acc, t)}
}

Expand Down Expand Up @@ -920,7 +967,7 @@ object Type {
def zonkRhoMeta[F[_]: Applicative](t: Type.Rho)(mfn: Meta => F[Option[Type.Rho]]): F[Type.Rho] =
t match {
case Type.TyApply(on, arg) =>
(zonkMeta(on)(mfn), zonkMeta(arg)(mfn)).mapN(Type.TyApply(_, _))
(zonkRhoMeta(on)(mfn), zonkMeta(arg)(mfn)).mapN(Type.TyApply(_, _))
case t@Type.TyMeta(m) =>
mfn(m).map {
case None => t
Expand Down
4 changes: 3 additions & 1 deletion core/src/test/scala/org/bykn/bosatsu/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ object TestUtils {
case t@Type.TyMeta(_) =>
sys.error(s"illegal meta ($t) escape in ${te.repr}")
case Type.TyApply(left, right) =>
Type.TyApply(checkType(left, bound), checkType(right, bound))
Type.TyApply(
checkType(left, bound).asInstanceOf[Type.Rho],
checkType(right, bound))
case q: Type.Quantified =>
q.copy(in = checkType(q.in, bound ++ q.vars.toList.map(_._1)).asInstanceOf[Type.Rho])
case Type.TyConst(_) => t
Expand Down
18 changes: 15 additions & 3 deletions core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object NTypeGen {
val genBound: Gen[Type.Var.Bound] =
lowerIdent.map { v => Type.Var.Bound(v) }

def genRootType(genC: Option[Gen[Type.Const]]): Gen[Type] = {
def genRootType(genC: Option[Gen[Type.Const]]): Gen[Type.Leaf] = {
val b = genBound.map(Type.TyVar(_))
genC match {
case None => b
Expand Down Expand Up @@ -97,7 +97,7 @@ object NTypeGen {
shrink(in).map(Type.exists(items.tail, _))
case _: Leaf => Stream.empty
case TyApply(on, arg) =>
on #:: arg #:: shrink(on).map(TyApply(_, arg)) #::: shrink(arg).map(TyApply(on, _))
on #:: arg #:: shrink(on).collect { case r: Type.Rho => TyApply(r, arg) } #::: shrink(arg).map(TyApply(on, _))
}
Shrink(shrink(_))
}
Expand Down Expand Up @@ -163,6 +163,18 @@ object NTypeGen {
}


def genTypeRho(d: Int, genC: Option[Gen[Type.Const]]): Gen[Type.Rho] = {
val root = genRootType(genC)
if (d <= 0) root
else {
val recurse = Gen.lzy(genTypeRho(d - 1, genC))
val genApply = Gen.zip(recurse, genDepth(d - 1, genC))
.map { case (a, b) => Type.TyApply(a, b) }

Gen.frequency((3, root), (1, genApply))
}
}

def genDepth(d: Int, genC: Option[Gen[Type.Const]]): Gen[Type] =
if (d <= 0) genRootType(genC)
else {
Expand All @@ -187,7 +199,7 @@ object NTypeGen {
Type.quantify(q, t)
}

val genApply = Gen.zip(recurse, recurse).map { case (a, b) => Type.TyApply(a, b) }
val genApply = Gen.zip(genTypeRho(d - 1, genC), recurse).map { case (a, b) => Type.TyApply(a, b) }

Gen.frequency(
(2, recurse),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ class RankNInferTest extends AnyFunSuite {
}

test("match with custom generic types") {
def tv(a: String): Type = Type.TyVar(Type.Var.Bound(a))
def tv(a: String): Type.Rho = Type.TyVar(Type.Var.Bound(a))

import OptionTypes._

Expand Down Expand Up @@ -418,7 +418,7 @@ class RankNInferTest extends AnyFunSuite {
}

test("Test a constructor with ForAll") {
def tv(a: String): Type = Type.TyVar(Type.Var.Bound(a))
def tv(a: String): Type.Rho = Type.TyVar(Type.Var.Bound(a))

val pureName = defType("Pure")
val optName = defType("Option")
Expand Down
12 changes: 12 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ class TypeTest extends AnyFunSuite {
(parse("foo"), List(parse("bar"))))
}

test("freeBoundVar doesn't change by applyAll") {
forAll(NTypeGen.genDepth03, Gen.listOf(NTypeGen.genDepth03)) { (ts, args) =>
val applied = Type.applyAll(ts, args)
val free0 = Type.freeBoundTyVars(ts :: args)
val free1 = Type.freeBoundTyVars(applied :: Nil)
assert(free1.toSet == free0.toSet,
s"applied = ${Type.typeParser.render(applied)}, (${Type.typeParser.render(ts)})[${
args.iterator.map(Type.typeParser.render(_)).mkString(", ")
}]})")
}
}

test("types are well ordered") {
forAll(NTypeGen.genDepth03, NTypeGen.genDepth03, NTypeGen.genDepth03) {
org.bykn.bosatsu.OrderingLaws.law(_, _, _)
Expand Down

0 comments on commit 1dbc7d2

Please sign in to comment.