diff --git a/varopt.go b/varopt.go index 5ab1ee5..22751c2 100644 --- a/varopt.go +++ b/varopt.go @@ -5,6 +5,7 @@ package varopt import ( "container/heap" "fmt" + "math" "math/rand" ) @@ -48,6 +49,8 @@ type vsample struct { type largeHeap []vsample +var ErrInvalidWeight = fmt.Errorf("Negative, zero, or NaN weight") + // New returns a new Varopt sampler with given capacity (i.e., // reservoir size) and random number generator. func New(capacity int, rnd *rand.Rand) *Varopt { @@ -58,14 +61,16 @@ func New(capacity int, rnd *rand.Rand) *Varopt { } // Add considers a new observation for the sample with given weight. -func (s *Varopt) Add(sample Sample, weight float64) { +// +// An error will be returned if the weight is either negative or NaN. +func (s *Varopt) Add(sample Sample, weight float64) error { individual := vsample{ sample: sample, weight: weight, } - if weight <= 0 { - panic(fmt.Sprint("Invalid weight <= 0: ", weight)) + if weight <= 0 || math.IsNaN(weight) { + return ErrInvalidWeight } s.totalCount++ @@ -73,7 +78,7 @@ func (s *Varopt) Add(sample Sample, weight float64) { if s.Size() < s.capacity { heap.Push(&s.L, individual) - return + return nil } // the X <- {} step from the paper is not done here, @@ -115,6 +120,7 @@ func (s *Varopt) Add(sample Sample, weight float64) { } s.T = append(s.T, s.X...) s.X = s.X[:0] + return nil } func (s *Varopt) uniform() float64 { diff --git a/varopt_test.go b/varopt_test.go index 06f7bc1..c7f0b68 100644 --- a/varopt_test.go +++ b/varopt_test.go @@ -147,3 +147,17 @@ func testUnbiased(t *testing.T, bbr, bsr float64) { [][][]varopt.Sample{smallBlocks, bigBlocks}, ) } + +func TestInvalidWeight(t *testing.T) { + rnd := rand.New(rand.NewSource(98887)) + v := varopt.New(1, rnd) + + err := v.Add(nil, math.NaN()) + require.Equal(t, err, varopt.ErrInvalidWeight) + + err = v.Add(nil, -1) + require.Equal(t, err, varopt.ErrInvalidWeight) + + err = v.Add(nil, 0) + require.Equal(t, err, varopt.ErrInvalidWeight) +}