Skip to content

Commit

Permalink
Make math.Add64 and math.Mul64 generic (#3205)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Jul 19, 2024
1 parent d4ec4e7 commit 9a6418c
Show file tree
Hide file tree
Showing 41 changed files with 122 additions and 111 deletions.
4 changes: 2 additions & 2 deletions genesis/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ func (c Config) Unparse() (UnparsedConfig, error) {
func (c *Config) InitialSupply() (uint64, error) {
initialSupply := uint64(0)
for _, allocation := range c.Allocations {
newInitialSupply, err := math.Add64(initialSupply, allocation.InitialAmount)
newInitialSupply, err := math.Add(initialSupply, allocation.InitialAmount)
if err != nil {
return 0, err
}
for _, unlock := range allocation.UnlockSchedule {
newInitialSupply, err = math.Add64(newInitialSupply, unlock.Amount)
newInitialSupply, err = math.Add(newInitialSupply, unlock.Amount)
if err != nil {
return 0, err
}
Expand Down
4 changes: 2 additions & 2 deletions snow/consensus/snowman/bootstrapper/majority.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (m *Majority) RecordOpinion(_ context.Context, nodeID ids.NodeID, blkIDs se

weight := m.nodeWeights[nodeID]
for blkID := range blkIDs {
newWeight, err := math.Add64(m.received[blkID], weight)
newWeight, err := math.Add(m.received[blkID], weight)
if err != nil {
return err
}
Expand All @@ -84,7 +84,7 @@ func (m *Majority) RecordOpinion(_ context.Context, nodeID ids.NodeID, blkIDs se
err error
)
for _, weight := range m.nodeWeights {
totalWeight, err = math.Add64(totalWeight, weight)
totalWeight, err = math.Add(totalWeight, weight)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion snow/consensus/snowman/bootstrapper/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func Sample[T comparable](elements map[T]uint64, maxSize int) (set.Set[T], error
for key, weight := range elements {
keys[i] = key
weights[i] = weight
totalWeight, err = math.Add64(totalWeight, weight)
totalWeight, err = math.Add(totalWeight, weight)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion snow/engine/avalanche/state/serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *Serializer) BuildStopVtx(
return nil, err
}
parentHeight := parent.v.vtx.Height()
childHeight, err := math.Add64(parentHeight, 1)
childHeight, err := math.Add(parentHeight, 1)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion snow/engine/snowman/syncer/state_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func (ss *stateSyncer) AcceptedStateSummary(ctx context.Context, nodeID ids.Node
continue
}

newWeight, err := safemath.Add64(nodeWeight, ws.weight)
newWeight, err := safemath.Add(nodeWeight, ws.weight)
if err != nil {
ss.Ctx.Log.Error("failed to calculate new summary weight",
zap.Stringer("nodeID", nodeID),
Expand Down
4 changes: 2 additions & 2 deletions snow/engine/snowman/transitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (t *Transitive) Gossip(ctx context.Context) error {
return nil
}

nextHeightToAccept, err := math.Add64(lastAcceptedHeight, 1)
nextHeightToAccept, err := math.Add(lastAcceptedHeight, 1)
if err != nil {
t.Ctx.Log.Error("skipping block gossip",
zap.String("reason", "block height overflow"),
Expand Down Expand Up @@ -886,7 +886,7 @@ func (t *Transitive) sendQuery(
}

_, lastAcceptedHeight := t.Consensus.LastAccepted()
nextHeightToAccept, err := math.Add64(lastAcceptedHeight, 1)
nextHeightToAccept, err := math.Add(lastAcceptedHeight, 1)
if err != nil {
t.Ctx.Log.Error("dropped query for block",
zap.String("reason", "block height overflow"),
Expand Down
2 changes: 1 addition & 1 deletion snow/networking/benchlist/benchlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (b *benchlist) bench(nodeID ids.NodeID) {
return
}

newBenchedStake, err := safemath.Add64(benchedStake, validatorStake)
newBenchedStake, err := safemath.Add(benchedStake, validatorStake)
if err != nil {
// This should never happen
b.ctx.Log.Error("overflow calculating new benched stake",
Expand Down
4 changes: 2 additions & 2 deletions snow/validators/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (s *vdrSet) addWeight(nodeID ids.NodeID, weight uint64) error {
}

oldWeight := vdr.Weight
newWeight, err := math.Add64(oldWeight, weight)
newWeight, err := math.Add(oldWeight, weight)
if err != nil {
return err
}
Expand Down Expand Up @@ -137,7 +137,7 @@ func (s *vdrSet) subsetWeight(subset set.Set[ids.NodeID]) (uint64, error) {
err error
)
for nodeID := range subset {
totalWeight, err = math.Add64(totalWeight, s.getWeight(nodeID))
totalWeight, err = math.Add(totalWeight, s.getWeight(nodeID))
if err != nil {
return 0, err
}
Expand Down
36 changes: 18 additions & 18 deletions utils/math/safe_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,31 @@ package math

import (
"errors"
"math"

"golang.org/x/exp/constraints"

"github.com/ava-labs/avalanchego/utils"
)

var (
ErrOverflow = errors.New("overflow")
ErrUnderflow = errors.New("underflow")

// Deprecated: Add64 is deprecated. Use Add[uint64] instead.
Add64 = Add[uint64]

// Deprecated: Mul64 is deprecated. Use Mul[uint64] instead.
Mul64 = Mul[uint64]
)

// Add64 returns:
// MaxUint returns the maximum value of an unsigned integer of type T.
func MaxUint[T constraints.Unsigned]() T {
return ^T(0)
}

// Add returns:
// 1) a + b
// 2) If there is overflow, an error
//
// Note that we don't have a generic Add function because checking for
// an overflow requires knowing the max size of a given type, which we
// don't know if we're adding generic types.
func Add64(a, b uint64) (uint64, error) {
if a > math.MaxUint64-b {
func Add[T constraints.Unsigned](a, b T) (T, error) {
if a > MaxUint[T]()-b {
return 0, ErrOverflow
}
return a + b, nil
Expand All @@ -36,20 +40,16 @@ func Add64(a, b uint64) (uint64, error) {
// 2) If there is underflow, an error
func Sub[T constraints.Unsigned](a, b T) (T, error) {
if a < b {
return utils.Zero[T](), ErrUnderflow
return 0, ErrUnderflow
}
return a - b, nil
}

// Mul64 returns:
// Mul returns:
// 1) a * b
// 2) If there is overflow, an error
//
// Note that we don't have a generic Mul function because checking for
// an overflow requires knowing the max size of a given type, which we
// don't know if we're adding generic types.
func Mul64(a, b uint64) (uint64, error) {
if b != 0 && a > math.MaxUint64/b {
func Mul[T constraints.Unsigned](a, b T) (T, error) {
if b != 0 && a > MaxUint[T]()/b {
return 0, ErrOverflow
}
return a * b, nil
Expand Down
41 changes: 26 additions & 15 deletions utils/math/safe_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,39 @@ import (

const maxUint64 uint64 = math.MaxUint64

func TestAdd64(t *testing.T) {
func TestMaxUint(t *testing.T) {
require := require.New(t)

sum, err := Add64(0, maxUint64)
require.Equal(uint(math.MaxUint), MaxUint[uint]())
require.Equal(uint8(math.MaxUint8), MaxUint[uint8]())
require.Equal(uint16(math.MaxUint16), MaxUint[uint16]())
require.Equal(uint32(math.MaxUint32), MaxUint[uint32]())
require.Equal(uint64(math.MaxUint64), MaxUint[uint64]())
require.Equal(uintptr(math.MaxUint), MaxUint[uintptr]())
}

func TestAdd(t *testing.T) {
require := require.New(t)

sum, err := Add(0, maxUint64)
require.NoError(err)
require.Equal(maxUint64, sum)

sum, err = Add64(maxUint64, 0)
sum, err = Add(maxUint64, 0)
require.NoError(err)
require.Equal(maxUint64, sum)

sum, err = Add64(uint64(1<<62), uint64(1<<62))
sum, err = Add(uint64(1<<62), uint64(1<<62))
require.NoError(err)
require.Equal(uint64(1<<63), sum)

_, err = Add64(1, maxUint64)
_, err = Add(1, maxUint64)
require.ErrorIs(err, ErrOverflow)

_, err = Add64(maxUint64, 1)
_, err = Add(maxUint64, 1)
require.ErrorIs(err, ErrOverflow)

_, err = Add64(maxUint64, maxUint64)
_, err = Add(maxUint64, maxUint64)
require.ErrorIs(err, ErrOverflow)
}

Expand Down Expand Up @@ -63,34 +74,34 @@ func TestSub(t *testing.T) {
require.ErrorIs(err, ErrUnderflow)
}

func TestMul64(t *testing.T) {
func TestMul(t *testing.T) {
require := require.New(t)

got, err := Mul64(0, maxUint64)
got, err := Mul(0, maxUint64)
require.NoError(err)
require.Zero(got)

got, err = Mul64(maxUint64, 0)
got, err = Mul(maxUint64, 0)
require.NoError(err)
require.Zero(got)

got, err = Mul64(uint64(1), uint64(3))
got, err = Mul(uint64(1), uint64(3))
require.NoError(err)
require.Equal(uint64(3), got)

got, err = Mul64(uint64(3), uint64(1))
got, err = Mul(uint64(3), uint64(1))
require.NoError(err)
require.Equal(uint64(3), got)

got, err = Mul64(uint64(2), uint64(3))
got, err = Mul(uint64(2), uint64(3))
require.NoError(err)
require.Equal(uint64(6), got)

got, err = Mul64(maxUint64, 0)
got, err = Mul(maxUint64, 0)
require.NoError(err)
require.Zero(got)

_, err = Mul64(maxUint64-1, 2)
_, err = Mul(maxUint64-1, 2)
require.ErrorIs(err, ErrOverflow)
}

Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (s *weightedArray) Initialize(weights []uint64) error {

cumulativeWeight := uint64(0)
for i := 0; i < len(s.arr); i++ {
newWeight, err := math.Add64(
newWeight, err := math.Add(
cumulativeWeight,
s.arr[i].cumulativeWeight,
)
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func CalcWeightedPoW(exponent float64, size int) (uint64, []uint64, error) {
weight := uint64(math.Pow(float64(i+1), exponent))
weights[i] = weight

newWeight, err := safemath.Add64(totalWeight, weight)
newWeight, err := safemath.Add(totalWeight, weight)
if err != nil {
return 0, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_best.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type weightedBest struct {
func (s *weightedBest) Initialize(weights []uint64) error {
totalWeight := uint64(0)
for _, weight := range weights {
newWeight, err := safemath.Add64(totalWeight, weight)
newWeight, err := safemath.Add(totalWeight, weight)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (s *weightedHeap) Initialize(weights []uint64) error {
// Explicitly performing a shift here allows the compiler to avoid
// checking for negative numbers, which saves a couple cycles
parentIndex := (i - 1) >> 1
newWeight, err := math.Add64(
newWeight, err := math.Add(
s.heap[parentIndex].cumulativeWeight,
s.heap[i].cumulativeWeight,
)
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_linear.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *weightedLinear) Initialize(weights []uint64) error {
utils.Sort(s.arr)

for i := 1; i < len(s.arr); i++ {
newWeight, err := math.Add64(
newWeight, err := math.Add(
s.arr[i-1].cumulativeWeight,
s.arr[i].cumulativeWeight,
)
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_uniform.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type weightedUniform struct {
func (s *weightedUniform) Initialize(weights []uint64) error {
totalWeight := uint64(0)
for _, weight := range weights {
newWeight, err := safemath.Add64(totalWeight, weight)
newWeight, err := safemath.Add(totalWeight, weight)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion utils/sampler/weighted_without_replacement_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type weightedWithoutReplacementGeneric struct {
func (s *weightedWithoutReplacementGeneric) Initialize(weights []uint64) error {
totalWeight := uint64(0)
for _, weight := range weights {
newWeight, err := safemath.Add64(totalWeight, weight)
newWeight, err := safemath.Add(totalWeight, weight)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions vms/avm/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ func (s *Service) GetBalance(_ *http.Request, args *GetBalanceArgs, reply *GetBa
if !args.IncludePartial && (len(owners.Addrs) != 1 || owners.Locktime > now) {
continue
}
amt, err := safemath.Add64(transferable.Amount(), uint64(reply.Balance))
amt, err := safemath.Add(transferable.Amount(), uint64(reply.Balance))
if err != nil {
return err
}
Expand Down Expand Up @@ -650,7 +650,7 @@ func (s *Service) GetAllBalances(_ *http.Request, args *GetAllBalancesArgs, repl
assetID := utxo.AssetID()
assetIDs.Add(assetID)
balance := balances[assetID] // 0 if key doesn't exist
balance, err := safemath.Add64(transferable.Amount(), balance)
balance, err := safemath.Add(transferable.Amount(), balance)
if err != nil {
balances[assetID] = math.MaxUint64
} else {
Expand Down Expand Up @@ -1264,7 +1264,7 @@ func (s *Service) buildSendMultiple(args *SendMultipleArgs) (*txs.Tx, ids.ShortI
assetIDs[output.AssetID] = assetID
}
currentAmount := amounts[assetID]
newAmount, err := safemath.Add64(currentAmount, uint64(output.Amount))
newAmount, err := safemath.Add(currentAmount, uint64(output.Amount))
if err != nil {
return nil, ids.ShortEmpty, fmt.Errorf("problem calculating required spend amount: %w", err)
}
Expand Down Expand Up @@ -1295,7 +1295,7 @@ func (s *Service) buildSendMultiple(args *SendMultipleArgs) (*txs.Tx, ids.ShortI
amountsWithFee[assetID] = amount
}

amountWithFee, err := safemath.Add64(amounts[s.vm.feeAssetID], s.vm.TxFee)
amountWithFee, err := safemath.Add(amounts[s.vm.feeAssetID], s.vm.TxFee)
if err != nil {
return nil, ids.ShortEmpty, fmt.Errorf("problem calculating required spend amount: %w", err)
}
Expand Down Expand Up @@ -1818,7 +1818,7 @@ func (s *Service) buildImport(args *ImportArgs) (*txs.Tx, error) {
return nil, err
}
for asset, amount := range localAmountsSpent {
newAmount, err := safemath.Add64(amountsSpent[asset], amount)
newAmount, err := safemath.Add(amountsSpent[asset], amount)
if err != nil {
return nil, fmt.Errorf("problem calculating required spend amount: %w", err)
}
Expand Down Expand Up @@ -1955,7 +1955,7 @@ func (s *Service) buildExport(args *ExportArgs) (*txs.Tx, ids.ShortID, error) {

amounts := map[ids.ID]uint64{}
if assetID == s.vm.feeAssetID {
amountWithFee, err := safemath.Add64(uint64(args.Amount), s.vm.TxFee)
amountWithFee, err := safemath.Add(uint64(args.Amount), s.vm.TxFee)
if err != nil {
return nil, ids.ShortEmpty, fmt.Errorf("problem calculating required spend amount: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions vms/avm/utxo/spender.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (s *spender) Spend(
// this input doesn't have an amount, so I don't care about it here
continue
}
newAmountSpent, err := math.Add64(amountSpent, input.Amount())
newAmountSpent, err := math.Add(amountSpent, input.Amount())
if err != nil {
// there was an error calculating the consumed amount, just error
return nil, nil, nil, errSpendOverflow
Expand Down Expand Up @@ -274,7 +274,7 @@ func (s *spender) SpendAll(
// this input doesn't have an amount, so I don't care about it here
continue
}
newAmountSpent, err := math.Add64(amountSpent, input.Amount())
newAmountSpent, err := math.Add(amountSpent, input.Amount())
if err != nil {
// there was an error calculating the consumed amount, just error
return nil, nil, nil, errSpendOverflow
Expand Down
Loading

0 comments on commit 9a6418c

Please sign in to comment.