Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add property checks for BinNat #1176

Merged
merged 6 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 86 additions & 5 deletions test_workspace/BinNat.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package Bosatsu/BinNat

from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat

export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2,
prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat)
# a natural number with three variants:
# Zero = 0
# Odd(n) = 2n + 1
Expand All @@ -11,6 +12,9 @@ export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
# Zero, Odd(Zero), Even(Zero), Odd(Odd(Zero)), Even(Odd(Zero))
enum BinNat: Zero, Odd(half: BinNat), Even(half1: BinNat)

def is_even(b: BinNat) -> Bool:
b matches Zero | Even(_)

# Convert a BinNat into the equivalent Int
# this is O(log(b)) operation
def toInt(b: BinNat) -> Int:
Expand All @@ -37,7 +41,36 @@ def toBinNat(n: Int) -> BinNat:
(dec(n), fns)
)
# Now apply all the transformations
fns.foldLeft(Zero, \n, fn -> fn(n))
fns.foldLeft(Zero, (n, fn) -> fn(n))

def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
recur a:
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ
case Odd(a1):
match b:
case Odd(b1): cmp_BinNat(a1, b1)
case Even(b1):
# 2n + 1 <> 2m + 2
# if n <= m, LT
# if n > m GT
match cmp_BinNat(a1, b1):
case LT | EQ: LT
case GT: GT
case Zero: GT
case Even(a1):
match b:
case Even(b1): cmp_BinNat(a1, b1)
case Odd(b1):
# 2n + 2 <> 2m + 1
# if n >= m, GT
# if n < m LT
match cmp_BinNat(a1, b1):
case GT | EQ: GT
case LT: LT
case Zero: GT

# Return the next number
def next(b: BinNat) -> BinNat:
Expand Down Expand Up @@ -97,11 +130,45 @@ def times2(b: BinNat) -> BinNat:
#2(2(n + 1)) = 2((2n + 1) + 1)
Even(Odd(n))

def sub_BinNat(left: BinNat, right: BinNat) -> BinNat:
recur left:
case Zero: Zero
case Odd(left) as odd:
match right:
case Zero: odd
case Odd(right):
# (2n + 1) - (2m + 1) = 2(n - m)
times2(sub_BinNat(left, right))
case Even(right):
# (2n + 1) - (2m + 2) = 2(n - m) - 1
times2(sub_BinNat(left, right)).prev()
case Even(left) as even:
match right:
Zero: even
Odd(right):
# (2n + 2) - (2m + 1)
# if n >= m: 2(n - m) + 1
# if n < m: 0
diff = sub_BinNat(left, right)
match diff:
case Zero:
match cmp_BinNat(left, right):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems bad to do internally, seems like we should be able to do this once outside the loop... should think some more.

case LT: Zero
case _: Odd(diff)
case _: Odd(diff)
Even(right):
# (2n + 2) - (2m + 2) = 2(n - m)
times2(sub_BinNat(left, right))

def div2(b: BinNat) -> BinNat:
match b:
case Zero: Zero
case Odd(n): n
case Even(n): prev(n)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was a bug! yay for property checks!

case Odd(n):
# (2n + 1)/2 = n
n
case Even(n):
# (2n + 2)/2 = n + 1
next(n)

# multiply two BinNat together
def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
Expand All @@ -122,6 +189,21 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
prod = times_BinNat(left, right)
times2(prod.add_BinNat(right))

one = Odd(Zero)

def exp(base: BinNat, power: BinNat) -> BinNat:
recur power:
case Zero: one
case Odd(n):
# b^(2n + 1) == (b^n) * (b^n) * b
bn = exp(base, n)
bn.times_BinNat(bn).times_BinNat(base)
case Even(n):
# b^(2n + 2) = (b^n * b)^2
bn = exp(base, n)
bn1 = bn.times_BinNat(base)
bn1.times_BinNat(bn1)

# fold(fn, a, Zero) = a
# fold(fn, a, n) = fold(fn, fn(a, n - 1), n - 1)
def fold_left_BinNat(fn: (a, BinNat) -> a, init: a, cnt: BinNat) -> a:
Expand Down Expand Up @@ -158,7 +240,6 @@ def next_law(i, msg):
def times2_law(i, msg):
Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg)

one = Odd(Zero)
two = one.next()
three = two.next()
four = three.next()
Expand Down
60 changes: 58 additions & 2 deletions test_workspace/NumberProps.bosatsu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package Bosatsu/NumberProps

from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat)
from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat, is_even as is_even_BinNat,
times2 as times2_BinNat, div2 as div2_BinNat, Zero as BNZero, Even as BNEven,
times_BinNat, exp as exp_BinNat, cmp_BinNat, toInt as binNat_to_Int, add_BinNat,
next as next_BinNat, sub_BinNat,
)
from Bosatsu/Nat import (Nat, Zero as NZero, Succ as NSucc, to_Nat as int_to_Nat, is_even as is_even_Nat,
times2 as times2_Nat, div2 as div2_Nat, cmp_Nat, to_Int as nat_to_Int, add as add_Nat,
mult as mult_Nat, exp as exp_Nat, sub_Nat)
Expand Down Expand Up @@ -51,6 +55,7 @@ def exp_Int(base: Int, power: Int) -> Int:
int_loop(power, 1, (p, acc) -> (p.sub(1), acc.times(base)))

small_rand_Nat: Rand[Nat] = int_range(7).map_Rand(int_to_Nat)
small_rand_BinNat: Rand[BinNat] = int_range(7).map_Rand(int_to_BinNat)

nat_props = suite_Prop(
"Nat props",
Expand Down Expand Up @@ -103,7 +108,58 @@ nat_props = suite_Prop(
]
)

all_props = [int_props, nat_props]
binnat_props = suite_Prop(
"BinNat props",
[
forall_Prop(rand_BinNat, "if is_even(n) then times2(div2(n)) == n", n -> (
if is_even_BinNat(n):
n1 = times2_BinNat(div2_BinNat(n))
Assertion(cmp_BinNat(n1, n) matches EQ, "times2/div2")
else:
# we return the previous number
n1 = times2_BinNat(div2_BinNat(n))
Assertion(cmp_BinNat(n1.next_BinNat(), n) matches EQ, "times2/div2")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "cmp_BinNat matches cmp_Int", ((n1, n2)) -> (
cmp_n = cmp_BinNat(n1, n2)
cmp_i = cmp_Int(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(cmp_Comparison(cmp_n, cmp_i) matches EQ, "cmp_BinNat")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "add homomorphism", ((n1, n2)) -> (
n3 = add_BinNat(n1, n2)
i3 = add(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(cmp_Int(n3.binNat_to_Int(), i3) matches EQ, "add homomorphism")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "sub_BinNat homomorphism", ((n1, n2)) -> (
n3 = sub_BinNat(n1, n2)
i1 = n1.binNat_to_Int()
i2 = n2.binNat_to_Int()
match cmp_Int(i1, i2):
case EQ | GT:
i3 = sub(i1, i2)
Assertion(cmp_Int(n3.binNat_to_Int(), i3) matches EQ, "sub_BinNat homomorphism")
case LT:
Assertion(n3 matches BNZero, "sub to zero")
)),
forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "mult homomorphism", ((n1, n2)) -> (
n3 = times_BinNat(n1, n2)
i3 = times(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(cmp_Int(n3.binNat_to_Int(), i3) matches EQ, "mult homomorphism")
)),
forall_Prop(prod_Rand(small_rand_BinNat, small_rand_BinNat), "exp homomorphism", ((n1, n2)) -> (
n3 = exp_BinNat(n1, n2)
i3 = exp_Int(n1.binNat_to_Int(), n2.binNat_to_Int())
Assertion(cmp_Int(n3.binNat_to_Int(), i3) matches EQ, "exp homomorphism")
)),
forall_Prop(rand_BinNat, "times2 == x -> mult(2, x)", n -> (
t2 = n.times2_BinNat()
t2_2 = times_BinNat(n, BNEven(BNZero))
Assertion(cmp_BinNat(t2, t2_2) matches EQ, "times2 == mult(2, _)")
)),
]
)

all_props = [int_props, nat_props, binnat_props]

seed = 123456
test = TestSuite("properties", [
Expand Down
Loading