From 5e38978e4496d125826ebbd35a670ae3548bdfd0 Mon Sep 17 00:00:00 2001 From: Nathan VanBenschoten Date: Sun, 9 Jan 2022 14:35:00 -0500 Subject: [PATCH] apd: support BigInt fast-paths on 32-bit architectures This commit switches the BigInt fast-paths from operating on uint values to uint64 values. This allows them to work on 32-bit architectures for any value that can fit in a 64-bit unsigned integer. There are two benefits of this. The first is that it unifies the arithmetic fast-paths across the architectures so that they perform more similarly. The second is that it removes some dead code on 64-bit architectures, so that we can avoid issues with https://github.com/jordanlewis/gcassert/pull/3. --- bigint.go | 180 ++++++++++++++++++++++++++---------------------------- 1 file changed, 88 insertions(+), 92 deletions(-) diff --git a/bigint.go b/bigint.go index 36b4d53..f1906c4 100644 --- a/bigint.go +++ b/bigint.go @@ -25,7 +25,7 @@ import ( // The inlineWords capacity is set to accommodate any value that would fit in a // 128-bit integer (i.e. values with an absolute value up to 2^128 - 1). -const inlineWords = 2 +const inlineWords = 128 / bits.UintSize // BigInt is a wrapper around big.Int. It minimizes memory allocation by using // an inline array to back the big.Int's variable-length "nat" slice when the @@ -206,54 +206,72 @@ func (z *BigInt) updateInner(src *big.Int) { } } -// innerAsUint returns the BigInt's current absolute value as a uint and a flag -// indicating whether the value is negative. If the value is not stored inline -// or if it can not fit in a uint, false is returned. +const wordsInUint64 = 64 / bits.UintSize + +func init() { + if inlineWords < wordsInUint64 { + panic("inline array must be at least 64 bits large") + } +} + +// innerAsUint64 returns the BigInt's current absolute value as a uint64 and a +// flag indicating whether the value is negative. If the value is not stored +// inline or if it can not fit in a uint64, false is returned. // // NOTE: this was carefully written to permit function inlining. Modify with // care. //gcassert:inline -func (z *BigInt) innerAsUint() (val uint, neg bool, ok bool) { +func (z *BigInt) innerAsUint64() (val uint64, neg bool, ok bool) { if !z.isInline() { // The value is not stored inline. return 0, false, false } - if inlineWords == 2 { + if wordsInUint64 == 1 && inlineWords == 2 { // Manually unrolled loop for current inlineWords setting. if z._inline[1] != 0 { - // The value can not fit in a uint. + // The value can not fit in a uint64. return 0, false, false } } else { // Fallback for other values of inlineWords. - for i := 1; i < len(z._inline); i++ { + for i := wordsInUint64; i < len(z._inline); i++ { if z._inline[i] != 0 { - // The value can not fit in a uint. + // The value can not fit in a uint64. return 0, false, false } } } - val = uint(z._inline[0]) + val = uint64(z._inline[0]) + if wordsInUint64 == 2 { + // From big.low64. + val = uint64(z._inline[1])<<32 | val + } neg = z._inner == negSentinel return val, neg, true } -// updateInnerFromUint updates the BigInt's current value with the provided +// updateInnerFromUint64 updates the BigInt's current value with the provided // absolute value and sign. // // NOTE: this was carefully written to permit function inlining. Modify with // care. //gcassert:inline -func (z *BigInt) updateInnerFromUint(val uint, neg bool) { - // Set the inline value, making sure to clear out all other words. +func (z *BigInt) updateInnerFromUint64(val uint64, neg bool) { + // Set the inline value. z._inline[0] = big.Word(val) - if inlineWords == 2 { + if wordsInUint64 == 2 { + // From (big.nat).setUint64. + z._inline[1] = big.Word(val >> 32) + } + + // Clear out all other words in the inline array. + if wordsInUint64 == 1 && inlineWords == 2 { // Manually unrolled loop for current inlineWords setting. z._inline[1] = 0 } else { // Fallback for other values of inlineWords. - for i := 1; i < len(z._inline); i++ { + for i := wordsInUint64; i < len(z._inline); i++ { z._inline[i] = 0 } } @@ -285,14 +303,14 @@ func (z *BigInt) Size() uintptr { /////////////////////////////////////////////////////////////////////////////// //gcassert:inline -func addInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) { +func addInline(xVal, yVal uint64, xNeg, yNeg bool) (zVal uint64, zNeg, ok bool) { if xNeg == yNeg { - sum, carry := bits.Add(xVal, yVal, 0) + sum, carry := bits.Add64(xVal, yVal, 0) overflow := carry != 0 return sum, xNeg, !overflow } - diff, borrow := bits.Sub(xVal, yVal, 0) + diff, borrow := bits.Sub64(xVal, yVal, 0) if borrow != 0 { // underflow xNeg = !xNeg diff = yVal - xVal @@ -304,15 +322,15 @@ func addInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) { } //gcassert:inline -func mulInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) { - hi, lo := bits.Mul(xVal, yVal) +func mulInline(xVal, yVal uint64, xNeg, yNeg bool) (zVal uint64, zNeg, ok bool) { + hi, lo := bits.Mul64(xVal, yVal) neg := xNeg != yNeg overflow := hi != 0 return lo, neg, !overflow } //gcassert:inline -func quoInline(xVal, yVal uint, xNeg, yNeg bool) (quoVal uint, quoNeg, ok bool) { +func quoInline(xVal, yVal uint64, xNeg, yNeg bool) (quoVal uint64, quoNeg, ok bool) { if yVal == 0 { // divide by 0 return 0, false, false } @@ -322,7 +340,7 @@ func quoInline(xVal, yVal uint, xNeg, yNeg bool) (quoVal uint, quoNeg, ok bool) } //gcassert:inline -func remInline(xVal, yVal uint, xNeg, yNeg bool) (remVal uint, remNeg, ok bool) { +func remInline(xVal, yVal uint64, xNeg, yNeg bool) (remVal uint64, remNeg, ok bool) { if yVal == 0 { // divide by 0 return 0, false, false } @@ -350,10 +368,10 @@ func (z *BigInt) Abs(x *BigInt) *BigInt { // Add calls (big.Int).Add. func (z *BigInt) Add(x, y *BigInt) *BigInt { - if xVal, xNeg, ok := x.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if xVal, xNeg, ok := x.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { if zVal, zNeg, ok := addInline(xVal, yVal, xNeg, yNeg); ok { - z.updateInnerFromUint(zVal, zNeg) + z.updateInnerFromUint64(zVal, zNeg) return z } } @@ -389,13 +407,13 @@ func (z *BigInt) Append(buf []byte, base int) []byte { // Fast-path that avoids innerOrNil, allowing inner to be inlined. return append(buf, ""...) } - if zVal, zNeg, ok := z.innerAsUint(); ok { + if zVal, zNeg, ok := z.innerAsUint64(); ok { // Check if the base is supported by strconv.AppendUint. if base >= 2 && base <= 36 { if zNeg { buf = append(buf, '-') } - return strconv.AppendUint(buf, uint64(zVal), base) + return strconv.AppendUint(buf, zVal, base) } } var tmp1 big.Int @@ -450,8 +468,8 @@ func (z *BigInt) Bytes() []byte { // Cmp calls (big.Int).Cmp. func (z *BigInt) Cmp(y *BigInt) (r int) { - if zVal, zNeg, ok := z.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if zVal, zNeg, ok := z.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { switch { case zNeg == yNeg: switch { @@ -477,8 +495,8 @@ func (z *BigInt) Cmp(y *BigInt) (r int) { // CmpAbs calls (big.Int).CmpAbs. func (z *BigInt) CmpAbs(y *BigInt) (r int) { - if zVal, _, ok := z.innerAsUint(); ok { - if yVal, _, ok := y.innerAsUint(); ok { + if zVal, _, ok := z.innerAsUint64(); ok { + if yVal, _, ok := y.innerAsUint64(); ok { switch { case zVal < yVal: r = -1 @@ -571,18 +589,16 @@ func (z *BigInt) GobDecode(buf []byte) error { // Int64 calls (big.Int).Int64. func (z *BigInt) Int64() int64 { - if bits.UintSize == 64 { - if zVal, zNeg, ok := z.innerAsUint(); ok { - // The unchecked cast from uint64 to int64 looks unsafe, but it is - // allowed and is identical to the logic in (big.Int).Int64. Per the - // method's contract: - // > If z cannot be represented in an int64, the result is undefined. - zi := int64(zVal) - if zNeg { - zi = -zi - } - return zi + if zVal, zNeg, ok := z.innerAsUint64(); ok { + // The unchecked cast from uint64 to int64 looks unsafe, but it is + // allowed and is identical to the logic in (big.Int).Int64. Per the + // method's contract: + // > If z cannot be represented in an int64, the result is undefined. + zi := int64(zVal) + if zNeg { + zi = -zi } + return zi } var tmp1 big.Int return z.inner(&tmp1).Int64() @@ -590,12 +606,10 @@ func (z *BigInt) Int64() int64 { // IsInt64 calls (big.Int).IsInt64. func (z *BigInt) IsInt64() bool { - if bits.UintSize == 64 { - if zVal, zNeg, ok := z.innerAsUint(); ok { - // From (big.Int).IsInt64. - zi := int64(zVal) - return zi >= 0 || zNeg && zi == -zi - } + if zVal, zNeg, ok := z.innerAsUint64(); ok { + // From (big.Int).IsInt64. + zi := int64(zVal) + return zi >= 0 || zNeg && zi == -zi } var tmp1 big.Int return z.inner(&tmp1).IsInt64() @@ -603,10 +617,8 @@ func (z *BigInt) IsInt64() bool { // IsUint64 calls (big.Int).IsUint64. func (z *BigInt) IsUint64() bool { - if bits.UintSize == 64 { - if _, zNeg, ok := z.innerAsUint(); ok { - return !zNeg - } + if _, zNeg, ok := z.innerAsUint64(); ok { + return !zNeg } var tmp1 big.Int return z.inner(&tmp1).IsUint64() @@ -668,10 +680,10 @@ func (z *BigInt) ModSqrt(x, p *BigInt) *BigInt { // Mul calls (big.Int).Mul. func (z *BigInt) Mul(x, y *BigInt) *BigInt { - if xVal, xNeg, ok := x.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if xVal, xNeg, ok := x.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { if zVal, zNeg, ok := mulInline(xVal, yVal, xNeg, yNeg); ok { - z.updateInnerFromUint(zVal, zNeg) + z.updateInnerFromUint64(zVal, zNeg) return z } } @@ -736,10 +748,10 @@ func (z *BigInt) ProbablyPrime(n int) bool { // Quo calls (big.Int).Quo. func (z *BigInt) Quo(x, y *BigInt) *BigInt { - if xVal, xNeg, ok := x.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if xVal, xNeg, ok := x.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { if quoVal, quoNeg, ok := quoInline(xVal, yVal, xNeg, yNeg); ok { - z.updateInnerFromUint(quoVal, quoNeg) + z.updateInnerFromUint64(quoVal, quoNeg) return z } } @@ -753,12 +765,12 @@ func (z *BigInt) Quo(x, y *BigInt) *BigInt { // QuoRem calls (big.Int).QuoRem. func (z *BigInt) QuoRem(x, y, r *BigInt) (*BigInt, *BigInt) { - if xVal, xNeg, ok := x.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if xVal, xNeg, ok := x.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { if quoVal, quoNeg, ok := quoInline(xVal, yVal, xNeg, yNeg); ok { if remVal, remNeg, ok := remInline(xVal, yVal, xNeg, yNeg); ok { - z.updateInnerFromUint(quoVal, quoNeg) - r.updateInnerFromUint(remVal, remNeg) + z.updateInnerFromUint64(quoVal, quoNeg) + r.updateInnerFromUint64(remVal, remNeg) return z, r } } @@ -784,10 +796,10 @@ func (z *BigInt) Rand(rnd *rand.Rand, n *BigInt) *BigInt { // Rem calls (big.Int).Rem. func (z *BigInt) Rem(x, y *BigInt) *BigInt { - if xVal, xNeg, ok := x.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if xVal, xNeg, ok := x.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { if remVal, remNeg, ok := remInline(xVal, yVal, xNeg, yNeg); ok { - z.updateInnerFromUint(remVal, remNeg) + z.updateInnerFromUint64(remVal, remNeg) return z } } @@ -861,19 +873,12 @@ func (z *BigInt) SetBytes(buf []byte) *BigInt { // SetInt64 calls (big.Int).SetInt64. func (z *BigInt) SetInt64(x int64) *BigInt { - if bits.UintSize == 64 { - neg := false - if x < 0 { - neg = true - x = -x - } - z.updateInnerFromUint(uint(x), neg) - return z + neg := false + if x < 0 { + neg = true + x = -x } - var tmp1 big.Int - zi := z.inner(&tmp1) - zi.SetInt64(x) - z.updateInner(zi) + z.updateInnerFromUint64(uint64(x), neg) return z } @@ -890,14 +895,7 @@ func (z *BigInt) SetString(s string, base int) (*BigInt, bool) { // SetUint64 calls (big.Int).SetUint64. func (z *BigInt) SetUint64(x uint64) *BigInt { - if bits.UintSize == 64 { - z.updateInnerFromUint(uint(x), false) - return z - } - var tmp1 big.Int - zi := z.inner(&tmp1) - zi.SetUint64(x) - z.updateInner(zi) + z.updateInnerFromUint64(x, false) return z } @@ -935,10 +933,10 @@ func (z *BigInt) String() string { // Sub calls (big.Int).Sub. func (z *BigInt) Sub(x, y *BigInt) *BigInt { - if xVal, xNeg, ok := x.innerAsUint(); ok { - if yVal, yNeg, ok := y.innerAsUint(); ok { + if xVal, xNeg, ok := x.innerAsUint64(); ok { + if yVal, yNeg, ok := y.innerAsUint64(); ok { if zVal, zNeg, ok := addInline(xVal, yVal, xNeg, !yNeg); ok { - z.updateInnerFromUint(zVal, zNeg) + z.updateInnerFromUint64(zVal, zNeg) return z } } @@ -968,10 +966,8 @@ func (z *BigInt) TrailingZeroBits() uint { // Uint64 calls (big.Int).Uint64. func (z *BigInt) Uint64() uint64 { - if bits.UintSize == 64 { - if zVal, _, ok := z.innerAsUint(); ok { - return uint64(zVal) - } + if zVal, _, ok := z.innerAsUint64(); ok { + return zVal } var tmp1 big.Int return z.inner(&tmp1).Uint64()