From 497ea02745d1c34f21f78f5708e21a0f53bb3dc9 Mon Sep 17 00:00:00 2001 From: aalu1418 <50029043+aalu1418@users.noreply.github.com> Date: Mon, 22 Apr 2024 15:35:35 -0600 Subject: [PATCH] allow signed + unsigned = integers --- pkg/utils/mathutil/mathutil.go | 8 +++++--- pkg/utils/mathutil/mathutil_test.go | 17 ++++++++++------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/pkg/utils/mathutil/mathutil.go b/pkg/utils/mathutil/mathutil.go index 370fe871f..4c7da09f0 100644 --- a/pkg/utils/mathutil/mathutil.go +++ b/pkg/utils/mathutil/mathutil.go @@ -2,6 +2,7 @@ package mathutil import ( "fmt" + "math" "golang.org/x/exp/constraints" ) @@ -26,15 +27,16 @@ func Min[V constraints.Ordered](first V, vals ...V) V { return min } -func Avg[V constraints.Unsigned](arr []V) (V, error) { +func Avg[V constraints.Integer](arr ...V) (V, error) { total := V(0) for _, v := range arr { prev := total total += v - // check addition overflow - if total < prev { + // check addition overflow (positive + negative) + if (total < prev && !math.Signbit(float64(v))) || + (total > prev && math.Signbit(float64(v))) { return 0, fmt.Errorf("overflow: addition %T", V(0)) } } diff --git a/pkg/utils/mathutil/mathutil_test.go b/pkg/utils/mathutil/mathutil_test.go index ab3c796bf..8ef63510b 100644 --- a/pkg/utils/mathutil/mathutil_test.go +++ b/pkg/utils/mathutil/mathutil_test.go @@ -2,6 +2,7 @@ package mathutil import ( "fmt" + "math" "testing" "github.com/stretchr/testify/assert" @@ -35,21 +36,23 @@ func TestMin(t *testing.T) { func TestAvg(t *testing.T) { // happy path - r, err := Avg([]uint8{1, 2, 3}) + r, err := Avg(int8(1), -2, 4) assert.NoError(t, err) - assert.Equal(t, uint8(2), r) + assert.Equal(t, int8(1), r) // single element - r, err = Avg([]uint8{0}) + r, err = Avg(int8(0)) assert.NoError(t, err) - assert.Equal(t, uint8(0), r) + assert.Equal(t, int8(0), r) // overflow addition - r, err = Avg([]uint8{255, 1}) + r, err = Avg(int8(math.MaxInt8), 1) + assert.ErrorContains(t, err, fmt.Sprintf("overflow: addition")) + r, err = Avg(int8(math.MinInt8), -1) assert.ErrorContains(t, err, fmt.Sprintf("overflow: addition")) // overflow length - a := make([]uint8, 256) - r, err = Avg(a) + a := make([]int8, 256) + r, err = Avg(a...) assert.ErrorContains(t, err, "overflow: array len") }