Skip to content

Commit

Permalink
Merge pull request #12 from lightstep/jmacd/nan_check
Browse files Browse the repository at this point in the history
Check for NaN values; return error instead of panicking
  • Loading branch information
jmacd authored Nov 15, 2019
2 parents f865a35 + cd12c03 commit 358db24
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
14 changes: 10 additions & 4 deletions varopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package varopt
import (
"container/heap"
"fmt"
"math"
"math/rand"
)

Expand Down Expand Up @@ -48,6 +49,8 @@ type vsample struct {

type largeHeap []vsample

var ErrInvalidWeight = fmt.Errorf("Negative, zero, or NaN weight")

// New returns a new Varopt sampler with given capacity (i.e.,
// reservoir size) and random number generator.
func New(capacity int, rnd *rand.Rand) *Varopt {
Expand All @@ -58,22 +61,24 @@ func New(capacity int, rnd *rand.Rand) *Varopt {
}

// Add considers a new observation for the sample with given weight.
func (s *Varopt) Add(sample Sample, weight float64) {
//
// An error will be returned if the weight is either negative or NaN.
func (s *Varopt) Add(sample Sample, weight float64) error {
individual := vsample{
sample: sample,
weight: weight,
}

if weight <= 0 {
panic(fmt.Sprint("Invalid weight <= 0: ", weight))
if weight <= 0 || math.IsNaN(weight) {
return ErrInvalidWeight
}

s.totalCount++
s.totalWeight += weight

if s.Size() < s.capacity {
heap.Push(&s.L, individual)
return
return nil
}

// the X <- {} step from the paper is not done here,
Expand Down Expand Up @@ -115,6 +120,7 @@ func (s *Varopt) Add(sample Sample, weight float64) {
}
s.T = append(s.T, s.X...)
s.X = s.X[:0]
return nil
}

func (s *Varopt) uniform() float64 {
Expand Down
14 changes: 14 additions & 0 deletions varopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,17 @@ func testUnbiased(t *testing.T, bbr, bsr float64) {
[][][]varopt.Sample{smallBlocks, bigBlocks},
)
}

func TestInvalidWeight(t *testing.T) {
rnd := rand.New(rand.NewSource(98887))
v := varopt.New(1, rnd)

err := v.Add(nil, math.NaN())
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, -1)
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, 0)
require.Equal(t, err, varopt.ErrInvalidWeight)
}

0 comments on commit 358db24

Please sign in to comment.