Skip to content

Commit

Permalink
Add inplace decimal operations (#10)
Browse files Browse the repository at this point in the history
* add mutable decimal methods

* fix root

* gofmt

* add Clone(), rm assignment quo

* add SetInt64

* fix SetInt64

* More mut speedups

Co-authored-by: ValarDragon <[email protected]>
  • Loading branch information
2 people authored and faddat committed Jan 7, 2023
1 parent 8c67157 commit 44a8b5a
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 54 deletions.
189 changes: 135 additions & 54 deletions types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ func (d Dec) GTE(d2 Dec) bool { return (d.i).Cmp(d2.i) >= 0 } // greater
func (d Dec) LT(d2 Dec) bool { return (d.i).Cmp(d2.i) < 0 } // less than
func (d Dec) LTE(d2 Dec) bool { return (d.i).Cmp(d2.i) <= 0 } // less than or equal
func (d Dec) Neg() Dec { return Dec{new(big.Int).Neg(d.i)} } // reverse the decimal sign
func (d Dec) NegMut() Dec { d.i.Neg(d.i); return d } // reverse the decimal sign, mutable
func (d Dec) Abs() Dec { return Dec{new(big.Int).Abs(d.i)} } // absolute value
func (d Dec) Set(d2 Dec) Dec { d.i.Set(d2.i); return d } // set to existing dec value
func (d Dec) Clone() Dec { return Dec{new(big.Int).Set(d.i)} } // clone new dec

// BigInt returns a copy of the underlying big.Int.
func (d Dec) BigInt() *big.Int {
Expand All @@ -224,123 +227,192 @@ func (d Dec) BigInt() *big.Int {
return copy.Set(d.i)
}

func (d Dec) ImmutOp(op func(Dec, Dec) Dec, d2 Dec) Dec {
return op(d.Clone(), d2)
}

func (d Dec) ImmutOpInt(op func(Dec, Int) Dec, d2 Int) Dec {
return op(d.Clone(), d2)
}

func (d Dec) ImmutOpInt64(op func(Dec, int64) Dec, d2 int64) Dec {
// TODO: use already allocated operand bigint to avoid
// newint each time, add mutex for race condition
return op(d.Clone(), d2)
}

func (d Dec) SetInt64(i int64) Dec {
d.i.SetInt64(i)
d.i.Mul(d.i, precisionReuse)
return d
}

// addition
func (d Dec) Add(d2 Dec) Dec {
res := new(big.Int).Add(d.i, d2.i)
return d.ImmutOp(Dec.AddMut, d2)
}

// mutable addition
func (d Dec) AddMut(d2 Dec) Dec {
d.i.Add(d.i, d2.i)

if res.BitLen() > 255+DecimalPrecisionBits {
if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{res}
return d
}

// subtraction
func (d Dec) Sub(d2 Dec) Dec {
res := new(big.Int).Sub(d.i, d2.i)
return d.ImmutOp(Dec.SubMut, d2)
}

// mutable subtraction
func (d Dec) SubMut(d2 Dec) Dec {
d.i.Sub(d.i, d2.i)

if res.BitLen() > 255+DecimalPrecisionBits {
if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{res}
return d
}

// multiplication
func (d Dec) Mul(d2 Dec) Dec {
mul := new(big.Int).Mul(d.i, d2.i)
chopped := chopPrecisionAndRound(mul)
return d.ImmutOp(Dec.MulMut, d2)
}

// mutable multiplication
func (d Dec) MulMut(d2 Dec) Dec {
d.i.Mul(d.i, d2.i)
chopped := chopPrecisionAndRound(d.i)

if chopped.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{chopped}
*d.i = *chopped
return d
}

// multiplication truncate
func (d Dec) MulTruncate(d2 Dec) Dec {
mul := new(big.Int).Mul(d.i, d2.i)
chopped := chopPrecisionAndTruncate(mul)
return d.ImmutOp(Dec.MulTruncateMut, d2)
}

// mutable multiplication truncage
func (d Dec) MulTruncateMut(d2 Dec) Dec {
d.i.Mul(d.i, d2.i)
chopped := chopPrecisionAndTruncate(d.i)

if chopped.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{chopped}
*d.i = *chopped
return d
}

// multiplication
func (d Dec) MulInt(i Int) Dec {
mul := new(big.Int).Mul(d.i, i.i)
return d.ImmutOpInt(Dec.MulIntMut, i)
}

if mul.BitLen() > 255+DecimalPrecisionBits {
func (d Dec) MulIntMut(i Int) Dec {
d.i.Mul(d.i, i.i)
if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{mul}
return d
}

// MulInt64 - multiplication with int64
func (d Dec) MulInt64(i int64) Dec {
mul := new(big.Int).Mul(d.i, big.NewInt(i))
return d.ImmutOpInt64(Dec.MulInt64Mut, i)
}

if mul.BitLen() > 255+DecimalPrecisionBits {
func (d Dec) MulInt64Mut(i int64) Dec {
d.i.Mul(d.i, big.NewInt(i))

if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{mul}
return d
}

// quotient
func (d Dec) Quo(d2 Dec) Dec {
// multiply precision twice
mul := new(big.Int).Mul(d.i, precisionReuse)
mul.Mul(mul, precisionReuse)
return d.ImmutOp(Dec.QuoMut, d2)
}

quo := new(big.Int).Quo(mul, d2.i)
chopped := chopPrecisionAndRound(quo)
// mutable quotient
func (d Dec) QuoMut(d2 Dec) Dec {
// multiply precision twice
d.i.Mul(d.i, precisionReuse)
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if chopped.BitLen() > 255+DecimalPrecisionBits {
chopPrecisionAndRound(d.i)
if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// quotient truncate
func (d Dec) QuoTruncate(d2 Dec) Dec {
// multiply precision twice
mul := new(big.Int).Mul(d.i, precisionReuse)
mul.Mul(mul, precisionReuse)
return d.ImmutOp(Dec.QuoTruncateMut, d2)
}

quo := new(big.Int).Quo(mul, d2.i)
chopped := chopPrecisionAndTruncate(quo)
// mutable quotient truncate
func (d Dec) QuoTruncateMut(d2 Dec) Dec {
// multiply precision twice
d.i.Mul(d.i, precisionReuse)
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if chopped.BitLen() > 255+DecimalPrecisionBits {
chopPrecisionAndTruncate(d.i)
if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// quotient, round up
func (d Dec) QuoRoundUp(d2 Dec) Dec {
// multiply precision twice
mul := new(big.Int).Mul(d.i, precisionReuse)
mul.Mul(mul, precisionReuse)
return d.ImmutOp(Dec.QuoRoundupMut, d2)
}

quo := new(big.Int).Quo(mul, d2.i)
chopped := chopPrecisionAndRoundUp(quo)
// mutable quotient, round up
func (d Dec) QuoRoundupMut(d2 Dec) Dec {
// multiply precision twice
d.i.Mul(d.i, precisionReuse)
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if chopped.BitLen() > 255+DecimalPrecisionBits {
chopPrecisionAndRoundUp(d.i)
if d.i.BitLen() > 255+DecimalPrecisionBits {
panic("Int overflow")
}
return Dec{chopped}
return d
}

// quotient
func (d Dec) QuoInt(i Int) Dec {
mul := new(big.Int).Quo(d.i, i.i)
return Dec{mul}
return d.ImmutOpInt(Dec.QuoIntMut, i)
}

func (d Dec) QuoIntMut(i Int) Dec {
d.i.Quo(d.i, i.i)
return d
}

// QuoInt64 - quotient with int64
func (d Dec) QuoInt64(i int64) Dec {
mul := new(big.Int).Quo(d.i, big.NewInt(i))
return Dec{mul}
return d.ImmutOpInt64(Dec.QuoInt64Mut, i)
}

func (d Dec) QuoInt64Mut(i int64) Dec {
d.i.Quo(d.i, big.NewInt(i))
return d
}

// ApproxRoot returns an approximate estimation of a Dec's positive real nth root
Expand All @@ -361,8 +433,8 @@ func (d Dec) ApproxRoot(root uint64) (guess Dec, err error) {
}()

if d.IsNegative() {
absRoot, err := d.MulInt64(-1).ApproxRoot(root)
return absRoot.MulInt64(-1), err
absRoot, err := d.Neg().ApproxRoot(root)
return absRoot.NegMut(), err
}

if root == 1 || d.IsZero() || d.Equal(OneDec()) {
Expand All @@ -373,40 +445,45 @@ func (d Dec) ApproxRoot(root uint64) (guess Dec, err error) {
return OneDec(), nil
}

rootInt := NewIntFromUint64(root)
guess, delta := OneDec(), OneDec()

for iter := 0; delta.Abs().GT(SmallestDec()) && iter < maxApproxRootIterations; iter++ {
prev := guess.Power(root - 1)
if prev.IsZero() {
prev = SmallestDec()
}
delta = d.Quo(prev)
delta = delta.Sub(guess)
delta = delta.QuoInt(rootInt)
delta.Set(d).QuoMut(prev)
delta.SubMut(guess)
delta.QuoInt64Mut(int64(root))

guess = guess.Add(delta)
guess.AddMut(delta)
}

return guess, nil
}

// Power returns a the result of raising to a positive integer power
func (d Dec) Power(power uint64) Dec {
res := Dec{new(big.Int).Set(d.i)}
return res.PowerMut(power)
}

func (d Dec) PowerMut(power uint64) Dec {
// TODO: use mutable functions here
if power == 0 {
return OneDec()
}
tmp := OneDec()

for i := power; i > 1; {
if i%2 != 0 {
tmp = tmp.Mul(d)
tmp.MulMut(d)
}
i /= 2
d = d.Mul(d)
d.MulMut(d)
}

return d.Mul(tmp)
return d.MulMut(tmp)
}

// ApproxSqrt is a wrapper around ApproxRoot for the common special case
Expand Down Expand Up @@ -621,7 +698,11 @@ func (d Dec) Ceil() Dec {

// MaxSortableDec is the largest Dec that can be passed into SortableDecBytes()
// Its negative form is the least Dec that can be passed in.
var MaxSortableDec = OneDec().Quo(SmallestDec())
var MaxSortableDec Dec

func init() {
MaxSortableDec = OneDec().Quo(SmallestDec())
}

// ValidSortableDec ensures that a Dec is within the sortable bounds,
// a Dec can't have a precision of less than 10^-18.
Expand Down
3 changes: 3 additions & 0 deletions types/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ func (s *decimalTestSuite) TestPower() {
for i, tc := range testCases {
res := tc.input.Power(tc.power)
s.Require().True(tc.expected.Sub(res).Abs().LTE(sdk.SmallestDec()), "unexpected result for test case %d, input: %v", i, tc.input)
s.Require().True(tc.expected.Sub(tc.input.PowerMut(tc.power)).Abs().LTE(sdk.SmallestDec()),
"unexpected result for test case %d, input %v", i, tc.input)
s.Require().True(res.Equal(tc.input), "unexpected result for test case %d, input: %v", i, tc.input)
}
}

Expand Down

0 comments on commit 44a8b5a

Please sign in to comment.