From db2f575931030ee0fdd5ad2324d12f8517f2e96c Mon Sep 17 00:00:00 2001 From: jmacd Date: Tue, 12 Nov 2019 23:26:37 -0800 Subject: [PATCH 1/3] Check for NaN weight --- varopt.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/varopt.go b/varopt.go index 5ab1ee5..fc802c5 100644 --- a/varopt.go +++ b/varopt.go @@ -5,6 +5,7 @@ package varopt import ( "container/heap" "fmt" + "math" "math/rand" ) @@ -64,8 +65,8 @@ func (s *Varopt) Add(sample Sample, weight float64) { weight: weight, } - if weight <= 0 { - panic(fmt.Sprint("Invalid weight <= 0: ", weight)) + if weight <= 0 || math.IsNaN(weight) { + panic(fmt.Sprint("Invalid weight: ", weight)) } s.totalCount++ From 5cd2650f3fdc3f9c914b102ec11df57c687d439d Mon Sep 17 00:00:00 2001 From: jmacd Date: Wed, 13 Nov 2019 21:51:59 -0800 Subject: [PATCH 2/3] Return an error, don't panic --- varopt.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/varopt.go b/varopt.go index fc802c5..f786d12 100644 --- a/varopt.go +++ b/varopt.go @@ -49,6 +49,8 @@ type vsample struct { type largeHeap []vsample +var ErrInvalidWeight = fmt.Errorf("Negative or NaN weight") + // New returns a new Varopt sampler with given capacity (i.e., // reservoir size) and random number generator. func New(capacity int, rnd *rand.Rand) *Varopt { @@ -59,14 +61,16 @@ func New(capacity int, rnd *rand.Rand) *Varopt { } // Add considers a new observation for the sample with given weight. -func (s *Varopt) Add(sample Sample, weight float64) { +// +// An error will be returned if the weight is either negative or NaN. +func (s *Varopt) Add(sample Sample, weight float64) error { individual := vsample{ sample: sample, weight: weight, } if weight <= 0 || math.IsNaN(weight) { - panic(fmt.Sprint("Invalid weight: ", weight)) + return ErrInvalidWeight } s.totalCount++ @@ -74,7 +78,7 @@ func (s *Varopt) Add(sample Sample, weight float64) { if s.Size() < s.capacity { heap.Push(&s.L, individual) - return + return nil } // the X <- {} step from the paper is not done here, @@ -116,6 +120,7 @@ func (s *Varopt) Add(sample Sample, weight float64) { } s.T = append(s.T, s.X...) s.X = s.X[:0] + return nil } func (s *Varopt) uniform() float64 { From b00b2fa4dd994f7f8bcb12518bcc1d84de31c555 Mon Sep 17 00:00:00 2001 From: jmacd Date: Fri, 15 Nov 2019 12:42:56 -0800 Subject: [PATCH 3/3] Add a test --- varopt.go | 2 +- varopt_test.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/varopt.go b/varopt.go index f786d12..22751c2 100644 --- a/varopt.go +++ b/varopt.go @@ -49,7 +49,7 @@ type vsample struct { type largeHeap []vsample -var ErrInvalidWeight = fmt.Errorf("Negative or NaN weight") +var ErrInvalidWeight = fmt.Errorf("Negative, zero, or NaN weight") // New returns a new Varopt sampler with given capacity (i.e., // reservoir size) and random number generator. diff --git a/varopt_test.go b/varopt_test.go index 06f7bc1..c7f0b68 100644 --- a/varopt_test.go +++ b/varopt_test.go @@ -147,3 +147,17 @@ func testUnbiased(t *testing.T, bbr, bsr float64) { [][][]varopt.Sample{smallBlocks, bigBlocks}, ) } + +func TestInvalidWeight(t *testing.T) { + rnd := rand.New(rand.NewSource(98887)) + v := varopt.New(1, rnd) + + err := v.Add(nil, math.NaN()) + require.Equal(t, err, varopt.ErrInvalidWeight) + + err = v.Add(nil, -1) + require.Equal(t, err, varopt.ErrInvalidWeight) + + err = v.Add(nil, 0) + require.Equal(t, err, varopt.ErrInvalidWeight) +}