diff --git a/core/src/test/scala/org/bykn/bosatsu/ShapeTest.scala b/core/src/test/scala/org/bykn/bosatsu/ShapeTest.scala index c0d5e9230..132e62059 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ShapeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ShapeTest.scala @@ -3,8 +3,6 @@ package org.bykn.bosatsu import org.bykn.bosatsu.rankn.TypeEnv import org.scalatest.funsuite.AnyFunSuite -import cats.syntax.all._ - class ShapeTest extends AnyFunSuite { def makeTE( diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 31200abad..1fd7eb8b8 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -212,7 +212,7 @@ class RankNInferTest extends AnyFunSuite { // this could be used to test the string representation of expressions def checkTERepr(statement: String, repr: String) = - checkLast(statement)(te => assert(te.repr == repr)) + checkLast(statement)(te => assert(te.repr.render(80) == repr)) /** Test that a program is ill-typed */ diff --git a/test_workspace/AvlTree.bosatsu b/test_workspace/AvlTree.bosatsu index fe80f13e6..fac7ff005 100644 --- a/test_workspace/AvlTree.bosatsu +++ b/test_workspace/AvlTree.bosatsu @@ -210,8 +210,7 @@ contains_test = ( ]) ) -def eq_i(a, b): - cmp_Int(a, b) matches EQ +eq_i = eq_Int def add_increases_size(t, i, msg): s0 = size(t) diff --git a/test_workspace/BinNat.bosatsu b/test_workspace/BinNat.bosatsu index 37147c4b3..f1243afa8 100644 --- a/test_workspace/BinNat.bosatsu +++ b/test_workspace/BinNat.bosatsu @@ -2,7 +2,8 @@ package Bosatsu/BinNat from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat -export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev +export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, + prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, sub_Option, eq_BinNat) # a natural number with three variants: # Zero = 0 # Odd(n) = 2n + 1 @@ -11,6 +12,9 @@ export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev # Zero, Odd(Zero), Even(Zero), Odd(Odd(Zero)), Even(Odd(Zero)) enum BinNat: Zero, Odd(half: BinNat), Even(half1: BinNat) +def is_even(b: BinNat) -> Bool: + b matches Zero | Even(_) + # Convert a BinNat into the equivalent Int # this is O(log(b)) operation def toInt(b: BinNat) -> Int: @@ -37,71 +41,168 @@ def toBinNat(n: Int) -> BinNat: (dec(n), fns) ) # Now apply all the transformations - fns.foldLeft(Zero, \n, fn -> fn(n)) + fns.foldLeft(Zero, (n, fn) -> fn(n)) + +def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison: + recur a: + case Zero: + match b: + case Odd(_) | Even(_): LT + case Zero: EQ + case Odd(a1): + match b: + case Odd(b1): cmp_BinNat(a1, b1) + case Even(b1): + # 2n + 1 <> 2m + 2 + # if n <= m, LT + # if n > m GT + match cmp_BinNat(a1, b1): + case LT | EQ: LT + case GT: GT + case Zero: GT + case Even(a1): + match b: + case Even(b1): cmp_BinNat(a1, b1) + case Odd(b1): + # 2n + 2 <> 2m + 1 + # if n >= m, GT + # if n < m LT + match cmp_BinNat(a1, b1): + case GT | EQ: GT + case LT: LT + case Zero: GT + +# this is more efficient potentially than cmp_BinNat +# because at the first difference we can stop. In the worst +# case of equality, the cost is the same. +def eq_BinNat(a: BinNat, b: BinNat) -> Bool: + recur a: + case Zero: b matches Zero + case Odd(n): + match b: + case Odd(m): eq_BinNat(n, m) + case _: False + case Even(n): + match b: + case Even(m): eq_BinNat(n, m) + case _: False # Return the next number def next(b: BinNat) -> BinNat: recur b: - Zero: Odd(Zero) Odd(half): # (2n + 1) + 1 = 2(n + 1) Even(half) Even(half1): # 2(n + 1) + 1 Odd(next(half1)) + Zero: Odd(Zero) # Return the previous number if the number is > 0, else return 0 def prev(b: BinNat) -> BinNat: recur b: - Zero: Zero - Odd(Zero): - # This breaks the law below because 0 - 1 = 0 in this function - Zero - Odd(half): + case Zero | Odd(Zero): Zero + case Odd(half): # (2n + 1) - 1 = 2n = 2(n-1 + 1) Even(prev(half)) - Even(half1): + case Even(half1): # 2(n + 1) - 1 = 2n + 1 Odd(half1) def add_BinNat(left: BinNat, right: BinNat) -> BinNat: recur left: - Zero: right Odd(left) as odd: match right: - Zero: odd Odd(right): # 2left + 1 + 2right + 1 = 2((left + right) + 1) Even(add_BinNat(left, right)) Even(right): # 2left + 1 + 2(right + 1) = 2((left + right) + 1) + 1 Odd(add_BinNat(left, right.next())) + Zero: odd Even(left) as even: match right: - Zero: even Odd(right): # 2(left + 1) + 2right + 1 = 2((left + right) + 1) + 1 Odd(add_BinNat(left, right.next())) Even(right): # 2(left + 1) + 2(right + 1) = 2((left + right + 1) + 1) Even(add_BinNat(left, right.next())) + Zero: even + Zero: right # multiply by 2 def times2(b: BinNat) -> BinNat: recur b: - Zero: Zero Odd(n): #2(2n + 1) = Even(2n) Even(times2(n)) Even(n): #2(2(n + 1)) = 2((2n + 1) + 1) Even(Odd(n)) + Zero: Zero + +# 2n - 1 if it is defined +def doub_prev(b: BinNat) -> Option[BinNat]: + match b: + case Odd(n): + # 2(2n + 1) - 1 = 4n + 1 = Odd(2n) + Some(Odd(times2(n))) + case Even(n): + # 2(2n + 2) - 1 = 4n + 3 = 2(2n + 1) + 1 + Some(Odd(Odd(n))) + case Zero: None + +def sub_Option(left: BinNat, right: BinNat) -> Option[BinNat]: + recur left: + case Zero: + match right: + case Zero: Some(Zero) + case _: None + case Odd(left) as odd: + match right: + case Zero: Some(odd) + case Odd(right): + # (2n + 1) - (2m + 1) = 2(n - m) + match sub_Option(left, right): + case Some(n_m): Some(times2(n_m)) + case None: None + case Even(right): + # (2n + 1) - (2m + 2) = 2(n - m) - 1 + # note if (2n + 1) > (2m + 2), then n > m + match sub_Option(left, right): + case Some(n_m): doub_prev(n_m) + case None: None + case Even(left) as even: + match right: + case Zero: Some(even) + case Odd(right): + # Even can't equal odd, so we never return + # zero. Next an even - odd is odd. + # (2n + 2) - (2m + 1) = 2(n - m) + 1 + match sub_Option(left, right): + case Some(n_m): Some(Odd(n_m)) + case None: None + case Even(right): + # (2n + 2) - (2m + 2) = 2(n - m) + match sub_Option(left, right): + case Some(n_m): Some(times2(n_m)) + case None: None + +def sub_BinNat(left: BinNat, right: BinNat) -> BinNat: + match sub_Option(left, right): + case Some(v): v + case None: Zero def div2(b: BinNat) -> BinNat: match b: case Zero: Zero - case Odd(n): n - case Even(n): prev(n) + case Odd(n): + # (2n + 1)/2 = n + n + case Even(n): + # (2n + 2)/2 = n + 1 + next(n) # multiply two BinNat together def times_BinNat(left: BinNat, right: BinNat) -> BinNat: @@ -122,6 +223,21 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat: prod = times_BinNat(left, right) times2(prod.add_BinNat(right)) +one = Odd(Zero) + +def exp(base: BinNat, power: BinNat) -> BinNat: + recur power: + case Zero: one + case Odd(n): + # b^(2n + 1) == (b^n) * (b^n) * b + bn = exp(base, n) + bn.times_BinNat(bn).times_BinNat(base) + case Even(n): + # b^(2n + 2) = (b^n * b)^2 + bn = exp(base, n) + bn1 = bn.times_BinNat(base) + bn1.times_BinNat(bn1) + # fold(fn, a, Zero) = a # fold(fn, a, n) = fold(fn, fn(a, n - 1), n - 1) def fold_left_BinNat(fn: (a, BinNat) -> a, init: a, cnt: BinNat) -> a: @@ -158,7 +274,6 @@ def next_law(i, msg): def times2_law(i, msg): Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg) -one = Odd(Zero) two = one.next() three = two.next() four = three.next() @@ -210,4 +325,5 @@ test = TestSuite( Assertion(fib(two).toInt().eq_Int(2), "fib(2) == 2"), Assertion(fib(three).toInt().eq_Int(3), "fib(3) == 3"), Assertion(fib(four).toInt().eq_Int(5), "fib(4) == 5"), + Assertion(cmp_BinNat(54.toBinNat(), 54.toBinNat()) matches EQ, "54 == 54"), ]) diff --git a/test_workspace/Nat.bosatsu b/test_workspace/Nat.bosatsu index b41872b48..664edebcd 100644 --- a/test_workspace/Nat.bosatsu +++ b/test_workspace/Nat.bosatsu @@ -122,7 +122,7 @@ n4 = Succ(n3) n5 = Succ(n4) def operator ==(i0: Int, i1: Int): - cmp_Int(i0, i1) matches EQ + eq_Int(i0, i1) def addLaw(n1: Nat, n2: Nat, label: String) -> Test: Assertion(add(n1, n2).to_Int() == (n1.to_Int() + n2.to_Int()), label) diff --git a/test_workspace/NumberProps.bosatsu b/test_workspace/NumberProps.bosatsu index 6fafaebc2..9fd9b5d0a 100644 --- a/test_workspace/NumberProps.bosatsu +++ b/test_workspace/NumberProps.bosatsu @@ -1,6 +1,10 @@ package Bosatsu/NumberProps -from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat) +from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat, is_even as is_even_BinNat, + times2 as times2_BinNat, div2 as div2_BinNat, Zero as BNZero, Even as BNEven, + times_BinNat, exp as exp_BinNat, cmp_BinNat, toInt as binNat_to_Int, add_BinNat, + next as next_BinNat, sub_BinNat, eq_BinNat, sub_Option as sub_BinNat_Option, + ) from Bosatsu/Nat import (Nat, Zero as NZero, Succ as NSucc, to_Nat as int_to_Nat, is_even as is_even_Nat, times2 as times2_Nat, div2 as div2_Nat, cmp_Nat, to_Int as nat_to_Int, add as add_Nat, mult as mult_Nat, exp as exp_Nat, sub_Nat) @@ -17,8 +21,6 @@ rand_Int: Rand[Int] = from_pair(int_range(128), geometric_Int) rand_Nat: Rand[Nat] = rand_Int.map_Rand(int_to_Nat) rand_BinNat: Rand[BinNat] = rand_Int.map_Rand(int_to_BinNat) -eq_Int = (a, b) -> a.cmp_Int(b) matches EQ - int_props = suite_Prop( "Int props", [ @@ -51,6 +53,7 @@ def exp_Int(base: Int, power: Int) -> Int: int_loop(power, 1, (p, acc) -> (p.sub(1), acc.times(base))) small_rand_Nat: Rand[Nat] = int_range(7).map_Rand(int_to_Nat) +small_rand_BinNat: Rand[BinNat] = int_range(7).map_Rand(int_to_BinNat) nat_props = suite_Prop( "Nat props", @@ -72,7 +75,7 @@ nat_props = suite_Prop( forall_Prop(prod_Rand(rand_Nat, rand_Nat), "add homomorphism", ((n1, n2)) -> ( n3 = add_Nat(n1, n2) i3 = add(n1.nat_to_Int(), n2.nat_to_Int()) - Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "add homomorphism") + Assertion(eq_Int(n3.nat_to_Int(), i3), "add homomorphism") )), forall_Prop(prod_Rand(rand_Nat, rand_Nat), "sub_Nat homomorphism", ((n1, n2)) -> ( n3 = sub_Nat(n1, n2) @@ -81,19 +84,19 @@ nat_props = suite_Prop( match cmp_Int(i1, i2): case EQ | GT: i3 = sub(i1, i2) - Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "sub_Nat homomorphism") + Assertion(eq_Int(n3.nat_to_Int(), i3), "sub_Nat homomorphism") case LT: Assertion(n3 matches NZero, "sub to zero") )), forall_Prop(prod_Rand(rand_Nat, rand_Nat), "mult homomorphism", ((n1, n2)) -> ( n3 = mult_Nat(n1, n2) i3 = times(n1.nat_to_Int(), n2.nat_to_Int()) - Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "mult homomorphism") + Assertion(eq_Int(n3.nat_to_Int(), i3), "mult homomorphism") )), forall_Prop(prod_Rand(small_rand_Nat, small_rand_Nat), "exp homomorphism", ((n1, n2)) -> ( n3 = exp_Nat(n1, n2) i3 = exp_Int(n1.nat_to_Int(), n2.nat_to_Int()) - Assertion(cmp_Int(n3.nat_to_Int(), i3) matches EQ, "exp homomorphism") + Assertion(eq_Int(n3.nat_to_Int(), i3), "exp homomorphism") )), forall_Prop(rand_Nat, "times2 == x -> mult(2, x)", n -> ( t2 = n.times2_Nat() @@ -103,7 +106,76 @@ nat_props = suite_Prop( ] ) -all_props = [int_props, nat_props] +binnat_props = suite_Prop( + "BinNat props", + [ + forall_Prop(rand_BinNat, "if is_even(n) then times2(div2(n)) == n", n -> ( + if is_even_BinNat(n): + n1 = times2_BinNat(div2_BinNat(n)) + n_str = binNat_to_Int(n).int_to_String() + n1_str = binNat_to_Int(n1).int_to_String() + Assertion(cmp_BinNat(n1, n) matches EQ, "even, times2/div2: n = ${n_str}, n1 = ${n1_str}") + else: + # we return the previous number + n1 = times2_BinNat(div2_BinNat(n)) + n_str = binNat_to_Int(n).int_to_String() + n1_str = binNat_to_Int(n1).int_to_String() + Assertion(cmp_BinNat(n1.next_BinNat(), n) matches EQ, "times2/div2: n = ${n_str}, n1 = ${n1_str}") + )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "cmp_BinNat matches cmp_Int", ((n1, n2)) -> ( + cmp_n = cmp_BinNat(n1, n2) + cmp_i = cmp_Int(n1.binNat_to_Int(), n2.binNat_to_Int()) + Assertion(cmp_Comparison(cmp_n, cmp_i) matches EQ, "cmp_BinNat") + )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "cmp_BinNat matches eq_BinNat", ((n1, n2)) -> ( + eq1 = cmp_BinNat(n1, n2) matches EQ + eq2 = eq_BinNat(n1, n2) + correct = (eq1, eq2) matches (True, True) | (False, False) + Assertion(correct, "cmp vs eq consistency") + )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "add homomorphism", ((n1, n2)) -> ( + n3 = add_BinNat(n1, n2) + i3 = add(n1.binNat_to_Int(), n2.binNat_to_Int()) + Assertion(eq_Int(n3.binNat_to_Int(), i3), "add homomorphism") + )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "sub_BinNat homomorphism", ((n1, n2)) -> ( + n3 = sub_BinNat(n1, n2) + i1 = n1.binNat_to_Int() + i2 = n2.binNat_to_Int() + match cmp_Int(i1, i2): + case EQ | GT: + i3 = sub(i1, i2) + Assertion(eq_Int(n3.binNat_to_Int(), i3), "sub_BinNat homomorphism") + case LT: + Assertion(n3 matches BNZero, "sub to zero") + )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "sub_BinNat_Option is None implies a < b", ((n1, n2)) -> ( + match sub_BinNat_Option(n1, n2): + case Some(n3): + n3_sub = sub_BinNat(n1, n2) + Assertion(cmp_BinNat(n3, n3_sub) matches EQ, "sub_BinNat same as sub_BinNat_Option when Some") + case None: + Assertion(cmp_BinNat(n1, n2) matches LT, "otherwise n1 < n2") + )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "mult homomorphism", ((n1, n2)) -> ( + n3 = times_BinNat(n1, n2) + i3 = times(n1.binNat_to_Int(), n2.binNat_to_Int()) + Assertion(eq_Int(n3.binNat_to_Int(), i3), "mult homomorphism") + )), + forall_Prop(prod_Rand(small_rand_BinNat, small_rand_BinNat), "exp homomorphism", ((n1, n2)) -> ( + n3 = exp_BinNat(n1, n2) + i3 = exp_Int(n1.binNat_to_Int(), n2.binNat_to_Int()) + Assertion(eq_Int(n3.binNat_to_Int(), i3), "exp homomorphism") + )), + forall_Prop(rand_BinNat, "times2 == x -> mult(2, x)", n -> ( + t2 = n.times2_BinNat() + t2_2 = times_BinNat(n, BNEven(BNZero)) + Assertion(cmp_BinNat(t2, t2_2) matches EQ, "times2 == mult(2, _)") + )), + ] +) + +all_props = [int_props, nat_props, binnat_props] seed = 123456 test = TestSuite("properties", [ diff --git a/test_workspace/Properties.bosatsu b/test_workspace/Properties.bosatsu index 78f979d4d..0abc099db 100644 --- a/test_workspace/Properties.bosatsu +++ b/test_workspace/Properties.bosatsu @@ -33,7 +33,7 @@ def run_Prop(prop: Prop, trials: Int, seed: Int) -> Test: signed64 = int_range(1 << 64).map_Rand(i -> i - (1 << 63)) -def operator ==(a, b): cmp_Int(a, b) matches EQ +def operator ==(a, b): eq_Int(a, b) not_law = forall_Prop( signed64,