Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tests around subtyping #1057

Merged
merged 1 commit into from
Oct 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import Identifier.Constructor

import org.scalatest.funsuite.AnyFunSuite

import cats.syntax.all._

class RankNInferTest extends AnyFunSuite {

val emptyRegion: Region = Region(0, 0)
Expand All @@ -39,7 +41,8 @@ class RankNInferTest extends AnyFunSuite {
val t1 = typeFrom(left)
val t2 = typeFrom(right)

Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion)
(Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion) >>
Infer.substitutionCheck(t2, t1, emptyRegion, emptyRegion))
.runFully(Map.empty, Map.empty, Type.builtInKinds)
}

Expand All @@ -48,8 +51,32 @@ class RankNInferTest extends AnyFunSuite {
assert(res.isRight, s"$left does not unify with $right\n\n$res")
}

def assertTypesDisjoint(left: String, right: String) =
assert(runUnify(left, right).isLeft, s"$left unexpectedly unifies with $right")
// Test that left is strictly smaller (not equal to right)
def assert_:<:(left: String, right: String) = {
val t1 = typeFrom(left)
val t2 = typeFrom(right)

val res1 = Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion)
.runFully(Map.empty, Map.empty, Type.builtInKinds)
assert(res1.isRight, s"$left is not :<: $right\n\n$res1")

val res2 = Infer.substitutionCheck(t2, t1, emptyRegion, emptyRegion)
.runFully(Map.empty, Map.empty, Type.builtInKinds)
assert(res2.isLeft, s"$left is unexpectedly = $right\n\n$res2")
}

def assertTypesDisjoint(left: String, right: String) = {
val t1 = typeFrom(left)
val t2 = typeFrom(right)

val res1 = Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion)
.runFully(Map.empty, Map.empty, Type.builtInKinds)
val res2 = Infer.substitutionCheck(t2, t1, emptyRegion, emptyRegion)
.runFully(Map.empty, Map.empty, Type.builtInKinds)

assert(res1.isLeft, s"$left is unexpectedly :<: $right\n\n$res1")
assert(res2.isLeft, s"$right is unexpectedly :<: $left\n\n$res2")
}

def defType(n: String): Type.Const.Defined =
Type.Const.Defined(testPackage, TypeName(Identifier.Constructor(n)))
Expand Down Expand Up @@ -159,21 +186,31 @@ class RankNInferTest extends AnyFunSuite {

test("assert some basic unifications") {
assertTypesUnify("forall a. a", "forall b. b")
assertTypesUnify("forall a. a", "Int")
assertTypesUnify("forall a, b. a -> b", "forall b. b -> Int")
assertTypesUnify("forall a, b. a -> b", "forall a. a -> (forall b. b -> b)")
assert_:<:("forall a. a", "Int")
// function is contravariant in first arg test that against the above
assert_:<:("Int -> Int", "(forall a. a) -> Int")

assert_:<:("forall a, b. a -> b", "forall b. b -> Int")
assert_:<:("forall a, b. a -> b", "forall a. a -> (forall b. b -> b)")

// forall commutes with covariant types
assertTypesUnify("forall a, b. a -> b", "forall a. a -> (forall b. b)")
assertTypesUnify("forall a. List[a]", "List[forall a. a]")

assert_:<:("forall a. a -> Int", "(forall a. a) -> Int")
assert_:<:("List[forall a. a -> Int]", "List[(forall a. a) -> Int]")

assertTypesUnify("forall f: +* -> *. f[forall a. a]", "forall a. forall f: +* -> *. f[a]")
assertTypesDisjoint("forall f: * -> *. f[forall a. a]", "forall a. forall f: * -> *. f[a]")
assert_:<:("forall a. forall f: -* -> *. f[a]", "forall f: -* -> *. f[forall a. a]")

assertTypesUnify("(forall a. a) -> Int", "(forall a. a) -> Int")
assertTypesUnify("(forall a. a -> Int) -> Int", "(forall a. a -> Int) -> Int")
assertTypesUnify("forall a, b. a -> b -> b", "forall a. a -> a -> a")
// these aren't disjoint but the right is more polymorphic
assertTypesDisjoint("forall a. a -> a -> a", "forall a, b. a -> b -> b")
assertTypesUnify("forall a, b. a -> b", "forall b, c. b -> (c -> Int)")
// assertTypesUnify("(forall a. a)[Int]", "Int")
// assertTypesUnify("(forall a. Int)[b]", "Int")
assertTypesUnify("forall a, f: * -> *. f[a]", "forall x. List[x]")
assertTypesUnify("forall a, f: +* -> *. f[a]", "forall x. List[x]")
assert_:<:("forall a, b. a -> b -> b", "forall a. a -> a -> a")
assert_:<:("forall a, b. a -> b", "forall b, c. b -> (c -> Int)")
assert_:<:("forall a, f: * -> *. f[a]", "forall x. List[x]")
assert_:<:("forall a, f: +* -> *. f[a]", "forall x. List[x]")
assertTypesDisjoint("forall a, f: -* -> *. f[a]", "forall x. List[x]")
//assertTypesUnify("(forall a, b. a -> b)[x, y]", "z -> w")

assertTypesDisjoint("Int", "String")
assertTypesDisjoint("Int -> Unit", "String")
Expand Down
Loading