Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add property checks for BinNat #1176

Merged
merged 6 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/ShapeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package org.bykn.bosatsu
import org.bykn.bosatsu.rankn.TypeEnv
import org.scalatest.funsuite.AnyFunSuite

import cats.syntax.all._

class ShapeTest extends AnyFunSuite {

def makeTE(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class RankNInferTest extends AnyFunSuite {

// this could be used to test the string representation of expressions
def checkTERepr(statement: String, repr: String) =
checkLast(statement)(te => assert(te.repr == repr))
checkLast(statement)(te => assert(te.repr.render(80) == repr))

/** Test that a program is ill-typed
*/
Expand Down
3 changes: 1 addition & 2 deletions test_workspace/AvlTree.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ contains_test = (
])
)

def eq_i(a, b):
cmp_Int(a, b) matches EQ
eq_i = eq_Int

def add_increases_size(t, i, msg):
s0 = size(t)
Expand Down
148 changes: 132 additions & 16 deletions test_workspace/BinNat.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package Bosatsu/BinNat

from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat

export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2,
prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, sub_Option, eq_BinNat)
# a natural number with three variants:
# Zero = 0
# Odd(n) = 2n + 1
Expand All @@ -11,6 +12,9 @@ export BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, prev
# Zero, Odd(Zero), Even(Zero), Odd(Odd(Zero)), Even(Odd(Zero))
enum BinNat: Zero, Odd(half: BinNat), Even(half1: BinNat)

def is_even(b: BinNat) -> Bool:
b matches Zero | Even(_)

# Convert a BinNat into the equivalent Int
# this is O(log(b)) operation
def toInt(b: BinNat) -> Int:
Expand All @@ -37,71 +41,168 @@ def toBinNat(n: Int) -> BinNat:
(dec(n), fns)
)
# Now apply all the transformations
fns.foldLeft(Zero, \n, fn -> fn(n))
fns.foldLeft(Zero, (n, fn) -> fn(n))

def cmp_BinNat(a: BinNat, b: BinNat) -> Comparison:
recur a:
case Zero:
match b:
case Odd(_) | Even(_): LT
case Zero: EQ
case Odd(a1):
match b:
case Odd(b1): cmp_BinNat(a1, b1)
case Even(b1):
# 2n + 1 <> 2m + 2
# if n <= m, LT
# if n > m GT
match cmp_BinNat(a1, b1):
case LT | EQ: LT
case GT: GT
case Zero: GT
case Even(a1):
match b:
case Even(b1): cmp_BinNat(a1, b1)
case Odd(b1):
# 2n + 2 <> 2m + 1
# if n >= m, GT
# if n < m LT
match cmp_BinNat(a1, b1):
case GT | EQ: GT
case LT: LT
case Zero: GT

# this is more efficient potentially than cmp_BinNat
# because at the first difference we can stop. In the worst
# case of equality, the cost is the same.
def eq_BinNat(a: BinNat, b: BinNat) -> Bool:
recur a:
case Zero: b matches Zero
case Odd(n):
match b:
case Odd(m): eq_BinNat(n, m)
case _: False
case Even(n):
match b:
case Even(m): eq_BinNat(n, m)
case _: False

# Return the next number
def next(b: BinNat) -> BinNat:
recur b:
Zero: Odd(Zero)
Odd(half):
# (2n + 1) + 1 = 2(n + 1)
Even(half)
Even(half1):
# 2(n + 1) + 1
Odd(next(half1))
Zero: Odd(Zero)

# Return the previous number if the number is > 0, else return 0
def prev(b: BinNat) -> BinNat:
recur b:
Zero: Zero
Odd(Zero):
# This breaks the law below because 0 - 1 = 0 in this function
Zero
Odd(half):
case Zero | Odd(Zero): Zero
case Odd(half):
# (2n + 1) - 1 = 2n = 2(n-1 + 1)
Even(prev(half))
Even(half1):
case Even(half1):
# 2(n + 1) - 1 = 2n + 1
Odd(half1)

def add_BinNat(left: BinNat, right: BinNat) -> BinNat:
recur left:
Zero: right
Odd(left) as odd:
match right:
Zero: odd
Odd(right):
# 2left + 1 + 2right + 1 = 2((left + right) + 1)
Even(add_BinNat(left, right))
Even(right):
# 2left + 1 + 2(right + 1) = 2((left + right) + 1) + 1
Odd(add_BinNat(left, right.next()))
Zero: odd
Even(left) as even:
match right:
Zero: even
Odd(right):
# 2(left + 1) + 2right + 1 = 2((left + right) + 1) + 1
Odd(add_BinNat(left, right.next()))
Even(right):
# 2(left + 1) + 2(right + 1) = 2((left + right + 1) + 1)
Even(add_BinNat(left, right.next()))
Zero: even
Zero: right

# multiply by 2
def times2(b: BinNat) -> BinNat:
recur b:
Zero: Zero
Odd(n):
#2(2n + 1) = Even(2n)
Even(times2(n))
Even(n):
#2(2(n + 1)) = 2((2n + 1) + 1)
Even(Odd(n))
Zero: Zero

# 2n - 1 if it is defined
def doub_prev(b: BinNat) -> Option[BinNat]:
match b:
case Odd(n):
# 2(2n + 1) - 1 = 4n + 1 = Odd(2n)
Some(Odd(times2(n)))
case Even(n):
# 2(2n + 2) - 1 = 4n + 3 = 2(2n + 1) + 1
Some(Odd(Odd(n)))
case Zero: None

def sub_Option(left: BinNat, right: BinNat) -> Option[BinNat]:
recur left:
case Zero:
match right:
case Zero: Some(Zero)
case _: None
case Odd(left) as odd:
match right:
case Zero: Some(odd)
case Odd(right):
# (2n + 1) - (2m + 1) = 2(n - m)
match sub_Option(left, right):
case Some(n_m): Some(times2(n_m))
case None: None
case Even(right):
# (2n + 1) - (2m + 2) = 2(n - m) - 1
# note if (2n + 1) > (2m + 2), then n > m
match sub_Option(left, right):
case Some(n_m): doub_prev(n_m)
case None: None
case Even(left) as even:
match right:
case Zero: Some(even)
case Odd(right):
# Even can't equal odd, so we never return
# zero. Next an even - odd is odd.
# (2n + 2) - (2m + 1) = 2(n - m) + 1
match sub_Option(left, right):
case Some(n_m): Some(Odd(n_m))
case None: None
case Even(right):
# (2n + 2) - (2m + 2) = 2(n - m)
match sub_Option(left, right):
case Some(n_m): Some(times2(n_m))
case None: None

def sub_BinNat(left: BinNat, right: BinNat) -> BinNat:
match sub_Option(left, right):
case Some(v): v
case None: Zero

def div2(b: BinNat) -> BinNat:
match b:
case Zero: Zero
case Odd(n): n
case Even(n): prev(n)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was a bug! yay for property checks!

case Odd(n):
# (2n + 1)/2 = n
n
case Even(n):
# (2n + 2)/2 = n + 1
next(n)

# multiply two BinNat together
def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
Expand All @@ -122,6 +223,21 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat:
prod = times_BinNat(left, right)
times2(prod.add_BinNat(right))

one = Odd(Zero)

def exp(base: BinNat, power: BinNat) -> BinNat:
recur power:
case Zero: one
case Odd(n):
# b^(2n + 1) == (b^n) * (b^n) * b
bn = exp(base, n)
bn.times_BinNat(bn).times_BinNat(base)
case Even(n):
# b^(2n + 2) = (b^n * b)^2
bn = exp(base, n)
bn1 = bn.times_BinNat(base)
bn1.times_BinNat(bn1)

# fold(fn, a, Zero) = a
# fold(fn, a, n) = fold(fn, fn(a, n - 1), n - 1)
def fold_left_BinNat(fn: (a, BinNat) -> a, init: a, cnt: BinNat) -> a:
Expand Down Expand Up @@ -158,7 +274,6 @@ def next_law(i, msg):
def times2_law(i, msg):
Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg)

one = Odd(Zero)
two = one.next()
three = two.next()
four = three.next()
Expand Down Expand Up @@ -210,4 +325,5 @@ test = TestSuite(
Assertion(fib(two).toInt().eq_Int(2), "fib(2) == 2"),
Assertion(fib(three).toInt().eq_Int(3), "fib(3) == 3"),
Assertion(fib(four).toInt().eq_Int(5), "fib(4) == 5"),
Assertion(cmp_BinNat(54.toBinNat(), 54.toBinNat()) matches EQ, "54 == 54"),
])
2 changes: 1 addition & 1 deletion test_workspace/Nat.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ n4 = Succ(n3)
n5 = Succ(n4)

def operator ==(i0: Int, i1: Int):
cmp_Int(i0, i1) matches EQ
eq_Int(i0, i1)

def addLaw(n1: Nat, n2: Nat, label: String) -> Test:
Assertion(add(n1, n2).to_Int() == (n1.to_Int() + n2.to_Int()), label)
Expand Down
Loading
Loading