Skip to content

Commit

Permalink
allow signed + unsigned = integers
Browse files Browse the repository at this point in the history
  • Loading branch information
aalu1418 committed Apr 22, 2024
1 parent 8e5c6f4 commit 497ea02
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
8 changes: 5 additions & 3 deletions pkg/utils/mathutil/mathutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mathutil

import (
"fmt"
"math"

"golang.org/x/exp/constraints"
)
Expand All @@ -26,15 +27,16 @@ func Min[V constraints.Ordered](first V, vals ...V) V {
return min
}

func Avg[V constraints.Unsigned](arr []V) (V, error) {
func Avg[V constraints.Integer](arr ...V) (V, error) {
total := V(0)

for _, v := range arr {
prev := total
total += v

// check addition overflow
if total < prev {
// check addition overflow (positive + negative)
if (total < prev && !math.Signbit(float64(v))) ||
(total > prev && math.Signbit(float64(v))) {
return 0, fmt.Errorf("overflow: addition %T", V(0))
}
}
Expand Down
17 changes: 10 additions & 7 deletions pkg/utils/mathutil/mathutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mathutil

import (
"fmt"
"math"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -35,21 +36,23 @@ func TestMin(t *testing.T) {

func TestAvg(t *testing.T) {
// happy path
r, err := Avg([]uint8{1, 2, 3})
r, err := Avg(int8(1), -2, 4)
assert.NoError(t, err)
assert.Equal(t, uint8(2), r)
assert.Equal(t, int8(1), r)

// single element
r, err = Avg([]uint8{0})
r, err = Avg(int8(0))
assert.NoError(t, err)
assert.Equal(t, uint8(0), r)
assert.Equal(t, int8(0), r)

// overflow addition
r, err = Avg([]uint8{255, 1})
r, err = Avg(int8(math.MaxInt8), 1)
assert.ErrorContains(t, err, fmt.Sprintf("overflow: addition"))
r, err = Avg(int8(math.MinInt8), -1)
assert.ErrorContains(t, err, fmt.Sprintf("overflow: addition"))

// overflow length
a := make([]uint8, 256)
r, err = Avg(a)
a := make([]int8, 256)
r, err = Avg(a...)
assert.ErrorContains(t, err, "overflow: array len")
}

0 comments on commit 497ea02

Please sign in to comment.