Skip to content

Commit

Permalink
Ensure no meta loops in Infer (#1064)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Oct 26, 2023
1 parent ed5e2ff commit 000a8a6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
57 changes: 38 additions & 19 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,19 @@ object Infer {
def getFreeTyVars(ts: List[Type]): Infer[Set[Type.Var]] =
ts.traverse(zonkType).map(Type.freeTyVars(_).toSet)

private val pureNone: Infer[None.type] = pure(None)

def zonk(m: Type.Meta): Infer[Option[Type.Rho]] =
readMeta(m).flatMap {
case None => pure(None)
case Some(ty) =>
case None => pureNone
case sty @ Some(ty) =>
zonkRho(ty).flatMap { ty1 =>
// short out multiple hops (I guess an optimization?)
// note: this meta was already written, so we know
// the kind must match
writeMeta(m, ty1).as(Some(ty1))
if ((ty1: Type) === ty) pure(sty)
else {
// we were able to resolve more of the inner metas
// inside ty, so update the state
writeMeta(m, ty1).as(Some(ty1))
}
}
}

Expand Down Expand Up @@ -650,24 +654,39 @@ object Infer {
} yield (leftT, rightT)
}

// invariant the flexible type variable tv1 is not bound
def unifyUnboundVar(m: Type.Meta, ty2: Type.Tau, left: Region, right: Region): Infer[Unit] =
// invariant the flexible type variable ty1 is not bound
def unifyUnboundVar(ty1: Type.TyMeta, ty2: Type.Tau, left: Region, right: Region): Infer[Unit] =
ty2 match {
case meta2@Type.TyMeta(m2) =>
readMeta(m2).flatMap {
case Some(ty2) => unify(Type.TyMeta(m), ty2, left, right)
val m = ty1.toMeta
if (m2.id == m.id) unit
else (readMeta(m2).flatMap {
case Some(ty2) => unify(ty1, ty2, left, right)
case None =>
// we have to check that the kind matches before writing to a meta
if (Kind.leftSubsumesRight(m.kind, m2.kind)) {
// we have to check that the kind matches before writing to a meta
writeMeta(m, ty2)
// Both m and m2 are not set. We just point one at the other
// by convention point to the smaller item which
// definitely prevents cycles.
if (m.id > m2.id) writeMeta(m, meta2)
else {
// since we checked above we know that
// m.id != m2.id, so it is safe to write without
// creating a self-loop here
writeMeta(m2, ty1)
}
}
else {
fail(Error.KindMetaMismatch(Type.TyMeta(m), meta2, m2.kind, left, right))
fail(Error.KindMetaMismatch(ty1, meta2, m2.kind, left, right))
}
}
})
case nonMeta =>
// we have a non-meta, but inside of it (TyApply) we may have
// metas. Let's go ahead and zonk them now to minimize nesting
// metas inside metas.
zonkType(nonMeta)
.flatMap { nm2 =>
val m = ty1.toMeta
val tvs2 = Type.metaTvs(nm2 :: Nil)
if (tvs2(m)) fail(Error.UnexpectedMeta(m, nonMeta, left, right))
else {
Expand All @@ -678,24 +697,24 @@ object Infer {
writeMeta(m, nonMeta)
}
else {
fail(Error.KindMetaMismatch(Type.TyMeta(m), nonMeta, nmk, left, right))
fail(Error.KindMetaMismatch(ty1, nonMeta, nmk, left, right))
}
}
}
}
}

def unifyVar(tv: Type.Meta, t: Type.Tau, left: Region, right: Region): Infer[Unit] =
readMeta(tv).flatMap {
def unifyVar(tv: Type.TyMeta, t: Type.Tau, left: Region, right: Region): Infer[Unit] =
readMeta(tv.toMeta).flatMap {
case None => unifyUnboundVar(tv, t, left, right)
case Some(ty1) => unify(ty1, t, left, right)
}

def unify(t1: Type.Tau, t2: Type.Tau, r1: Region, r2: Region): Infer[Unit] =
(t1, t2) match {
case (Type.TyMeta(m1), Type.TyMeta(m2)) if m1.id == m2.id => unit
case (Type.TyMeta(m), tpe) => unifyVar(m, tpe, r1, r2)
case (tpe, Type.TyMeta(m)) => unifyVar(m, tpe, r2, r1)
case (meta@Type.TyMeta(_), tpe) => unifyVar(meta, tpe, r1, r2)
case (tpe, meta@Type.TyMeta(_)) => unifyVar(meta, tpe, r2, r1)
case (Type.TyApply(a1, b1), Type.TyApply(a2, b2)) =>
unifyType(a1, a2, r1, r2) *> unifyType(b1, b2, r1, r2)
case (Type.TyConst(c1), Type.TyConst(c2)) if c1 == c2 => unit
Expand Down
2 changes: 1 addition & 1 deletion core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2897,7 +2897,7 @@ def quick_sort0(cmp, left, right):
[*smalls, *bigs]
""")) { case kie@PackageError.TypeErrorIn(_, _) =>
assert(kie.message(Map.empty, Colorize.None) == """in file: <unknown source>, package QS
type error: expected type Bosatsu/Predef::Fn3[(?43, ?41) -> Bosatsu/Predef::Comparison]
type error: expected type Bosatsu/Predef::Fn3[(?13, ?9) -> Bosatsu/Predef::Comparison]
Region(403,414)
to be the same as type Bosatsu/Predef::Fn2
hint: the first type is a function with 3 arguments and the second is a function with 2 arguments.
Expand Down

0 comments on commit 000a8a6

Please sign in to comment.