diff --git a/bigint.go b/bigint.go index 7ae9cc3..301f09c 100644 --- a/bigint.go +++ b/bigint.go @@ -290,12 +290,18 @@ func mulInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) { } func quoInline(xVal, yVal uint, xNeg, yNeg bool) (quoVal uint, quoNeg, ok bool) { + if yVal == 0 { // divide by 0 + return 0, false, false + } quo := xVal / yVal neg := xNeg != yNeg return quo, neg, true } func remInline(xVal, yVal uint, xNeg, yNeg bool) (remVal uint, remNeg, ok bool) { + if yVal == 0 { // divide by 0 + return 0, false, false + } rem := xVal % yVal return rem, xNeg, true } @@ -356,7 +362,7 @@ func (z *BigInt) AndNot(x, y *BigInt) *BigInt { // Append calls (big.Int).Append. func (z *BigInt) Append(buf []byte, base int) []byte { var tmp1 big.Int - return z.inner(&tmp1).Append(buf, base) + return z.innerOrNil(&tmp1).Append(buf, base) } // Binomial calls (big.Int).Binomial. @@ -771,10 +777,10 @@ func (z *BigInt) Set(x *BigInt) *BigInt { } // SetBit calls (big.Int).SetBit. -func (z *BigInt) SetBit(x *BigInt, i int, v uint) *BigInt { +func (z *BigInt) SetBit(x *BigInt, i int, b uint) *BigInt { var tmp1, tmp2 big.Int zi := z.inner(&tmp1) - zi.SetBit(x.inner(&tmp2), i, v) + zi.SetBit(x.inner(&tmp2), i, b) z.updateInner(zi) return z } @@ -864,7 +870,7 @@ func (z *BigInt) Sqrt(x *BigInt) *BigInt { // String calls (big.Int).String. func (z *BigInt) String() string { var tmp1 big.Int - return z.inner(&tmp1).String() + return z.innerOrNil(&tmp1).String() } // Sub calls (big.Int).Sub. @@ -887,7 +893,7 @@ func (z *BigInt) Sub(x, y *BigInt) *BigInt { // Text calls (big.Int).Text. func (z *BigInt) Text(base int) string { var tmp1 big.Int - return z.inner(&tmp1).Text(base) + return z.innerOrNil(&tmp1).Text(base) } // TrailingZeroBits calls (big.Int).TrailingZeroBits. diff --git a/bigint_test.go b/bigint_test.go index 2d925e5..f457dc5 100644 --- a/bigint_test.go +++ b/bigint_test.go @@ -21,14 +21,844 @@ import ( "encoding/json" "encoding/xml" "fmt" + "math" "math/big" "math/rand" + "reflect" "strconv" "strings" "testing" "testing/quick" ) +// TestBigIntMatchesMathBigInt uses testing/quick to verify that all methods on +// apd.BigInt and math/big.Int have identical behavior under all inputs. +func TestBigIntMatchesMathBigInt(t *testing.T) { + // Until we import github.com/stretchr/testify/require. + require := func(t *testing.T, err error) { + if err != nil { + t.Error(err) + } + } + + // Catch specific panics and map to strings. + const ( + panicDivisionByZero = "division by zero" + panicJacobi = "invalid 2nd argument to Int.Jacobi: need odd integer" + panicNegativeBit = "negative bit index" + panicSquareRootOfNegativeNum = "square root of negative number" + ) + catchPanic := func(fn func() string, catches ...string) (res string) { + defer func() { + if r := recover(); r != nil { + if rs, ok := r.(string); ok { + for _, catch := range catches { + if strings.Contains(rs, catch) { + res = fmt.Sprintf("caught: %s", r) + } + } + } + if res == "" { // not caught + panic(r) + } + } + }() + return fn() + } + + t.Run("Abs", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Abs(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Abs(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Add", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Add(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Add(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("And", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).And(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).And(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("AndNot", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).AndNot(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).AndNot(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Append", func(t *testing.T) { + apd := func(z numberOrNil) []byte { + return z.toApd(t).Append(nil, 10) + } + math := func(z numberOrNil) []byte { + return z.toMath(t).Append(nil, 10) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Binomial", func(t *testing.T) { + t.Skip("too slow") + apd := func(z number, n, k int64) string { + return z.toApd(t).Binomial(n, k).String() + } + math := func(z number, n, k int64) string { + return z.toMath(t).Binomial(n, k).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Bit", func(t *testing.T) { + apd := func(z number, i int) string { + return catchPanic(func() string { + return strconv.FormatUint(uint64(z.toApd(t).Bit(i)), 10) + }, panicNegativeBit) + } + math := func(z number, i int) string { + return catchPanic(func() string { + return strconv.FormatUint(uint64(z.toMath(t).Bit(i)), 10) + }, panicNegativeBit) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("BitLen", func(t *testing.T) { + apd := func(z number) int { + return z.toApd(t).BitLen() + } + math := func(z number) int { + return z.toMath(t).BitLen() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Bits", func(t *testing.T) { + emptyToNil := func(w []big.Word) []big.Word { + if len(w) == 0 { + return nil + } + return w + } + apd := func(z number) []big.Word { + return emptyToNil(z.toApd(t).Bits()) + } + math := func(z number) []big.Word { + return emptyToNil(z.toMath(t).Bits()) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Bytes", func(t *testing.T) { + apd := func(z number) []byte { + return z.toApd(t).Bytes() + } + math := func(z number) []byte { + return z.toMath(t).Bytes() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Cmp", func(t *testing.T) { + apd := func(z, y number) int { + return z.toApd(t).Cmp(y.toApd(t)) + } + math := func(z, y number) int { + return z.toMath(t).Cmp(y.toMath(t)) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("CmpAbs", func(t *testing.T) { + apd := func(z, y number) int { + return z.toApd(t).CmpAbs(y.toApd(t)) + } + math := func(z, y number) int { + return z.toMath(t).CmpAbs(y.toMath(t)) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Div", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Div(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Div(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("DivMod", func(t *testing.T) { + apd := func(z, x, y, m number) string { + return catchPanic(func() string { + zi, mi := z.toApd(t), m.toApd(t) + zi.DivMod(x.toApd(t), y.toApd(t), mi) + return zi.String() + " | " + mi.String() + }, panicDivisionByZero) + } + math := func(z, x, y, m number) string { + return catchPanic(func() string { + zi, mi := z.toMath(t), m.toMath(t) + zi.DivMod(x.toMath(t), y.toMath(t), mi) + return zi.String() + " | " + mi.String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Exp", func(t *testing.T) { + t.Skip("too slow") + apd := func(z, x, y, m number) string { + return z.toApd(t).Exp(x.toApd(t), y.toApd(t), m.toApd(t)).String() + } + math := func(z, x, y, m number) string { + return z.toMath(t).Exp(x.toMath(t), y.toMath(t), m.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("FillBytes", func(t *testing.T) { + apd := func(z number) []byte { + return z.toApd(t).FillBytes(make([]byte, len(z))) + } + math := func(z number) []byte { + return z.toMath(t).FillBytes(make([]byte, len(z))) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Format", func(t *testing.T) { + // Call indirectly through fmt.Sprint. + apd := func(z numberOrNil) string { + return fmt.Sprint(z.toApd(t)) + } + math := func(z numberOrNil) string { + return fmt.Sprint(z.toMath(t)) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("GCD", func(t *testing.T) { + apd := func(z number, x, y numberOrNil, a, b number) string { + return z.toApd(t).GCD(x.toApd(t), y.toApd(t), a.toApd(t), b.toApd(t)).String() + } + math := func(z number, x, y numberOrNil, a, b number) string { + return z.toMath(t).GCD(x.toMath(t), y.toMath(t), a.toMath(t), b.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("GobEncode", func(t *testing.T) { + apd := func(z numberOrNil) ([]byte, error) { + return z.toApd(t).GobEncode() + } + math := func(z numberOrNil) ([]byte, error) { + return z.toMath(t).GobEncode() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("GobDecode", func(t *testing.T) { + apd := func(z number, buf []byte) (string, error) { + zi := z.toApd(t) + err := zi.GobDecode(buf) + return zi.String(), err + } + math := func(z number, buf []byte) (string, error) { + zi := z.toMath(t) + err := zi.GobDecode(buf) + return zi.String(), err + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Int64", func(t *testing.T) { + apd := func(z number) int64 { + return z.toApd(t).Int64() + } + math := func(z number) int64 { + return z.toMath(t).Int64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("IsInt64", func(t *testing.T) { + apd := func(z number) bool { + return z.toApd(t).IsInt64() + } + math := func(z number) bool { + return z.toMath(t).IsInt64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("IsUint64", func(t *testing.T) { + apd := func(z number) bool { + return z.toApd(t).IsUint64() + } + math := func(z number) bool { + return z.toMath(t).IsUint64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Lsh", func(t *testing.T) { + const maxShift = 1000 // avoid makeslice: len out of range + apd := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toApd(t).Lsh(x.toApd(t), n).String() + } + math := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toMath(t).Lsh(x.toMath(t), n).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("MarshalJSON", func(t *testing.T) { + apd := func(z numberOrNil) ([]byte, error) { + return z.toApd(t).MarshalJSON() + } + math := func(z numberOrNil) ([]byte, error) { + return z.toMath(t).MarshalJSON() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("MarshalText", func(t *testing.T) { + apd := func(z numberOrNil) ([]byte, error) { + return z.toApd(t).MarshalText() + } + math := func(z numberOrNil) ([]byte, error) { + return z.toMath(t).MarshalText() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Mod", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Mod(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero, panicJacobi) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Mod(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero, panicJacobi) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("ModInverse", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).ModInverse(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).ModInverse(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("ModSqrt", func(t *testing.T) { + t.Skip("too slow") + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).ModSqrt(x.toApd(t), y.toApd(t)).String() + }, panicJacobi) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).ModSqrt(x.toMath(t), y.toMath(t)).String() + }, panicJacobi) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Mul", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Mul(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Mul(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("MulRange", func(t *testing.T) { + t.Skip("too slow") + apd := func(z number, x, y int64) string { + return z.toApd(t).MulRange(x, y).String() + } + math := func(z number, x, y int64) string { + return z.toMath(t).MulRange(x, y).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Neg", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Neg(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Neg(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Not", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Not(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Not(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Or", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Or(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Or(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("ProbablyPrime", func(t *testing.T) { + apd := func(z number) bool { + return z.toApd(t).ProbablyPrime(64) + } + math := func(z number) bool { + return z.toMath(t).ProbablyPrime(64) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Quo", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Quo(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Quo(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("QuoRem", func(t *testing.T) { + apd := func(z, x, y, r number) string { + return catchPanic(func() string { + zi, ri := z.toApd(t), r.toApd(t) + zi.QuoRem(x.toApd(t), y.toApd(t), ri) + return zi.String() + " | " + ri.String() + }, panicDivisionByZero) + } + math := func(z, x, y, r number) string { + return catchPanic(func() string { + zi, ri := z.toMath(t), r.toMath(t) + zi.QuoRem(x.toMath(t), y.toMath(t), ri) + return zi.String() + " | " + ri.String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Rand", func(t *testing.T) { + apd := func(z, n number, seed int64) string { + rng := rand.New(rand.NewSource(seed)) + return z.toApd(t).Rand(rng, n.toApd(t)).String() + } + math := func(z, n number, seed int64) string { + rng := rand.New(rand.NewSource(seed)) + return z.toMath(t).Rand(rng, n.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Rem", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Rem(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Rem(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Rsh", func(t *testing.T) { + const maxShift = 1000 // avoid makeslice: len out of range + apd := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toApd(t).Rsh(x.toApd(t), n).String() + } + math := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toMath(t).Rsh(x.toMath(t), n).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Scan", func(t *testing.T) { + // Call indirectly through fmt.Scan. + apd := func(z, src number) (string, error) { + zi := z.toApd(t) + _, err := fmt.Sscan(string(src), zi) + return zi.String(), err + } + math := func(z, src number) (string, error) { + zi := z.toMath(t) + _, err := fmt.Sscan(string(src), zi) + return zi.String(), err + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Set", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Set(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Set(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetBit", func(t *testing.T) { + const maxBit = 1000 // avoid makeslice: len out of range + apd := func(z, x number, i int, b bool) string { + if i > maxBit { + i = maxBit + } + bi := uint(0) + if b { + bi = 1 + } + return catchPanic(func() string { + return z.toApd(t).SetBit(x.toApd(t), i, bi).String() + }, panicNegativeBit) + } + math := func(z, x number, i int, b bool) string { + if i > maxBit { + i = maxBit + } + bi := uint(0) + if b { + bi = 1 + } + return catchPanic(func() string { + return z.toMath(t).SetBit(x.toMath(t), i, bi).String() + }, panicNegativeBit) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetBits", func(t *testing.T) { + apd := func(z number, abs []big.Word) string { + return z.toApd(t).SetBits(abs).String() + } + math := func(z number, abs []big.Word) string { + return z.toMath(t).SetBits(abs).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetBytes", func(t *testing.T) { + apd := func(z number, buf []byte) string { + return z.toApd(t).SetBytes(buf).String() + } + math := func(z number, buf []byte) string { + return z.toMath(t).SetBytes(buf).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetInt64", func(t *testing.T) { + apd := func(z number, x int64) string { + return z.toApd(t).SetInt64(x).String() + } + math := func(z number, x int64) string { + return z.toMath(t).SetInt64(x).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetString", func(t *testing.T) { + apd := func(z, x number) (string, bool) { + zi, ok := z.toApd(t).SetString(string(x), 10) + return zi.String(), ok + } + math := func(z, x number) (string, bool) { + zi, ok := z.toMath(t).SetString(string(x), 10) + return zi.String(), ok + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetUint64", func(t *testing.T) { + apd := func(z number, x uint64) string { + return z.toApd(t).SetUint64(x).String() + } + math := func(z number, x uint64) string { + return z.toMath(t).SetUint64(x).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Sign", func(t *testing.T) { + apd := func(z number) int { + return z.toApd(t).Sign() + } + math := func(z number) int { + return z.toMath(t).Sign() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Sqrt", func(t *testing.T) { + apd := func(z, x number) string { + return catchPanic(func() string { + return z.toApd(t).Sqrt(x.toApd(t)).String() + }, panicSquareRootOfNegativeNum) + } + math := func(z, x number) string { + return catchPanic(func() string { + return z.toMath(t).Sqrt(x.toMath(t)).String() + }, panicSquareRootOfNegativeNum) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("String", func(t *testing.T) { + apd := func(z numberOrNil) string { + return z.toApd(t).String() + } + math := func(z numberOrNil) string { + return z.toMath(t).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Sub", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Sub(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Sub(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Text", func(t *testing.T) { + apd := func(z numberOrNil) string { + return z.toApd(t).Text(10) + } + math := func(z numberOrNil) string { + return z.toMath(t).Text(10) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("TrailingZeroBits", func(t *testing.T) { + apd := func(z number) uint { + return z.toApd(t).TrailingZeroBits() + } + math := func(z number) uint { + return z.toMath(t).TrailingZeroBits() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Uint64", func(t *testing.T) { + apd := func(z number) uint64 { + return z.toApd(t).Uint64() + } + math := func(z number) uint64 { + return z.toMath(t).Uint64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("UnmarshalJSON", func(t *testing.T) { + apd := func(z number, text []byte) (string, error) { + zi := z.toApd(t) + if err := zi.UnmarshalJSON(text); err != nil { + return "", err + } + return zi.String(), nil + } + math := func(z number, text []byte) (string, error) { + zi := z.toMath(t) + if err := zi.UnmarshalJSON(text); err != nil { + return "", err + } + return zi.String(), nil + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("UnmarshalText", func(t *testing.T) { + apd := func(z number, text []byte) (string, error) { + zi := z.toApd(t) + if err := zi.UnmarshalText(text); err != nil { + return "", err + } + return zi.String(), nil + } + math := func(z number, text []byte) (string, error) { + zi := z.toMath(t) + if err := zi.UnmarshalText(text); err != nil { + return "", err + } + return zi.String(), nil + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Xor", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Xor(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Xor(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) +} + +// number is a quick.Generator for decimal numbers. +type number string + +func (n number) Generate(r *rand.Rand, size int) reflect.Value { + var s string + if r.Intn(2) != 0 { + s = n.generateInterestingNumber(r) + } else { + s = n.generateRandomNumber(r, size) + } + return reflect.ValueOf(number(s)) +} + +func (z *BigInt) incr() *BigInt { return z.Add(z, bigOne) } +func (z *BigInt) decr() *BigInt { return z.Sub(z, bigOne) } + +var interestingNumbers = [...]*BigInt{ + new(BigInt).SetInt64(math.MinInt64).decr(), + new(BigInt).SetInt64(math.MinInt64), + new(BigInt).SetInt64(math.MinInt64).incr(), + new(BigInt).SetInt64(math.MinInt32), + new(BigInt).SetInt64(math.MinInt16), + new(BigInt).SetInt64(math.MinInt8), + new(BigInt).SetInt64(0), + new(BigInt).SetInt64(math.MaxInt8), + new(BigInt).SetInt64(math.MaxUint8), + new(BigInt).SetInt64(math.MaxInt16), + new(BigInt).SetInt64(math.MaxUint16), + new(BigInt).SetInt64(math.MaxInt32), + new(BigInt).SetInt64(math.MaxUint32), + new(BigInt).SetInt64(math.MaxInt64).decr(), + new(BigInt).SetInt64(math.MaxInt64), + new(BigInt).SetInt64(math.MaxInt64).incr(), + new(BigInt).SetUint64(math.MaxUint64).decr(), + new(BigInt).SetUint64(math.MaxUint64), + new(BigInt).SetUint64(math.MaxUint64).incr(), +} + +func (number) generateInterestingNumber(r *rand.Rand) string { + return interestingNumbers[r.Intn(len(interestingNumbers))].String() +} + +var numbers = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} + +func (number) generateRandomNumber(r *rand.Rand, size int) string { + var s strings.Builder + if r.Intn(2) != 0 { + s.WriteByte('-') // neg + } + digits := r.Intn(size) + 1 + for i := 0; i < digits; i++ { + s.WriteByte(numbers[r.Intn(len(numbers))]) + } + return s.String() +} + +func (n number) toApd(t *testing.T) *BigInt { + var x BigInt + if _, ok := x.SetString(string(n), 10); !ok { + t.Fatalf("failed to SetString(%q)", n) + } + return &x +} + +func (n number) toMath(t *testing.T) *big.Int { + var x big.Int + if _, ok := x.SetString(string(n), 10); !ok { + t.Fatalf("failed to SetString(%q)", n) + } + return &x +} + +type numberOrNil struct { + Num number + Nil bool +} + +func (n numberOrNil) toApd(t *testing.T) *BigInt { + if n.Nil { + return nil + } + return n.Num.toApd(t) +} + +func (n numberOrNil) toMath(t *testing.T) *big.Int { + if n.Nil { + return nil + } + return n.Num.toMath(t) +} + ////////////////////////////////////////////////////////////////////////////////// // The following tests were copied from the standard library's math/big package // //////////////////////////////////////////////////////////////////////////////////