diff --git a/inmem.go b/inmem.go
index 71deef1..7c427ac 100644
--- a/inmem.go
+++ b/inmem.go
@@ -55,6 +55,10 @@ type IntervalMetrics struct {
 	// Samples maps the key to an AggregateSample,
 	// which has the rolled up view of a sample
 	Samples map[string]SampledValue
+
+	// done is closed when this interval has ended, and a new IntervalMetrics
+	// has been created to receive any future metrics.
+	done chan struct{}
 }
 
 // NewIntervalMetrics creates a new IntervalMetrics for a given interval
@@ -65,6 +69,7 @@ func NewIntervalMetrics(intv time.Time) *IntervalMetrics {
 		Points:   make(map[string][]float32),
 		Counters: make(map[string]SampledValue),
 		Samples:  make(map[string]SampledValue),
+		done:     make(chan struct{}),
 	}
 }
 
@@ -270,33 +275,39 @@ func (i *InmemSink) Data() []*IntervalMetrics {
 	return intervals
 }
 
-func (i *InmemSink) getExistingInterval(intv time.Time) *IntervalMetrics {
-	i.intervalLock.RLock()
-	defer i.intervalLock.RUnlock()
+// getInterval returns the current interval. A new interval is created if no
+// previous interval exists, or if the current time is beyond the window for the
+// current interval.
+func (i *InmemSink) getInterval() *IntervalMetrics {
+	intv := time.Now().Truncate(i.interval)
 
+	// Attempt to return the existing interval first, because it only requires
+	// a read lock.
+	i.intervalLock.RLock()
 	n := len(i.intervals)
 	if n > 0 && i.intervals[n-1].Interval == intv {
+		defer i.intervalLock.RUnlock()
 		return i.intervals[n-1]
 	}
-	return nil
-}
+	i.intervalLock.RUnlock()
 
-func (i *InmemSink) createInterval(intv time.Time) *IntervalMetrics {
 	i.intervalLock.Lock()
 	defer i.intervalLock.Unlock()
 
-	// Check for an existing interval
-	n := len(i.intervals)
+	// Re-check for an existing interval now that the lock is re-acquired.
+	n = len(i.intervals)
 	if n > 0 && i.intervals[n-1].Interval == intv {
 		return i.intervals[n-1]
 	}
 
-	// Add the current interval
 	current := NewIntervalMetrics(intv)
 	i.intervals = append(i.intervals, current)
-	n++
+	if n > 0 {
+		close(i.intervals[n-1].done)
+	}
 
-	// Truncate the intervals if they are too long
+	n++
+	// Prune old intervals if the count exceeds the max.
 	if n >= i.maxIntervals {
 		copy(i.intervals[0:], i.intervals[n-i.maxIntervals:])
 		i.intervals = i.intervals[:i.maxIntervals]
@@ -304,15 +315,6 @@ func (i *InmemSink) createInterval(intv time.Time) *IntervalMetrics {
 	return current
 }
 
-// getInterval returns the current interval to write to
-func (i *InmemSink) getInterval() *IntervalMetrics {
-	intv := time.Now().Truncate(i.interval)
-	if m := i.getExistingInterval(intv); m != nil {
-		return m
-	}
-	return i.createInterval(intv)
-}
-
 // Flattens the key for formatting, removes spaces
 func (i *InmemSink) flattenKey(parts []string) string {
 	buf := &bytes.Buffer{}
diff --git a/inmem_endpoint.go b/inmem_endpoint.go
index 5fac958..24eefa9 100644
--- a/inmem_endpoint.go
+++ b/inmem_endpoint.go
@@ -1,6 +1,7 @@
 package metrics
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"sort"
@@ -68,6 +69,10 @@ func (i *InmemSink) DisplayMetrics(resp http.ResponseWriter, req *http.Request)
 		interval = data[n-2]
 	}
 
+	return newMetricSummaryFromInterval(interval), nil
+}
+
+func newMetricSummaryFromInterval(interval *IntervalMetrics) MetricsSummary {
 	interval.RLock()
 	defer interval.RUnlock()
 
@@ -103,7 +108,7 @@ func (i *InmemSink) DisplayMetrics(resp http.ResponseWriter, req *http.Request)
 	summary.Counters = formatSamples(interval.Counters)
 	summary.Samples = formatSamples(interval.Samples)
 
-	return summary, nil
+	return summary
 }
 
 func formatSamples(source map[string]SampledValue) []SampledValue {
@@ -129,3 +134,29 @@ func formatSamples(source map[string]SampledValue) []SampledValue {
 
 	return output
 }
+
+type Encoder interface {
+	Encode(interface{}) error
+}
+
+// Stream writes metrics using encoder.Encode each time an interval ends. Runs
+// until the request context is cancelled, or the encoder returns an error.
+// The caller is responsible for logging any errors from encoder.
+func (i *InmemSink) Stream(ctx context.Context, encoder Encoder) {
+	interval := i.getInterval()
+
+	for {
+		select {
+		case <-interval.done:
+			summary := newMetricSummaryFromInterval(interval)
+			if err := encoder.Encode(summary); err != nil {
+				return
+			}
+
+			// update interval to the next one
+			interval = i.getInterval()
+		case <-ctx.Done():
+			return
+		}
+	}
+}
diff --git a/inmem_endpoint_test.go b/inmem_endpoint_test.go
index bb3ebe0..baccd8b 100644
--- a/inmem_endpoint_test.go
+++ b/inmem_endpoint_test.go
@@ -1,6 +1,11 @@
 package metrics
 
 import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
 	"testing"
 	"time"
 
@@ -273,3 +278,72 @@ func TestDisplayMetrics_RaceMetricsSetGauge(t *testing.T) {
 
 	verify.Values(t, "all", got, float32(42))
 }
+
+func TestInmemSink_Stream(t *testing.T) {
+	interval := 10 * time.Millisecond
+	total := 50 * time.Millisecond
+	inm := NewInmemSink(interval, total)
+
+	ctx, cancel := context.WithTimeout(context.Background(), total*2)
+	defer cancel()
+
+	chDone := make(chan struct{})
+
+	go func() {
+		for i := float32(0); ctx.Err() == nil; i++ {
+			inm.SetGaugeWithLabels([]string{"gauge", "foo"}, 20+i, []Label{{"a", "b"}})
+			inm.EmitKey([]string{"key", "foo"}, 30+i)
+			inm.IncrCounterWithLabels([]string{"counter", "bar"}, 40+i, []Label{{"a", "b"}})
+			inm.IncrCounterWithLabels([]string{"counter", "bar"}, 50+i, []Label{{"a", "b"}})
+			inm.AddSampleWithLabels([]string{"sample", "bar"}, 60+i, []Label{{"a", "b"}})
+			inm.AddSampleWithLabels([]string{"sample", "bar"}, 70+i, []Label{{"a", "b"}})
+			time.Sleep(interval / 3)
+		}
+		close(chDone)
+	}()
+
+	resp := httptest.NewRecorder()
+	enc := encoder{
+		encoder: json.NewEncoder(resp),
+		flusher: resp,
+	}
+	inm.Stream(ctx, enc)
+
+	<-chDone
+	decoder := json.NewDecoder(resp.Body)
+	var prevGaugeValue float32
+	for i := 0; i < 8; i++ {
+		var summary MetricsSummary
+		if err := decoder.Decode(&summary); err != nil {
+			t.Fatalf("expected no error while decoding response %d, got %v", i, err)
+		}
+		if count := len(summary.Gauges); count != 1 {
+			t.Fatalf("expected exactly one gauge in response %d, got %v", i, count)
+		}
+		value := summary.Gauges[0].Value
+		// The upper bound of the gauge value is not known, but we can expect it
+		// to be less than 50 because it increments by 3 every interval and we run
+		// for ~10 intervals.
+		if value < 20 || value > 50 {
+			t.Fatalf("expected interval %d gauge value between 20 and 50, got %v", i, value)
+		}
+		if value <= prevGaugeValue {
+			t.Fatalf("expected interval %d gauge value to be greater than previous, %v == %v", i, value, prevGaugeValue)
+		}
+		prevGaugeValue = value
+	}
+}
+
+type encoder struct {
+	flusher http.Flusher
+	encoder *json.Encoder
+}
+
+func (e encoder) Encode(metrics interface{}) error {
+	if err := e.encoder.Encode(metrics); err != nil {
+		fmt.Println("failed to encode metrics summary", "error", err)
+		return err
+	}
+	e.flusher.Flush()
+	return nil
+}
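Usage sketch (not part of the diff): the test above drives Stream directly against an httptest.ResponseRecorder, but the same pattern can back a streaming HTTP endpoint in a server. The code below is a minimal, illustrative sketch; the handler path, service name, sink intervals, and the github.com/armon/go-metrics import path are assumptions for the example, not something this change adds.

package main

import (
	"encoding/json"
	"net/http"
	"time"

	metrics "github.com/armon/go-metrics" // assumed import path for this sketch
)

// streamEncoder satisfies the Encoder interface added in inmem_endpoint.go:
// it JSON-encodes each MetricsSummary and flushes the response so the client
// receives the summary as soon as its interval ends.
type streamEncoder struct {
	enc     *json.Encoder
	flusher http.Flusher
}

func (s streamEncoder) Encode(v interface{}) error {
	if err := s.enc.Encode(v); err != nil {
		return err
	}
	s.flusher.Flush()
	return nil
}

func main() {
	// 10s intervals retained for one minute; values are arbitrary for the sketch.
	inm := metrics.NewInmemSink(10*time.Second, time.Minute)
	metrics.NewGlobal(metrics.DefaultConfig("example"), inm)

	http.HandleFunc("/v1/metrics/stream", func(w http.ResponseWriter, req *http.Request) {
		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "streaming not supported", http.StatusInternalServerError)
			return
		}
		// Blocks until the client disconnects (cancelling req.Context()) or
		// Encode returns an error, writing one summary per completed interval.
		inm.Stream(req.Context(), streamEncoder{enc: json.NewEncoder(w), flusher: flusher})
	})

	http.ListenAndServe(":8080", nil)
}

Because Stream takes the small Encoder interface rather than an http.ResponseWriter, the test can substitute an httptest.ResponseRecorder and a caller can choose its own serialization and flushing behaviour.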