diff --git a/core/src/main/scala/org/bykn/bosatsu/Nat.scala b/core/src/main/scala/org/bykn/bosatsu/Nat.scala new file mode 100644 index 000000000..784ce7ea1 --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/Nat.scala @@ -0,0 +1,219 @@ +package org.bykn.bosatsu +import org.bykn.bosatsu.Nat.Shift +import org.bykn.bosatsu.Nat.Small + +import Nat._ + +sealed abstract class Nat { lhs => + def toBigInt: BigInt = + lhs match { + case Small(asInt) => + BigInt(toLong(asInt)) + case Shift(x, b) => + // m(x + 1) + b + two_32_BigInt * (x.inc.toBigInt) + toLong(b) + } + + override def toString = toBigInt.toString + + def maybeLong: Option[Long] = + lhs match { + case Small(asInt) => Some(toLong(asInt)) + case Shift(x, b) => + // 2^32 * (x + 1) + b + x.maybeLong match { + case Some(v0) => + if (v0 < Int.MaxValue) { + val value = v0 + 1 + Some((value << 32) + toLong(b)) + } + else None + case None => None + } + } + + + def inc: Nat = + lhs match { + case Small(asInt) => + val next = asInt + 1 + if (next != 0) wrapInt(next) + else two_32 + case Shift(x, b) => + // m(x + 1) + b + val b1 = b + 1 + if (b1 != 0) Shift(x, b1) + else { + // m(x + 1) + m + // m(x + 1 + 1) + Shift(x.inc, 0) + } + } + + def dec: Nat = + lhs match { + case Small(asInt) => + if (asInt != 0) wrapInt(asInt - 1) + else zero + case Shift(x, b) => + if (b != 0) Shift(x, b - 1) + else { + if (x.isZero) wrapInt(-1) + else Shift(x.dec, -1) + } + } + + def isZero: Boolean = + lhs match { + case Small(0) => true + case _ => false + } + + // this * 2^32 + def shift_32: Nat = + lhs match { + case Small(asInt) => + if (asInt == 0) zero + else { + // m*(b - 1 + 1) + Shift(wrapInt(asInt - 1), 0) + } + case Shift(x, b) => + // m*(m*(x + 1) + b) + // m*((m(x + 1) + (b - 1)) + 1) + 0 + if (b != 0) { + val b1 = b - 1 + Shift(Shift(x, b1), 0) + } + else if (x.isZero) { + // m * (m(0 + 1) + 0) + // m * ((m - 1) + 1) + 0 + val inner = wrapLong(0xFFFFFFFFL) + Shift(inner, 0) + } + else { + // x > 0 + //m(m(x + 1)) = + //m(m(x + 1) - 1 + 1) + 0 + //m((m(x + 1) - 1) + 1) + 0 + val inner = x.inc.shift_32.dec + Shift(inner, 0) + } + } + + def +(rhs: Nat): Nat = { + lhs match { + case Small(l) => + rhs match { + case Small(r) => + wrapLong(toLong(l) + toLong(r)) + case Shift(x, b) => + val res = toLong(l) + toLong(b) + val low = lowBits(res) + val high = highBits(res) + if (high == 0) Shift(x, low) + else { + Shift(x + wrapInt(high), low) + } + } + + case Shift(x, b) => + rhs match { + case Small(l) => + val res = toLong(l) + toLong(b) + val low = lowBits(res) + val high = highBits(res) + if (high == 0) Shift(x, low) + else { + Shift(x + wrapInt(high), low) + } + case Shift(x1, b1) => + // (m(x + 1) + b) + (m (x1 + 1) + b1) = m(x + x1 + 1 + 1) + (b + b1) + val res = toLong(b) + toLong(b1) + val low = lowBits(res) + val high = highBits(res) + val xs = (x + x1).inc + if (high == 0) Shift(xs, low) + else { + Shift(xs + wrapInt(high), low) + } + } + } + } + def *(rhs: Nat): Nat = { + lhs match { + case Small(l) => + if (l == 0) zero + else (rhs match { + case Small(r) => + // can't overflow Long + wrapLong(toLong(l) * toLong(r)) + case Shift(x, b) => + // (m(x + 1) + b) * l + (x.inc * lhs).shift_32 + wrapInt(b) * lhs + }) + + case Shift(x, b) => + // (m(x + 1) + b) * rhs + // (x + 1)*rhs*m + b * rhs + (rhs * x.inc).shift_32 + rhs * wrapInt(b) + } + } +} + +object Nat { + private val two_32_BigInt = BigInt(1L << 32) + + private val intMaskLow: Long = 0xFFFFFFFFL + def toLong(i: Int): Long = i.toLong & intMaskLow + def lowBits(l: Long): Int = l.toInt + private def highBits(l: Long): Int = lowBits(l >> 32) + + // all numbers from [0, 2^{32} - 1] + private case class Small(asInt: Int) extends Nat + // base * (x + 1) + b + private case class Shift(x: Nat, b: Int) extends Nat + + private val cache: Array[Small] = + Array.tabulate(1024)(Small(_)) + + val zero: Nat = cache(0) + val one: Nat = cache(1) + val two_32: Nat = Shift(zero, 0) + + // if the number is <= 0, return 0 + def fromInt(i: Int): Nat = + if (i < 0) zero + else if (i < cache.length) cache(i) + else Small(i) + + // if the number is <= 0, return 0 + def fromLong(l: Long): Nat = + if (l < 0) zero + else wrapLong(l) + + def wrapInt(i: Int): Nat = + if (0 <= i && i < cache.length) cache(i) + else Small(i) + + def wrapLong(l: Long): Nat = { + val low = lowBits(l) + val high = highBits(l) + if (high == 0) wrapInt(low) + else { + Shift(wrapInt(high - 1), low) + } + } + + // if b < 0 return 0 + def fromBigInt(b: BigInt): Nat = + if (b <= 0) zero + else if (b < two_32_BigInt) { + fromLong(b.toLong) + } + else { + val low = b % two_32_BigInt + val high = b >> 32 + Shift(fromBigInt(high - 1), lowBits(low.toLong)) + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/bykn/bosatsu/NatTest.scala b/core/src/test/scala/org/bykn/bosatsu/NatTest.scala new file mode 100644 index 000000000..9625c2d34 --- /dev/null +++ b/core/src/test/scala/org/bykn/bosatsu/NatTest.scala @@ -0,0 +1,154 @@ +package org.bykn.bosatsu + +import org.scalacheck.{Gen, Prop} + +import Prop.forAll + +class NatTest extends munit.ScalaCheckSuite { + + override def scalaCheckTestParameters = + super.scalaCheckTestParameters + .withMinSuccessfulTests(if (Platform.isScalaJvm) 100000 else 100) + .withMaxDiscardRatio(10) + + //override def scalaCheckInitialSeed = "YOFqcGzXOFtFVgRFxOmODi5100tVovDS3EPOv0Ihk4C=" + + lazy val genNat: Gen[Nat] = { + val recur = Gen.lzy(genNat) + Gen.frequency( + // make sure to exercise the cached table + (1, Gen.chooseNum(0, 1024, 1).map(Nat.fromInt(_))), + (5, Gen.chooseNum(0, Long.MaxValue, Int.MaxValue.toLong, Int.MaxValue.toLong + 1).map(Nat.fromLong(_))), + (1, Gen.zip(recur, recur).map { case (a, b) => a + b }), + (1, Gen.zip(recur, recur).map { case (a, b) => a * b }) + ) + } + + test("constants are right") { + assertEquals(Nat.zero.maybeLong, Some(0L)) + assertEquals(Nat.one.maybeLong, Some(1L)) + assertEquals(Nat.two_32.maybeLong, Some(1L << 32)) + } + property("fromInt/maybeLong identity") { + forAll { (i: Int) => + val n = Nat.fromInt(i) + n.maybeLong match { + case None => fail(s"couldn't maybeLong $i") + case Some(l) => assert(i < 0 || (l.toInt == i)) + } + } + } + property("fromLong/maybeLong identity") { + forAll { (i: Long) => + val n = Nat.fromLong(i) + n.maybeLong match { + case Some(l) => assert(i < 0 || (l == i)) + case None => fail(s"couldn't maybeLong $i") + } + } + } + + def trunc(i: Int): Long = if (i < 0) 0L else i.toLong + + property("toLong => lowBits is identity") { + val big = 0x80000000L + + assertEquals(Nat.lowBits(big), Int.MinValue, s"${Nat.lowBits(big)}") + forAll(Gen.chooseNum(Int.MinValue, Int.MaxValue)) { (i: Int) => + val l = Nat.toLong(i) + assertEquals(Nat.lowBits(Nat.toLong(i)), i, s"long = $l") + } + } + + property("x + y homomorphism") { + forAll(genNat, genNat) { (ni, nj) => + val nk = ni + nj + assertEquals(nk.toBigInt, ni.toBigInt + nj.toBigInt) + } + } + + property("x * y = y * x") { + forAll(genNat, genNat) { (n1, n2) => + assertEquals(n1 * n2, n2 * n1) + } + } + + property("x + y = y + x") { + forAll(genNat, genNat) { (n1, n2) => + assertEquals(n1 + n2, n2 + n1) + } + } + + property("x * y homomorphism") { + forAll(genNat, genNat) { (ni, nj) => + val nk = ni * nj + assertEquals(nk.toBigInt, ni.toBigInt * nj.toBigInt) + } + } + + property("x.inc == x + 1") { + forAll(genNat) { n => + val i = n.inc + val a = n + Nat.one + assertEquals(i.toBigInt, a.toBigInt) + } + } + + property("x.dec == x - 1 when x > 0") { + forAll(genNat) { n => + val i = n.dec.toBigInt + if (n == Nat.zero) assertEquals(i, BigInt(0)) + else assertEquals(i, n.toBigInt - 1) + } + } + + property("x.shift_32 == x.toBigInt * 2^32") { + val n1 = BigInt("8588888260151524380556863712485265508") + val shift = n1 << 32 + assertEquals(Nat.fromBigInt(n1).shift_32.toBigInt, shift) + + forAll(genNat) { n => + val s = n.shift_32 + val viaBigInt = n.toBigInt << 32 + val viaTimes = n * Nat.two_32 + assertEquals(s.toBigInt, viaBigInt, s"viaTimes = $viaTimes") + assertEquals(s, viaTimes) + } + } + + property("Nat.fromBigInt/toBigInt") { + assertEquals(Nat.fromLong(Long.MaxValue).toBigInt, BigInt(Long.MaxValue)) + + forAll { (bi0: BigInt) => + val bi = bi0.abs + val n = Nat.fromBigInt(bi) + val b2 = n.toBigInt + assertEquals(b2, bi) + } + } + + property("x.inc.dec == x") { + forAll(genNat) { n => + assertEquals(n.inc.dec, n) + } + } + + property("x.dec.inc == x || x.isZero") { + forAll(genNat) { n => + assert((n.dec.inc == n) || n.isZero) + } + } + + property("if the value is > Long.MaxValue maybeLong = None") { + forAll(genNat) { n => + val bi = n.toBigInt + val ml = n.maybeLong + assertEquals(ml.isEmpty, bi > Long.MaxValue) + } + } + property("the string repr matches toBigInt") { + forAll(genNat) { n => + assertEquals(n.toString, n.toBigInt.toString) + } + } +} \ No newline at end of file diff --git a/test_workspace/BinInt.bosatsu b/test_workspace/BinInt.bosatsu new file mode 100644 index 000000000..fc3d047e7 --- /dev/null +++ b/test_workspace/BinInt.bosatsu @@ -0,0 +1,98 @@ +package Bosatsu/BinInt + +from Bosatsu/BinNat import (BinNat, Zero as BNZero, + prev as prev_BinNat, next as next_BinNat, sub_Option, sub_BinNat, + add_BinNat, toInt as binNat_to_Int, toBinNat as int_to_BinNat, + cmp_BinNat, eq_BinNat, +) + +from Bosatsu/Predef import (add as add_Int,) + +export (BinInt(), binNat_to_BinInt, int_to_BinInt, binInt_to_Int, + negate, abs, add, not, cmp, eq, sub,) + +# BWNot(x) == -x - 1 +enum BinInt: + FromBinNat(bn: BinNat) + BWNot(arg: BinNat) + +def cmp(a: BinInt, b: BinInt) -> Comparison: + match a: + case FromBinNat(a): + match b: + case FromBinNat(b): cmp_BinNat(a, b) + case BWNot(_): GT + case BWNot(a): + match b: + case BWNot(b): + # -a - 1 <> -b - 1 == (a <> b).invert + cmp_BinNat(b, a) + case FromBinNat(_): LT + +def eq(a: BinInt, b: BinInt) -> Bool: + match a: + case FromBinNat(a): + match b: + case FromBinNat(b): eq_BinNat(a, b) + case BWNot(_): False + case BWNot(a): + match b: + case BWNot(b): eq_BinNat(a, b) + case FromBinNat(_): False + +def binNat_to_BinInt(bn: BinNat) -> BinInt: FromBinNat(bn) + +def not(bi: BinInt) -> BinInt: + match bi: + case FromBinNat(b): BWNot(b) + case BWNot(b): FromBinNat(b) + +def int_to_BinInt(i: Int) -> BinInt: + if cmp_Int(i, 0) matches LT: + # x = -(-x - 1) - 1 + BWNot(int_to_BinNat(not_Int(i))) + else: + FromBinNat(int_to_BinNat(i)) + +def binInt_to_Int(bi: BinInt) -> Int: + match bi: + case FromBinNat(bn): binNat_to_Int(bn) + case BWNot(x): not_Int(binNat_to_Int(x)) + +def negate(bi: BinInt) -> BinInt: + # -x = -(x - 1) - 1 + # -(-x - 1) = x + 1 + match bi: + case FromBinNat(BNZero): bi + case FromBinNat(bn): BWNot(bn.prev_BinNat()) + case BWNot(x): FromBinNat(x.next_BinNat()) + +def abs(bi: BinInt) -> BinNat: + match bi: + case FromBinNat(bn): bn + case BWNot(x): + #abs(-x - 1) = x + 1 + x.next_BinNat() + +def add(x: BinInt, y: BinInt) -> BinInt: + match (x, y): + case (FromBinNat(x), FromBinNat(y)): + FromBinNat(x.add_BinNat(y)) + case (FromBinNat(x), BWNot(y)): + # x + (-y - 1) = x - (y + 1) + ypos = y.next_BinNat() + match sub_Option(x, ypos): + case Some(bi): FromBinNat(bi) + case None: FromBinNat(sub_BinNat(ypos, x)).negate() + case (BWNot(x), FromBinNat(y)): + # -x - 1 + y = y - (x + 1) + xpos = x.next_BinNat() + match sub_Option(y, xpos): + case Some(bi): FromBinNat(bi) + case None: FromBinNat(sub_BinNat(xpos, y)).negate() + case (BWNot(x), BWNot(y)): + # (-x - 1) + (-y - 1) == -(x + y + 1) - 1 + BWNot(add_BinNat(x, y).next_BinNat()) + +def sub(a: BinInt, b: BinInt) -> BinInt: + add(a, negate(b)) \ No newline at end of file diff --git a/test_workspace/BinNat.bosatsu b/test_workspace/BinNat.bosatsu index 49d5676f5..052a469af 100644 --- a/test_workspace/BinNat.bosatsu +++ b/test_workspace/BinNat.bosatsu @@ -3,7 +3,7 @@ package Bosatsu/BinNat from Bosatsu/Nat import Nat, Zero as NatZero, Succ as NatSucc, times2 as times2_Nat export (BinNat(), toInt, toNat, toBinNat, next, add_BinNat, times2, div2, - prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, sub_Option, eq_BinNat) + prev, times_BinNat, exp, cmp_BinNat, is_even, sub_BinNat, sub_Option, eq_BinNat, divmod) # a natural number with three variants: # Zero = 0 # Odd(n) = 2n + 1 @@ -225,6 +225,55 @@ def times_BinNat(left: BinNat, right: BinNat) -> BinNat: one = Odd(Zero) +def divmod(numerator: BinNat, divisor: BinNat) -> (BinNat, BinNat): + # invariant: divisor >= 2 + # x.divmod(y) = (d, m) <=> x = y * d + m + # n.divmod(y) = (d1, m1) <=> + # n = y * d + m + # 2n + 1 = 2 * y * d + 2m + 1 = (2d) * y + (2m + 1) + # 2n + 2 = 2 * y * d + 2m + 2 + def loop(numerator): + recur numerator: + case Odd(n): + (d1, m1) = loop(n) + m = Odd(m1) + # if (2m + 1) < y => mod = 2m + 1 + # if (2m + 1) == y => mod = 0, increment d + # we know m < y, so 2m + 1 < 2y + 1, + # so we at most have to subtract 2 times + match cmp_BinNat(m, divisor): + case LT: (d1.times2(), m) + case _: + m2 = sub_BinNat(m, divisor) + match cmp_BinNat(m2, divisor): + case LT: (Odd(d1), m2) + case EQ: (Even(d1), Zero) + case GT: (Even(d1), sub_BinNat(m2, divisor)) + case Even(n): + (d1, m1) = loop(n) + m = Even(m1) + # we know m < y, so 2m + 2 < 2y + 2, + # subtracting twice: (2y + 2) - y - y < 2 + # so we at most have to subtract 2 times + match cmp_BinNat(m, divisor): + case LT: (d1.times2(), m) + case _: + m2 = sub_BinNat(m, divisor) + match cmp_BinNat(m2, divisor): + case LT: (Odd(d1), m2) + case EQ: (Even(d1), Zero) + case GT: + m3 = sub_BinNat(m2, divisor) + (Even(d1), m3) + case Zero: (Zero, Zero) + + match divisor: + case Odd(Zero): + # x.divmod(1) == (x, Zero) + (numerator, Zero) + case Odd(_) | Even(_): loop(numerator) + case Zero: (Zero, numerator) + def exp(base: BinNat, power: BinNat) -> BinNat: recur power: case Zero: one diff --git a/test_workspace/Nat.bosatsu b/test_workspace/Nat.bosatsu index 55db1dcc1..4aeff3cb3 100644 --- a/test_workspace/Nat.bosatsu +++ b/test_workspace/Nat.bosatsu @@ -95,8 +95,9 @@ def divmod(numerator: Nat, divisor: Nat) -> (Nat, Nat): else: loop(n, d, m1) match divisor: + case Succ(Zero): (numerator, Zero) + case Succ(_): loop(numerator, Zero, Zero) case Zero: (Zero, numerator) - case _: loop(numerator, Zero, Zero) one = Succ(Zero) diff --git a/test_workspace/NumberProps.bosatsu b/test_workspace/NumberProps.bosatsu index 740069890..9a690c8ba 100644 --- a/test_workspace/NumberProps.bosatsu +++ b/test_workspace/NumberProps.bosatsu @@ -4,10 +4,18 @@ from Bosatsu/BinNat import (BinNat, toBinNat as int_to_BinNat, is_even as is_eve times2 as times2_BinNat, div2 as div2_BinNat, Zero as BNZero, Even as BNEven, times_BinNat, exp as exp_BinNat, cmp_BinNat, toInt as binNat_to_Int, add_BinNat, next as next_BinNat, sub_BinNat, eq_BinNat, sub_Option as sub_BinNat_Option, + divmod as divmod_BinNat, ) + from Bosatsu/Nat import (Nat, Zero as NZero, Succ as NSucc, to_Nat as int_to_Nat, is_even as is_even_Nat, times2 as times2_Nat, div2 as div2_Nat, cmp_Nat, to_Int as nat_to_Int, add as add_Nat, mult as mult_Nat, exp as exp_Nat, sub_Nat, divmod as divmod_Nat) + +from Bosatsu/BinInt import (BinInt, int_to_BinInt, binInt_to_Int, add as add_BinInt, + negate as negate_BinNat, abs as abs_BinInt, sub as sub_BinInt, binNat_to_BinInt, + cmp as cmp_BinInt, not as not_BinInt, eq as eq_BinInt, + ) + from Bosatsu/Properties import (Prop, suite_Prop, forall_Prop, run_Prop) from Bosatsu/Rand import (Rand, from_pair, geometric_Int, int_range, map_Rand, prod_Rand) @@ -20,6 +28,11 @@ export (rand_Int, rand_Nat, rand_BinNat) rand_Int: Rand[Int] = from_pair(int_range(128), geometric_Int) rand_Nat: Rand[Nat] = rand_Int.map_Rand(int_to_Nat) rand_BinNat: Rand[BinNat] = rand_Int.map_Rand(int_to_BinNat) +rand_BinInt: Rand[BinInt] = ( + pos = rand_Int.map_Rand(int_to_BinInt) + neg = pos.map_Rand(not_BinInt) + from_pair(pos, neg) +) int_props = suite_Prop( "Int props", @@ -181,10 +194,56 @@ binnat_props = suite_Prop( t2_2 = times_BinNat(n, BNEven(BNZero)) Assertion(cmp_BinNat(t2, t2_2) matches EQ, "times2 == mult(2, _)") )), + forall_Prop(prod_Rand(rand_BinNat, rand_BinNat), "divmod homomorphism", ((n1, n2)) -> ( + (dn, mn) = divmod_BinNat(n1, n2) + di = div(n1.binNat_to_Int(), n2.binNat_to_Int()) + mi = mod_Int(n1.binNat_to_Int(), n2.binNat_to_Int()) + Assertion( + (eq_Int(dn.binNat_to_Int(), di), eq_Int(mn.binNat_to_Int(), mi)) matches (True, True), + "div BinNat") + )), ] ) -all_props = [int_props, nat_props, binnat_props] +binint_props = suite_Prop( + "BinInt props", + [ + forall_Prop(prod_Rand(rand_BinInt, rand_BinInt), "add homomorphism", ((n1, n2)) -> ( + n3 = add_BinInt(n1, n2) + n3i = binInt_to_Int(n3) + i3 = add(binInt_to_Int(n1), binInt_to_Int(n2)) + Assertion(eq_Int(n3i, i3), "add BinInt") + )), + forall_Prop(prod_Rand(rand_BinInt, rand_BinInt), "sub homomorphism", ((n1, n2)) -> ( + n3 = sub_BinInt(n1, n2) + n3i = binInt_to_Int(n3) + i3 = sub(binInt_to_Int(n1), binInt_to_Int(n2)) + Assertion(eq_Int(n3i, i3), "sub BinInt") + )), + forall_Prop(rand_BinInt, "x + (-x) == 0", x -> ( + nx = negate_BinNat(x) + z = add_BinInt(x, nx).binInt_to_Int() + Assertion(eq_Int(z, 0), "x + (-x) == 0") + )), + forall_Prop(rand_BinInt, "x + |x| == 0 or 2x", x -> ( + ax = abs_BinInt(x) + sum = add_BinInt(x, binNat_to_BinInt(ax)) + xi = x.binInt_to_Int() + if cmp_Int(xi, 0) matches GT: + Assertion(eq_Int(sum.binInt_to_Int(), times(xi, 2)), "x + |x| == 2|x|") + else: + Assertion(eq_Int(sum.binInt_to_Int(), 0), "x + |x| == 0 if x <= 0") + )), + forall_Prop(rand_BinInt, "x + not(x) == x - x - 1 = -1", x -> ( + nx = not_BinInt(x) + z = add_BinInt(x, nx) + neg_1 = int_to_BinInt(-1) + Assertion( + (cmp_BinInt(z, neg_1), eq_BinInt(z, neg_1)) matches (EQ, True), "x + not(x) = -1") + )), + ]) + +all_props = [int_props, nat_props, binnat_props, binint_props] seed = 123456 test = TestSuite("properties", [