Skip to content

Commit

Permalink
apd: support BigInt fast-paths on 32-bit architectures
Browse files Browse the repository at this point in the history
This commit switches the BigInt fast-paths from operating on uint values
to uint64 values. This allows them to work on 32-bit architectures for any
value that can fit in a 64-bit unsigned integer.

There are two benefits of this. The first is that it unifies the arithmetic
fast-paths across the architectures so that they perform more similarly. The
second is that it removes some dead code on 64-bit architectures, so that we
can avoid issues with jordanlewis/gcassert#3.
  • Loading branch information
nvanbenschoten committed Jan 9, 2022
1 parent fc7d319 commit 96f1660
Showing 1 changed file with 85 additions and 91 deletions.
176 changes: 85 additions & 91 deletions bigint.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

// The inlineWords capacity is set to accommodate any value that would fit in a
// 128-bit integer (i.e. values with an absolute value up to 2^128 - 1).
const inlineWords = 2
const inlineWords = 128 / bits.UintSize

// BigInt is a wrapper around big.Int. It minimizes memory allocation by using
// an inline array to back the big.Int's variable-length "nat" slice when the
Expand Down Expand Up @@ -206,54 +206,70 @@ func (z *BigInt) updateInner(src *big.Int) {
}
}

// innerAsUint returns the BigInt's current absolute value as a uint and a flag
// indicating whether the value is negative. If the value is not stored inline
// or if it can not fit in a uint, false is returned.
const wordsInUint64 = 64 / bits.UintSize

func init() {
if inlineWords < wordsInUint64 {
panic("inline array must be at least 64 bits large")
}
}

// innerAsUint64 returns the BigInt's current absolute value as a uint64 and a
// flag indicating whether the value is negative. If the value is not stored
// inline or if it can not fit in a uint64, false is returned.
//
// NOTE: this was carefully written to permit function inlining. Modify with
// care.
//gcassert:inline
func (z *BigInt) innerAsUint() (val uint, neg bool, ok bool) {
func (z *BigInt) innerAsUint64() (val uint64, neg bool, ok bool) {
if !z.isInline() {
// The value is not stored inline.
return 0, false, false
}
if inlineWords == 2 {
if bits.UintSize == 64 && inlineWords == 2 {
// Manually unrolled loop for current inlineWords setting.
if z._inline[1] != 0 {
// The value can not fit in a uint.
// The value can not fit in a uint64.
return 0, false, false
}
} else {
// Fallback for other values of inlineWords.
for i := 1; i < len(z._inline); i++ {
for i := wordsInUint64; i < len(z._inline); i++ {
if z._inline[i] != 0 {
// The value can not fit in a uint.
// The value can not fit in a uint64.
return 0, false, false
}
}
}

val = uint(z._inline[0])
val = uint64(z._inline[0])
if bits.UintSize == 32 {
// From big.low64.
val = uint64(z._inline[1])<<32 | val
}
neg = z._inner == negSentinel
return val, neg, true
}

// updateInnerFromUint updates the BigInt's current value with the provided
// updateInnerFromUint64 updates the BigInt's current value with the provided
// absolute value and sign.
//
// NOTE: this was carefully written to permit function inlining. Modify with
// care.
//gcassert:inline
func (z *BigInt) updateInnerFromUint(val uint, neg bool) {
func (z *BigInt) updateInnerFromUint64(val uint64, neg bool) {
// Set the inline value, making sure to clear out all other words.
z._inline[0] = big.Word(val)
if inlineWords == 2 {
if bits.UintSize == 32 {
// From (big.nat).setUint64.
z._inline[1] = big.Word(val >> 32)
}
if bits.UintSize == 64 && inlineWords == 2 {
// Manually unrolled loop for current inlineWords setting.
z._inline[1] = 0
} else {
// Fallback for other values of inlineWords.
for i := 1; i < len(z._inline); i++ {
for i := wordsInUint64; i < len(z._inline); i++ {
z._inline[i] = 0
}
}
Expand Down Expand Up @@ -285,14 +301,14 @@ func (z *BigInt) Size() uintptr {
///////////////////////////////////////////////////////////////////////////////

//gcassert:inline
func addInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) {
func addInline(xVal, yVal uint64, xNeg, yNeg bool) (zVal uint64, zNeg, ok bool) {
if xNeg == yNeg {
sum, carry := bits.Add(xVal, yVal, 0)
sum, carry := bits.Add64(xVal, yVal, 0)
overflow := carry != 0
return sum, xNeg, !overflow
}

diff, borrow := bits.Sub(xVal, yVal, 0)
diff, borrow := bits.Sub64(xVal, yVal, 0)
if borrow != 0 { // underflow
xNeg = !xNeg
diff = yVal - xVal
Expand All @@ -304,15 +320,15 @@ func addInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) {
}

//gcassert:inline
func mulInline(xVal, yVal uint, xNeg, yNeg bool) (zVal uint, zNeg, ok bool) {
hi, lo := bits.Mul(xVal, yVal)
func mulInline(xVal, yVal uint64, xNeg, yNeg bool) (zVal uint64, zNeg, ok bool) {
hi, lo := bits.Mul64(xVal, yVal)
neg := xNeg != yNeg
overflow := hi != 0
return lo, neg, !overflow
}

//gcassert:inline
func quoInline(xVal, yVal uint, xNeg, yNeg bool) (quoVal uint, quoNeg, ok bool) {
func quoInline(xVal, yVal uint64, xNeg, yNeg bool) (quoVal uint64, quoNeg, ok bool) {
if yVal == 0 { // divide by 0
return 0, false, false
}
Expand All @@ -322,7 +338,7 @@ func quoInline(xVal, yVal uint, xNeg, yNeg bool) (quoVal uint, quoNeg, ok bool)
}

//gcassert:inline
func remInline(xVal, yVal uint, xNeg, yNeg bool) (remVal uint, remNeg, ok bool) {
func remInline(xVal, yVal uint64, xNeg, yNeg bool) (remVal uint64, remNeg, ok bool) {
if yVal == 0 { // divide by 0
return 0, false, false
}
Expand Down Expand Up @@ -350,10 +366,10 @@ func (z *BigInt) Abs(x *BigInt) *BigInt {

// Add calls (big.Int).Add.
func (z *BigInt) Add(x, y *BigInt) *BigInt {
if xVal, xNeg, ok := x.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if xVal, xNeg, ok := x.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
if zVal, zNeg, ok := addInline(xVal, yVal, xNeg, yNeg); ok {
z.updateInnerFromUint(zVal, zNeg)
z.updateInnerFromUint64(zVal, zNeg)
return z
}
}
Expand Down Expand Up @@ -389,13 +405,13 @@ func (z *BigInt) Append(buf []byte, base int) []byte {
// Fast-path that avoids innerOrNil, allowing inner to be inlined.
return append(buf, "<nil>"...)
}
if zVal, zNeg, ok := z.innerAsUint(); ok {
if zVal, zNeg, ok := z.innerAsUint64(); ok {
// Check if the base is supported by strconv.AppendUint.
if base >= 2 && base <= 36 {
if zNeg {
buf = append(buf, '-')
}
return strconv.AppendUint(buf, uint64(zVal), base)
return strconv.AppendUint(buf, zVal, base)
}
}
var tmp1 big.Int
Expand Down Expand Up @@ -450,8 +466,8 @@ func (z *BigInt) Bytes() []byte {

// Cmp calls (big.Int).Cmp.
func (z *BigInt) Cmp(y *BigInt) (r int) {
if zVal, zNeg, ok := z.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if zVal, zNeg, ok := z.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
switch {
case zNeg == yNeg:
switch {
Expand All @@ -477,8 +493,8 @@ func (z *BigInt) Cmp(y *BigInt) (r int) {

// CmpAbs calls (big.Int).CmpAbs.
func (z *BigInt) CmpAbs(y *BigInt) (r int) {
if zVal, _, ok := z.innerAsUint(); ok {
if yVal, _, ok := y.innerAsUint(); ok {
if zVal, _, ok := z.innerAsUint64(); ok {
if yVal, _, ok := y.innerAsUint64(); ok {
switch {
case zVal < yVal:
r = -1
Expand Down Expand Up @@ -571,42 +587,36 @@ func (z *BigInt) GobDecode(buf []byte) error {

// Int64 calls (big.Int).Int64.
func (z *BigInt) Int64() int64 {
if bits.UintSize == 64 {
if zVal, zNeg, ok := z.innerAsUint(); ok {
// The unchecked cast from uint64 to int64 looks unsafe, but it is
// allowed and is identical to the logic in (big.Int).Int64. Per the
// method's contract:
// > If z cannot be represented in an int64, the result is undefined.
zi := int64(zVal)
if zNeg {
zi = -zi
}
return zi
if zVal, zNeg, ok := z.innerAsUint64(); ok {
// The unchecked cast from uint64 to int64 looks unsafe, but it is
// allowed and is identical to the logic in (big.Int).Int64. Per the
// method's contract:
// > If z cannot be represented in an int64, the result is undefined.
zi := int64(zVal)
if zNeg {
zi = -zi
}
return zi
}
var tmp1 big.Int
return z.inner(&tmp1).Int64()
}

// IsInt64 calls (big.Int).IsInt64.
func (z *BigInt) IsInt64() bool {
if bits.UintSize == 64 {
if zVal, zNeg, ok := z.innerAsUint(); ok {
// From (big.Int).IsInt64.
zi := int64(zVal)
return zi >= 0 || zNeg && zi == -zi
}
if zVal, zNeg, ok := z.innerAsUint64(); ok {
// From (big.Int).IsInt64.
zi := int64(zVal)
return zi >= 0 || zNeg && zi == -zi
}
var tmp1 big.Int
return z.inner(&tmp1).IsInt64()
}

// IsUint64 calls (big.Int).IsUint64.
func (z *BigInt) IsUint64() bool {
if bits.UintSize == 64 {
if _, zNeg, ok := z.innerAsUint(); ok {
return !zNeg
}
if _, zNeg, ok := z.innerAsUint64(); ok {
return !zNeg
}
var tmp1 big.Int
return z.inner(&tmp1).IsUint64()
Expand Down Expand Up @@ -668,10 +678,10 @@ func (z *BigInt) ModSqrt(x, p *BigInt) *BigInt {

// Mul calls (big.Int).Mul.
func (z *BigInt) Mul(x, y *BigInt) *BigInt {
if xVal, xNeg, ok := x.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if xVal, xNeg, ok := x.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
if zVal, zNeg, ok := mulInline(xVal, yVal, xNeg, yNeg); ok {
z.updateInnerFromUint(zVal, zNeg)
z.updateInnerFromUint64(zVal, zNeg)
return z
}
}
Expand Down Expand Up @@ -736,10 +746,10 @@ func (z *BigInt) ProbablyPrime(n int) bool {

// Quo calls (big.Int).Quo.
func (z *BigInt) Quo(x, y *BigInt) *BigInt {
if xVal, xNeg, ok := x.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if xVal, xNeg, ok := x.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
if quoVal, quoNeg, ok := quoInline(xVal, yVal, xNeg, yNeg); ok {
z.updateInnerFromUint(quoVal, quoNeg)
z.updateInnerFromUint64(quoVal, quoNeg)
return z
}
}
Expand All @@ -753,12 +763,12 @@ func (z *BigInt) Quo(x, y *BigInt) *BigInt {

// QuoRem calls (big.Int).QuoRem.
func (z *BigInt) QuoRem(x, y, r *BigInt) (*BigInt, *BigInt) {
if xVal, xNeg, ok := x.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if xVal, xNeg, ok := x.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
if quoVal, quoNeg, ok := quoInline(xVal, yVal, xNeg, yNeg); ok {
if remVal, remNeg, ok := remInline(xVal, yVal, xNeg, yNeg); ok {
z.updateInnerFromUint(quoVal, quoNeg)
r.updateInnerFromUint(remVal, remNeg)
z.updateInnerFromUint64(quoVal, quoNeg)
r.updateInnerFromUint64(remVal, remNeg)
return z, r
}
}
Expand All @@ -784,10 +794,10 @@ func (z *BigInt) Rand(rnd *rand.Rand, n *BigInt) *BigInt {

// Rem calls (big.Int).Rem.
func (z *BigInt) Rem(x, y *BigInt) *BigInt {
if xVal, xNeg, ok := x.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if xVal, xNeg, ok := x.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
if remVal, remNeg, ok := remInline(xVal, yVal, xNeg, yNeg); ok {
z.updateInnerFromUint(remVal, remNeg)
z.updateInnerFromUint64(remVal, remNeg)
return z
}
}
Expand Down Expand Up @@ -861,19 +871,12 @@ func (z *BigInt) SetBytes(buf []byte) *BigInt {

// SetInt64 calls (big.Int).SetInt64.
func (z *BigInt) SetInt64(x int64) *BigInt {
if bits.UintSize == 64 {
neg := false
if x < 0 {
neg = true
x = -x
}
z.updateInnerFromUint(uint(x), neg)
return z
neg := false
if x < 0 {
neg = true
x = -x
}
var tmp1 big.Int
zi := z.inner(&tmp1)
zi.SetInt64(x)
z.updateInner(zi)
z.updateInnerFromUint64(uint64(x), neg)
return z
}

Expand All @@ -890,14 +893,7 @@ func (z *BigInt) SetString(s string, base int) (*BigInt, bool) {

// SetUint64 calls (big.Int).SetUint64.
func (z *BigInt) SetUint64(x uint64) *BigInt {
if bits.UintSize == 64 {
z.updateInnerFromUint(uint(x), false)
return z
}
var tmp1 big.Int
zi := z.inner(&tmp1)
zi.SetUint64(x)
z.updateInner(zi)
z.updateInnerFromUint64(x, false)
return z
}

Expand Down Expand Up @@ -935,10 +931,10 @@ func (z *BigInt) String() string {

// Sub calls (big.Int).Sub.
func (z *BigInt) Sub(x, y *BigInt) *BigInt {
if xVal, xNeg, ok := x.innerAsUint(); ok {
if yVal, yNeg, ok := y.innerAsUint(); ok {
if xVal, xNeg, ok := x.innerAsUint64(); ok {
if yVal, yNeg, ok := y.innerAsUint64(); ok {
if zVal, zNeg, ok := addInline(xVal, yVal, xNeg, !yNeg); ok {
z.updateInnerFromUint(zVal, zNeg)
z.updateInnerFromUint64(zVal, zNeg)
return z
}
}
Expand Down Expand Up @@ -968,10 +964,8 @@ func (z *BigInt) TrailingZeroBits() uint {

// Uint64 calls (big.Int).Uint64.
func (z *BigInt) Uint64() uint64 {
if bits.UintSize == 64 {
if zVal, _, ok := z.innerAsUint(); ok {
return uint64(zVal)
}
if zVal, _, ok := z.innerAsUint64(); ok {
return zVal
}
var tmp1 big.Int
return z.inner(&tmp1).Uint64()
Expand Down

0 comments on commit 96f1660

Please sign in to comment.