diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4875dec..9257243 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,8 +15,6 @@ jobs: strategy: matrix: go: - - '1.11' - - '1.12' - '1.13' - '1.14' - '1.15' @@ -38,11 +36,12 @@ jobs: run: go test -v ./... - name: 'Vet' - run: go vet ./... + # -unsafeptr=false is needed because of the noescape function in bigint.go. + run: go vet -unsafeptr=false ./... - name: 'Staticcheck' # staticcheck requires go1.14. - if: ${{ matrix.go != '1.11' && matrix.go != '1.12' && matrix.go != '1.13' }} + if: ${{ matrix.go != '1.13' }} run: | go get honnef.co/go/tools/cmd/staticcheck staticcheck ./... diff --git a/README.md b/README.md index 2864950..09172a8 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,17 @@ apd is an arbitrary-precision decimal package for Go. - **Good performance**. Operations will either be fast enough or will produce an error if they will be slow. This prevents edge-case operations from consuming lots of CPU or memory. - **Condition flags and traps**. All operations will report whether their result is exact, is rounded, is over- or under-flowed, is [subnormal](https://en.wikipedia.org/wiki/Denormal_number), or is some other condition. `apd` supports traps which will trigger an error on any of these conditions. This makes it possible to guarantee exactness in computations, if needed. -`apd` has two main types. The first is [`Decimal`](https://godoc.org/github.com/cockroachdb/apd#Decimal) which holds the values of decimals. It is simple and uses a `big.Int` with an exponent to describe values. Most operations on `Decimal`s can’t produce errors as they work directly on the underlying `big.Int`. Notably, however, there are no arithmetic operations on `Decimal`s. +`apd` has three main types. -The second main type is [`Context`](https://godoc.org/github.com/cockroachdb/apd#Context), which is where all arithmetic operations are defined. A `Context` describes the precision, range, and some other restrictions during operations. These operations can all produce failures, and so return errors. +The first is [`BigInt`](https://godoc.org/github.com/cockroachdb/apd#BigInt) which is a wrapper around `big.Int` that exposes an identical API while reducing memory allocations. `BigInt` does so by using an inline array to back the `big.Int`'s variable-length value when the integer's absolute value is sufficiently small. `BigInt` also contains fast-paths that allow it to perform basic arithmetic directly on this inline array, only falling back to `big.Int` when the arithmetic gets complex or takes place on large values. + +The second is [`Decimal`](https://godoc.org/github.com/cockroachdb/apd#Decimal) which holds the values of decimals. It is simple and uses a `BigInt` with an exponent to describe values. Most operations on `Decimal`s can’t produce errors as they work directly on the underlying `big.Int`. Notably, however, there are no arithmetic operations on `Decimal`s. + +The third main type is [`Context`](https://godoc.org/github.com/cockroachdb/apd#Context), which is where all arithmetic operations are defined. A `Context` describes the precision, range, and some other restrictions during operations. These operations can all produce failures, and so return errors. `Context` operations, in addition to errors, return a [`Condition`](https://godoc.org/github.com/cockroachdb/apd#Condition), which is a bitfield of flags that occurred during an operation. These include overflow, underflow, inexact, rounded, and others. The `Traps` field of a `Context` can be set which will produce an error if the corresponding flag occurs. An example of this is given below. See the [examples](https://godoc.org/github.com/cockroachdb/apd#pkg-examples) for some operations that were previously difficult to perform in Go. ## Documentation -https://pkg.go.dev/github.com/cockroachdb/apd/v2?tab=doc +https://pkg.go.dev/github.com/cockroachdb/apd/v3?tab=doc diff --git a/bigint.go b/bigint.go new file mode 100644 index 0000000..d5837b3 --- /dev/null +++ b/bigint.go @@ -0,0 +1,939 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package apd + +import ( + "fmt" + "math/big" + "math/bits" + "math/rand" + "unsafe" +) + +// The inlineWords capacity is set to accommodate any value that would fit in a +// 128-bit integer (i.e. values with an absolute value up to 2^128 - 1). +const inlineWords = 2 + +// BigInt is a wrapper around big.Int. It minimizes memory allocation by using +// an inline array to back the big.Int's variable-length "nat" slice when the +// integer's value is sufficiently small. +// The zero value is ready to use. +type BigInt struct { + // A wrapped big.Int. Only set to the BigInt's value when the value exceeds + // what is representable in the _inline array. + // + // When the BigInt's value is still small enough to use the _inline array, + // this field doubles as integer's negative flag. See negSentinel. + // + // Methods should access this field through inner. + _inner *big.Int + + // The inlined backing array use for short-lived, stack-allocated big.Int + // structs during arithmetic when the value is small. + // + // Each BigInt maintains (through big.Int) an internal reference to a + // variable-length integer value, which is represented by a []big.Word. The + // _inline field and the inner and updateInner methods combine to allow + // BigInt to inline this variable-length integer array within the BigInt + // struct when its value is sufficiently small. In the inner method, we + // point a temporary big.Int's nat slice at this _inline array. big.Int will + // avoid re-allocating this array until it is provided with a value that + // exceeds the initial capacity. Later in updateInner, we detect whether the + // array has been re-allocated. If so, we switch to using the _inner. If + // not, we continue to use this array. + _inline [inlineWords]big.Word +} + +// NewBigInt allocates and returns a new BigInt set to x. +// +// NOTE: BigInt jumps through hoops to avoid escaping to the heap. As such, most +// users of BigInt should not need this function. They should instead declare a +// zero-valued BigInt directly on the stack and interact with references to this +// stack-allocated value. Recall that the zero-valued BigInt is ready to use. +func NewBigInt(x int64) *BigInt { + return new(BigInt).SetInt64(x) +} + +// Set as the value of BigInt._inner as a "sentinel" flag to indicate that a +// BigInt is negative ((big.Int).Sign() < 0) but the absolute value is still +// small enough to represent in the _inline array. +var negSentinel = new(big.Int) + +// isInline returns whether the BigInt stores its value in its _inline array. +func (z *BigInt) isInline() bool { + return z._inner == nil || z._inner == negSentinel +} + +// The memory representation of big.Int. Used for unsafe modification below. +type intStruct struct { + neg bool + abs []big.Word +} + +// noescape hides a pointer from escape analysis. noescape is the identity +// function but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// +// USE CAREFULLY! +// +// This was copied from strings.Builder, which has identical code which was +// itself copied from the runtime. +// For more, see issues #23382 and #7921 in github.com/golang/go. +//go:nosplit +//go:nocheckptr +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + //lint:ignore SA4016 intentional no-op to hide pointer from escape analysis. + return unsafe.Pointer(x ^ 0) +} + +// inner returns the BigInt's current value as a *big.Int. +// +// NOTE: this was carefully written to permit function inlining. Modify with +// care. +//gcassert:inline +func (z *BigInt) inner(tmp *big.Int) *big.Int { + // Point the big.Int at the inline array. When doing so, use noescape to + // avoid forcing the BigInt to escape to the heap. Go's escape analysis + // struggles with self-referential pointers, and it can't prove that we + // only assign _inner to a heap-allocated object (which must not contain + // pointers that reference the stack or the GC explodes) if the big.Int's + // backing array has been re-allocated onto the heap first. + // + // NOTE: SetBits sets the neg field to false, so this must come before the + // negSentinel handling. + tmp.SetBits((*[inlineWords]big.Word)(noescape(unsafe.Pointer(&z._inline[0])))[:]) + + if z._inner != nil { + if z._inner != negSentinel { + // The variable-length big.Int reference is set. + return z._inner + } + + // This is the negative sentinel, which indicates that the integer is + // negative but still stored inline. Update the big.Int accordingly. We + // use unsafe because (*big.Int).Neg is too complex and prevents this + // method from being inlined. + (*intStruct)(unsafe.Pointer(tmp)).neg = true + } + return tmp +} + +// innerOrNil is like inner, but returns a nil *big.Int if the receiver is nil. +func (z *BigInt) innerOrNil(tmp *big.Int) *big.Int { + if z == nil { + return nil + } + return z.inner(tmp) +} + +// innerOrAlias is like inner, but returns the provided *big.Int if the receiver +// and the other *BigInt argument reference the same object. +func (z *BigInt) innerOrAlias(tmp *big.Int, a *BigInt, ai *big.Int) *big.Int { + if a == z { + return ai + } + return z.inner(tmp) +} + +// innerOrNilOrAlias is like inner, but with the added semantics specified for +// both innerOrNil and innerOrAlias. +func (z *BigInt) innerOrNilOrAlias(tmp *big.Int, a *BigInt, ai *big.Int) *big.Int { + if z == nil { + return nil + } else if z == a { + return ai + } + return z.inner(tmp) +} + +// updateInner updates the BigInt's current value with the provided *big.Int. +// +// NOTE: this was carefully written to permit function inlining. Modify with +// care. +//gcassert:inline +func (z *BigInt) updateInner(src *big.Int) { + if z._inner == src { + return + } + + bits := src.Bits() + bitsLen := len(bits) + if bitsLen > 0 && &z._inline[0] != &bits[0] { + // The big.Int re-allocated its backing array during arithmetic because + // the value grew beyond what could fit in the _inline array. Switch to + // a heap-allocated, variable-length big.Int and store that in _inner. + // From now on, all arithmetic will use this big.Int directly. + // + // Allocate a new big.Int and perform a shallow-copy of the argument to + // prevent it from escaping off the stack. + z._inner = new(big.Int) + *z._inner = *src + } else { + // Zero out all words beyond the end of the big.Int's current Word + // slice. big.Int arithmetic can sometimes leave these words "dirty". + // They would cause issues when the _inline array is injected into the + // next big.Int if not cleared. + for bitsLen < len(z._inline) { + z._inline[bitsLen] = 0 + bitsLen++ + } + + // Set or unset the negative sentinel, according to the argument's sign. + // We use unsafe because (*big.Int).Sign is too complex and prevents + // this method from being inlined. + if (*intStruct)(unsafe.Pointer(src)).neg { + z._inner = negSentinel + } else { + z._inner = nil + } + } +} + +// innerAsUint returns the BigInt's current absolute value as a uint and a flag +// indicating whether the value is negative. If the value is not stored inline +// or if it can not fit in a uint, false is returned. +// +// NOTE: this was carefully written to permit function inlining. Modify with +// care. +//gcassert:inline +func (z *BigInt) innerAsUint() (val uint, neg bool, ok bool) { + if !z.isInline() { + // The value is not stored inline. + return 0, false, false + } + for i := 1; i < len(z._inline); i++ { + if z._inline[i] != 0 { + // The value can not fit in a uint. + return 0, false, false + } + } + + val = uint(z._inline[0]) + neg = z._inner == negSentinel + return val, neg, true +} + +// updateInnerFromUint updates the BigInt's current value with the provided +// absolute value and sign. +// +// NOTE: this was carefully written to permit function inlining. Modify with +// care. +//gcassert:inline +func (z *BigInt) updateInnerFromUint(val uint, neg bool) { + // Set the inline value, making sure to clear out all other words. + z._inline[0] = big.Word(val) + for i := 1; i < len(z._inline); i++ { + z._inline[i] = 0 + } + + // Set or unset the negative sentinel. + if neg { + z._inner = negSentinel + } else { + z._inner = nil + } +} + +const ( + bigIntSize = unsafe.Sizeof(BigInt{}) + mathBigIntSize = unsafe.Sizeof(big.Int{}) + mathWordSize = unsafe.Sizeof(big.Word(0)) +) + +// Size returns the total memory footprint of z in bytes. +func (z *BigInt) Size() uintptr { + if z.isInline() { + return bigIntSize + } + return bigIntSize + mathBigIntSize + uintptr(cap(z._inner.Bits()))*mathWordSize +} + +/////////////////////////////////////////////////////////////////////////////// +// inline arithmetic for small values // +/////////////////////////////////////////////////////////////////////////////// + +//gcassert:inline +func addInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) { + if xNeg == yNeg { + sum, carry := bits.Add(xVal, yVal, 0) + overflow := carry != 0 + return sum, xNeg, !overflow + } + + diff, borrow := bits.Sub(xVal, yVal, 0) + if borrow != 0 { // underflow + xNeg = !xNeg + diff = yVal - xVal + } + if diff == 0 { + xNeg = false + } + return diff, xNeg, true +} + +//gcassert:inline +func mulInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) { + hi, lo := bits.Mul(xVal, yVal) + neg := xNeg != yNeg + overflow := hi != 0 + return lo, neg, !overflow +} + +//gcassert:inline +func quoInline(xVal, yVal uint, xNeg, yNeg bool) (quoVal uint, quoNeg, ok bool) { + if yVal == 0 { // divide by 0 + return 0, false, false + } + quo := xVal / yVal + neg := xNeg != yNeg + return quo, neg, true +} + +//gcassert:inline +func remInline(xVal, yVal uint, xNeg, yNeg bool) (remVal uint, remNeg, ok bool) { + if yVal == 0 { // divide by 0 + return 0, false, false + } + rem := xVal % yVal + return rem, xNeg, true +} + +/////////////////////////////////////////////////////////////////////////////// +// big.Int API wrapper methods // +/////////////////////////////////////////////////////////////////////////////// + +// Abs calls (big.Int).Abs. +func (z *BigInt) Abs(x *BigInt) *BigInt { + if x.isInline() { + z._inline = x._inline + z._inner = nil // !negSentinel + return z + } + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Abs(x.inner(&tmp2)) + z.updateInner(zi) + return z +} + +// Add calls (big.Int).Add. +func (z *BigInt) Add(x, y *BigInt) *BigInt { + if xVal, xNeg, ok := x.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + if zVal, zNeg, ok := addInline(xVal, yVal, xNeg, yNeg); ok { + z.updateInnerFromUint(zVal, zNeg) + return z + } + } + } + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Add(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// And calls (big.Int).And. +func (z *BigInt) And(x, y *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.And(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// AndNot calls (big.Int).AndNot. +func (z *BigInt) AndNot(x, y *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.AndNot(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// Append calls (big.Int).Append. +func (z *BigInt) Append(buf []byte, base int) []byte { + var tmp1 big.Int + return z.innerOrNil(&tmp1).Append(buf, base) +} + +// Binomial calls (big.Int).Binomial. +func (z *BigInt) Binomial(n, k int64) *BigInt { + var tmp1 big.Int + zi := z.inner(&tmp1) + zi.Binomial(n, k) + z.updateInner(zi) + return z +} + +// Bit calls (big.Int).Bit. +func (z *BigInt) Bit(i int) uint { + if i == 0 && z.isInline() { + // Optimization for common case: odd/even test of z. + return uint(z._inline[0] & 1) + } + var tmp1 big.Int + return z.inner(&tmp1).Bit(i) +} + +// BitLen calls (big.Int).BitLen. +func (z *BigInt) BitLen() int { + if z.isInline() { + // Find largest non-zero inline word. + for i := len(z._inline) - 1; i >= 0; i-- { + if z._inline[i] != 0 { + return i*bits.UintSize + bits.Len(uint(z._inline[i])) + } + } + return 0 + } + var tmp1 big.Int + return z.inner(&tmp1).BitLen() +} + +// Bits calls (big.Int).Bits. +func (z *BigInt) Bits() []big.Word { + var tmp1 big.Int + return z.inner(&tmp1).Bits() +} + +// Bytes calls (big.Int).Bytes. +func (z *BigInt) Bytes() []byte { + var tmp1 big.Int + return z.inner(&tmp1).Bytes() +} + +// Cmp calls (big.Int).Cmp. +func (z *BigInt) Cmp(y *BigInt) (r int) { + if zVal, zNeg, ok := z.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + switch { + case zNeg == yNeg: + switch { + case zVal < yVal: + r = -1 + case zVal > yVal: + r = 1 + } + if zNeg { + r = -r + } + case zNeg: + r = -1 + default: + r = 1 + } + return r + } + } + var tmp1, tmp2 big.Int + return z.inner(&tmp1).Cmp(y.inner(&tmp2)) +} + +// CmpAbs calls (big.Int).CmpAbs. +func (z *BigInt) CmpAbs(y *BigInt) (r int) { + if zVal, _, ok := z.innerAsUint(); ok { + if yVal, _, ok := y.innerAsUint(); ok { + switch { + case zVal < yVal: + r = -1 + case zVal > yVal: + r = 1 + } + return r + } + } + var tmp1, tmp2 big.Int + return z.inner(&tmp1).CmpAbs(y.inner(&tmp2)) +} + +// Div calls (big.Int).Div. +func (z *BigInt) Div(x, y *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Div(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// DivMod calls (big.Int).DivMod. +func (z *BigInt) DivMod(x, y, m *BigInt) (*BigInt, *BigInt) { + var tmp1, tmp2, tmp3, tmp4 big.Int + zi := z.inner(&tmp1) + mi := m.inner(&tmp2) + // NOTE: innerOrAlias for the y param because (big.Int).DivMod needs to + // detect when y is aliased to the receiver. + zi.DivMod(x.inner(&tmp3), y.innerOrAlias(&tmp4, z, zi), mi) + z.updateInner(zi) + m.updateInner(mi) + return z, m +} + +// Exp calls (big.Int).Exp. +func (z *BigInt) Exp(x, y, m *BigInt) *BigInt { + var tmp1, tmp2, tmp3, tmp4 big.Int + zi := z.inner(&tmp1) + if zi.Exp(x.inner(&tmp2), y.inner(&tmp3), m.innerOrNil(&tmp4)) == nil { + return nil + } + z.updateInner(zi) + return z +} + +// Format calls (big.Int).Format. +func (z *BigInt) Format(s fmt.State, ch rune) { + var tmp1 big.Int + z.innerOrNil(&tmp1).Format(s, ch) +} + +// GCD calls (big.Int).GCD. +func (z *BigInt) GCD(x, y, a, b *BigInt) *BigInt { + var tmp1, tmp2, tmp3, tmp4, tmp5 big.Int + zi := z.inner(&tmp1) + ai := a.inner(&tmp2) + bi := b.inner(&tmp3) + xi := x.innerOrNil(&tmp4) + // NOTE: innerOrNilOrAlias for the y param because (big.Int).GCD needs to + // detect when y is aliased to b. See "avoid aliasing b" in lehmerGCD. + yi := y.innerOrNilOrAlias(&tmp5, b, bi) + zi.GCD(xi, yi, ai, bi) + z.updateInner(zi) + if xi != nil { + x.updateInner(xi) + } + if yi != nil { + y.updateInner(yi) + } + return z +} + +// GobEncode calls (big.Int).GobEncode. +func (z *BigInt) GobEncode() ([]byte, error) { + var tmp1 big.Int + return z.innerOrNil(&tmp1).GobEncode() +} + +// GobDecode calls (big.Int).GobDecode. +func (z *BigInt) GobDecode(buf []byte) error { + var tmp1 big.Int + zi := z.inner(&tmp1) + if err := zi.GobDecode(buf); err != nil { + return err + } + z.updateInner(zi) + return nil +} + +// Int64 calls (big.Int).Int64. +func (z *BigInt) Int64() int64 { + var tmp1 big.Int + return z.inner(&tmp1).Int64() +} + +// IsInt64 calls (big.Int).IsInt64. +func (z *BigInt) IsInt64() bool { + var tmp1 big.Int + return z.inner(&tmp1).IsInt64() +} + +// IsUint64 calls (big.Int).IsUint64. +func (z *BigInt) IsUint64() bool { + var tmp1 big.Int + return z.inner(&tmp1).IsUint64() +} + +// Lsh calls (big.Int).Lsh. +func (z *BigInt) Lsh(x *BigInt, n uint) *BigInt { + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Lsh(x.inner(&tmp2), n) + z.updateInner(zi) + return z +} + +// MarshalJSON calls (big.Int).MarshalJSON. +func (z *BigInt) MarshalJSON() ([]byte, error) { + var tmp1 big.Int + return z.innerOrNil(&tmp1).MarshalJSON() +} + +// MarshalText calls (big.Int).MarshalText. +func (z *BigInt) MarshalText() (text []byte, err error) { + var tmp1 big.Int + return z.innerOrNil(&tmp1).MarshalText() +} + +// Mod calls (big.Int).Mod. +func (z *BigInt) Mod(x, y *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + // NOTE: innerOrAlias for the y param because (big.Int).Mod needs to detect + // when y is aliased to the receiver. + zi.Mod(x.inner(&tmp2), y.innerOrAlias(&tmp3, z, zi)) + z.updateInner(zi) + return z +} + +// ModInverse calls (big.Int).ModInverse. +func (z *BigInt) ModInverse(g, n *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + if zi.ModInverse(g.inner(&tmp2), n.inner(&tmp3)) == nil { + return nil + } + z.updateInner(zi) + return z +} + +// ModSqrt calls (big.Int).ModSqrt. +func (z *BigInt) ModSqrt(x, p *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + if zi.ModSqrt(x.inner(&tmp2), p.inner(&tmp3)) == nil { + return nil + } + z.updateInner(zi) + return z +} + +// Mul calls (big.Int).Mul. +func (z *BigInt) Mul(x, y *BigInt) *BigInt { + if xVal, xNeg, ok := x.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + if zVal, zNeg, ok := mulInline(xVal, yVal, xNeg, yNeg); ok { + z.updateInnerFromUint(zVal, zNeg) + return z + } + } + } + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Mul(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// MulRange calls (big.Int).MulRange. +func (z *BigInt) MulRange(x, y int64) *BigInt { + var tmp1 big.Int + zi := z.inner(&tmp1) + zi.MulRange(x, y) + z.updateInner(zi) + return z +} + +// Neg calls (big.Int).Neg. +func (z *BigInt) Neg(x *BigInt) *BigInt { + if x.isInline() { + z._inline = x._inline + if x._inner == negSentinel { + z._inner = nil + } else { + z._inner = negSentinel + } + return z + } + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Neg(x.inner(&tmp2)) + z.updateInner(zi) + return z +} + +// Not calls (big.Int).Not. +func (z *BigInt) Not(x *BigInt) *BigInt { + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Not(x.inner(&tmp2)) + z.updateInner(zi) + return z +} + +// Or calls (big.Int).Or. +func (z *BigInt) Or(x, y *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Or(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// ProbablyPrime calls (big.Int).ProbablyPrime. +func (z *BigInt) ProbablyPrime(n int) bool { + var tmp1 big.Int + return z.inner(&tmp1).ProbablyPrime(n) +} + +// Quo calls (big.Int).Quo. +func (z *BigInt) Quo(x, y *BigInt) *BigInt { + if xVal, xNeg, ok := x.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + if quoVal, quoNeg, ok := quoInline(xVal, yVal, xNeg, yNeg); ok { + z.updateInnerFromUint(quoVal, quoNeg) + return z + } + } + } + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Quo(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// QuoRem calls (big.Int).QuoRem. +func (z *BigInt) QuoRem(x, y, r *BigInt) (*BigInt, *BigInt) { + if xVal, xNeg, ok := x.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + if quoVal, quoNeg, ok := quoInline(xVal, yVal, xNeg, yNeg); ok { + if remVal, remNeg, ok := remInline(xVal, yVal, xNeg, yNeg); ok { + z.updateInnerFromUint(quoVal, quoNeg) + r.updateInnerFromUint(remVal, remNeg) + return z, r + } + } + } + } + var tmp1, tmp2, tmp3, tmp4 big.Int + zi := z.inner(&tmp1) + ri := r.inner(&tmp2) + zi.QuoRem(x.inner(&tmp3), y.inner(&tmp4), ri) + z.updateInner(zi) + r.updateInner(ri) + return z, r +} + +// Rand calls (big.Int).Rand. +func (z *BigInt) Rand(rnd *rand.Rand, n *BigInt) *BigInt { + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Rand(rnd, n.inner(&tmp2)) + z.updateInner(zi) + return z +} + +// Rem calls (big.Int).Rem. +func (z *BigInt) Rem(x, y *BigInt) *BigInt { + if xVal, xNeg, ok := x.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + if remVal, remNeg, ok := remInline(xVal, yVal, xNeg, yNeg); ok { + z.updateInnerFromUint(remVal, remNeg) + return z + } + } + } + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Rem(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// Rsh calls (big.Int).Rsh. +func (z *BigInt) Rsh(x *BigInt, n uint) *BigInt { + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Rsh(x.inner(&tmp2), n) + z.updateInner(zi) + return z +} + +// Scan calls (big.Int).Scan. +func (z *BigInt) Scan(s fmt.ScanState, ch rune) error { + var tmp1 big.Int + zi := z.inner(&tmp1) + if err := zi.Scan(s, ch); err != nil { + return err + } + z.updateInner(zi) + return nil +} + +// Set calls (big.Int).Set. +func (z *BigInt) Set(x *BigInt) *BigInt { + if x.isInline() { + *z = *x + return z + } + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Set(x.inner(&tmp2)) + z.updateInner(zi) + return z +} + +// SetBit calls (big.Int).SetBit. +func (z *BigInt) SetBit(x *BigInt, i int, b uint) *BigInt { + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.SetBit(x.inner(&tmp2), i, b) + z.updateInner(zi) + return z +} + +// SetBits calls (big.Int).SetBits. +func (z *BigInt) SetBits(abs []big.Word) *BigInt { + var tmp1 big.Int + zi := z.inner(&tmp1) + zi.SetBits(abs) + z.updateInner(zi) + return z +} + +// SetBytes calls (big.Int).SetBytes. +func (z *BigInt) SetBytes(buf []byte) *BigInt { + var tmp1 big.Int + zi := z.inner(&tmp1) + zi.SetBytes(buf) + z.updateInner(zi) + return z +} + +// SetInt64 calls (big.Int).SetInt64. +func (z *BigInt) SetInt64(x int64) *BigInt { + if bits.UintSize == 64 { + neg := false + if x < 0 { + neg = true + x = -x + } + z.updateInnerFromUint(uint(x), neg) + return z + } + var tmp1 big.Int + zi := z.inner(&tmp1) + zi.SetInt64(x) + z.updateInner(zi) + return z +} + +// SetString calls (big.Int).SetString. +func (z *BigInt) SetString(s string, base int) (*BigInt, bool) { + var tmp1 big.Int + zi := z.inner(&tmp1) + if _, ok := zi.SetString(s, base); !ok { + return nil, false + } + z.updateInner(zi) + return z, true +} + +// SetUint64 calls (big.Int).SetUint64. +func (z *BigInt) SetUint64(x uint64) *BigInt { + if bits.UintSize == 64 { + z.updateInnerFromUint(uint(x), false) + return z + } + var tmp1 big.Int + zi := z.inner(&tmp1) + zi.SetUint64(x) + z.updateInner(zi) + return z +} + +// Sign calls (big.Int).Sign. +func (z *BigInt) Sign() int { + if z._inner == nil { + if z._inline == [inlineWords]big.Word{} { + return 0 + } + return 1 + } else if z._inner == negSentinel { + return -1 + } + return z._inner.Sign() +} + +// Sqrt calls (big.Int).Sqrt. +func (z *BigInt) Sqrt(x *BigInt) *BigInt { + var tmp1, tmp2 big.Int + zi := z.inner(&tmp1) + zi.Sqrt(x.inner(&tmp2)) + z.updateInner(zi) + return z +} + +// String calls (big.Int).String. +func (z *BigInt) String() string { + var tmp1 big.Int + return z.innerOrNil(&tmp1).String() +} + +// Sub calls (big.Int).Sub. +func (z *BigInt) Sub(x, y *BigInt) *BigInt { + if xVal, xNeg, ok := x.innerAsUint(); ok { + if yVal, yNeg, ok := y.innerAsUint(); ok { + if zVal, zNeg, ok := addInline(xVal, yVal, xNeg, !yNeg); ok { + z.updateInnerFromUint(zVal, zNeg) + return z + } + } + } + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Sub(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} + +// Text calls (big.Int).Text. +func (z *BigInt) Text(base int) string { + var tmp1 big.Int + return z.innerOrNil(&tmp1).Text(base) +} + +// TrailingZeroBits calls (big.Int).TrailingZeroBits. +func (z *BigInt) TrailingZeroBits() uint { + var tmp1 big.Int + return z.inner(&tmp1).TrailingZeroBits() +} + +// Uint64 calls (big.Int).Uint64. +func (z *BigInt) Uint64() uint64 { + var tmp1 big.Int + return z.inner(&tmp1).Uint64() +} + +// UnmarshalJSON calls (big.Int).UnmarshalJSON. +func (z *BigInt) UnmarshalJSON(text []byte) error { + var tmp1 big.Int + zi := z.inner(&tmp1) + if err := zi.UnmarshalJSON(text); err != nil { + return err + } + z.updateInner(zi) + return nil +} + +// UnmarshalText calls (big.Int).UnmarshalText. +func (z *BigInt) UnmarshalText(text []byte) error { + var tmp1 big.Int + zi := z.inner(&tmp1) + if err := zi.UnmarshalText(text); err != nil { + return err + } + z.updateInner(zi) + return nil +} + +// Xor calls (big.Int).Xor. +func (z *BigInt) Xor(x, y *BigInt) *BigInt { + var tmp1, tmp2, tmp3 big.Int + zi := z.inner(&tmp1) + zi.Xor(x.inner(&tmp2), y.inner(&tmp3)) + z.updateInner(zi) + return z +} diff --git a/bigint_go1.14_test.go b/bigint_go1.14_test.go new file mode 100644 index 0000000..5c52756 --- /dev/null +++ b/bigint_go1.14_test.go @@ -0,0 +1,157 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +//go:build go1.14 +// +build go1.14 + +package apd + +import ( + "testing" + "testing/quick" +) + +////////////////////////////////////////////////////////////////////////////////// +// The following tests were copied from the standard library's math/big package // +////////////////////////////////////////////////////////////////////////////////// + +func checkGcd(aBytes, bBytes []byte) bool { + x := new(BigInt) + y := new(BigInt) + a := new(BigInt).SetBytes(aBytes) + b := new(BigInt).SetBytes(bBytes) + + d := new(BigInt).GCD(x, y, a, b) + x.Mul(x, a) + y.Mul(y, b) + x.Add(x, y) + + return x.Cmp(d) == 0 +} + +var gcdTests = []struct { + d, x, y, a, b string +}{ + // a <= 0 || b <= 0 + {"0", "0", "0", "0", "0"}, + {"7", "0", "1", "0", "7"}, + {"7", "0", "-1", "0", "-7"}, + {"11", "1", "0", "11", "0"}, + {"7", "-1", "-2", "-77", "35"}, + {"935", "-3", "8", "64515", "24310"}, + {"935", "-3", "-8", "64515", "-24310"}, + {"935", "3", "-8", "-64515", "-24310"}, + + {"1", "-9", "47", "120", "23"}, + {"7", "1", "-2", "77", "35"}, + {"935", "-3", "8", "64515", "24310"}, + {"935000000000000000", "-3", "8", "64515000000000000000", "24310000000000000000"}, + {"1", "-221", "22059940471369027483332068679400581064239780177629666810348940098015901108344", "98920366548084643601728869055592650835572950932266967461790948584315647051443", "991"}, +} + +func testGcd(t *testing.T, d, x, y, a, b *BigInt) { + var X *BigInt + if x != nil { + X = new(BigInt) + } + var Y *BigInt + if y != nil { + Y = new(BigInt) + } + + D := new(BigInt).GCD(X, Y, a, b) + if D.Cmp(d) != 0 { + t.Errorf("GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, D, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, Y, y) + } + + // check results in presence of aliasing (issue #11284) + a2 := new(BigInt).Set(a) + b2 := new(BigInt).Set(b) + a2.GCD(X, Y, a2, b2) // result is same as 1st argument + if a2.Cmp(d) != 0 { + t.Errorf("aliased z = a GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, a2, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("aliased z = a GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("aliased z = a GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, Y, y) + } + + a2 = new(BigInt).Set(a) + b2 = new(BigInt).Set(b) + b2.GCD(X, Y, a2, b2) // result is same as 2nd argument + if b2.Cmp(d) != 0 { + t.Errorf("aliased z = b GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, b2, d) + } + if x != nil && X.Cmp(x) != 0 { + t.Errorf("aliased z = b GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, X, x) + } + if y != nil && Y.Cmp(y) != 0 { + t.Errorf("aliased z = b GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, Y, y) + } + + a2 = new(BigInt).Set(a) + b2 = new(BigInt).Set(b) + D = new(BigInt).GCD(a2, b2, a2, b2) // x = a, y = b + if D.Cmp(d) != 0 { + t.Errorf("aliased x = a, y = b GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, D, d) + } + if x != nil && a2.Cmp(x) != 0 { + t.Errorf("aliased x = a, y = b GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, a2, x) + } + if y != nil && b2.Cmp(y) != 0 { + t.Errorf("aliased x = a, y = b GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, b2, y) + } + + a2 = new(BigInt).Set(a) + b2 = new(BigInt).Set(b) + D = new(BigInt).GCD(b2, a2, a2, b2) // x = b, y = a + if D.Cmp(d) != 0 { + t.Errorf("aliased x = b, y = a GCD(%s, %s, %s, %s): got d = %s, want %s", x, y, a, b, D, d) + } + if x != nil && b2.Cmp(x) != 0 { + t.Errorf("aliased x = b, y = a GCD(%s, %s, %s, %s): got x = %s, want %s", x, y, a, b, b2, x) + } + if y != nil && a2.Cmp(y) != 0 { + t.Errorf("aliased x = b, y = a GCD(%s, %s, %s, %s): got y = %s, want %s", x, y, a, b, a2, y) + } +} + +// This was not supported in go1.13. See https://go.dev/doc/go1.14: +// > The GCD method now allows the inputs a and b to be zero or negative. +func TestBigIntGcd(t *testing.T) { + for _, test := range gcdTests { + d, _ := new(BigInt).SetString(test.d, 0) + x, _ := new(BigInt).SetString(test.x, 0) + y, _ := new(BigInt).SetString(test.y, 0) + a, _ := new(BigInt).SetString(test.a, 0) + b, _ := new(BigInt).SetString(test.b, 0) + + testGcd(t, d, nil, nil, a, b) + testGcd(t, d, x, nil, a, b) + testGcd(t, d, nil, y, a, b) + testGcd(t, d, x, y, a, b) + } + + if err := quick.Check(checkGcd, nil); err != nil { + t.Error(err) + } +} diff --git a/bigint_go1.15.go b/bigint_go1.15.go new file mode 100644 index 0000000..e8ad605 --- /dev/null +++ b/bigint_go1.15.go @@ -0,0 +1,26 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +//go:build go1.15 +// +build go1.15 + +package apd + +import "math/big" + +// FillBytes calls (big.Int).FillBytes. +func (z *BigInt) FillBytes(buf []byte) []byte { + var tmp1 big.Int + return z.inner(&tmp1).FillBytes(buf) +} diff --git a/bigint_go1.15_test.go b/bigint_go1.15_test.go new file mode 100644 index 0000000..fd04115 --- /dev/null +++ b/bigint_go1.15_test.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +//go:build go1.15 +// +build go1.15 + +package apd + +import ( + "testing" + "testing/quick" +) + +// TestBigIntMatchesMathBigInt15 is like TestBigIntMatchesMathBigInt, but for +// parts of the shared BigInt/big.Int API that were introduced in go1.15. +func TestBigIntMatchesMathBigInt15(t *testing.T) { + t.Run("FillBytes", func(t *testing.T) { + apd := func(z number) []byte { + return z.toApd(t).FillBytes(make([]byte, len(z))) + } + math := func(z number) []byte { + return z.toMath(t).FillBytes(make([]byte, len(z))) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) +} + +////////////////////////////////////////////////////////////////////////////////// +// The following tests were copied from the standard library's math/big package // +////////////////////////////////////////////////////////////////////////////////// + +func TestBigIntFillBytes(t *testing.T) { + checkResult := func(t *testing.T, buf []byte, want *BigInt) { + t.Helper() + got := new(BigInt).SetBytes(buf) + if got.CmpAbs(want) != 0 { + t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf) + } + } + panics := func(f func()) (panic bool) { + defer func() { panic = recover() != nil }() + f() + return + } + + for _, n := range []string{ + "0", + "1000", + "0xffffffff", + "-0xffffffff", + "0xffffffffffffffff", + "0x10000000000000000", + "0xabababababababababababababababababababababababababa", + "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + } { + t.Run(n, func(t *testing.T) { + t.Logf(n) + x, ok := new(BigInt).SetString(n, 0) + if !ok { + panic("invalid test entry") + } + + // Perfectly sized buffer. + byteLen := (x.BitLen() + 7) / 8 + buf := make([]byte, byteLen) + checkResult(t, x.FillBytes(buf), x) + + // Way larger, checking all bytes get zeroed. + buf = make([]byte, 100) + for i := range buf { + buf[i] = 0xff + } + checkResult(t, x.FillBytes(buf), x) + + // Too small. + if byteLen > 0 { + buf = make([]byte, byteLen-1) + if !panics(func() { x.FillBytes(buf) }) { + t.Errorf("expected panic for small buffer and value %x", x) + } + } + }) + } +} diff --git a/bigint_test.go b/bigint_test.go new file mode 100644 index 0000000..2191344 --- /dev/null +++ b/bigint_test.go @@ -0,0 +1,2834 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package apd + +import ( + "bytes" + "encoding/gob" + "encoding/hex" + "encoding/json" + "encoding/xml" + "fmt" + "math" + "math/big" + "math/rand" + "reflect" + "strconv" + "strings" + "testing" + "testing/quick" +) + +// TestBigIntMatchesMathBigInt uses testing/quick to verify that all methods on +// apd.BigInt and math/big.Int have identical behavior for all inputs. +func TestBigIntMatchesMathBigInt(t *testing.T) { + // Catch specific panics and map to strings. + const ( + panicDivisionByZero = "division by zero" + panicJacobi = "invalid 2nd argument to Int.Jacobi: need odd integer" + panicNegativeBit = "negative bit index" + panicSquareRootOfNegativeNum = "square root of negative number" + ) + catchPanic := func(fn func() string, catches ...string) (res string) { + defer func() { + if r := recover(); r != nil { + if rs, ok := r.(string); ok { + for _, catch := range catches { + if strings.Contains(rs, catch) { + res = fmt.Sprintf("caught: %s", r) + } + } + } + if res == "" { // not caught + panic(r) + } + } + }() + return fn() + } + + t.Run("Abs", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Abs(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Abs(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Add", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Add(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Add(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("And", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).And(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).And(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("AndNot", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).AndNot(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).AndNot(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Append", func(t *testing.T) { + apd := func(z numberOrNil) []byte { + return z.toApd(t).Append(nil, 10) + } + math := func(z numberOrNil) []byte { + return z.toMath(t).Append(nil, 10) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Binomial", func(t *testing.T) { + t.Skip("too slow") + apd := func(z number, n, k int64) string { + return z.toApd(t).Binomial(n, k).String() + } + math := func(z number, n, k int64) string { + return z.toMath(t).Binomial(n, k).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Bit", func(t *testing.T) { + apd := func(z number, i int) string { + return catchPanic(func() string { + return strconv.FormatUint(uint64(z.toApd(t).Bit(i)), 10) + }, panicNegativeBit) + } + math := func(z number, i int) string { + return catchPanic(func() string { + return strconv.FormatUint(uint64(z.toMath(t).Bit(i)), 10) + }, panicNegativeBit) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("BitLen", func(t *testing.T) { + apd := func(z number) int { + return z.toApd(t).BitLen() + } + math := func(z number) int { + return z.toMath(t).BitLen() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Bits", func(t *testing.T) { + emptyToNil := func(w []big.Word) []big.Word { + if len(w) == 0 { + return nil + } + return w + } + apd := func(z number) []big.Word { + return emptyToNil(z.toApd(t).Bits()) + } + math := func(z number) []big.Word { + return emptyToNil(z.toMath(t).Bits()) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Bytes", func(t *testing.T) { + apd := func(z number) []byte { + return z.toApd(t).Bytes() + } + math := func(z number) []byte { + return z.toMath(t).Bytes() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Cmp", func(t *testing.T) { + apd := func(z, y number) int { + return z.toApd(t).Cmp(y.toApd(t)) + } + math := func(z, y number) int { + return z.toMath(t).Cmp(y.toMath(t)) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("CmpAbs", func(t *testing.T) { + apd := func(z, y number) int { + return z.toApd(t).CmpAbs(y.toApd(t)) + } + math := func(z, y number) int { + return z.toMath(t).CmpAbs(y.toMath(t)) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Div", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Div(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Div(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("DivMod", func(t *testing.T) { + apd := func(z, x, y, m number) string { + return catchPanic(func() string { + zi, mi := z.toApd(t), m.toApd(t) + zi.DivMod(x.toApd(t), y.toApd(t), mi) + return zi.String() + " | " + mi.String() + }, panicDivisionByZero) + } + math := func(z, x, y, m number) string { + return catchPanic(func() string { + zi, mi := z.toMath(t), m.toMath(t) + zi.DivMod(x.toMath(t), y.toMath(t), mi) + return zi.String() + " | " + mi.String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Exp", func(t *testing.T) { + t.Skip("too slow") + apd := func(z, x, y, m number) string { + return z.toApd(t).Exp(x.toApd(t), y.toApd(t), m.toApd(t)).String() + } + math := func(z, x, y, m number) string { + return z.toMath(t).Exp(x.toMath(t), y.toMath(t), m.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Format", func(t *testing.T) { + // Call indirectly through fmt.Sprint. + apd := func(z numberOrNil) string { + return fmt.Sprint(z.toApd(t)) + } + math := func(z numberOrNil) string { + return fmt.Sprint(z.toMath(t)) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("GCD", func(t *testing.T) { + apd := func(z number, x, y numberOrNil, a, b number) string { + return z.toApd(t).GCD(x.toApd(t), y.toApd(t), a.toApd(t), b.toApd(t)).String() + } + math := func(z number, x, y numberOrNil, a, b number) string { + return z.toMath(t).GCD(x.toMath(t), y.toMath(t), a.toMath(t), b.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("GobEncode", func(t *testing.T) { + apd := func(z numberOrNil) ([]byte, error) { + return z.toApd(t).GobEncode() + } + math := func(z numberOrNil) ([]byte, error) { + return z.toMath(t).GobEncode() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("GobDecode", func(t *testing.T) { + apd := func(z number, buf []byte) (string, error) { + zi := z.toApd(t) + err := zi.GobDecode(buf) + return zi.String(), err + } + math := func(z number, buf []byte) (string, error) { + zi := z.toMath(t) + err := zi.GobDecode(buf) + return zi.String(), err + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Int64", func(t *testing.T) { + apd := func(z number) int64 { + return z.toApd(t).Int64() + } + math := func(z number) int64 { + return z.toMath(t).Int64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("IsInt64", func(t *testing.T) { + apd := func(z number) bool { + return z.toApd(t).IsInt64() + } + math := func(z number) bool { + return z.toMath(t).IsInt64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("IsUint64", func(t *testing.T) { + apd := func(z number) bool { + return z.toApd(t).IsUint64() + } + math := func(z number) bool { + return z.toMath(t).IsUint64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Lsh", func(t *testing.T) { + const maxShift = 1000 // avoid makeslice: len out of range + apd := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toApd(t).Lsh(x.toApd(t), n).String() + } + math := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toMath(t).Lsh(x.toMath(t), n).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("MarshalJSON", func(t *testing.T) { + apd := func(z numberOrNil) ([]byte, error) { + return z.toApd(t).MarshalJSON() + } + math := func(z numberOrNil) ([]byte, error) { + return z.toMath(t).MarshalJSON() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("MarshalText", func(t *testing.T) { + apd := func(z numberOrNil) ([]byte, error) { + return z.toApd(t).MarshalText() + } + math := func(z numberOrNil) ([]byte, error) { + return z.toMath(t).MarshalText() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Mod", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Mod(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero, panicJacobi) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Mod(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero, panicJacobi) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("ModInverse", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).ModInverse(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).ModInverse(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("ModSqrt", func(t *testing.T) { + t.Skip("too slow") + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).ModSqrt(x.toApd(t), y.toApd(t)).String() + }, panicJacobi) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).ModSqrt(x.toMath(t), y.toMath(t)).String() + }, panicJacobi) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Mul", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Mul(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Mul(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("MulRange", func(t *testing.T) { + t.Skip("too slow") + apd := func(z number, x, y int64) string { + return z.toApd(t).MulRange(x, y).String() + } + math := func(z number, x, y int64) string { + return z.toMath(t).MulRange(x, y).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Neg", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Neg(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Neg(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Not", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Not(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Not(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Or", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Or(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Or(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("ProbablyPrime", func(t *testing.T) { + apd := func(z number) bool { + return z.toApd(t).ProbablyPrime(64) + } + math := func(z number) bool { + return z.toMath(t).ProbablyPrime(64) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Quo", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Quo(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Quo(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("QuoRem", func(t *testing.T) { + apd := func(z, x, y, r number) string { + return catchPanic(func() string { + zi, ri := z.toApd(t), r.toApd(t) + zi.QuoRem(x.toApd(t), y.toApd(t), ri) + return zi.String() + " | " + ri.String() + }, panicDivisionByZero) + } + math := func(z, x, y, r number) string { + return catchPanic(func() string { + zi, ri := z.toMath(t), r.toMath(t) + zi.QuoRem(x.toMath(t), y.toMath(t), ri) + return zi.String() + " | " + ri.String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Rand", func(t *testing.T) { + apd := func(z, n number, seed int64) string { + rng := rand.New(rand.NewSource(seed)) + return z.toApd(t).Rand(rng, n.toApd(t)).String() + } + math := func(z, n number, seed int64) string { + rng := rand.New(rand.NewSource(seed)) + return z.toMath(t).Rand(rng, n.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Rem", func(t *testing.T) { + apd := func(z, x, y number) string { + return catchPanic(func() string { + return z.toApd(t).Rem(x.toApd(t), y.toApd(t)).String() + }, panicDivisionByZero) + } + math := func(z, x, y number) string { + return catchPanic(func() string { + return z.toMath(t).Rem(x.toMath(t), y.toMath(t)).String() + }, panicDivisionByZero) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Rsh", func(t *testing.T) { + const maxShift = 1000 // avoid makeslice: len out of range + apd := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toApd(t).Rsh(x.toApd(t), n).String() + } + math := func(z, x number, n uint) string { + if n > maxShift { + n = maxShift + } + return z.toMath(t).Rsh(x.toMath(t), n).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Scan", func(t *testing.T) { + // Call indirectly through fmt.Sscan. + apd := func(z, src number) (string, error) { + zi := z.toApd(t) + _, err := fmt.Sscan(string(src), zi) + return zi.String(), err + } + math := func(z, src number) (string, error) { + zi := z.toMath(t) + _, err := fmt.Sscan(string(src), zi) + return zi.String(), err + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Set", func(t *testing.T) { + apd := func(z, x number) string { + return z.toApd(t).Set(x.toApd(t)).String() + } + math := func(z, x number) string { + return z.toMath(t).Set(x.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetBit", func(t *testing.T) { + const maxBit = 1000 // avoid makeslice: len out of range + apd := func(z, x number, i int, b bool) string { + if i > maxBit { + i = maxBit + } + bi := uint(0) + if b { + bi = 1 + } + return catchPanic(func() string { + return z.toApd(t).SetBit(x.toApd(t), i, bi).String() + }, panicNegativeBit) + } + math := func(z, x number, i int, b bool) string { + if i > maxBit { + i = maxBit + } + bi := uint(0) + if b { + bi = 1 + } + return catchPanic(func() string { + return z.toMath(t).SetBit(x.toMath(t), i, bi).String() + }, panicNegativeBit) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetBits", func(t *testing.T) { + apd := func(z number, abs []big.Word) string { + return z.toApd(t).SetBits(abs).String() + } + math := func(z number, abs []big.Word) string { + return z.toMath(t).SetBits(abs).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetBytes", func(t *testing.T) { + apd := func(z number, buf []byte) string { + return z.toApd(t).SetBytes(buf).String() + } + math := func(z number, buf []byte) string { + return z.toMath(t).SetBytes(buf).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetInt64", func(t *testing.T) { + apd := func(z number, x int64) string { + return z.toApd(t).SetInt64(x).String() + } + math := func(z number, x int64) string { + return z.toMath(t).SetInt64(x).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetString", func(t *testing.T) { + apd := func(z, x number) (string, bool) { + zi, ok := z.toApd(t).SetString(string(x), 10) + return zi.String(), ok + } + math := func(z, x number) (string, bool) { + zi, ok := z.toMath(t).SetString(string(x), 10) + return zi.String(), ok + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("SetUint64", func(t *testing.T) { + apd := func(z number, x uint64) string { + return z.toApd(t).SetUint64(x).String() + } + math := func(z number, x uint64) string { + return z.toMath(t).SetUint64(x).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Sign", func(t *testing.T) { + apd := func(z number) int { + return z.toApd(t).Sign() + } + math := func(z number) int { + return z.toMath(t).Sign() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Sqrt", func(t *testing.T) { + apd := func(z, x number) string { + return catchPanic(func() string { + return z.toApd(t).Sqrt(x.toApd(t)).String() + }, panicSquareRootOfNegativeNum) + } + math := func(z, x number) string { + return catchPanic(func() string { + return z.toMath(t).Sqrt(x.toMath(t)).String() + }, panicSquareRootOfNegativeNum) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("String", func(t *testing.T) { + apd := func(z numberOrNil) string { + return z.toApd(t).String() + } + math := func(z numberOrNil) string { + return z.toMath(t).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Sub", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Sub(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Sub(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Text", func(t *testing.T) { + apd := func(z numberOrNil) string { + return z.toApd(t).Text(10) + } + math := func(z numberOrNil) string { + return z.toMath(t).Text(10) + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("TrailingZeroBits", func(t *testing.T) { + apd := func(z number) uint { + return z.toApd(t).TrailingZeroBits() + } + math := func(z number) uint { + return z.toMath(t).TrailingZeroBits() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Uint64", func(t *testing.T) { + apd := func(z number) uint64 { + return z.toApd(t).Uint64() + } + math := func(z number) uint64 { + return z.toMath(t).Uint64() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("UnmarshalJSON", func(t *testing.T) { + apd := func(z number, text []byte) (string, error) { + zi := z.toApd(t) + if err := zi.UnmarshalJSON(text); err != nil { + return "", err + } + return zi.String(), nil + } + math := func(z number, text []byte) (string, error) { + zi := z.toMath(t) + if err := zi.UnmarshalJSON(text); err != nil { + return "", err + } + return zi.String(), nil + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("UnmarshalText", func(t *testing.T) { + apd := func(z number, text []byte) (string, error) { + zi := z.toApd(t) + if err := zi.UnmarshalText(text); err != nil { + return "", err + } + return zi.String(), nil + } + math := func(z number, text []byte) (string, error) { + zi := z.toMath(t) + if err := zi.UnmarshalText(text); err != nil { + return "", err + } + return zi.String(), nil + } + require(t, quick.CheckEqual(apd, math, nil)) + }) + + t.Run("Xor", func(t *testing.T) { + apd := func(z, x, y number) string { + return z.toApd(t).Xor(x.toApd(t), y.toApd(t)).String() + } + math := func(z, x, y number) string { + return z.toMath(t).Xor(x.toMath(t), y.toMath(t)).String() + } + require(t, quick.CheckEqual(apd, math, nil)) + }) +} + +// number is a quick.Generator for large integer numbers. +type number string + +func (n number) Generate(r *rand.Rand, size int) reflect.Value { + var s string + if r.Intn(2) != 0 { + s = n.generateInterestingNumber(r) + } else { + s = n.generateRandomNumber(r, size) + } + return reflect.ValueOf(number(s)) +} + +func (z *BigInt) incr() *BigInt { return z.Add(z, bigOne) } +func (z *BigInt) decr() *BigInt { return z.Sub(z, bigOne) } + +var interestingNumbers = [...]*BigInt{ + new(BigInt).SetInt64(math.MinInt64).decr(), + new(BigInt).SetInt64(math.MinInt64), + new(BigInt).SetInt64(math.MinInt64).incr(), + new(BigInt).SetInt64(math.MinInt32), + new(BigInt).SetInt64(math.MinInt16), + new(BigInt).SetInt64(math.MinInt8), + new(BigInt).SetInt64(0), + new(BigInt).SetInt64(math.MaxInt8), + new(BigInt).SetInt64(math.MaxUint8), + new(BigInt).SetInt64(math.MaxInt16), + new(BigInt).SetInt64(math.MaxUint16), + new(BigInt).SetInt64(math.MaxInt32), + new(BigInt).SetInt64(math.MaxUint32), + new(BigInt).SetInt64(math.MaxInt64).decr(), + new(BigInt).SetInt64(math.MaxInt64), + new(BigInt).SetInt64(math.MaxInt64).incr(), + new(BigInt).SetUint64(math.MaxUint64).decr(), + new(BigInt).SetUint64(math.MaxUint64), + new(BigInt).SetUint64(math.MaxUint64).incr(), +} + +func (number) generateInterestingNumber(r *rand.Rand) string { + return interestingNumbers[r.Intn(len(interestingNumbers))].String() +} + +var numbers = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} + +func (number) generateRandomNumber(r *rand.Rand, size int) string { + var s strings.Builder + if r.Intn(2) != 0 { + s.WriteByte('-') // neg + } + digits := r.Intn(size) + 1 + for i := 0; i < digits; i++ { + s.WriteByte(numbers[r.Intn(len(numbers))]) + } + return s.String() +} + +func (n number) toApd(t *testing.T) *BigInt { + var x BigInt + if _, ok := x.SetString(string(n), 10); !ok { + t.Fatalf("failed to SetString(%q)", n) + } + return &x +} + +func (n number) toMath(t *testing.T) *big.Int { + var x big.Int + if _, ok := x.SetString(string(n), 10); !ok { + t.Fatalf("failed to SetString(%q)", n) + } + return &x +} + +type numberOrNil struct { + Num number + Nil bool +} + +func (n numberOrNil) toApd(t *testing.T) *BigInt { + if n.Nil { + return nil + } + return n.Num.toApd(t) +} + +func (n numberOrNil) toMath(t *testing.T) *big.Int { + if n.Nil { + return nil + } + return n.Num.toMath(t) +} + +// Until we import github.com/stretchr/testify/require. +func require(t *testing.T, err error) { + if err != nil { + t.Error(err) + } +} + +////////////////////////////////////////////////////////////////////////////////// +// The following tests were copied from the standard library's math/big package // +////////////////////////////////////////////////////////////////////////////////// + +// +// Tests from src/math/big/int_test.go +// + +type funZZ func(z, x, y *BigInt) *BigInt +type argZZ struct { + z, x, y *BigInt +} + +var sumZZ = []argZZ{ + {NewBigInt(0), NewBigInt(0), NewBigInt(0)}, + {NewBigInt(1), NewBigInt(1), NewBigInt(0)}, + {NewBigInt(1111111110), NewBigInt(123456789), NewBigInt(987654321)}, + {NewBigInt(-1), NewBigInt(-1), NewBigInt(0)}, + {NewBigInt(864197532), NewBigInt(-123456789), NewBigInt(987654321)}, + {NewBigInt(-1111111110), NewBigInt(-123456789), NewBigInt(-987654321)}, +} + +var prodZZ = []argZZ{ + {NewBigInt(0), NewBigInt(0), NewBigInt(0)}, + {NewBigInt(0), NewBigInt(1), NewBigInt(0)}, + {NewBigInt(1), NewBigInt(1), NewBigInt(1)}, + {NewBigInt(-991 * 991), NewBigInt(991), NewBigInt(-991)}, + // TODO(gri) add larger products +} + +func TestBigIntSignZ(t *testing.T) { + var zero BigInt + for _, a := range sumZZ { + s := a.z.Sign() + e := a.z.Cmp(&zero) + if s != e { + t.Errorf("got %d; want %d for z = %v", s, e, a.z) + } + } +} + +func TestBigIntSetZ(t *testing.T) { + for _, a := range sumZZ { + var z BigInt + z.Set(a.z) + if (&z).Cmp(a.z) != 0 { + t.Errorf("got z = %v; want %v", &z, a.z) + } + } +} + +func TestBigIntAbsZ(t *testing.T) { + var zero BigInt + for _, a := range sumZZ { + var z BigInt + z.Abs(a.z) + var e BigInt + e.Set(a.z) + if e.Cmp(&zero) < 0 { + e.Sub(&zero, &e) + } + if z.Cmp(&e) != 0 { + t.Errorf("got z = %v; want %v", &z, &e) + } + } +} + +func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) { + var z BigInt + f(&z, a.x, a.y) + if (&z).Cmp(a.z) != 0 { + t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, &z, a.z) + } +} + +func TestBigIntSumZZ(t *testing.T) { + AddZZ := func(z, x, y *BigInt) *BigInt { return z.Add(x, y) } + SubZZ := func(z, x, y *BigInt) *BigInt { return z.Sub(x, y) } + for _, a := range sumZZ { + arg := a + testFunZZ(t, "AddZZ", AddZZ, arg) + + arg = argZZ{a.z, a.y, a.x} + testFunZZ(t, "AddZZ symmetric", AddZZ, arg) + + arg = argZZ{a.x, a.z, a.y} + testFunZZ(t, "SubZZ", SubZZ, arg) + + arg = argZZ{a.y, a.z, a.x} + testFunZZ(t, "SubZZ symmetric", SubZZ, arg) + } +} + +func TestBigIntProdZZ(t *testing.T) { + MulZZ := func(z, x, y *BigInt) *BigInt { return z.Mul(x, y) } + for _, a := range prodZZ { + arg := a + testFunZZ(t, "MulZZ", MulZZ, arg) + + arg = argZZ{a.z, a.y, a.x} + testFunZZ(t, "MulZZ symmetric", MulZZ, arg) + } +} + +// mulBytes returns x*y via grade school multiplication. Both inputs +// and the result are assumed to be in big-endian representation (to +// match the semantics of BigInt.Bytes and BigInt.SetBytes). +func mulBytes(x, y []byte) []byte { + z := make([]byte, len(x)+len(y)) + + // multiply + k0 := len(z) - 1 + for j := len(y) - 1; j >= 0; j-- { + d := int(y[j]) + if d != 0 { + k := k0 + carry := 0 + for i := len(x) - 1; i >= 0; i-- { + t := int(z[k]) + int(x[i])*d + carry + z[k], carry = byte(t), t>>8 + k-- + } + z[k] = byte(carry) + } + k0-- + } + + // normalize (remove leading 0's) + i := 0 + for i < len(z) && z[i] == 0 { + i++ + } + + return z[i:] +} + +func checkMul(a, b []byte) bool { + var x, y, z1 BigInt + x.SetBytes(a) + y.SetBytes(b) + z1.Mul(&x, &y) + + var z2 BigInt + z2.SetBytes(mulBytes(a, b)) + + return z1.Cmp(&z2) == 0 +} + +func TestBigIntMul(t *testing.T) { + if err := quick.Check(checkMul, nil); err != nil { + t.Error(err) + } +} + +var mulRangesN = []struct { + a, b uint64 + prod string +}{ + {0, 0, "0"}, + {1, 1, "1"}, + {1, 2, "2"}, + {1, 3, "6"}, + {10, 10, "10"}, + {0, 100, "0"}, + {0, 1e9, "0"}, + {1, 0, "1"}, // empty range + {100, 1, "1"}, // empty range + {1, 10, "3628800"}, // 10! + {1, 20, "2432902008176640000"}, // 20! + {1, 100, + "933262154439441526816992388562667004907159682643816214685929" + + "638952175999932299156089414639761565182862536979208272237582" + + "51185210916864000000000000000000000000", // 100! + }, +} + +var mulRangesZ = []struct { + a, b int64 + prod string +}{ + // entirely positive ranges are covered by mulRangesN + {-1, 1, "0"}, + {-2, -1, "2"}, + {-3, -2, "6"}, + {-3, -1, "-6"}, + {1, 3, "6"}, + {-10, -10, "-10"}, + {0, -1, "1"}, // empty range + {-1, -100, "1"}, // empty range + {-1, 1, "0"}, // range includes 0 + {-1e9, 0, "0"}, // range includes 0 + {-1e9, 1e9, "0"}, // range includes 0 + {-10, -1, "3628800"}, // 10! + {-20, -2, "-2432902008176640000"}, // -20! + {-99, -1, + "-933262154439441526816992388562667004907159682643816214685929" + + "638952175999932299156089414639761565182862536979208272237582" + + "511852109168640000000000000000000000", // -99! + }, +} + +func TestBigIntMulRangeZ(t *testing.T) { + var tmp BigInt + // test entirely positive ranges + for i, r := range mulRangesN { + prod := tmp.MulRange(int64(r.a), int64(r.b)).String() + if prod != r.prod { + t.Errorf("#%da: got %s; want %s", i, prod, r.prod) + } + } + // test other ranges + for i, r := range mulRangesZ { + prod := tmp.MulRange(r.a, r.b).String() + if prod != r.prod { + t.Errorf("#%db: got %s; want %s", i, prod, r.prod) + } + } +} + +func TestBigIntBinomial(t *testing.T) { + var z BigInt + for _, test := range []struct { + n, k int64 + want string + }{ + {0, 0, "1"}, + {0, 1, "0"}, + {1, 0, "1"}, + {1, 1, "1"}, + {1, 10, "0"}, + {4, 0, "1"}, + {4, 1, "4"}, + {4, 2, "6"}, + {4, 3, "4"}, + {4, 4, "1"}, + {10, 1, "10"}, + {10, 9, "10"}, + {10, 5, "252"}, + {11, 5, "462"}, + {11, 6, "462"}, + {100, 10, "17310309456440"}, + {100, 90, "17310309456440"}, + {1000, 10, "263409560461970212832400"}, + {1000, 990, "263409560461970212832400"}, + } { + if got := z.Binomial(test.n, test.k).String(); got != test.want { + t.Errorf("Binomial(%d, %d) = %s; want %s", test.n, test.k, got, test.want) + } + } +} + +// Examples from the Go Language Spec, section "Arithmetic operators" +var divisionSignsTests = []struct { + x, y int64 + q, r int64 // T-division + d, m int64 // Euclidean division +}{ + {5, 3, 1, 2, 1, 2}, + {-5, 3, -1, -2, -2, 1}, + {5, -3, -1, 2, -1, 2}, + {-5, -3, 1, -2, 2, 1}, + {1, 2, 0, 1, 0, 1}, + {8, 4, 2, 0, 2, 0}, +} + +func TestBigIntDivisionSigns(t *testing.T) { + for i, test := range divisionSignsTests { + x := NewBigInt(test.x) + y := NewBigInt(test.y) + q := NewBigInt(test.q) + r := NewBigInt(test.r) + d := NewBigInt(test.d) + m := NewBigInt(test.m) + + q1 := new(BigInt).Quo(x, y) + r1 := new(BigInt).Rem(x, y) + if q1.Cmp(q) != 0 || r1.Cmp(r) != 0 { + t.Errorf("#%d QuoRem: got (%s, %s), want (%s, %s)", i, q1, r1, q, r) + } + + q2, r2 := new(BigInt).QuoRem(x, y, new(BigInt)) + if q2.Cmp(q) != 0 || r2.Cmp(r) != 0 { + t.Errorf("#%d QuoRem: got (%s, %s), want (%s, %s)", i, q2, r2, q, r) + } + + d1 := new(BigInt).Div(x, y) + m1 := new(BigInt).Mod(x, y) + if d1.Cmp(d) != 0 || m1.Cmp(m) != 0 { + t.Errorf("#%d DivMod: got (%s, %s), want (%s, %s)", i, d1, m1, d, m) + } + + d2, m2 := new(BigInt).DivMod(x, y, new(BigInt)) + if d2.Cmp(d) != 0 || m2.Cmp(m) != 0 { + t.Errorf("#%d DivMod: got (%s, %s), want (%s, %s)", i, d2, m2, d, m) + } + } +} + +func checkSetBytes(b []byte) bool { + hex1 := hex.EncodeToString(new(BigInt).SetBytes(b).Bytes()) + hex2 := hex.EncodeToString(b) + + for len(hex1) < len(hex2) { + hex1 = "0" + hex1 + } + + for len(hex1) > len(hex2) { + hex2 = "0" + hex2 + } + + return hex1 == hex2 +} + +func TestBigIntSetBytes(t *testing.T) { + if err := quick.Check(checkSetBytes, nil); err != nil { + t.Error(err) + } +} + +func checkBytes(b []byte) bool { + // trim leading zero bytes since Bytes() won't return them + // (was issue 12231) + for len(b) > 0 && b[0] == 0 { + b = b[1:] + } + b2 := new(BigInt).SetBytes(b).Bytes() + return bytes.Equal(b, b2) +} + +func TestBigIntBytes(t *testing.T) { + if err := quick.Check(checkBytes, nil); err != nil { + t.Error(err) + } +} + +func checkQuo(x, y []byte) bool { + u := new(BigInt).SetBytes(x) + v := new(BigInt).SetBytes(y) + + var tmp1 big.Int + if len(v.inner(&tmp1).Bits()) == 0 { + return true + } + + r := new(BigInt) + q, r := new(BigInt).QuoRem(u, v, r) + + if r.Cmp(v) >= 0 { + return false + } + + uprime := new(BigInt).Set(q) + uprime.Mul(uprime, v) + uprime.Add(uprime, r) + + return uprime.Cmp(u) == 0 +} + +var quoTests = []struct { + x, y string + q, r string +}{ + { + "476217953993950760840509444250624797097991362735329973741718102894495832294430498335824897858659711275234906400899559094370964723884706254265559534144986498357", + "9353930466774385905609975137998169297361893554149986716853295022578535724979483772383667534691121982974895531435241089241440253066816724367338287092081996", + "50911", + "1", + }, + { + "11510768301994997771168", + "1328165573307167369775", + "8", + "885443715537658812968", + }, +} + +func TestBigIntQuo(t *testing.T) { + if err := quick.Check(checkQuo, nil); err != nil { + t.Error(err) + } + + for i, test := range quoTests { + x, _ := new(BigInt).SetString(test.x, 10) + y, _ := new(BigInt).SetString(test.y, 10) + expectedQ, _ := new(BigInt).SetString(test.q, 10) + expectedR, _ := new(BigInt).SetString(test.r, 10) + + r := new(BigInt) + q, r := new(BigInt).QuoRem(x, y, r) + + if q.Cmp(expectedQ) != 0 || r.Cmp(expectedR) != 0 { + t.Errorf("#%d got (%s, %s) want (%s, %s)", i, q, r, expectedQ, expectedR) + } + } +} + +var bitLenTests = []struct { + in string + out int +}{ + {"-1", 1}, + {"0", 0}, + {"1", 1}, + {"2", 2}, + {"4", 3}, + {"0xabc", 12}, + {"0x8000", 16}, + {"0x80000000", 32}, + {"0x800000000000", 48}, + {"0x8000000000000000", 64}, + {"0x80000000000000000000", 80}, + {"-0x4000000000000000000000", 87}, +} + +func TestBigIntBitLen(t *testing.T) { + for i, test := range bitLenTests { + x, ok := new(BigInt).SetString(test.in, 0) + if !ok { + t.Errorf("#%d test input invalid: %s", i, test.in) + continue + } + + if n := x.BitLen(); n != test.out { + t.Errorf("#%d got %d want %d", i, n, test.out) + } + } +} + +var expTests = []struct { + x, y, m string + out string +}{ + // y <= 0 + {"0", "0", "", "1"}, + {"1", "0", "", "1"}, + {"-10", "0", "", "1"}, + {"1234", "-1", "", "1"}, + {"1234", "-1", "0", "1"}, + {"17", "-100", "1234", "865"}, + {"2", "-100", "1234", ""}, + + // m == 1 + {"0", "0", "1", "0"}, + {"1", "0", "1", "0"}, + {"-10", "0", "1", "0"}, + {"1234", "-1", "1", "0"}, + + // misc + {"5", "1", "3", "2"}, + {"5", "-7", "", "1"}, + {"-5", "-7", "", "1"}, + {"5", "0", "", "1"}, + {"-5", "0", "", "1"}, + {"5", "1", "", "5"}, + {"-5", "1", "", "-5"}, + {"-5", "1", "7", "2"}, + {"-2", "3", "2", "0"}, + {"5", "2", "", "25"}, + {"1", "65537", "2", "1"}, + {"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"}, + {"0x8000000000000000", "2", "6719", "4944"}, + {"0x8000000000000000", "3", "6719", "5447"}, + {"0x8000000000000000", "1000", "6719", "1603"}, + {"0x8000000000000000", "1000000", "6719", "3199"}, + {"0x8000000000000000", "-1000000", "6719", "3663"}, // 3663 = ModInverse(3199, 6719) Issue #25865 + + {"0xffffffffffffffffffffffffffffffff", "0x12345678123456781234567812345678123456789", "0x01112222333344445555666677778889", "0x36168FA1DB3AAE6C8CE647E137F97A"}, + + { + "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", + "298472983472983471903246121093472394872319615612417471234712061", + "29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464", + "23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291", + }, + // test case for issue 8822 + { + "11001289118363089646017359372117963499250546375269047542777928006103246876688756735760905680604646624353196869572752623285140408755420374049317646428185270079555372763503115646054602867593662923894140940837479507194934267532831694565516466765025434902348314525627418515646588160955862839022051353653052947073136084780742729727874803457643848197499548297570026926927502505634297079527299004267769780768565695459945235586892627059178884998772989397505061206395455591503771677500931269477503508150175717121828518985901959919560700853226255420793148986854391552859459511723547532575574664944815966793196961286234040892865", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", + "21484252197776302499639938883777710321993113097987201050501182909581359357618579566746556372589385361683610524730509041328855066514963385522570894839035884713051640171474186548713546686476761306436434146475140156284389181808675016576845833340494848283681088886584219750554408060556769486628029028720727393293111678826356480455433909233520504112074401376133077150471237549474149190242010469539006449596611576612573955754349042329130631128234637924786466585703488460540228477440853493392086251021228087076124706778899179648655221663765993962724699135217212118535057766739392069738618682722216712319320435674779146070442", + }, + { + "-0x1BCE04427D8032319A89E5C4136456671AC620883F2C4139E57F91307C485AD2D6204F4F87A58262652DB5DBBAC72B0613E51B835E7153BEC6068F5C8D696B74DBD18FEC316AEF73985CF0475663208EB46B4F17DD9DA55367B03323E5491A70997B90C059FB34809E6EE55BCFBD5F2F52233BFE62E6AA9E4E26A1D4C2439883D14F2633D55D8AA66A1ACD5595E778AC3A280517F1157989E70C1A437B849F1877B779CC3CDDEDE2DAA6594A6C66D181A00A5F777EE60596D8773998F6E988DEAE4CCA60E4DDCF9590543C89F74F603259FCAD71660D30294FBBE6490300F78A9D63FA660DC9417B8B9DDA28BEB3977B621B988E23D4D954F322C3540541BC649ABD504C50FADFD9F0987D58A2BF689313A285E773FF02899A6EF887D1D4A0D2", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", + "21484252197776302499639938883777710321993113097987201050501182909581359357618579566746556372589385361683610524730509041328855066514963385522570894839035884713051640171474186548713546686476761306436434146475140156284389181808675016576845833340494848283681088886584219750554408060556769486628029028720727393293111678826356480455433909233520504112074401376133077150471237549474149190242010469539006449596611576612573955754349042329130631128234637924786466585703488460540228477440853493392086251021228087076124706778899179648655221663765993962724699135217212118535057766739392069738618682722216712319320435674779146070442", + }, + + // test cases for issue 13907 + {"0xffffffff00000001", "0xffffffff00000001", "0xffffffff00000001", "0"}, + {"0xffffffffffffffff00000001", "0xffffffffffffffff00000001", "0xffffffffffffffff00000001", "0"}, + {"0xffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffff00000001", "0"}, + {"0xffffffffffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffffffffffff00000001", "0xffffffffffffffffffffffffffffffff00000001", "0"}, + + { + "2", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", // odd + "0x6AADD3E3E424D5B713FCAA8D8945B1E055166132038C57BBD2D51C833F0C5EA2007A2324CE514F8E8C2F008A2F36F44005A4039CB55830986F734C93DAF0EB4BAB54A6A8C7081864F44346E9BC6F0A3EB9F2C0146A00C6A05187D0C101E1F2D038CDB70CB5E9E05A2D188AB6CBB46286624D4415E7D4DBFAD3BCC6009D915C406EED38F468B940F41E6BEDC0430DD78E6F19A7DA3A27498A4181E24D738B0072D8F6ADB8C9809A5B033A09785814FD9919F6EF9F83EEA519BEC593855C4C10CBEEC582D4AE0792158823B0275E6AEC35242740468FAF3D5C60FD1E376362B6322F78B7ED0CA1C5BBCD2B49734A56C0967A1D01A100932C837B91D592CE08ABFF", + }, + { + "2", + "0xB08FFB20760FFED58FADA86DFEF71AD72AA0FA763219618FE022C197E54708BB1191C66470250FCE8879487507CEE41381CA4D932F81C2B3F1AB20B539D50DCD", + "0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF72", // even + "0x7858794B5897C29F4ED0B40913416AB6C48588484E6A45F2ED3E26C941D878E923575AAC434EE2750E6439A6976F9BB4D64CEDB2A53CE8D04DD48CADCDF8E46F22747C6B81C6CEA86C0D873FBF7CEF262BAAC43A522BD7F32F3CDAC52B9337C77B3DCFB3DB3EDD80476331E82F4B1DF8EFDC1220C92656DFC9197BDC1877804E28D928A2A284B8DED506CBA304435C9D0133C246C98A7D890D1DE60CBC53A024361DA83A9B8775019083D22AC6820ED7C3C68F8E801DD4EC779EE0A05C6EB682EF9840D285B838369BA7E148FA27691D524FAEAF7C6ECE2A4B99A294B9F2C241857B5B90CC8BFFCFCF18DFA7D676131D5CD3855A5A3E8EBFA0CDFADB4D198B4A", + }, +} + +func TestBigIntExp(t *testing.T) { + for i, test := range expTests { + x, ok1 := new(BigInt).SetString(test.x, 0) + y, ok2 := new(BigInt).SetString(test.y, 0) + + var ok3, ok4 bool + var out, m *BigInt + + if len(test.out) == 0 { + out, ok3 = nil, true + } else { + out, ok3 = new(BigInt).SetString(test.out, 0) + } + + if len(test.m) == 0 { + m, ok4 = nil, true + } else { + m, ok4 = new(BigInt).SetString(test.m, 0) + } + + if !ok1 || !ok2 || !ok3 || !ok4 { + t.Errorf("#%d: error in input", i) + continue + } + + z1 := new(BigInt).Exp(x, y, m) + if !(z1 == nil && out == nil || z1.Cmp(out) == 0) { + t.Errorf("#%d: got %x want %x", i, z1, out) + } + + if m == nil { + // The result should be the same as for m == 0; + // specifically, there should be no div-zero panic. + m = new(BigInt) // m != nil && len(m.abs) == 0 + z2 := new(BigInt).Exp(x, y, m) + if z2.Cmp(z1) != 0 { + t.Errorf("#%d: got %x want %x", i, z2, z1) + } + } + } +} + +type intShiftTest struct { + in string + shift uint + out string +} + +var rshTests = []intShiftTest{ + {"0", 0, "0"}, + {"-0", 0, "0"}, + {"0", 1, "0"}, + {"0", 2, "0"}, + {"1", 0, "1"}, + {"1", 1, "0"}, + {"1", 2, "0"}, + {"2", 0, "2"}, + {"2", 1, "1"}, + {"-1", 0, "-1"}, + {"-1", 1, "-1"}, + {"-1", 10, "-1"}, + {"-100", 2, "-25"}, + {"-100", 3, "-13"}, + {"-100", 100, "-1"}, + {"4294967296", 0, "4294967296"}, + {"4294967296", 1, "2147483648"}, + {"4294967296", 2, "1073741824"}, + {"18446744073709551616", 0, "18446744073709551616"}, + {"18446744073709551616", 1, "9223372036854775808"}, + {"18446744073709551616", 2, "4611686018427387904"}, + {"18446744073709551616", 64, "1"}, + {"340282366920938463463374607431768211456", 64, "18446744073709551616"}, + {"340282366920938463463374607431768211456", 128, "1"}, +} + +func TestBigIntRsh(t *testing.T) { + for i, test := range rshTests { + in, _ := new(BigInt).SetString(test.in, 10) + expected, _ := new(BigInt).SetString(test.out, 10) + out := new(BigInt).Rsh(in, test.shift) + + if out.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, out, expected) + } + } +} + +func TestBigIntRshSelf(t *testing.T) { + for i, test := range rshTests { + z, _ := new(BigInt).SetString(test.in, 10) + expected, _ := new(BigInt).SetString(test.out, 10) + z.Rsh(z, test.shift) + + if z.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, z, expected) + } + } +} + +var lshTests = []intShiftTest{ + {"0", 0, "0"}, + {"0", 1, "0"}, + {"0", 2, "0"}, + {"1", 0, "1"}, + {"1", 1, "2"}, + {"1", 2, "4"}, + {"2", 0, "2"}, + {"2", 1, "4"}, + {"2", 2, "8"}, + {"-87", 1, "-174"}, + {"4294967296", 0, "4294967296"}, + {"4294967296", 1, "8589934592"}, + {"4294967296", 2, "17179869184"}, + {"18446744073709551616", 0, "18446744073709551616"}, + {"9223372036854775808", 1, "18446744073709551616"}, + {"4611686018427387904", 2, "18446744073709551616"}, + {"1", 64, "18446744073709551616"}, + {"18446744073709551616", 64, "340282366920938463463374607431768211456"}, + {"1", 128, "340282366920938463463374607431768211456"}, +} + +func TestBigIntLsh(t *testing.T) { + for i, test := range lshTests { + in, _ := new(BigInt).SetString(test.in, 10) + expected, _ := new(BigInt).SetString(test.out, 10) + out := new(BigInt).Lsh(in, test.shift) + + if out.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, out, expected) + } + } +} + +func TestBigIntLshSelf(t *testing.T) { + for i, test := range lshTests { + z, _ := new(BigInt).SetString(test.in, 10) + expected, _ := new(BigInt).SetString(test.out, 10) + z.Lsh(z, test.shift) + + if z.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, z, expected) + } + } +} + +func TestBigIntLshRsh(t *testing.T) { + for i, test := range rshTests { + in, _ := new(BigInt).SetString(test.in, 10) + out := new(BigInt).Lsh(in, test.shift) + out = out.Rsh(out, test.shift) + + if in.Cmp(out) != 0 { + t.Errorf("#%d: got %s want %s", i, out, in) + } + } + for i, test := range lshTests { + in, _ := new(BigInt).SetString(test.in, 10) + out := new(BigInt).Lsh(in, test.shift) + out.Rsh(out, test.shift) + + if in.Cmp(out) != 0 { + t.Errorf("#%d: got %s want %s", i, out, in) + } + } +} + +// Entries must be sorted by value in ascending order. +var cmpAbsTests = []string{ + "0", + "1", + "2", + "10", + "10000000", + "2783678367462374683678456387645876387564783686583485", + "2783678367462374683678456387645876387564783686583486", + "32957394867987420967976567076075976570670947609750670956097509670576075067076027578341538", +} + +func TestBigIntCmpAbs(t *testing.T) { + values := make([]*BigInt, len(cmpAbsTests)) + var prev *BigInt + for i, s := range cmpAbsTests { + x, ok := new(BigInt).SetString(s, 0) + if !ok { + t.Fatalf("SetString(%s, 0) failed", s) + } + if prev != nil && prev.Cmp(x) >= 0 { + t.Fatal("cmpAbsTests entries not sorted in ascending order") + } + values[i] = x + prev = x + } + + for i, x := range values { + for j, y := range values { + // try all combinations of signs for x, y + for k := 0; k < 4; k++ { + var a, b BigInt + a.Set(x) + b.Set(y) + if k&1 != 0 { + a.Neg(&a) + } + if k&2 != 0 { + b.Neg(&b) + } + + got := a.CmpAbs(&b) + want := 0 + switch { + case i > j: + want = 1 + case i < j: + want = -1 + } + if got != want { + t.Errorf("absCmp |%s|, |%s|: got %d; want %d", &a, &b, got, want) + } + } + } + } +} + +func TestBigIntCmpSelf(t *testing.T) { + for _, s := range cmpAbsTests { + x, ok := new(BigInt).SetString(s, 0) + if !ok { + t.Fatalf("SetString(%s, 0) failed", s) + } + got := x.Cmp(x) + want := 0 + if got != want { + t.Errorf("x = %s: x.Cmp(x): got %d; want %d", x, got, want) + } + } +} + +var int64Tests = []string{ + // int64 + "0", + "1", + "-1", + "4294967295", + "-4294967295", + "4294967296", + "-4294967296", + "9223372036854775807", + "-9223372036854775807", + "-9223372036854775808", + + // not int64 + "0x8000000000000000", + "-0x8000000000000001", + "38579843757496759476987459679745", + "-38579843757496759476987459679745", +} + +func TestBigInt64(t *testing.T) { + for _, s := range int64Tests { + var x BigInt + _, ok := x.SetString(s, 0) + if !ok { + t.Errorf("SetString(%s, 0) failed", s) + continue + } + + want, err := strconv.ParseInt(s, 0, 64) + if err != nil { + if err.(*strconv.NumError).Err == strconv.ErrRange { + if x.IsInt64() { + t.Errorf("IsInt64(%s) succeeded unexpectedly", s) + } + } else { + t.Errorf("ParseInt(%s) failed", s) + } + continue + } + + if !x.IsInt64() { + t.Errorf("IsInt64(%s) failed unexpectedly", s) + } + + got := x.Int64() + if got != want { + t.Errorf("Int64(%s) = %d; want %d", s, got, want) + } + } +} + +var uint64Tests = []string{ + // uint64 + "0", + "1", + "4294967295", + "4294967296", + "8589934591", + "8589934592", + "9223372036854775807", + "9223372036854775808", + "0x08000000000000000", + + // not uint64 + "0x10000000000000000", + "-0x08000000000000000", + "-1", +} + +func TestBigIntUint64(t *testing.T) { + for _, s := range uint64Tests { + var x BigInt + _, ok := x.SetString(s, 0) + if !ok { + t.Errorf("SetString(%s, 0) failed", s) + continue + } + + want, err := strconv.ParseUint(s, 0, 64) + if err != nil { + // check for sign explicitly (ErrRange doesn't cover signed input) + if s[0] == '-' || err.(*strconv.NumError).Err == strconv.ErrRange { + if x.IsUint64() { + t.Errorf("IsUint64(%s) succeeded unexpectedly", s) + } + } else { + t.Errorf("ParseUint(%s) failed", s) + } + continue + } + + if !x.IsUint64() { + t.Errorf("IsUint64(%s) failed unexpectedly", s) + } + + got := x.Uint64() + if got != want { + t.Errorf("Uint64(%s) = %d; want %d", s, got, want) + } + } +} + +var bitwiseTests = []struct { + x, y string + and, or, xor, andNot string +}{ + {"0x00", "0x00", "0x00", "0x00", "0x00", "0x00"}, + {"0x00", "0x01", "0x00", "0x01", "0x01", "0x00"}, + {"0x01", "0x00", "0x00", "0x01", "0x01", "0x01"}, + {"-0x01", "0x00", "0x00", "-0x01", "-0x01", "-0x01"}, + {"-0xaf", "-0x50", "-0xf0", "-0x0f", "0xe1", "0x41"}, + {"0x00", "-0x01", "0x00", "-0x01", "-0x01", "0x00"}, + {"0x01", "0x01", "0x01", "0x01", "0x00", "0x00"}, + {"-0x01", "-0x01", "-0x01", "-0x01", "0x00", "0x00"}, + {"0x07", "0x08", "0x00", "0x0f", "0x0f", "0x07"}, + {"0x05", "0x0f", "0x05", "0x0f", "0x0a", "0x00"}, + {"0xff", "-0x0a", "0xf6", "-0x01", "-0xf7", "0x09"}, + {"0x013ff6", "0x9a4e", "0x1a46", "0x01bffe", "0x01a5b8", "0x0125b0"}, + {"-0x013ff6", "0x9a4e", "0x800a", "-0x0125b2", "-0x01a5bc", "-0x01c000"}, + {"-0x013ff6", "-0x9a4e", "-0x01bffe", "-0x1a46", "0x01a5b8", "0x8008"}, + { + "0x1000009dc6e3d9822cba04129bcbe3401", + "0xb9bd7d543685789d57cb918e833af352559021483cdb05cc21fd", + "0x1000001186210100001000009048c2001", + "0xb9bd7d543685789d57cb918e8bfeff7fddb2ebe87dfbbdfe35fd", + "0xb9bd7d543685789d57ca918e8ae69d6fcdb2eae87df2b97215fc", + "0x8c40c2d8822caa04120b8321400", + }, + { + "0x1000009dc6e3d9822cba04129bcbe3401", + "-0xb9bd7d543685789d57cb918e833af352559021483cdb05cc21fd", + "0x8c40c2d8822caa04120b8321401", + "-0xb9bd7d543685789d57ca918e82229142459020483cd2014001fd", + "-0xb9bd7d543685789d57ca918e8ae69d6fcdb2eae87df2b97215fe", + "0x1000001186210100001000009048c2000", + }, + { + "-0x1000009dc6e3d9822cba04129bcbe3401", + "-0xb9bd7d543685789d57cb918e833af352559021483cdb05cc21fd", + "-0xb9bd7d543685789d57cb918e8bfeff7fddb2ebe87dfbbdfe35fd", + "-0x1000001186210100001000009048c2001", + "0xb9bd7d543685789d57ca918e8ae69d6fcdb2eae87df2b97215fc", + "0xb9bd7d543685789d57ca918e82229142459020483cd2014001fc", + }, +} + +type bitFun func(z, x, y *BigInt) *BigInt + +func testBitFun(t *testing.T, msg string, f bitFun, x, y *BigInt, exp string) { + expected := new(BigInt) + expected.SetString(exp, 0) + + out := f(new(BigInt), x, y) + if out.Cmp(expected) != 0 { + t.Errorf("%s: got %s want %s", msg, out, expected) + } +} + +func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *BigInt, exp string) { + self := new(BigInt) + self.Set(x) + expected := new(BigInt) + expected.SetString(exp, 0) + + self = f(self, self, y) + if self.Cmp(expected) != 0 { + t.Errorf("%s: got %s want %s", msg, self, expected) + } +} + +func altBit(x *BigInt, i int) uint { + z := new(BigInt).Rsh(x, uint(i)) + z = z.And(z, NewBigInt(1)) + if z.Cmp(new(BigInt)) != 0 { + return 1 + } + return 0 +} + +func altSetBit(z *BigInt, x *BigInt, i int, b uint) *BigInt { + one := NewBigInt(1) + m := one.Lsh(one, uint(i)) + switch b { + case 1: + return z.Or(x, m) + case 0: + return z.AndNot(x, m) + } + panic("set bit is not 0 or 1") +} + +func testBitset(t *testing.T, x *BigInt) { + n := x.BitLen() + z := new(BigInt).Set(x) + z1 := new(BigInt).Set(x) + for i := 0; i < n+10; i++ { + old := z.Bit(i) + old1 := altBit(z1, i) + if old != old1 { + t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1) + } + z := new(BigInt).SetBit(z, i, 1) + z1 := altSetBit(new(BigInt), z1, i, 1) + if z.Bit(i) == 0 { + t.Errorf("bitset: bit %d of %s got 0 want 1", i, x) + } + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1) + } + z.SetBit(z, i, 0) + altSetBit(z1, z1, i, 0) + if z.Bit(i) != 0 { + t.Errorf("bitset: bit %d of %s got 1 want 0", i, x) + } + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1) + } + altSetBit(z1, z1, i, old) + z.SetBit(z, i, old) + if z.Cmp(z1) != 0 { + t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1) + } + } + if z.Cmp(x) != 0 { + t.Errorf("bitset: got %s want %s", z, x) + } +} + +var bitsetTests = []struct { + x string + i int + b uint +}{ + {"0", 0, 0}, + {"0", 200, 0}, + {"1", 0, 1}, + {"1", 1, 0}, + {"-1", 0, 1}, + {"-1", 200, 1}, + {"0x2000000000000000000000000000", 108, 0}, + {"0x2000000000000000000000000000", 109, 1}, + {"0x2000000000000000000000000000", 110, 0}, + {"-0x2000000000000000000000000001", 108, 1}, + {"-0x2000000000000000000000000001", 109, 0}, + {"-0x2000000000000000000000000001", 110, 1}, +} + +func TestBigIntBitSet(t *testing.T) { + for _, test := range bitwiseTests { + x := new(BigInt) + x.SetString(test.x, 0) + testBitset(t, x) + x = new(BigInt) + x.SetString(test.y, 0) + testBitset(t, x) + } + for i, test := range bitsetTests { + x := new(BigInt) + x.SetString(test.x, 0) + b := x.Bit(test.i) + if b != test.b { + t.Errorf("#%d got %v want %v", i, b, test.b) + } + } + z := NewBigInt(1) + z.SetBit(NewBigInt(0), 2, 1) + if z.Cmp(NewBigInt(4)) != 0 { + t.Errorf("destination leaked into result; got %s want 4", z) + } +} + +var tzbTests = []struct { + in string + out uint +}{ + {"0", 0}, + {"1", 0}, + {"-1", 0}, + {"4", 2}, + {"-8", 3}, + {"0x4000000000000000000", 74}, + {"-0x8000000000000000000", 75}, +} + +func TestBigIntTrailingZeroBits(t *testing.T) { + for i, test := range tzbTests { + in, _ := new(BigInt).SetString(test.in, 0) + want := test.out + got := in.TrailingZeroBits() + + if got != want { + t.Errorf("#%d: got %v want %v", i, got, want) + } + } +} + +func TestBigIntBitwise(t *testing.T) { + x := new(BigInt) + y := new(BigInt) + for _, test := range bitwiseTests { + x.SetString(test.x, 0) + y.SetString(test.y, 0) + + testBitFun(t, "and", (*BigInt).And, x, y, test.and) + testBitFunSelf(t, "and", (*BigInt).And, x, y, test.and) + testBitFun(t, "andNot", (*BigInt).AndNot, x, y, test.andNot) + testBitFunSelf(t, "andNot", (*BigInt).AndNot, x, y, test.andNot) + testBitFun(t, "or", (*BigInt).Or, x, y, test.or) + testBitFunSelf(t, "or", (*BigInt).Or, x, y, test.or) + testBitFun(t, "xor", (*BigInt).Xor, x, y, test.xor) + testBitFunSelf(t, "xor", (*BigInt).Xor, x, y, test.xor) + } +} + +var notTests = []struct { + in string + out string +}{ + {"0", "-1"}, + {"1", "-2"}, + {"7", "-8"}, + {"0", "-1"}, + {"-81910", "81909"}, + { + "298472983472983471903246121093472394872319615612417471234712061", + "-298472983472983471903246121093472394872319615612417471234712062", + }, +} + +func TestBigIntNot(t *testing.T) { + in := new(BigInt) + out := new(BigInt) + expected := new(BigInt) + for i, test := range notTests { + in.SetString(test.in, 10) + expected.SetString(test.out, 10) + out = out.Not(in) + if out.Cmp(expected) != 0 { + t.Errorf("#%d: got %s want %s", i, out, expected) + } + out = out.Not(out) + if out.Cmp(in) != 0 { + t.Errorf("#%d: got %s want %s", i, out, in) + } + } +} + +var modInverseTests = []struct { + element string + modulus string +}{ + {"1234567", "458948883992"}, + {"239487239847", "2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919"}, + {"-10", "13"}, // issue #16984 + {"10", "-13"}, + {"-17", "-13"}, +} + +func TestBigIntModInverse(t *testing.T) { + var element, modulus, gcd, inverse BigInt + one := NewBigInt(1) + for _, test := range modInverseTests { + (&element).SetString(test.element, 10) + (&modulus).SetString(test.modulus, 10) + (&inverse).ModInverse(&element, &modulus) + (&inverse).Mul(&inverse, &element) + (&inverse).Mod(&inverse, &modulus) + if (&inverse).Cmp(one) != 0 { + t.Errorf("ModInverse(%d,%d)*%d%%%d=%d, not 1", &element, &modulus, &element, &modulus, &inverse) + } + } + // exhaustive test for small values + for n := 2; n < 100; n++ { + (&modulus).SetInt64(int64(n)) + for x := 1; x < n; x++ { + (&element).SetInt64(int64(x)) + (&gcd).GCD(nil, nil, &element, &modulus) + if (&gcd).Cmp(one) != 0 { + continue + } + (&inverse).ModInverse(&element, &modulus) + (&inverse).Mul(&inverse, &element) + (&inverse).Mod(&inverse, &modulus) + if (&inverse).Cmp(one) != 0 { + t.Errorf("ModInverse(%d,%d)*%d%%%d=%d, not 1", &element, &modulus, &element, &modulus, &inverse) + } + } + } +} + +// testModSqrt is a helper for TestModSqrt, +// which checks that ModSqrt can compute a square-root of elt^2. +func testModSqrt(t *testing.T, elt, mod, sq, sqrt *BigInt) bool { + var sqChk, sqrtChk, sqrtsq BigInt + sq.Mul(elt, elt) + sq.Mod(sq, mod) + z := sqrt.ModSqrt(sq, mod) + if z != sqrt { + t.Errorf("ModSqrt returned wrong value %s", z) + } + + // test ModSqrt arguments outside the range [0,mod) + sqChk.Add(sq, mod) + z = sqrtChk.ModSqrt(&sqChk, mod) + if z != &sqrtChk || z.Cmp(sqrt) != 0 { + t.Errorf("ModSqrt returned inconsistent value %s", z) + } + sqChk.Sub(sq, mod) + z = sqrtChk.ModSqrt(&sqChk, mod) + if z != &sqrtChk || z.Cmp(sqrt) != 0 { + t.Errorf("ModSqrt returned inconsistent value %s", z) + } + + // test x aliasing z + z = sqrtChk.ModSqrt(sqrtChk.Set(sq), mod) + if z != &sqrtChk || z.Cmp(sqrt) != 0 { + t.Errorf("ModSqrt returned inconsistent value %s", z) + } + + // make sure we actually got a square root + if sqrt.Cmp(elt) == 0 { + return true // we found the "desired" square root + } + sqrtsq.Mul(sqrt, sqrt) // make sure we found the "other" one + sqrtsq.Mod(&sqrtsq, mod) + return sq.Cmp(&sqrtsq) == 0 +} + +var primes = []string{ + "2", + "3", + "5", + "7", + "11", + + "13756265695458089029", + "13496181268022124907", + "10953742525620032441", + "17908251027575790097", + + // https://golang.org/issue/638 + "18699199384836356663", + + "98920366548084643601728869055592650835572950932266967461790948584315647051443", + "94560208308847015747498523884063394671606671904944666360068158221458669711639", + + // https://primes.utm.edu/lists/small/small3.html + "449417999055441493994709297093108513015373787049558499205492347871729927573118262811508386655998299074566974373711472560655026288668094291699357843464363003144674940345912431129144354948751003607115263071543163", + "230975859993204150666423538988557839555560243929065415434980904258310530753006723857139742334640122533598517597674807096648905501653461687601339782814316124971547968912893214002992086353183070342498989426570593", + "5521712099665906221540423207019333379125265462121169655563495403888449493493629943498064604536961775110765377745550377067893607246020694972959780839151452457728855382113555867743022746090187341871655890805971735385789993", + "203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", + + // ECC primes: https://tools.ietf.org/html/draft-ladd-safecurves-02 + "3618502788666131106986593281521497120414687020801267626233049500247285301239", // Curve1174: 2^251-9 + "57896044618658097711785492504343953926634992332820282019728792003956564819949", // Curve25519: 2^255-19 + "9850501549098619803069760025035903451269934817616361666987073351061430442874302652853566563721228910201656997576599", // E-382: 2^382-105 + "42307582002575910332922579714097346549017899709713998034217522897561970639123926132812109468141778230245837569601494931472367", // Curve41417: 2^414-17 + "6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151", // E-521: 2^521-1 +} + +func TestBigIntModSqrt(t *testing.T) { + var elt, mod, modx4, sq, sqrt BigInt + r := rand.New(rand.NewSource(9)) + for i, s := range primes[1:] { // skip 2, use only odd primes + mod.SetString(s, 10) + modx4.Lsh(&mod, 2) + + // test a few random elements per prime + for x := 1; x < 5; x++ { + elt.Rand(r, &modx4) + elt.Sub(&elt, &mod) // test range [-mod, 3*mod) + if !testModSqrt(t, &elt, &mod, &sq, &sqrt) { + t.Errorf("#%d: failed (sqrt(e) = %s)", i, &sqrt) + } + } + + if testing.Short() && i > 2 { + break + } + } + + if testing.Short() { + return + } + + // exhaustive test for small values + for n := 3; n < 100; n++ { + mod.SetInt64(int64(n)) + if !mod.ProbablyPrime(10) { + continue + } + isSquare := make([]bool, n) + + // test all the squares + for x := 1; x < n; x++ { + elt.SetInt64(int64(x)) + if !testModSqrt(t, &elt, &mod, &sq, &sqrt) { + t.Errorf("#%d: failed (sqrt(%d,%d) = %s)", x, &elt, &mod, &sqrt) + } + isSquare[sq.Uint64()] = true + } + + // test all non-squares + for x := 1; x < n; x++ { + sq.SetInt64(int64(x)) + z := sqrt.ModSqrt(&sq, &mod) + if !isSquare[x] && z != nil { + t.Errorf("#%d: failed (sqrt(%d,%d) = nil)", x, &sqrt, &mod) + } + } + } +} + +func TestBigIntIssue2607(t *testing.T) { + // This code sequence used to hang. + n := NewBigInt(10) + n.Rand(rand.New(rand.NewSource(9)), n) +} + +func TestBigIntSqrt(t *testing.T) { + root := 0 + r := new(BigInt) + for i := 0; i < 10000; i++ { + if (root+1)*(root+1) <= i { + root++ + } + n := NewBigInt(int64(i)) + r.SetInt64(-2) + r.Sqrt(n) + if r.Cmp(NewBigInt(int64(root))) != 0 { + t.Errorf("Sqrt(%v) = %v, want %v", n, r, root) + } + } + + for i := 0; i < 1000; i += 10 { + n, _ := new(BigInt).SetString("1"+strings.Repeat("0", i), 10) + r := new(BigInt).Sqrt(n) + root, _ := new(BigInt).SetString("1"+strings.Repeat("0", i/2), 10) + if r.Cmp(root) != 0 { + t.Errorf("Sqrt(1e%d) = %v, want 1e%d", i, r, i/2) + } + } + + // Test aliasing. + r.SetInt64(100) + r.Sqrt(r) + if r.Int64() != 10 { + t.Errorf("Sqrt(100) = %v, want 10 (aliased output)", r.Int64()) + } +} + +// We can't test this together with the other Exp tests above because +// it requires a different receiver setup. +func TestBigIntIssue22830(t *testing.T) { + one := new(BigInt).SetInt64(1) + base, _ := new(BigInt).SetString("84555555300000000000", 10) + mod, _ := new(BigInt).SetString("66666670001111111111", 10) + want, _ := new(BigInt).SetString("17888885298888888889", 10) + + var tests = []int64{ + 0, 1, -1, + } + + for _, n := range tests { + m := NewBigInt(n) + if got := m.Exp(base, one, mod); got.Cmp(want) != 0 { + t.Errorf("(%v).Exp(%s, 1, %s) = %s, want %s", n, base, mod, got, want) + } + } +} + +// +// Tests from src/math/big/intconv_test.go +// + +var stringTests = []struct { + in string + out string + base int + val int64 + ok bool +}{ + // invalid inputs + {in: ""}, + {in: "a"}, + {in: "z"}, + {in: "+"}, + {in: "-"}, + {in: "0b"}, + {in: "0o"}, + {in: "0x"}, + {in: "0y"}, + {in: "2", base: 2}, + {in: "0b2", base: 0}, + {in: "08"}, + {in: "8", base: 8}, + {in: "0xg", base: 0}, + {in: "g", base: 16}, + + // invalid inputs with separators + // (smoke tests only - a comprehensive set of tests is in natconv_test.go) + {in: "_"}, + {in: "0_"}, + {in: "_0"}, + {in: "-1__0"}, + {in: "0x10_"}, + {in: "1_000", base: 10}, // separators are not permitted for bases != 0 + {in: "d_e_a_d", base: 16}, + + // valid inputs + {"0", "0", 0, 0, true}, + {"0", "0", 10, 0, true}, + {"0", "0", 16, 0, true}, + {"+0", "0", 0, 0, true}, + {"-0", "0", 0, 0, true}, + {"10", "10", 0, 10, true}, + {"10", "10", 10, 10, true}, + {"10", "10", 16, 16, true}, + {"-10", "-10", 16, -16, true}, + {"+10", "10", 16, 16, true}, + {"0b10", "2", 0, 2, true}, + {"0o10", "8", 0, 8, true}, + {"0x10", "16", 0, 16, true}, + {in: "0x10", base: 16}, + {"-0x10", "-16", 0, -16, true}, + {"+0x10", "16", 0, 16, true}, + {"00", "0", 0, 0, true}, + {"0", "0", 8, 0, true}, + {"07", "7", 0, 7, true}, + {"7", "7", 8, 7, true}, + {"023", "19", 0, 19, true}, + {"23", "23", 8, 19, true}, + {"cafebabe", "cafebabe", 16, 0xcafebabe, true}, + {"0b0", "0", 0, 0, true}, + {"-111", "-111", 2, -7, true}, + {"-0b111", "-7", 0, -7, true}, + {"0b1001010111", "599", 0, 0x257, true}, + {"1001010111", "1001010111", 2, 0x257, true}, + {"A", "a", 36, 10, true}, + {"A", "A", 37, 36, true}, + {"ABCXYZ", "abcxyz", 36, 623741435, true}, + {"ABCXYZ", "ABCXYZ", 62, 33536793425, true}, + + // valid input with separators + // (smoke tests only - a comprehensive set of tests is in natconv_test.go) + {"1_000", "1000", 0, 1000, true}, + {"0b_1010", "10", 0, 10, true}, + {"+0o_660", "432", 0, 0660, true}, + {"-0xF00D_1E", "-15731998", 0, -0xf00d1e, true}, +} + +func TestBigIntText(t *testing.T) { + z := new(BigInt) + for _, test := range stringTests { + if !test.ok { + continue + } + + _, ok := z.SetString(test.in, test.base) + if !ok { + t.Errorf("%v: failed to parse", test) + continue + } + + base := test.base + if base == 0 { + base = 10 + } + + if got := z.Text(base); got != test.out { + t.Errorf("%v: got %s; want %s", test, got, test.out) + } + } +} + +func TestBigIntAppendText(t *testing.T) { + z := new(BigInt) + var buf []byte + for _, test := range stringTests { + if !test.ok { + continue + } + + _, ok := z.SetString(test.in, test.base) + if !ok { + t.Errorf("%v: failed to parse", test) + continue + } + + base := test.base + if base == 0 { + base = 10 + } + + i := len(buf) + buf = z.Append(buf, base) + if got := string(buf[i:]); got != test.out { + t.Errorf("%v: got %s; want %s", test, got, test.out) + } + } +} + +func TestBigIntGetString(t *testing.T) { + format := func(base int) string { + switch base { + case 2: + return "%b" + case 8: + return "%o" + case 16: + return "%x" + } + return "%d" + } + + z := new(BigInt) + for i, test := range stringTests { + if !test.ok { + continue + } + z.SetInt64(test.val) + + if test.base == 10 { + if got := z.String(); got != test.out { + t.Errorf("#%da got %s; want %s", i, got, test.out) + } + } + + f := format(test.base) + got := fmt.Sprintf(f, z) + if f == "%d" { + if got != fmt.Sprintf("%d", test.val) { + t.Errorf("#%db got %s; want %d", i, got, test.val) + } + } else { + if got != test.out { + t.Errorf("#%dc got %s; want %s", i, got, test.out) + } + } + } +} + +func TestBigIntSetString(t *testing.T) { + tmp := new(BigInt) + for i, test := range stringTests { + // initialize to a non-zero value so that issues with parsing + // 0 are detected + tmp.SetInt64(1234567890) + n1, ok1 := new(BigInt).SetString(test.in, test.base) + n2, ok2 := tmp.SetString(test.in, test.base) + expected := NewBigInt(test.val) + if ok1 != test.ok || ok2 != test.ok { + t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok) + continue + } + if !ok1 { + if n1 != nil { + t.Errorf("#%d (input '%s') n1 != nil", i, test.in) + } + continue + } + if !ok2 { + if n2 != nil { + t.Errorf("#%d (input '%s') n2 != nil", i, test.in) + } + continue + } + + if n1.Cmp(expected) != 0 { + t.Errorf("#%d (input '%s') got: %s want: %d", i, test.in, n1, test.val) + } + if n2.Cmp(expected) != 0 { + t.Errorf("#%d (input '%s') got: %s want: %d", i, test.in, n2, test.val) + } + } +} + +var formatTests = []struct { + input string + format string + output string +}{ + {"", "%x", ""}, + {"", "%#x", ""}, + {"", "%#y", "%!y(big.Int=)"}, + + {"10", "%b", "1010"}, + {"10", "%o", "12"}, + {"10", "%d", "10"}, + {"10", "%v", "10"}, + {"10", "%x", "a"}, + {"10", "%X", "A"}, + {"-10", "%X", "-A"}, + {"10", "%y", "%!y(big.Int=10)"}, + {"-10", "%y", "%!y(big.Int=-10)"}, + + {"10", "%#b", "0b1010"}, + {"10", "%#o", "012"}, + {"10", "%O", "0o12"}, + {"-10", "%#b", "-0b1010"}, + {"-10", "%#o", "-012"}, + {"-10", "%O", "-0o12"}, + {"10", "%#d", "10"}, + {"10", "%#v", "10"}, + {"10", "%#x", "0xa"}, + {"10", "%#X", "0XA"}, + {"-10", "%#X", "-0XA"}, + {"10", "%#y", "%!y(big.Int=10)"}, + {"-10", "%#y", "%!y(big.Int=-10)"}, + + {"1234", "%d", "1234"}, + {"1234", "%3d", "1234"}, + {"1234", "%4d", "1234"}, + {"-1234", "%d", "-1234"}, + {"1234", "% 5d", " 1234"}, + {"1234", "%+5d", "+1234"}, + {"1234", "%-5d", "1234 "}, + {"1234", "%x", "4d2"}, + {"1234", "%X", "4D2"}, + {"-1234", "%3x", "-4d2"}, + {"-1234", "%4x", "-4d2"}, + {"-1234", "%5x", " -4d2"}, + {"-1234", "%-5x", "-4d2 "}, + {"1234", "%03d", "1234"}, + {"1234", "%04d", "1234"}, + {"1234", "%05d", "01234"}, + {"1234", "%06d", "001234"}, + {"-1234", "%06d", "-01234"}, + {"1234", "%+06d", "+01234"}, + {"1234", "% 06d", " 01234"}, + {"1234", "%-6d", "1234 "}, + {"1234", "%-06d", "1234 "}, + {"-1234", "%-06d", "-1234 "}, + + {"1234", "%.3d", "1234"}, + {"1234", "%.4d", "1234"}, + {"1234", "%.5d", "01234"}, + {"1234", "%.6d", "001234"}, + {"-1234", "%.3d", "-1234"}, + {"-1234", "%.4d", "-1234"}, + {"-1234", "%.5d", "-01234"}, + {"-1234", "%.6d", "-001234"}, + + {"1234", "%8.3d", " 1234"}, + {"1234", "%8.4d", " 1234"}, + {"1234", "%8.5d", " 01234"}, + {"1234", "%8.6d", " 001234"}, + {"-1234", "%8.3d", " -1234"}, + {"-1234", "%8.4d", " -1234"}, + {"-1234", "%8.5d", " -01234"}, + {"-1234", "%8.6d", " -001234"}, + + {"1234", "%+8.3d", " +1234"}, + {"1234", "%+8.4d", " +1234"}, + {"1234", "%+8.5d", " +01234"}, + {"1234", "%+8.6d", " +001234"}, + {"-1234", "%+8.3d", " -1234"}, + {"-1234", "%+8.4d", " -1234"}, + {"-1234", "%+8.5d", " -01234"}, + {"-1234", "%+8.6d", " -001234"}, + + {"1234", "% 8.3d", " 1234"}, + {"1234", "% 8.4d", " 1234"}, + {"1234", "% 8.5d", " 01234"}, + {"1234", "% 8.6d", " 001234"}, + {"-1234", "% 8.3d", " -1234"}, + {"-1234", "% 8.4d", " -1234"}, + {"-1234", "% 8.5d", " -01234"}, + {"-1234", "% 8.6d", " -001234"}, + + {"1234", "%.3x", "4d2"}, + {"1234", "%.4x", "04d2"}, + {"1234", "%.5x", "004d2"}, + {"1234", "%.6x", "0004d2"}, + {"-1234", "%.3x", "-4d2"}, + {"-1234", "%.4x", "-04d2"}, + {"-1234", "%.5x", "-004d2"}, + {"-1234", "%.6x", "-0004d2"}, + + {"1234", "%8.3x", " 4d2"}, + {"1234", "%8.4x", " 04d2"}, + {"1234", "%8.5x", " 004d2"}, + {"1234", "%8.6x", " 0004d2"}, + {"-1234", "%8.3x", " -4d2"}, + {"-1234", "%8.4x", " -04d2"}, + {"-1234", "%8.5x", " -004d2"}, + {"-1234", "%8.6x", " -0004d2"}, + + {"1234", "%+8.3x", " +4d2"}, + {"1234", "%+8.4x", " +04d2"}, + {"1234", "%+8.5x", " +004d2"}, + {"1234", "%+8.6x", " +0004d2"}, + {"-1234", "%+8.3x", " -4d2"}, + {"-1234", "%+8.4x", " -04d2"}, + {"-1234", "%+8.5x", " -004d2"}, + {"-1234", "%+8.6x", " -0004d2"}, + + {"1234", "% 8.3x", " 4d2"}, + {"1234", "% 8.4x", " 04d2"}, + {"1234", "% 8.5x", " 004d2"}, + {"1234", "% 8.6x", " 0004d2"}, + {"1234", "% 8.7x", " 00004d2"}, + {"1234", "% 8.8x", " 000004d2"}, + {"-1234", "% 8.3x", " -4d2"}, + {"-1234", "% 8.4x", " -04d2"}, + {"-1234", "% 8.5x", " -004d2"}, + {"-1234", "% 8.6x", " -0004d2"}, + {"-1234", "% 8.7x", "-00004d2"}, + {"-1234", "% 8.8x", "-000004d2"}, + + {"1234", "%-8.3d", "1234 "}, + {"1234", "%-8.4d", "1234 "}, + {"1234", "%-8.5d", "01234 "}, + {"1234", "%-8.6d", "001234 "}, + {"1234", "%-8.7d", "0001234 "}, + {"1234", "%-8.8d", "00001234"}, + {"-1234", "%-8.3d", "-1234 "}, + {"-1234", "%-8.4d", "-1234 "}, + {"-1234", "%-8.5d", "-01234 "}, + {"-1234", "%-8.6d", "-001234 "}, + {"-1234", "%-8.7d", "-0001234"}, + {"-1234", "%-8.8d", "-00001234"}, + + {"16777215", "%b", "111111111111111111111111"}, // 2**24 - 1 + + {"0", "%.d", ""}, + {"0", "%.0d", ""}, + {"0", "%3.d", ""}, +} + +func TestBigIntFormat(t *testing.T) { + for i, test := range formatTests { + var x *BigInt + if test.input != "" { + var ok bool + x, ok = new(BigInt).SetString(test.input, 0) + if !ok { + t.Errorf("#%d failed reading input %s", i, test.input) + } + } + output := fmt.Sprintf(test.format, x) + if output != test.output { + t.Errorf("#%d got %q; want %q, {%q, %q, %q}", i, output, test.output, test.input, test.format, test.output) + } + } +} + +var scanTests = []struct { + input string + format string + output string + remaining int +}{ + {"1010", "%b", "10", 0}, + {"0b1010", "%v", "10", 0}, + {"12", "%o", "10", 0}, + {"012", "%v", "10", 0}, + {"10", "%d", "10", 0}, + {"10", "%v", "10", 0}, + {"a", "%x", "10", 0}, + {"0xa", "%v", "10", 0}, + {"A", "%X", "10", 0}, + {"-A", "%X", "-10", 0}, + {"+0b1011001", "%v", "89", 0}, + {"0xA", "%v", "10", 0}, + {"0 ", "%v", "0", 1}, + {"2+3", "%v", "2", 2}, + {"0XABC 12", "%v", "2748", 3}, +} + +func TestBigIntScan(t *testing.T) { + var buf bytes.Buffer + for i, test := range scanTests { + x := new(BigInt) + buf.Reset() + buf.WriteString(test.input) + if _, err := fmt.Fscanf(&buf, test.format, x); err != nil { + t.Errorf("#%d error: %s", i, err) + } + if x.String() != test.output { + t.Errorf("#%d got %s; want %s", i, x.String(), test.output) + } + if buf.Len() != test.remaining { + t.Errorf("#%d got %d bytes remaining; want %d", i, buf.Len(), test.remaining) + } + } +} + +// +// Tests from src/math/big/intmarsh_test.go +// + +var encodingTests = []string{ + "0", + "1", + "2", + "10", + "1000", + "1234567890", + "298472983472983471903246121093472394872319615612417471234712061", +} + +func TestBigIntGobEncoding(t *testing.T) { + var medium bytes.Buffer + enc := gob.NewEncoder(&medium) + dec := gob.NewDecoder(&medium) + for _, test := range encodingTests { + for _, sign := range []string{"", "+", "-"} { + x := sign + test + medium.Reset() // empty buffer for each test case (in case of failures) + var tx BigInt + tx.SetString(x, 10) + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %s failed: %s", &tx, err) + continue + } + var rx BigInt + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +// Sending a nil BigInt pointer (inside a slice) on a round trip through gob should yield a zero. +func TestBigIntGobEncodingNilIntInSlice(t *testing.T) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + dec := gob.NewDecoder(buf) + + var in = make([]*BigInt, 1) + err := enc.Encode(&in) + if err != nil { + t.Errorf("gob encode failed: %q", err) + } + var out []*BigInt + err = dec.Decode(&out) + if err != nil { + t.Fatalf("gob decode failed: %q", err) + } + if len(out) != 1 { + t.Fatalf("wrong len; want 1 got %d", len(out)) + } + var zero BigInt + if out[0].Cmp(&zero) != 0 { + t.Fatalf("transmission of (*BigInt)(nil) failed: got %s want 0", out) + } +} + +func TestBigIntJSONEncoding(t *testing.T) { + for _, test := range encodingTests { + for _, sign := range []string{"", "+", "-"} { + x := sign + test + var tx BigInt + tx.SetString(x, 10) + b, err := json.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + continue + } + var rx BigInt + if err := json.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("JSON encoding of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +func TestBigIntXMLEncoding(t *testing.T) { + for _, test := range encodingTests { + for _, sign := range []string{"", "+", "-"} { + x := sign + test + var tx BigInt + tx.SetString(x, 0) + b, err := xml.Marshal(&tx) + if err != nil { + t.Errorf("marshaling of %s failed: %s", &tx, err) + continue + } + var rx BigInt + if err := xml.Unmarshal(b, &rx); err != nil { + t.Errorf("unmarshaling of %s failed: %s", &tx, err) + continue + } + if rx.Cmp(&tx) != 0 { + t.Errorf("XML encoding of %s failed: got %s want %s", &tx, &rx, &tx) + } + } + } +} + +// +// Benchmarks from src/math/big/int_test.go +// + +func BenchmarkBigIntBinomial(b *testing.B) { + var z BigInt + for i := b.N - 1; i >= 0; i-- { + z.Binomial(1000, 990) + } +} + +func BenchmarkBigIntQuoRem(b *testing.B) { + x, _ := new(BigInt).SetString("153980389784927331788354528594524332344709972855165340650588877572729725338415474372475094155672066328274535240275856844648695200875763869073572078279316458648124537905600131008790701752441155668003033945258023841165089852359980273279085783159654751552359397986180318708491098942831252291841441726305535546071", 0) + y, _ := new(BigInt).SetString("7746362281539803897849273317883545285945243323447099728551653406505888775727297253384154743724750941556720663282745352402758568446486952008757638690735720782793164586481245379056001310087907017524411556680030339452580238411650898523599802732790857831596547515523593979861803187084910989428312522918414417263055355460715745539358014631136245887418412633787074173796862711588221766398229333338511838891484974940633857861775630560092874987828057333663969469797013996401149696897591265769095952887917296740109742927689053276850469671231961384715398038978492733178835452859452433234470997285516534065058887757272972533841547437247509415567206632827453524027585684464869520087576386907357207827931645864812453790560013100879070175244115566800303394525802384116508985235998027327908578315965475155235939798618031870849109894283125229184144172630553554607112725169432413343763989564437170644270643461665184965150423819594083121075825", 0) + q := new(BigInt) + r := new(BigInt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.QuoRem(y, x, r) + } +} + +func BenchmarkBigIntExp(b *testing.B) { + x, _ := new(BigInt).SetString("11001289118363089646017359372117963499250546375269047542777928006103246876688756735760905680604646624353196869572752623285140408755420374049317646428185270079555372763503115646054602867593662923894140940837479507194934267532831694565516466765025434902348314525627418515646588160955862839022051353653052947073136084780742729727874803457643848197499548297570026926927502505634297079527299004267769780768565695459945235586892627059178884998772989397505061206395455591503771677500931269477503508150175717121828518985901959919560700853226255420793148986854391552859459511723547532575574664944815966793196961286234040892865", 0) + y, _ := new(BigInt).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF72", 0) + n, _ := new(BigInt).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", 0) + out := new(BigInt) + for i := 0; i < b.N; i++ { + out.Exp(x, y, n) + } +} + +func BenchmarkBigIntExp2(b *testing.B) { + x, _ := new(BigInt).SetString("2", 0) + y, _ := new(BigInt).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF72", 0) + n, _ := new(BigInt).SetString("0xAC6BDB41324A9A9BF166DE5E1389582FAF72B6651987EE07FC3192943DB56050A37329CBB4A099ED8193E0757767A13DD52312AB4B03310DCD7F48A9DA04FD50E8083969EDB767B0CF6095179A163AB3661A05FBD5FAAAE82918A9962F0B93B855F97993EC975EEAA80D740ADBF4FF747359D041D5C33EA71D281E446B14773BCA97B43A23FB801676BD207A436C6481F1D2B9078717461A5B9D32E688F87748544523B524B0D57D5EA77A2775D2ECFA032CFBDBF52FB3786160279004E57AE6AF874E7303CE53299CCC041C7BC308D82A5698F3A8D0C38271AE35F8E9DBFBB694B5C803D89F7AE435DE236D525F54759B65E372FCD68EF20FA7111F9E4AFF73", 0) + out := new(BigInt) + for i := 0; i < b.N; i++ { + out.Exp(x, y, n) + } +} + +func BenchmarkBigIntBitset(b *testing.B) { + z := new(BigInt) + z.SetBit(z, 512, 1) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + z.SetBit(z, i&512, 1) + } +} + +func BenchmarkBigIntBitsetNeg(b *testing.B) { + z := NewBigInt(-1) + z.SetBit(z, 512, 0) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + z.SetBit(z, i&512, 0) + } +} + +func BenchmarkBigIntBitsetOrig(b *testing.B) { + z := new(BigInt) + altSetBit(z, z, 512, 1) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + altSetBit(z, z, i&512, 1) + } +} + +func BenchmarkBigIntBitsetNegOrig(b *testing.B) { + z := NewBigInt(-1) + altSetBit(z, z, 512, 0) + b.ResetTimer() + b.StartTimer() + for i := b.N - 1; i >= 0; i-- { + altSetBit(z, z, i&512, 0) + } +} + +func BenchmarkBigIntModInverse(b *testing.B) { + p := new(BigInt).SetInt64(1) // Mersenne prime 2**1279 -1 + p.Lsh(p, 1279) + p.Sub(p, bigOne) + x := new(BigInt).Sub(p, bigOne) + z := new(BigInt) + for i := 0; i < b.N; i++ { + z.ModInverse(x, p) + } +} + +func BenchmarkBigIntSqrt(b *testing.B) { + n, _ := new(BigInt).SetString("1"+strings.Repeat("0", 1001), 10) + b.ResetTimer() + t := new(BigInt) + for i := 0; i < b.N; i++ { + t.Sqrt(n) + } +} + +// randBigInt returns a pseudo-random Int in the range [1<<(size-1), (1< 1<<(size-1) +} + +func benchmarkBigIntDiv(b *testing.B, aSize, bSize int) { + var r = rand.New(rand.NewSource(1234)) + aa := randBigInt(r, uint(aSize)) + bb := randBigInt(r, uint(bSize)) + if aa.Cmp(bb) < 0 { + aa, bb = bb, aa + } + x := new(BigInt) + y := new(BigInt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.DivMod(aa, bb, y) + } +} + +func BenchmarkBigIntDiv(b *testing.B) { + sizes := []int{ + 10, 20, 50, 100, 200, 500, 1000, + 1e4, 1e5, 1e6, 1e7, + } + for _, i := range sizes { + j := 2 * i + b.Run(fmt.Sprintf("%d/%d", j, i), func(b *testing.B) { + benchmarkBigIntDiv(b, j, i) + }) + } +} diff --git a/const.go b/const.go index 9a386e0..bbd0432 100644 --- a/const.go +++ b/const.go @@ -14,13 +14,13 @@ package apd -import "math/big" +import "math" var ( - bigOne = big.NewInt(1) - bigTwo = big.NewInt(2) - bigFive = big.NewInt(5) - bigTen = big.NewInt(10) + bigOne = NewBigInt(1) + bigTwo = NewBigInt(2) + bigFive = NewBigInt(5) + bigTen = NewBigInt(10) decimalZero = New(0, 0) decimalOneEighth = New(125, -3) @@ -30,6 +30,9 @@ var ( decimalThree = New(3, 0) decimalEight = New(8, 0) + decimalMaxInt64 = New(math.MaxInt64, 0) + decimalMinInt64 = New(math.MinInt64, 0) + decimalCbrtC1 = makeConst(strCbrtC1) decimalCbrtC2 = makeConst(strCbrtC2) decimalCbrtC3 = makeConst(strCbrtC3) diff --git a/context.go b/context.go index 2fb3eae..4b3a96c 100644 --- a/context.go +++ b/context.go @@ -16,7 +16,6 @@ package apd import ( "math" - "math/big" "github.com/pkg/errors" ) @@ -42,7 +41,7 @@ type Context struct { Traps Condition // Rounding specifies the Rounder to use during rounding. RoundHalfUp is used if // empty or not present in Roundings. - Rounding string + Rounding Rounder } const ( @@ -73,9 +72,10 @@ var BaseContext = Context{ // WithPrecision returns a copy of c but with the specified precision. func (c *Context) WithPrecision(p uint32) *Context { - r := *c + r := new(Context) + *r = *c r.Precision = p - return &r + return r } // goError converts flags into an error based on c.Traps. @@ -136,7 +136,8 @@ func (c *Context) add(d, x, y *Decimal, subtract bool) (Condition, error) { } return 0, nil } - a, b, s, err := upscale(x, y) + var tmp BigInt + a, b, s, err := upscale(x, y, &tmp) if err != nil { return 0, errors.Wrap(err, "add") } @@ -283,12 +284,13 @@ func (c *Context) Quo(d, x, y *Decimal) (Condition, error) { // An integer variable, adjust, is initialized to 0. var adjust int64 // The result coefficient is initialized to 0. - quo := new(Decimal) + var quo Decimal var res Condition var diff int64 if !x.IsZero() { - dividend := new(big.Int).Abs(&x.Coeff) - divisor := new(big.Int).Abs(&y.Coeff) + var dividend, divisor BigInt + dividend.Abs(&x.Coeff) + divisor.Abs(&y.Coeff) // The operand coefficients are adjusted so that the coefficient of the // dividend is greater than or equal to the coefficient of the divisor and @@ -296,20 +298,21 @@ func (c *Context) Quo(d, x, y *Decimal) (Condition, error) { // While the coefficient of the dividend is less than the coefficient of // the divisor it is multiplied by 10 and adjust is incremented by 1. - for dividend.Cmp(divisor) < 0 { - dividend.Mul(dividend, bigTen) + for dividend.Cmp(&divisor) < 0 { + dividend.Mul(÷nd, bigTen) adjust++ } // While the coefficient of the dividend is greater than or equal to ten // times the coefficient of the divisor the coefficient of the divisor is // multiplied by 10 and adjust is decremented by 1. - for tmp := new(big.Int); ; { - tmp.Mul(divisor, bigTen) - if dividend.Cmp(tmp) < 0 { + var tmp BigInt + for { + tmp.Mul(&divisor, bigTen) + if dividend.Cmp(&tmp) < 0 { break } - divisor.Set(tmp) + divisor.Set(&tmp) adjust-- } @@ -320,8 +323,8 @@ func (c *Context) Quo(d, x, y *Decimal) (Condition, error) { // While the coefficient of the divisor is smaller than or equal to the // coefficient of the dividend the former is subtracted from the latter and // the coefficient of the result is incremented by 1. - for divisor.Cmp(dividend) <= 0 { - dividend.Sub(dividend, divisor) + for divisor.Cmp(÷nd) <= 0 { + dividend.Sub(÷nd, &divisor) quo.Coeff.Add(&quo.Coeff, bigOne) } @@ -335,7 +338,7 @@ func (c *Context) Quo(d, x, y *Decimal) (Condition, error) { // Otherwise, the coefficients of the result and the dividend are multiplied // by 10 and adjust is incremented by 1. quo.Coeff.Mul(&quo.Coeff, bigTen) - dividend.Mul(dividend, bigTen) + dividend.Mul(÷nd, bigTen) adjust++ } @@ -346,10 +349,9 @@ func (c *Context) Quo(d, x, y *Decimal) (Condition, error) { // taken into account for rounding. if dividend.Sign() != 0 && adj >= int64(c.MinExponent) { res |= Inexact | Rounded - dividend.Mul(dividend, bigTwo) - half := dividend.Cmp(divisor) - rounding := c.rounding() - if rounding(&quo.Coeff, quo.Negative, half) { + dividend.Mul(÷nd, bigTwo) + half := dividend.Cmp(&divisor) + if c.Rounding.ShouldAddOne(&quo.Coeff, quo.Negative, half) { roundAddOne(&quo.Coeff, &diff) } } @@ -360,7 +362,7 @@ func (c *Context) Quo(d, x, y *Decimal) (Condition, error) { // the coefficient calculation from the original exponent of the dividend. res |= quo.setExponent(c, res, int64(x.Exponent), int64(-y.Exponent), -adjust, diff) quo.Negative = neg - d.Set(quo) + d.Set(&quo) return c.goError(res) } @@ -375,7 +377,8 @@ func (c *Context) QuoInteger(d, x, y *Decimal) (Condition, error) { neg := x.Negative != y.Negative var res Condition - a, b, _, err := upscale(x, y) + var tmp BigInt + a, b, _, err := upscale(x, y, &tmp) if err != nil { return 0, errors.Wrap(err, "QuoInteger") } @@ -416,13 +419,14 @@ func (c *Context) Rem(d, x, y *Decimal) (Condition, error) { d.Set(decimalNaN) return c.goError(res) } - a, b, s, err := upscale(x, y) + var tmp1 BigInt + a, b, s, err := upscale(x, y, &tmp1) if err != nil { return 0, errors.Wrap(err, "Rem") } - tmp := new(big.Int) - tmp.QuoRem(a, b, &d.Coeff) - if NumDigits(tmp) > int64(c.Precision) { + var tmp2 BigInt + tmp2.QuoRem(a, b, &d.Coeff) + if NumDigits(&tmp2) > int64(c.Precision) { d.Set(decimalNaN) return c.goError(DivisionImpossible) } @@ -485,7 +489,8 @@ func (c *Context) Sqrt(d, x *Decimal) (Condition, error) { workp = 7 } - f := new(Decimal).Set(x) + var f Decimal + f.Set(x) nd := x.NumDigits() e := nd + int64(x.Exponent) f.Exponent = int32(-nd) @@ -494,23 +499,23 @@ func (c *Context) Sqrt(d, x *Decimal) (Condition, error) { ed := MakeErrDecimal(nc) // Set approx to the first guess, based on whether e (the exponent part of x) // is odd or even. - approx := new(Decimal) + var approx Decimal if e%2 == 0 { approx.SetFinite(819, -3) - ed.Mul(approx, approx, f) - ed.Add(approx, approx, New(259, -3)) + ed.Mul(&approx, &approx, &f) + ed.Add(&approx, &approx, New(259, -3)) } else { f.Exponent-- e++ approx.SetFinite(259, -2) - ed.Mul(approx, approx, f) - ed.Add(approx, approx, New(819, -4)) + ed.Mul(&approx, &approx, &f) + ed.Add(&approx, &approx, New(819, -4)) } // Now we repeatedly improve approx. Our precision improves quadratically, // which we keep track of in p. p := uint32(3) - tmp := new(Decimal) + var tmp Decimal // The algorithm in the paper says to use c.Precision + 2. decNumber uses // workp + 2. But we use workp + 5 to make the tests pass. This means it is @@ -522,11 +527,11 @@ func (c *Context) Sqrt(d, x *Decimal) (Condition, error) { } nc.Precision = p // tmp = f / approx - ed.Quo(tmp, f, approx) + ed.Quo(&tmp, &f, &approx) // tmp = approx + f / approx - ed.Add(tmp, tmp, approx) + ed.Add(&tmp, &tmp, &approx) // approx = 0.5 * (approx + f / approx) - ed.Mul(approx, tmp, decimalHalf) + ed.Mul(&approx, &tmp, decimalHalf) } // At this point the paper says: "approx is now within 1 ulp of the properly @@ -542,7 +547,7 @@ func (c *Context) Sqrt(d, x *Decimal) (Condition, error) { return 0, err } - d.Set(approx) + d.Set(&approx) d.Exponent += int32(e / 2) nc.Precision = c.Precision nc.Rounding = RoundHalfEven @@ -561,12 +566,10 @@ func (c *Context) Cbrt(d, x *Decimal) (Condition, error) { return res, err } + var ax, z Decimal + ax.Abs(x) + z.Set(&ax) neg := x.Negative - ax := x - if x.Negative { - ax = new(Decimal).Abs(x) - } - z := new(Decimal).Set(ax) nc := BaseContext.WithPrecision(c.Precision*2 + 2) ed := MakeErrDecimal(nc) exp8 := 0 @@ -580,46 +583,47 @@ func (c *Context) Cbrt(d, x *Decimal) (Condition, error) { // x = z * 8^exp8 will hold. for z.Cmp(decimalOneEighth) < 0 { exp8-- - ed.Mul(z, z, decimalEight) + ed.Mul(&z, &z, decimalEight) } for z.Cmp(decimalOne) > 0 { exp8++ - ed.Mul(z, z, decimalOneEighth) + ed.Mul(&z, &z, decimalOneEighth) } // Use this polynomial to approximate the cube root between 0.125 and 1. // z = (-0.46946116 * z + 1.072302) * z + 0.3812513 // It will serve as an initial estimate, hence the precision of this // computation may only impact performance, not correctness. - z0 := new(Decimal).Set(z) - ed.Mul(z, z, decimalCbrtC1) - ed.Add(z, z, decimalCbrtC2) - ed.Mul(z, z, z0) - ed.Add(z, z, decimalCbrtC3) + var z0 Decimal + z0.Set(&z) + ed.Mul(&z, &z, decimalCbrtC1) + ed.Add(&z, &z, decimalCbrtC2) + ed.Mul(&z, &z, &z0) + ed.Add(&z, &z, decimalCbrtC3) for ; exp8 < 0; exp8++ { - ed.Mul(z, z, decimalHalf) + ed.Mul(&z, &z, decimalHalf) } for ; exp8 > 0; exp8-- { - ed.Mul(z, z, decimalTwo) + ed.Mul(&z, &z, decimalTwo) } // Loop until convergence. - for loop := nc.newLoop("cbrt", z, c.Precision+1, 1); ; { + for loop := nc.newLoop("cbrt", &z, c.Precision+1, 1); ; { // z = (2.0 * z0 + x / (z0 * z0) ) / 3.0; - z0.Set(z) - ed.Mul(z, z, z0) - ed.Quo(z, ax, z) - ed.Add(z, z, z0) - ed.Add(z, z, z0) - ed.Quo(z, z, decimalThree) + z0.Set(&z) + ed.Mul(&z, &z, &z0) + ed.Quo(&z, &ax, &z) + ed.Add(&z, &z, &z0) + ed.Add(&z, &z, &z0) + ed.Quo(&z, &z, decimalThree) if err := ed.Err(); err != nil { return 0, err } - if done, err := loop.done(z); err != nil { + if done, err := loop.done(&z); err != nil { return 0, err } else if done { break @@ -627,19 +631,19 @@ func (c *Context) Cbrt(d, x *Decimal) (Condition, error) { } z0.Set(x) - res, err := c.Round(d, z) + res, err := c.Round(d, &z) d.Negative = neg // Set z = d^3 to check for exactness. - ed.Mul(z, d, d) - ed.Mul(z, z, d) + ed.Mul(&z, d, d) + ed.Mul(&z, &z, d) if err := ed.Err(); err != nil { return 0, err } // Result is exact - if z0.Cmp(z) == 0 { + if z0.Cmp(&z) == 0 { return 0, nil } return res, err @@ -689,12 +693,8 @@ func (c *Context) Ln(d, x *Decimal) (Condition, error) { nc.Rounding = RoundHalfEven ed := MakeErrDecimal(nc) - tmp1 := new(Decimal) - tmp2 := new(Decimal) - tmp3 := new(Decimal) - tmp4 := new(Decimal) - - z := new(Decimal).Set(x) + var tmp1, tmp2, tmp3, tmp4, z, resAdjust Decimal + z.Set(x) // To get an initial estimate, we first reduce the input range to the interval // [0.1, 1) by changing the exponent, and later adjust the result by a @@ -718,16 +718,14 @@ func (c *Context) Ln(d, x *Decimal) (Condition, error) { // precision. So for z close to 1 (before scaling) we use a power series // instead (which converges very rapidly in this range). - resAdjust := new(Decimal) - // tmp1 = z - 1 - ed.Sub(tmp1, z, decimalOne) + ed.Sub(&tmp1, &z, decimalOne) // tmp3 = 0.1 tmp3.SetFinite(1, -1) usePowerSeries := false - if tmp2.Abs(tmp1).Cmp(tmp3) <= 0 { + if tmp2.Abs(&tmp1).Cmp(&tmp3) <= 0 { usePowerSeries = true } else { // Reduce input to range [0.1, 1). @@ -738,12 +736,12 @@ func (c *Context) Ln(d, x *Decimal) (Condition, error) { // ln(10^expDelta) = expDelta * ln(10) // to the result. resAdjust.setCoefficient(int64(expDelta)) - ed.Mul(resAdjust, resAdjust, decimalLn10.get(p)) + ed.Mul(&resAdjust, &resAdjust, decimalLn10.get(p)) // tmp1 = z - 1 - ed.Sub(tmp1, z, decimalOne) + ed.Sub(&tmp1, &z, decimalOne) - if tmp2.Abs(tmp1).Cmp(tmp3) <= 0 { + if tmp2.Abs(&tmp1).Cmp(&tmp3) <= 0 { usePowerSeries = true } else { // Compute an initial estimate using floats. @@ -768,30 +766,32 @@ func (c *Context) Ln(d, x *Decimal) (Condition, error) { // tmp1 is already x // tmp3 = x + 2 - ed.Add(tmp3, tmp1, decimalTwo) + ed.Add(&tmp3, &tmp1, decimalTwo) // tmp2 = (x / (x+2)) - ed.Quo(tmp2, tmp1, tmp3) + ed.Quo(&tmp2, &tmp1, &tmp3) // tmp1 = tmp3 = 2 * (x / (x+2)) - ed.Add(tmp3, tmp2, tmp2) - tmp1.Set(tmp3) + ed.Add(&tmp3, &tmp2, &tmp2) + tmp1.Set(&tmp3) - eps := Decimal{Coeff: *bigOne, Exponent: -int32(p)} + var eps Decimal + eps.Coeff.Set(bigOne) + eps.Exponent = -int32(p) for n := 1; ; n++ { // tmp3 *= (x / (x+2))^2 - ed.Mul(tmp3, tmp3, tmp2) - ed.Mul(tmp3, tmp3, tmp2) + ed.Mul(&tmp3, &tmp3, &tmp2) + ed.Mul(&tmp3, &tmp3, &tmp2) // tmp4 = 2n+1 tmp4.SetFinite(int64(2*n+1), 0) - ed.Quo(tmp4, tmp3, tmp4) + ed.Quo(&tmp4, &tmp3, &tmp4) - ed.Add(tmp1, tmp1, tmp4) + ed.Add(&tmp1, &tmp1, &tmp4) - if tmp4.Abs(tmp4).Cmp(&eps) <= 0 { + if tmp4.Abs(&tmp4).Cmp(&eps) <= 0 { break } } @@ -803,24 +803,24 @@ func (c *Context) Ln(d, x *Decimal) (Condition, error) { // tmp1 = a_n (either from initial estimate or last iteration) // tmp2 = exp(a_n) - ed.Exp(tmp2, tmp1) + ed.Exp(&tmp2, &tmp1) // tmp3 = exp(a_n) - z - ed.Sub(tmp3, tmp2, z) + ed.Sub(&tmp3, &tmp2, &z) // tmp3 = 2 * (exp(a_n) - z) - ed.Add(tmp3, tmp3, tmp3) + ed.Add(&tmp3, &tmp3, &tmp3) // tmp4 = exp(a_n) + z - ed.Add(tmp4, tmp2, z) + ed.Add(&tmp4, &tmp2, &z) // tmp2 = 2 * (exp(a_n) - z) / (exp(a_n) + z) - ed.Quo(tmp2, tmp3, tmp4) + ed.Quo(&tmp2, &tmp3, &tmp4) // tmp1 = a_(n+1) = a_n - 2 * (exp(a_n) - z) / (exp(a_n) + z) - ed.Sub(tmp1, tmp1, tmp2) + ed.Sub(&tmp1, &tmp1, &tmp2) - if done, err := loop.done(tmp1); err != nil { + if done, err := loop.done(&tmp1); err != nil { return 0, err } else if done { break @@ -832,12 +832,12 @@ func (c *Context) Ln(d, x *Decimal) (Condition, error) { } // Apply the adjustment due to the initial rescaling. - ed.Add(tmp1, tmp1, resAdjust) + ed.Add(&tmp1, &tmp1, &resAdjust) if err := ed.Err(); err != nil { return 0, err } - res := c.round(d, tmp1) + res := c.round(d, &tmp1) res |= Inexact return c.goError(res) } @@ -853,14 +853,14 @@ func (c *Context) Log10(d, x *Decimal) (Condition, error) { nc := BaseContext.WithPrecision(c.Precision + 2) nc.Rounding = RoundHalfEven - z := new(Decimal) - _, err := nc.Ln(z, x) + var z Decimal + _, err := nc.Ln(&z, x) if err != nil { return 0, errors.Wrap(err, "ln") } nc.Precision = c.Precision - qr, err := nc.Mul(d, z, decimalInvLn10.get(c.Precision+2)) + qr, err := nc.Mul(d, &z, decimalInvLn10.get(c.Precision+2)) if err != nil { return 0, err } @@ -898,7 +898,8 @@ func (c *Context) Exp(d, x *Decimal) (Condition, error) { // Stage 1 cp := c.Precision - tmp1 := new(Decimal).Abs(x) + var tmp1 Decimal + tmp1.Abs(x) if f, err := tmp1.Float64(); err == nil { // This algorithm doesn't work if currentprecision*23 < |x|. Attempt to // increase the working precision if needed as long as it isn't too large. If @@ -907,9 +908,10 @@ func (c *Context) Exp(d, x *Decimal) (Condition, error) { cp = uint32(math.Ceil(ncp)) } } - tmp2 := New(int64(cp)*23, 0) + var tmp2 Decimal + tmp2.SetInt64(int64(cp) * 23) // if abs(x) > 23*currentprecision; assert false - if tmp1.Cmp(tmp2) > 0 { + if tmp1.Cmp(&tmp2) > 0 { res |= Overflow if x.Sign() < 0 { res = res.negateOverflowFlags() @@ -922,7 +924,7 @@ func (c *Context) Exp(d, x *Decimal) (Condition, error) { } // if abs(x) <= setexp(.9, -currentprecision); then result 1 tmp2.SetFinite(9, int32(-cp)-1) - if tmp1.Cmp(tmp2) <= 0 { + if tmp1.Cmp(&tmp2) <= 0 { d.Set(decimalOne) return c.goError(res) } @@ -933,14 +935,15 @@ func (c *Context) Exp(d, x *Decimal) (Condition, error) { if t < 0 { t = 0 } - k := New(1, t) - r := new(Decimal) + var k, r Decimal + k.SetFinite(1, t) nc := c.WithPrecision(cp) nc.Rounding = RoundHalfEven - if _, err := nc.Quo(r, x, k); err != nil { + if _, err := nc.Quo(&r, x, &k); err != nil { return 0, errors.Wrap(err, "Quo") } - ra := new(Decimal).Abs(r) + var ra Decimal + ra.Abs(&r) p := int64(cp) + int64(t) + 2 // Stage 3 @@ -958,27 +961,29 @@ func (c *Context) Exp(d, x *Decimal) (Condition, error) { // Stage 4 nc.Precision = uint32(p) ed := MakeErrDecimal(nc) - sum := New(1, 0) + var sum Decimal + sum.SetInt64(1) tmp2.Exponent = 0 for i := n - 1; i > 0; i-- { tmp2.setCoefficient(i) // tmp1 = r / i - ed.Quo(tmp1, r, tmp2) + ed.Quo(&tmp1, &r, &tmp2) // sum = sum * r / i - ed.Mul(sum, tmp1, sum) + ed.Mul(&sum, &tmp1, &sum) // sum = sum + 1 - ed.Add(sum, sum, decimalOne) + ed.Add(&sum, &sum, decimalOne) } if err != ed.Err() { return 0, err } // sum ** k - ki, err := exp10(int64(t)) + var tmpE BigInt + ki, err := exp10(int64(t), &tmpE) if err != nil { return 0, errors.Wrap(err, "ki") } - ires, err := nc.integerPower(d, sum, ki) + ires, err := nc.integerPower(d, &sum, ki) if err != nil { return 0, errors.Wrap(err, "integer power") } @@ -989,29 +994,31 @@ func (c *Context) Exp(d, x *Decimal) (Condition, error) { } // integerPower sets d = x**y. d and x must not point to the same Decimal. -func (c *Context) integerPower(d, x *Decimal, y *big.Int) (Condition, error) { +func (c *Context) integerPower(d, x *Decimal, y *BigInt) (Condition, error) { // See: https://en.wikipedia.org/wiki/Exponentiation_by_squaring. - b := new(big.Int).Set(y) + var b BigInt + b.Set(y) neg := b.Sign() < 0 if neg { - b.Abs(b) + b.Abs(&b) } - n, z := new(Decimal), d + var n Decimal n.Set(x) + z := d z.Set(decimalOne) ed := MakeErrDecimal(c) for b.Sign() > 0 { if b.Bit(0) == 1 { - ed.Mul(z, z, n) + ed.Mul(z, z, &n) } - b.Rsh(b, 1) + b.Rsh(&b, 1) // Only compute the next n if we are going to use it. Otherwise n can overflow // on the last iteration causing this to error. if b.Sign() > 0 { - ed.Mul(n, n, n) + ed.Mul(&n, &n, &n) } if err := ed.Err(); err != nil { // In the negative case, convert overflow to underflow. @@ -1034,8 +1041,8 @@ func (c *Context) Pow(d, x, y *Decimal) (Condition, error) { return res, err } - integ, frac := new(Decimal), new(Decimal) - y.Modf(integ, frac) + var integ, frac Decimal + y.Modf(&integ, &frac) yIsInt := frac.IsZero() neg := x.Negative && y.Form == Finite && yIsInt && integ.Coeff.Bit(0) == 1 && integ.Exponent == 0 @@ -1056,7 +1063,8 @@ func (c *Context) Pow(d, x, y *Decimal) (Condition, error) { } // Check if y is of type int. - tmp := new(Decimal).Abs(y) + var tmp Decimal + tmp.Abs(y) xs := x.Sign() ys := y.Sign() @@ -1101,7 +1109,7 @@ func (c *Context) Pow(d, x, y *Decimal) (Condition, error) { } // If integ.Exponent > 0, we need to add trailing 0s to integ.Coeff. - res := c.quantize(integ, integ, 0) + res := c.quantize(&integ, &integ, 0) nres, err := nc.integerPower(z, x, integ.setBig(&integ.Coeff)) res |= nres if err != nil { @@ -1117,18 +1125,18 @@ func (c *Context) Pow(d, x, y *Decimal) (Condition, error) { ed := MakeErrDecimal(nc) // Compute x**frac(y) - ed.Abs(tmp, x) - ed.Ln(tmp, tmp) - ed.Mul(tmp, tmp, frac) - ed.Exp(tmp, tmp) + ed.Abs(&tmp, x) + ed.Ln(&tmp, &tmp) + ed.Mul(&tmp, &tmp, &frac) + ed.Exp(&tmp, &tmp) // Join integer and frac parts back. - ed.Mul(tmp, z, tmp) + ed.Mul(&tmp, z, &tmp) if err := ed.Err(); err != nil { return ed.Flags, err } - res |= c.round(d, tmp) + res |= c.round(d, &tmp) d.Negative = neg res |= Inexact return c.goError(res) @@ -1166,7 +1174,8 @@ func (c *Context) quantize(d, v *Decimal, exp int32) Condition { if diff < MinExponent { return SystemUnderflow | Underflow } - d.Coeff.Mul(&d.Coeff, tableExp10(-int64(diff), nil)) + var tmpE BigInt + d.Coeff.Mul(&d.Coeff, tableExp10(-int64(diff), &tmpE)) } else if diff > 0 { p := int32(d.NumDigits()) - diff if p < 0 { @@ -1193,7 +1202,7 @@ func (c *Context) quantize(d, v *Decimal, exp int32) Condition { d.Exponent = -diff // Avoid the c.Precision == 0 check. - res = nc.rounding().Round(nc, d, d) + res = nc.Rounding.Round(nc, d, d) // Adjust for 0.9 -> 1.0 rollover. if d.Exponent > 0 { d.Coeff.Mul(&d.Coeff, bigTen) @@ -1242,8 +1251,8 @@ func (c *Context) RoundToIntegralExact(d, x *Decimal) (Condition, error) { // Ceil sets d to the smallest integer >= x. func (c *Context) Ceil(d, x *Decimal) (Condition, error) { - frac := new(Decimal) - x.Modf(d, frac) + var frac Decimal + x.Modf(d, &frac) if frac.Sign() > 0 { return c.Add(d, d, decimalOne) } @@ -1252,8 +1261,8 @@ func (c *Context) Ceil(d, x *Decimal) (Condition, error) { // Floor sets d to the largest integer <= x. func (c *Context) Floor(d, x *Decimal) (Condition, error) { - frac := new(Decimal) - x.Modf(d, frac) + var frac Decimal + x.Modf(d, &frac) if frac.Sign() < 0 { return c.Sub(d, d, decimalOne) } @@ -1274,9 +1283,10 @@ func (c *Context) Reduce(d, x *Decimal) (int, Condition, error) { } // exp10 returns x, 10^x. An error is returned if x is too large. -func exp10(x int64) (exp *big.Int, err error) { +// The returned value must not be mutated. +func exp10(x int64, tmp *BigInt) (exp *BigInt, err error) { if x > MaxExponent || x < MinExponent { return nil, errors.New(errExponentOutOfRangeStr) } - return tableExp10(x, nil), nil + return tableExp10(x, tmp), nil } diff --git a/decimal.go b/decimal.go index 4d36872..443ccc6 100644 --- a/decimal.go +++ b/decimal.go @@ -15,12 +15,11 @@ package apd import ( - "database/sql/driver" - "math" - "math/big" "strconv" "strings" + "unsafe" + "database/sql/driver" "github.com/pkg/errors" ) @@ -34,11 +33,11 @@ type Decimal struct { Form Form Negative bool Exponent int32 - Coeff big.Int + Coeff BigInt } // Form specifies the form of a Decimal. -type Form int +type Form int8 const ( // These constants must be in the following order. CmpTotal assumes that @@ -78,25 +77,20 @@ const ( // New creates a new decimal with the given coefficient and exponent. func New(coeff int64, exponent int32) *Decimal { - d := &Decimal{ - Negative: coeff < 0, - Coeff: *big.NewInt(coeff), - Exponent: exponent, - } - d.Coeff.Abs(&d.Coeff) + d := new(Decimal) + d.SetFinite(coeff, exponent) return d } // NewWithBigInt creates a new decimal with the given coefficient and exponent. -func NewWithBigInt(coeff *big.Int, exponent int32) *Decimal { - d := &Decimal{ - Exponent: exponent, - } +func NewWithBigInt(coeff *BigInt, exponent int32) *Decimal { + d := new(Decimal) d.Coeff.Set(coeff) if d.Coeff.Sign() < 0 { d.Negative = true d.Coeff.Abs(&d.Coeff) } + d.Exponent = exponent return d } @@ -243,21 +237,22 @@ func (d *Decimal) SetFloat64(f float64) (*Decimal, error) { return d, err } -// Int64 returns the int64 representation of x. If x cannot be represented in an int64, an error is returned. +// Int64 returns the int64 representation of x. If x cannot be represented in an +// int64, an error is returned. func (d *Decimal) Int64() (int64, error) { if d.Form != Finite { return 0, errors.Errorf("%s is not finite", d.String()) } - integ, frac := new(Decimal), new(Decimal) - d.Modf(integ, frac) + var integ, frac Decimal + d.Modf(&integ, &frac) if !frac.IsZero() { return 0, errors.Errorf("%s: has fractional part", d.String()) } var ed ErrDecimal - if integ.Cmp(New(math.MaxInt64, 0)) > 0 { + if integ.Cmp(decimalMaxInt64) > 0 { return 0, errors.Errorf("%s: greater than max int64", d.String()) } - if integ.Cmp(New(math.MinInt64, 0)) < 0 { + if integ.Cmp(decimalMinInt64) < 0 { return 0, errors.Errorf("%s: less than min int64", d.String()) } if err := ed.Err(); err != nil { @@ -325,17 +320,15 @@ func (d *Decimal) setExponent(c *Context, res Condition, xs ...int64) Condition // fractional parts and do operations similar Round. We avoid calling Round // directly because it calls setExponent and modifies the result's exponent // and coeff in ways that would be wrong here. - b := new(big.Int).Set(&d.Coeff) - tmp := &Decimal{ - Coeff: *b, - Exponent: r - Etiny, - } - integ, frac := new(Decimal), new(Decimal) - tmp.Modf(integ, frac) - frac.Abs(frac) + var tmp Decimal + tmp.Coeff.Set(&d.Coeff) + tmp.Exponent = r - Etiny + var integ, frac Decimal + tmp.Modf(&integ, &frac) + frac.Abs(&frac) if !frac.IsZero() { res |= Inexact - if c.rounding()(&integ.Coeff, integ.Negative, frac.Cmp(decimalHalf)) { + if c.Rounding.ShouldAddOne(&integ.Coeff, integ.Negative, frac.Cmp(decimalHalf)) { integ.Coeff.Add(&integ.Coeff, bigOne) } } @@ -343,7 +336,7 @@ func (d *Decimal) setExponent(c *Context, res Condition, xs ...int64) Condition res |= Clamped } r = Etiny - d.Coeff = integ.Coeff + d.Coeff.Set(&integ.Coeff) res |= Rounded } } else if v > c.MaxExponent { @@ -364,10 +357,11 @@ func (d *Decimal) setExponent(c *Context, res Condition, xs ...int64) Condition return res } -// upscale converts a and b to big.Ints with the same scaling. It returns +// upscale converts a and b to BigInts with the same scaling. It returns // them with this scaling, along with the scaling. An error can be produced -// if the resulting scale factor is out of range. -func upscale(a, b *Decimal) (*big.Int, *big.Int, int32, error) { +// if the resulting scale factor is out of range. The tmp argument must be +// provided and can be (but won't always be) one of the return values. +func upscale(a, b *Decimal, tmp *BigInt) (*BigInt, *BigInt, int32, error) { if a.Exponent == b.Exponent { return &a.Coeff, &b.Coeff, a.Exponent, nil } @@ -382,7 +376,7 @@ func upscale(a, b *Decimal) (*big.Int, *big.Int, int32, error) { if s > MaxExponent { return nil, nil, 0, errors.New(errExponentOutOfRangeStr) } - x := new(big.Int) + x := tmp e := tableExp10(s, x) x.Mul(&a.Coeff, e) y := &b.Coeff @@ -393,7 +387,7 @@ func upscale(a, b *Decimal) (*big.Int, *big.Int, int32, error) { } // setBig sets b to d's coefficient with negative. -func (d *Decimal) setBig(b *big.Int) *big.Int { +func (d *Decimal) setBig(b *BigInt) *BigInt { b.Set(&d.Coeff) if d.Negative { b.Neg(b) @@ -556,7 +550,7 @@ func (d *Decimal) Cmp(x *Decimal) int { return gt } - // Now have to use aligned big.Ints. This function previously used upscale to + // Now have to use aligned BigInts. This function previously used upscale to // align in all cases, but that requires an error in the return value. upscale // does that so that it can fail if it needs to take the Exp of too-large a // number, which is very slow. The only way for that to happen here is for d @@ -566,14 +560,14 @@ func (d *Decimal) Cmp(x *Decimal) int { var cmp int if d.Exponent < x.Exponent { - var xScaled big.Int + var xScaled, tmpE BigInt xScaled.Set(&x.Coeff) - xScaled.Mul(&xScaled, tableExp10(int64(x.Exponent)-int64(d.Exponent), nil)) + xScaled.Mul(&xScaled, tableExp10(int64(x.Exponent)-int64(d.Exponent), &tmpE)) cmp = d.Coeff.Cmp(&xScaled) } else { - var dScaled big.Int + var dScaled, tmpE BigInt dScaled.Set(&d.Coeff) - dScaled.Mul(&dScaled, tableExp10(int64(d.Exponent)-int64(x.Exponent), nil)) + dScaled.Mul(&dScaled, tableExp10(int64(d.Exponent)-int64(x.Exponent), &tmpE)) cmp = dScaled.Cmp(&x.Coeff) } if ds < 0 { @@ -646,9 +640,10 @@ func (d *Decimal) Modf(integ, frac *Decimal) { return } - e := tableExp10(exp, nil) + var tmpE BigInt + e := tableExp10(exp, &tmpE) - var icoeff *big.Int + var icoeff *BigInt if integ != nil { icoeff = &integ.Coeff integ.Exponent = 0 @@ -656,7 +651,7 @@ func (d *Decimal) Modf(integ, frac *Decimal) { } else { // This is the integ == nil branch, and we already checked if both integ and // frac were nil above, so frac can never be nil in this branch. - icoeff = new(big.Int) + icoeff = new(BigInt) } if frac != nil { @@ -728,12 +723,12 @@ func (d *Decimal) Reduce(x *Decimal) (*Decimal, int) { // Divide by 10 in a loop. In benchmarks of reduce0.decTest, this is 20% // faster than converting to a string and trimming the 0s from the end. - z := d.setBig(new(big.Int)) - r := new(big.Int) + var z, r BigInt + d.setBig(&z) for { - z.QuoRem(&d.Coeff, bigTen, r) + z.QuoRem(&d.Coeff, bigTen, &r) if r.Sign() == 0 { - d.Coeff.Set(z) + d.Coeff.Set(&z) nd++ } else { break @@ -743,6 +738,13 @@ func (d *Decimal) Reduce(x *Decimal) (*Decimal, int) { return d, nd } +const decimalSize = unsafe.Sizeof(Decimal{}) + +// Size returns the total memory footprint of d in bytes. +func (d *Decimal) Size() uintptr { + return decimalSize - bigIntSize + d.Coeff.Size() +} + // Value implements the database/sql/driver.Valuer interface. It converts d to a // string. func (d Decimal) Value() (driver.Value, error) { diff --git a/decimal_test.go b/decimal_test.go index 1bc2b63..29342e3 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -18,7 +18,6 @@ import ( "encoding/json" "fmt" "math" - "math/big" "testing" "unsafe" ) @@ -65,7 +64,7 @@ func TestNewWithBigInt(t *testing.T) { if err != nil { t.Fatal(err) } - b, ok := new(big.Int).SetString(tc, 10) + b, ok := new(BigInt).SetString(tc, 10) if !ok { t.Fatal("bad bigint") } @@ -74,7 +73,7 @@ func TestNewWithBigInt(t *testing.T) { t.Fatal("unexpected negative coeff") } // Verify that changing b doesn't change d. - b.Set(big.NewInt(1234)) + b.Set(NewBigInt(1234)) if d.CmpTotal(expect) != 0 { t.Fatalf("expected %s, got %s", expect, d) } @@ -85,19 +84,19 @@ func TestNewWithBigInt(t *testing.T) { func TestUpscale(t *testing.T) { tests := []struct { x, y *Decimal - a, b *big.Int + a, b *BigInt s int32 }{ - {x: New(1, 0), y: New(100, -1), a: big.NewInt(10), b: big.NewInt(100), s: -1}, - {x: New(1, 0), y: New(10, -1), a: big.NewInt(10), b: big.NewInt(10), s: -1}, - {x: New(1, 0), y: New(10, 0), a: big.NewInt(1), b: big.NewInt(10), s: 0}, - {x: New(1, 1), y: New(1, 0), a: big.NewInt(10), b: big.NewInt(1), s: 0}, - {x: New(10, -2), y: New(1, -1), a: big.NewInt(10), b: big.NewInt(10), s: -2}, - {x: New(1, -2), y: New(100, 1), a: big.NewInt(1), b: big.NewInt(100000), s: -2}, + {x: New(1, 0), y: New(100, -1), a: NewBigInt(10), b: NewBigInt(100), s: -1}, + {x: New(1, 0), y: New(10, -1), a: NewBigInt(10), b: NewBigInt(10), s: -1}, + {x: New(1, 0), y: New(10, 0), a: NewBigInt(1), b: NewBigInt(10), s: 0}, + {x: New(1, 1), y: New(1, 0), a: NewBigInt(10), b: NewBigInt(1), s: 0}, + {x: New(10, -2), y: New(1, -1), a: NewBigInt(10), b: NewBigInt(10), s: -2}, + {x: New(1, -2), y: New(100, 1), a: NewBigInt(1), b: NewBigInt(100000), s: -2}, } for _, tc := range tests { t.Run(fmt.Sprintf("%s, %s", tc.x, tc.y), func(t *testing.T) { - a, b, s, err := upscale(tc.x, tc.y) + a, b, s, err := upscale(tc.x, tc.y, new(BigInt)) if err != nil { t.Fatal(err) } @@ -786,10 +785,14 @@ func TestReduce(t *testing.T) { } // TestSizeof is meant to catch changes that unexpectedly increase -// the size of the Decimal struct. +// the size of the BigInt, Decimal, and Context structs. func TestSizeof(t *testing.T) { + var b BigInt + if s := unsafe.Sizeof(b); s != 24 { + t.Errorf("sizeof(BigInt) changed: %d", s) + } var d Decimal - if s := unsafe.Sizeof(d); s != 48 { + if s := unsafe.Sizeof(d); s != 32 { t.Errorf("sizeof(Decimal) changed: %d", s) } var c Context @@ -798,6 +801,41 @@ func TestSizeof(t *testing.T) { } } +// TestSize tests the Size method on BigInt and Decimal. Unlike Sizeof, which +// returns the shallow size of the structs, the Size method reports the total +// memory footprint of each struct and all referenced objects. +func TestSize(t *testing.T) { + var d Decimal + if e, s := uintptr(32), d.Size(); e != s { + t.Errorf("(*Decimal).Size() != %d: %d", e, s) + } + if e, s := uintptr(24), d.Coeff.Size(); e != s { + t.Errorf("(*BigInt).Size() != %d: %d", e, s) + } + // Set to an inlinable value. + d.SetInt64(1234) + if e, s := uintptr(32), d.Size(); e != s { + t.Errorf("(*Decimal).Size() != %d: %d", e, s) + } + if e, s := uintptr(24), d.Coeff.Size(); e != s { + t.Errorf("(*BigInt).Size() != %d: %d", e, s) + } + // Set to a non-inlinable value. + if _, _, err := d.SetString("123456789123456789123456789.123456789123456789"); err != nil { + t.Fatal(err) + } + if d.Coeff.isInline() { + // Sanity-check, in case inlineWords changes. + t.Fatal("BigInt inlined large value. Did inlineWords change?") + } + if e, s := uintptr(120), d.Size(); e != s { + t.Errorf("(*Decimal).Size() != %d: %d", e, s) + } + if e, s := uintptr(112), d.Coeff.Size(); e != s { + t.Errorf("(*BigInt).Size() != %d: %d", e, s) + } +} + func TestJSONEncoding(t *testing.T) { var encodingTests = []string{ "0", diff --git a/error.go b/error.go index 9dff9dc..4f3a24f 100644 --- a/error.go +++ b/error.go @@ -44,49 +44,55 @@ func (e *ErrDecimal) Err() error { return nil } -func (e *ErrDecimal) op2(d, x *Decimal, f func(a, b *Decimal) (Condition, error)) *Decimal { - if e.Err() != nil { - return d - } - res, err := f(d, x) +// update adjusts the ErrDecimal's state with the result of an operation. +func (e *ErrDecimal) update(res Condition, err error) { e.Flags |= res e.err = err - return d } -func (e *ErrDecimal) op3(d, x, y *Decimal, f func(a, b, c *Decimal) (Condition, error)) *Decimal { +// Abs performs e.Ctx.Abs(d, x) and returns d. +func (e *ErrDecimal) Abs(d, x *Decimal) *Decimal { if e.Err() != nil { return d } - res, err := f(d, x, y) - e.Flags |= res - e.err = err + e.update(e.Ctx.Abs(d, x)) return d } -// Abs performs e.Ctx.Abs(d, x) and returns d. -func (e *ErrDecimal) Abs(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Abs) -} - // Add performs e.Ctx.Add(d, x, y) and returns d. func (e *ErrDecimal) Add(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.Add) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Add(d, x, y)) + return d } // Ceil performs e.Ctx.Ceil(d, x) and returns d. func (e *ErrDecimal) Ceil(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Ceil) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Ceil(d, x)) + return d } // Exp performs e.Ctx.Exp(d, x) and returns d. func (e *ErrDecimal) Exp(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Exp) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Exp(d, x)) + return d } // Floor performs e.Ctx.Floor(d, x) and returns d. func (e *ErrDecimal) Floor(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Floor) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Floor(d, x)) + return d } // Int64 returns 0 if err is set. Otherwise returns d.Int64(). @@ -101,27 +107,47 @@ func (e *ErrDecimal) Int64(d *Decimal) int64 { // Ln performs e.Ctx.Ln(d, x) and returns d. func (e *ErrDecimal) Ln(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Ln) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Ln(d, x)) + return d } // Log10 performs d.Log10(x) and returns d. func (e *ErrDecimal) Log10(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Log10) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Log10(d, x)) + return d } // Mul performs e.Ctx.Mul(d, x, y) and returns d. func (e *ErrDecimal) Mul(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.Mul) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Mul(d, x, y)) + return d } // Neg performs e.Ctx.Neg(d, x) and returns d. func (e *ErrDecimal) Neg(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Neg) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Neg(d, x)) + return d } // Pow performs e.Ctx.Pow(d, x, y) and returns d. func (e *ErrDecimal) Pow(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.Pow) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Pow(d, x, y)) + return d } // Quantize performs e.Ctx.Quantize(d, v, exp) and returns d. @@ -129,20 +155,26 @@ func (e *ErrDecimal) Quantize(d, v *Decimal, exp int32) *Decimal { if e.Err() != nil { return d } - res, err := e.Ctx.Quantize(d, v, exp) - e.Flags |= res - e.err = err + e.update(e.Ctx.Quantize(d, v, exp)) return d } // Quo performs e.Ctx.Quo(d, x, y) and returns d. func (e *ErrDecimal) Quo(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.Quo) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Quo(d, x, y)) + return d } // QuoInteger performs e.Ctx.QuoInteger(d, x, y) and returns d. func (e *ErrDecimal) QuoInteger(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.QuoInteger) + if e.Err() != nil { + return d + } + e.update(e.Ctx.QuoInteger(d, x, y)) + return d } // Reduce performs e.Ctx.Reduce(d, x) and returns the number of zeros removed @@ -152,37 +184,60 @@ func (e *ErrDecimal) Reduce(d, x *Decimal) (int, *Decimal) { return 0, d } n, res, err := e.Ctx.Reduce(d, x) - e.Flags |= res - e.err = err + e.update(res, err) return n, d } // Rem performs e.Ctx.Rem(d, x, y) and returns d. func (e *ErrDecimal) Rem(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.Rem) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Rem(d, x, y)) + return d } // Round performs e.Ctx.Round(d, x) and returns d. func (e *ErrDecimal) Round(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Round) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Round(d, x)) + return d } // Sqrt performs e.Ctx.Sqrt(d, x) and returns d. func (e *ErrDecimal) Sqrt(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.Sqrt) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Sqrt(d, x)) + return d } // Sub performs e.Ctx.Sub(d, x, y) and returns d. func (e *ErrDecimal) Sub(d, x, y *Decimal) *Decimal { - return e.op3(d, x, y, e.Ctx.Sub) + if e.Err() != nil { + return d + } + e.update(e.Ctx.Sub(d, x, y)) + return d } // RoundToIntegralValue performs e.Ctx.RoundToIntegralValue(d, x) and returns d. func (e *ErrDecimal) RoundToIntegralValue(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.RoundToIntegralValue) + if e.Err() != nil { + return d + } + e.update(e.Ctx.RoundToIntegralValue(d, x)) + return d } // RoundToIntegralExact performs e.Ctx.RoundToIntegralExact(d, x) and returns d. func (e *ErrDecimal) RoundToIntegralExact(d, x *Decimal) *Decimal { - return e.op2(d, x, e.Ctx.RoundToIntegralExact) + if e.Err() != nil { + return d + } + e.update(e.Ctx.RoundToIntegralExact(d, x)) + return d } diff --git a/example_test.go b/example_test.go index c8f6e98..70bd264 100644 --- a/example_test.go +++ b/example_test.go @@ -17,7 +17,7 @@ package apd_test import ( "fmt" - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" ) // ExampleOverflow demonstrates how to detect or error on overflow. diff --git a/gda_test.go b/gda_test.go index 087c802..54681bd 100644 --- a/gda_test.go +++ b/gda_test.go @@ -296,32 +296,52 @@ func (tc TestCase) Run(c *Context, done chan error, d, x, y *Decimal) (res Condi func BenchmarkGDA(b *testing.B) { for _, fname := range GDAfiles { b.Run(fname, func(b *testing.B) { - b.StopTimer() + type benchCase struct { + tc TestCase + ctx *Context + ops [2]*Decimal + } _, tcs := readGDA(b, fname) - res := new(Decimal) - for i := 0; i < b.N; i++ { - Loop: - for _, tc := range tcs { - if GDAignore[tc.ID] || tc.Result == "?" || tc.HasNull() { - continue - } - switch tc.Operation { - case "apply", "toeng": - continue - } - operands := make([]*Decimal, 2) - for i, o := range tc.Operands { - d, _, err := NewFromString(o) - if err != nil { - continue Loop - } - operands[i] = d + bcs := make([]benchCase, 0, len(tcs)) + Loop: + for _, tc := range tcs { + if GDAignore[tc.ID] || tc.Result == "?" || tc.HasNull() { + continue + } + switch tc.Operation { + case "apply", "toeng": + continue + } + bc := benchCase{ + tc: tc, + ctx: tc.Context(b), + } + for i, o := range tc.Operands { + d, _, err := NewFromString(o) + if err != nil { + continue Loop } - c := tc.Context(b) - b.StartTimer() + bc.ops[i] = d + } + bcs = append(bcs, bc) + } + + // Translate inputs and outputs to Decimal vectors. + op1s := make([]Decimal, len(bcs)) + op2s := make([]Decimal, len(bcs)) + res := make([]Decimal, b.N*len(bcs)) + for i, bc := range bcs { + op1s[i].Set(bc.ops[0]) + if bc.ops[1] != nil { + op2s[i].Set(bc.ops[1]) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j, bc := range bcs { // Ignore errors here because the full tests catch them. - _, _ = tc.Run(c, nil, res, operands[0], operands[1]) - b.StopTimer() + _, _ = bc.tc.Run(bc.ctx, nil, &res[i*len(bcs)+j], &op1s[j], &op2s[j]) } } }) @@ -343,15 +363,15 @@ func readGDA(t testing.TB, name string) (string, []TestCase) { } func (tc TestCase) Context(t testing.TB) *Context { - _, ok := Roundings[tc.Rounding] - if !ok { + rounding := Rounder(tc.Rounding) + if _, ok := roundings[rounding]; !ok { t.Fatalf("unsupported rounding mode %s", tc.Rounding) } c := &Context{ Precision: uint32(tc.Precision), MaxExponent: int32(tc.MaxExponent), MinExponent: int32(tc.MinExponent), - Rounding: tc.Rounding, + Rounding: rounding, Traps: 0, } return c @@ -400,10 +420,6 @@ func gdaTest(t *testing.T, path string, tcs []TestCase) { t.Logf("%s:/^%s ", path, tc.ID) t.Logf("%s %s = %s (%s)", tc.Operation, strings.Join(tc.Operands, " "), tc.Result, strings.Join(tc.Conditions, " ")) t.Logf("prec: %d, round: %s, Emax: %d, Emin: %d", tc.Precision, tc.Rounding, tc.MaxExponent, tc.MinExponent) - _, ok := Roundings[tc.Rounding] - if !ok { - t.Fatalf("unsupported rounding mode %s", tc.Rounding) - } operands := make([]*Decimal, 2) c := tc.Context(t) var res, opres Condition diff --git a/go.mod b/go.mod index b5f8f21..ac9e498 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ -module github.com/cockroachdb/apd/v2 +module github.com/cockroachdb/apd/v3 + +go 1.13 require github.com/pkg/errors v0.8.0 diff --git a/loop.go b/loop.go index 4dfc0d9..210210d 100644 --- a/loop.go +++ b/loop.go @@ -19,8 +19,8 @@ type loop struct { precision int32 maxIterations uint64 // When to give up. arg *Decimal // original argument to function; only used for diagnostic. - prevZ *Decimal // Result from the previous iteration. - delta *Decimal // |Change| from previous iteration. + prevZ Decimal // Result from the previous iteration. + delta Decimal // |Change| from previous iteration. } const digitsToBitsRatio = math.Ln10 / math.Ln2 @@ -41,15 +41,13 @@ func (c *Context) newLoop(name string, arg *Decimal, precision uint32, maxItersP arg: new(Decimal).Set(arg), precision: int32(precision), maxIterations: 10 + uint64(maxItersPerDigit*int(precision)), - prevZ: new(Decimal), - delta: new(Decimal), } } // done reports whether the loop is done. If it does not converge // after the maximum number of iterations, it returns an error. func (l *loop) done(z *Decimal) (bool, error) { - if _, err := l.c.Sub(l.delta, l.prevZ, z); err != nil { + if _, err := l.c.Sub(&l.delta, &l.prevZ, z); err != nil { return false, err } sign := l.delta.Sign() @@ -60,7 +58,7 @@ func (l *loop) done(z *Decimal) (bool, error) { // Convergence can oscillate when the calculation is nearly // done and we're running out of bits. This stops that. // See next comment. - l.delta.Neg(l.delta) + l.delta.Neg(&l.delta) } // We stop if the delta is smaller than a change of 1 in the @@ -73,7 +71,9 @@ func (l *loop) done(z *Decimal) (bool, error) { // p = 3 // z = 0.001234 = 1234 * 10^-6 // eps = 0.00001 = 10^(-3+4-6) - eps := Decimal{Coeff: *bigOne, Exponent: -l.precision + int32(z.NumDigits()) + z.Exponent} + var eps Decimal + eps.Coeff.Set(bigOne) + eps.Exponent = -l.precision + int32(z.NumDigits()) + z.Exponent if l.delta.Cmp(&eps) <= 0 { return true, nil } @@ -81,7 +81,7 @@ func (l *loop) done(z *Decimal) (bool, error) { if l.i == l.maxIterations { return false, errors.Errorf( "%s %s: did not converge after %d iterations; prev,last result %s,%s delta %s precision: %d", - l.name, l.arg.String(), l.maxIterations, z, l.prevZ, l.delta, l.precision, + l.name, l.arg.String(), l.maxIterations, z.String(), l.prevZ.String(), l.delta.String(), l.precision, ) } l.prevZ.Set(z) diff --git a/round.go b/round.go index 09b62e1..4cea210 100644 --- a/round.go +++ b/round.go @@ -14,10 +14,6 @@ package apd -import ( - "math/big" -) - // Round sets d to rounded x, rounded to the precision specified by c. If c // has zero precision, no rounding will occur. If c has no Rounding specified, // RoundHalfUp is used. @@ -30,25 +26,42 @@ func (c *Context) round(d, x *Decimal) Condition { d.Set(x) return d.setExponent(c, 0, int64(d.Exponent)) } - rounder := c.rounding() - res := rounder.Round(c, d, x) + res := c.Rounding.Round(c, d, x) return res } -func (c *Context) rounding() Rounder { - rounding, ok := Roundings[c.Rounding] - if !ok { - return roundHalfUp +// Rounder specifies the behavior of rounding. +type Rounder string + +// ShouldAddOne returns true if 1 should be added to the absolute value +// of a number being rounded. result is the result to which the 1 would +// be added. neg is true if the number is negative. half is -1 if the +// discarded digits are < 0.5, 0 if = 0.5, or 1 if > 0.5. +func (r Rounder) ShouldAddOne(result *BigInt, neg bool, half int) bool { + // NOTE: this is written using a switch statement instead of some + // other form of dynamic dispatch to assist Go's escape analysis. + switch r { + case RoundDown: + return roundDown(result, neg, half) + case RoundHalfUp: + return roundHalfUp(result, neg, half) + case RoundHalfEven: + return roundHalfEven(result, neg, half) + case RoundCeiling: + return roundCeiling(result, neg, half) + case RoundFloor: + return roundFloor(result, neg, half) + case RoundHalfDown: + return roundHalfDown(result, neg, half) + case RoundUp: + return roundUp(result, neg, half) + case Round05Up: + return round05Up(result, neg, half) + default: + return roundHalfUp(result, neg, half) } - return rounding } -// Rounder defines a function that returns true if 1 should be added to the -// absolute value of a number being rounded. result is the result to which -// the 1 would be added. neg is true if the number is negative. half is -1 -// if the discarded digits are < 0.5, 0 if = 0.5, or 1 if > 0.5. -type Rounder func(result *big.Int, neg bool, half int) bool - // Round sets d to rounded x. func (r Rounder) Round(c *Context, d, x *Decimal) Condition { d.Set(x) @@ -74,18 +87,19 @@ func (r Rounder) Round(c *Context, d, x *Decimal) Condition { return SystemUnderflow | Underflow } res |= Rounded - y := new(big.Int) - e := tableExp10(diff, y) - m := new(big.Int) - y.QuoRem(&d.Coeff, e, m) + var y, m BigInt + e := tableExp10(diff, &y) + y.QuoRem(&d.Coeff, e, &m) if m.Sign() != 0 { res |= Inexact - discard := NewWithBigInt(m, int32(-diff)) - if r(y, x.Negative, discard.Cmp(decimalHalf)) { - roundAddOne(y, &diff) + var discard Decimal + discard.Coeff.Set(&m) + discard.Exponent = int32(-diff) + if r.ShouldAddOne(&y, x.Negative, discard.Cmp(decimalHalf)) { + roundAddOne(&y, &diff) } } - d.Coeff = *y + d.Coeff.Set(&y) } else { diff = 0 } @@ -94,7 +108,7 @@ func (r Rounder) Round(c *Context, d, x *Decimal) Condition { } // roundAddOne adds 1 to abs(b). -func roundAddOne(b *big.Int, diff *int64) { +func roundAddOne(b *BigInt, diff *int64) { if b.Sign() < 0 { panic("unexpected negative") } @@ -107,56 +121,52 @@ func roundAddOne(b *big.Int, diff *int64) { } } -var ( - // Roundings defines the set of Rounders used by Context. Users may add their - // own, but modification of this map is not safe during any other parallel - // Context operations. - Roundings = map[string]Rounder{ - RoundDown: roundDown, - RoundHalfUp: roundHalfUp, - RoundHalfEven: roundHalfEven, - RoundCeiling: roundCeiling, - RoundFloor: roundFloor, - RoundHalfDown: roundHalfDown, - RoundUp: roundUp, - Round05Up: round05Up, - } -) +// roundings is a set containing all available Rounders. +var roundings = map[Rounder]struct{}{ + RoundDown: {}, + RoundHalfUp: {}, + RoundHalfEven: {}, + RoundCeiling: {}, + RoundFloor: {}, + RoundHalfDown: {}, + RoundUp: {}, + Round05Up: {}, +} const ( // RoundDown rounds toward 0; truncate. - RoundDown = "down" + RoundDown Rounder = "down" // RoundHalfUp rounds up if the digits are >= 0.5. - RoundHalfUp = "half_up" + RoundHalfUp Rounder = "half_up" // RoundHalfEven rounds up if the digits are > 0.5. If the digits are equal // to 0.5, it rounds up if the previous digit is odd, always producing an // even digit. - RoundHalfEven = "half_even" + RoundHalfEven Rounder = "half_even" // RoundCeiling towards +Inf: rounds up if digits are > 0 and the number // is positive. - RoundCeiling = "ceiling" + RoundCeiling Rounder = "ceiling" // RoundFloor towards -Inf: rounds up if digits are > 0 and the number // is negative. - RoundFloor = "floor" + RoundFloor Rounder = "floor" // RoundHalfDown rounds up if the digits are > 0.5. - RoundHalfDown = "half_down" + RoundHalfDown Rounder = "half_down" // RoundUp rounds away from 0. - RoundUp = "up" + RoundUp Rounder = "up" // Round05Up rounds zero or five away from 0; same as round-up, except that // rounding up only occurs if the digit to be rounded up is 0 or 5. - Round05Up = "05up" + Round05Up Rounder = "05up" ) -func roundDown(result *big.Int, neg bool, half int) bool { +func roundDown(result *BigInt, neg bool, half int) bool { return false } -func roundUp(result *big.Int, neg bool, half int) bool { +func roundUp(result *BigInt, neg bool, half int) bool { return true } -func round05Up(result *big.Int, neg bool, half int) bool { - z := new(big.Int) +func round05Up(result *BigInt, neg bool, half int) bool { + var z BigInt z.Rem(result, bigFive) if z.Sign() == 0 { return true @@ -165,11 +175,11 @@ func round05Up(result *big.Int, neg bool, half int) bool { return z.Sign() == 0 } -func roundHalfUp(result *big.Int, neg bool, half int) bool { +func roundHalfUp(result *BigInt, neg bool, half int) bool { return half >= 0 } -func roundHalfEven(result *big.Int, neg bool, half int) bool { +func roundHalfEven(result *BigInt, neg bool, half int) bool { if half > 0 { return true } @@ -179,14 +189,14 @@ func roundHalfEven(result *big.Int, neg bool, half int) bool { return result.Bit(0) == 1 } -func roundHalfDown(result *big.Int, neg bool, half int) bool { +func roundHalfDown(result *BigInt, neg bool, half int) bool { return half > 0 } -func roundFloor(result *big.Int, neg bool, half int) bool { +func roundFloor(result *BigInt, neg bool, half int) bool { return neg } -func roundCeiling(result *big.Int, neg bool, half int) bool { +func roundCeiling(result *BigInt, neg bool, half int) bool { return !neg } diff --git a/table.go b/table.go index c1e2342..9abee51 100644 --- a/table.go +++ b/table.go @@ -14,8 +14,6 @@ package apd -import "math/big" - // digitsLookupTable is used to map binary digit counts to their corresponding // decimal border values. The map relies on the proof that (without leading zeros) // for any given number of binary digits r, such that the number represented is @@ -31,13 +29,13 @@ var digitsLookupTable [digitsTableSize + 1]tableVal type tableVal struct { digits int64 - border big.Int - nborder big.Int + border BigInt + nborder BigInt } func init() { - curVal := big.NewInt(1) - curExp := new(big.Int) + curVal := NewBigInt(1) + curExp := new(BigInt) for i := 1; i <= digitsTableSize; i++ { if i > 1 { curVal.Lsh(curVal, 1) @@ -59,14 +57,14 @@ func (d *Decimal) NumDigits() int64 { } // NumDigits returns the number of decimal digits of b. -func NumDigits(b *big.Int) int64 { +func NumDigits(b *BigInt) int64 { bl := b.BitLen() if bl == 0 { return 1 } if bl <= digitsTableSize { - val := digitsLookupTable[bl] + val := &digitsLookupTable[bl] // In general, we either have val.digits or val.digits+1 digits and we have // to compare with the border value. But that's not true for all values of // bl: in particular, if bl+1 maps to the same number of digits, then we @@ -90,9 +88,12 @@ func NumDigits(b *big.Int) int64 { } n := int64(float64(bl) / digitsToBitsRatio) - a := new(big.Int) - e := tableExp10(n, a) + var tmp BigInt + e := tableExp10(n, &tmp) + var a *BigInt if b.Sign() < 0 { + var tmpA BigInt + a := &tmpA a.Abs(b) } else { a = b @@ -109,30 +110,27 @@ func NumDigits(b *big.Int) int64 { // 10^3 inclusive. const powerTenTableSize = 128 -var pow10LookupTable [powerTenTableSize + 1]big.Int +var pow10LookupTable [powerTenTableSize + 1]BigInt func init() { - tmpInt := new(big.Int) for i := int64(0); i <= powerTenTableSize; i++ { - setBigWithPow(&pow10LookupTable[i], tmpInt, i) + setBigWithPow(&pow10LookupTable[i], i) } } -func setBigWithPow(bi *big.Int, tmpInt *big.Int, pow int64) { - if tmpInt == nil { - tmpInt = new(big.Int) - } - bi.Exp(bigTen, tmpInt.SetInt64(pow), nil) +func setBigWithPow(res *BigInt, pow int64) { + var tmp BigInt + tmp.SetInt64(pow) + res.Exp(bigTen, &tmp, nil) } // tableExp10 returns 10^x for x >= 0, looked up from a table when // possible. This returned value must not be mutated. tmp is used as an -// intermediate variable, but may be nil. -func tableExp10(x int64, tmp *big.Int) *big.Int { +// intermediate variable and must not be nil. +func tableExp10(x int64, tmp *BigInt) *BigInt { if x <= powerTenTableSize { return &pow10LookupTable[x] } - b := new(big.Int) - setBigWithPow(b, tmp, x) - return b + setBigWithPow(tmp, x) + return tmp } diff --git a/table_test.go b/table_test.go index 5d6266a..36fd193 100644 --- a/table_test.go +++ b/table_test.go @@ -16,51 +16,56 @@ package apd import ( "bytes" - "math/big" "math/rand" "strings" "testing" ) func BenchmarkNumDigitsLookup(b *testing.B) { - b.StopTimer() - runTest := func(start string, c byte) { + prep := func(start string, c byte) []*Decimal { + var ds []*Decimal buf := bytes.NewBufferString(start) for i := 1; i < digitsTableSize; i++ { buf.WriteByte(c) d, _, _ := NewFromString(buf.String()) - - b.StartTimer() - d.NumDigits() - b.StopTimer() + ds = append(ds, d) } + return ds } + var ds []*Decimal + ds = append(ds, prep("", '9')...) + ds = append(ds, prep("1", '0')...) + ds = append(ds, prep("-", '9')...) + ds = append(ds, prep("-1", '0')...) + b.ResetTimer() for i := 0; i < b.N; i++ { - runTest("", '9') - runTest("1", '0') - runTest("-", '9') - runTest("-1", '0') + for _, d := range ds { + d.NumDigits() + } } } func BenchmarkNumDigitsFull(b *testing.B) { - b.StopTimer() - runTest := func(start string, c byte) { + prep := func(start string, c byte) []*Decimal { + var ds []*Decimal buf := bytes.NewBufferString(start) for i := 1; i < 1000; i++ { buf.WriteByte(c) d, _, _ := NewFromString(buf.String()) - - b.StartTimer() - d.NumDigits() - b.StopTimer() + ds = append(ds, d) } + return ds } + var ds []*Decimal + ds = append(ds, prep("", '9')...) + ds = append(ds, prep("1", '0')...) + ds = append(ds, prep("-", '9')...) + ds = append(ds, prep("-1", '0')...) + b.ResetTimer() for i := 0; i < b.N; i++ { - runTest("", '9') - runTest("1", '0') - runTest("-", '9') - runTest("-1", '0') + for _, d := range ds { + d.NumDigits() + } } } @@ -92,13 +97,13 @@ func TestNumDigits(t *testing.T) { func TestDigitsLookupTable(t *testing.T) { // Make sure all elements in table make sense. - min := new(big.Int) - prevBorder := big.NewInt(0) + min := new(BigInt) + prevBorder := NewBigInt(0) for i := 1; i <= digitsTableSize; i++ { - elem := digitsLookupTable[i] + elem := &digitsLookupTable[i] min.SetInt64(2) - min.Exp(min, big.NewInt(int64(i-1)), nil) + min.Exp(min, NewBigInt(int64(i-1)), nil) if minLen := int64(len(min.String())); minLen != elem.digits { t.Errorf("expected 2^%d to have %d digits, found %d", i, elem.digits, minLen) } @@ -123,11 +128,11 @@ func TestDigitsLookupTable(t *testing.T) { // digit lengths line up. const randomTrials = 100 for i := 0; i < randomTrials; i++ { - a := big.NewInt(rand.Int63()) - b := big.NewInt(rand.Int63()) + a := NewBigInt(rand.Int63()) + b := NewBigInt(rand.Int63()) a.Mul(a, b) - d := &Decimal{Coeff: *a} + d := NewWithBigInt(a, 0) tableDigits := d.NumDigits() if actualDigits := int64(len(a.String())); actualDigits != tableDigits { t.Errorf("expected %d digits for %v, found %d", tableDigits, a, actualDigits) @@ -159,7 +164,8 @@ func TestTableExp10(t *testing.T) { } for i, test := range tests { - d := tableExp10(test.pow, nil) + var tmpE BigInt + d := tableExp10(test.pow, &tmpE) if s := d.String(); s != test.str { t.Errorf("%d: expected PowerOfTenDec(%d) to give %s, got %s", i, test.pow, test.str, s) }