diff --git a/CHANGELOG.md b/CHANGELOG.md index c4bea447f7be..7b12f0d411d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -106,6 +106,13 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm Instead it uses the `net.sock.peer` attributes. (#3581) - The parameters for the `RegisterCallback` method of the `Meter` from `go.opentelemetry.io/otel/metric` are changed. The slice of `instrument.Asynchronous` parameter is now passed as a variadic argument. (#3587) +- The `Callback` in `go.opentelemetry.io/otel/metric` has the added `Observer` parameter added. + This new parameter is used by `Callback` implementations to observe values for asynchronous instruments instead of calling the `Observe` method of the instrument directly. (#3584) + +### Fixed + +- The `RegisterCallback` method of the `Meter` from `go.opentelemetry.io/otel/sdk/metric` only registers a callback for instruments created by that meter. + Trying to register a callback with instruments from a different meter will result in an error being returned. (#3584) ### Deprecated diff --git a/example/prometheus/main.go b/example/prometheus/main.go index 4e9be3b84da6..5c02d01c9b55 100644 --- a/example/prometheus/main.go +++ b/example/prometheus/main.go @@ -28,6 +28,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/prometheus" + api "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/sdk/metric" ) @@ -68,9 +69,9 @@ func main() { if err != nil { log.Fatal(err) } - _, err = meter.RegisterCallback(func(ctx context.Context) error { + _, err = meter.RegisterCallback(func(_ context.Context, o api.Observer) error { n := -10. + rand.Float64()*(90.) // [-10, 100) - gauge.Observe(ctx, n, attrs...) + o.ObserveFloat64(gauge, n, attrs...) return nil }, gauge) if err != nil { diff --git a/metric/example_test.go b/metric/example_test.go index 4653d612e8c6..9b33c0f9b5f2 100644 --- a/metric/example_test.go +++ b/metric/example_test.go @@ -90,13 +90,13 @@ func ExampleMeter_asynchronous_multiple() { gcPause, _ := meter.Float64Histogram("gcPause") _, err := meter.RegisterCallback( - func(ctx context.Context) error { + func(ctx context.Context, o metric.Observer) error { memStats := &runtime.MemStats{} // This call does work runtime.ReadMemStats(memStats) - heapAlloc.Observe(ctx, int64(memStats.HeapAlloc)) - gcCount.Observe(ctx, int64(memStats.NumGC)) + o.ObserveInt64(heapAlloc, int64(memStats.HeapAlloc)) + o.ObserveInt64(gcCount, int64(memStats.NumGC)) // This function synchronously records the pauses computeGCPauses(ctx, gcPause, memStats.PauseNs[:]) diff --git a/metric/internal/global/instruments.go b/metric/internal/global/instruments.go index c4b3d1ff5ab5..1398ada26be7 100644 --- a/metric/internal/global/instruments.go +++ b/metric/internal/global/instruments.go @@ -24,6 +24,11 @@ import ( "go.opentelemetry.io/otel/metric/instrument" ) +// unwrapper unwraps to return the underlying instrument implementation. +type unwrapper interface { + Unwrap() instrument.Asynchronous +} + type afCounter struct { name string opts []instrument.Float64ObserverOption @@ -33,6 +38,9 @@ type afCounter struct { instrument.Asynchronous } +var _ unwrapper = (*afCounter)(nil) +var _ instrument.Float64ObservableCounter = (*afCounter)(nil) + func (i *afCounter) setDelegate(m metric.Meter) { ctr, err := m.Float64ObservableCounter(i.name, i.opts...) if err != nil { @@ -48,7 +56,7 @@ func (i *afCounter) Observe(ctx context.Context, x float64, attrs ...attribute.K } } -func (i *afCounter) unwrap() instrument.Asynchronous { +func (i *afCounter) Unwrap() instrument.Asynchronous { if ctr := i.delegate.Load(); ctr != nil { return ctr.(instrument.Float64ObservableCounter) } @@ -64,6 +72,9 @@ type afUpDownCounter struct { instrument.Asynchronous } +var _ unwrapper = (*afUpDownCounter)(nil) +var _ instrument.Float64ObservableUpDownCounter = (*afUpDownCounter)(nil) + func (i *afUpDownCounter) setDelegate(m metric.Meter) { ctr, err := m.Float64ObservableUpDownCounter(i.name, i.opts...) if err != nil { @@ -79,7 +90,7 @@ func (i *afUpDownCounter) Observe(ctx context.Context, x float64, attrs ...attri } } -func (i *afUpDownCounter) unwrap() instrument.Asynchronous { +func (i *afUpDownCounter) Unwrap() instrument.Asynchronous { if ctr := i.delegate.Load(); ctr != nil { return ctr.(instrument.Float64ObservableUpDownCounter) } @@ -104,13 +115,16 @@ func (i *afGauge) setDelegate(m metric.Meter) { i.delegate.Store(ctr) } +var _ unwrapper = (*afGauge)(nil) +var _ instrument.Float64ObservableGauge = (*afGauge)(nil) + func (i *afGauge) Observe(ctx context.Context, x float64, attrs ...attribute.KeyValue) { if ctr := i.delegate.Load(); ctr != nil { ctr.(instrument.Float64ObservableGauge).Observe(ctx, x, attrs...) } } -func (i *afGauge) unwrap() instrument.Asynchronous { +func (i *afGauge) Unwrap() instrument.Asynchronous { if ctr := i.delegate.Load(); ctr != nil { return ctr.(instrument.Float64ObservableGauge) } @@ -126,6 +140,9 @@ type aiCounter struct { instrument.Asynchronous } +var _ unwrapper = (*aiCounter)(nil) +var _ instrument.Int64ObservableCounter = (*aiCounter)(nil) + func (i *aiCounter) setDelegate(m metric.Meter) { ctr, err := m.Int64ObservableCounter(i.name, i.opts...) if err != nil { @@ -141,7 +158,7 @@ func (i *aiCounter) Observe(ctx context.Context, x int64, attrs ...attribute.Key } } -func (i *aiCounter) unwrap() instrument.Asynchronous { +func (i *aiCounter) Unwrap() instrument.Asynchronous { if ctr := i.delegate.Load(); ctr != nil { return ctr.(instrument.Int64ObservableCounter) } @@ -157,6 +174,9 @@ type aiUpDownCounter struct { instrument.Asynchronous } +var _ unwrapper = (*aiUpDownCounter)(nil) +var _ instrument.Int64ObservableUpDownCounter = (*aiUpDownCounter)(nil) + func (i *aiUpDownCounter) setDelegate(m metric.Meter) { ctr, err := m.Int64ObservableUpDownCounter(i.name, i.opts...) if err != nil { @@ -172,7 +192,7 @@ func (i *aiUpDownCounter) Observe(ctx context.Context, x int64, attrs ...attribu } } -func (i *aiUpDownCounter) unwrap() instrument.Asynchronous { +func (i *aiUpDownCounter) Unwrap() instrument.Asynchronous { if ctr := i.delegate.Load(); ctr != nil { return ctr.(instrument.Int64ObservableUpDownCounter) } @@ -188,6 +208,9 @@ type aiGauge struct { instrument.Asynchronous } +var _ unwrapper = (*aiGauge)(nil) +var _ instrument.Int64ObservableGauge = (*aiGauge)(nil) + func (i *aiGauge) setDelegate(m metric.Meter) { ctr, err := m.Int64ObservableGauge(i.name, i.opts...) if err != nil { @@ -203,7 +226,7 @@ func (i *aiGauge) Observe(ctx context.Context, x int64, attrs ...attribute.KeyVa } } -func (i *aiGauge) unwrap() instrument.Asynchronous { +func (i *aiGauge) Unwrap() instrument.Asynchronous { if ctr := i.delegate.Load(); ctr != nil { return ctr.(instrument.Int64ObservableGauge) } @@ -220,6 +243,8 @@ type sfCounter struct { instrument.Synchronous } +var _ instrument.Float64Counter = (*sfCounter)(nil) + func (i *sfCounter) setDelegate(m metric.Meter) { ctr, err := m.Float64Counter(i.name, i.opts...) if err != nil { @@ -244,6 +269,8 @@ type sfUpDownCounter struct { instrument.Synchronous } +var _ instrument.Float64UpDownCounter = (*sfUpDownCounter)(nil) + func (i *sfUpDownCounter) setDelegate(m metric.Meter) { ctr, err := m.Float64UpDownCounter(i.name, i.opts...) if err != nil { @@ -268,6 +295,8 @@ type sfHistogram struct { instrument.Synchronous } +var _ instrument.Float64Histogram = (*sfHistogram)(nil) + func (i *sfHistogram) setDelegate(m metric.Meter) { ctr, err := m.Float64Histogram(i.name, i.opts...) if err != nil { @@ -292,6 +321,8 @@ type siCounter struct { instrument.Synchronous } +var _ instrument.Int64Counter = (*siCounter)(nil) + func (i *siCounter) setDelegate(m metric.Meter) { ctr, err := m.Int64Counter(i.name, i.opts...) if err != nil { @@ -316,6 +347,8 @@ type siUpDownCounter struct { instrument.Synchronous } +var _ instrument.Int64UpDownCounter = (*siUpDownCounter)(nil) + func (i *siUpDownCounter) setDelegate(m metric.Meter) { ctr, err := m.Int64UpDownCounter(i.name, i.opts...) if err != nil { @@ -340,6 +373,8 @@ type siHistogram struct { instrument.Synchronous } +var _ instrument.Int64Histogram = (*siHistogram)(nil) + func (i *siHistogram) setDelegate(m metric.Meter) { ctr, err := m.Int64Histogram(i.name, i.opts...) if err != nil { diff --git a/metric/internal/global/meter.go b/metric/internal/global/meter.go index 92f35e9730bb..8acf632863cb 100644 --- a/metric/internal/global/meter.go +++ b/metric/internal/global/meter.go @@ -275,9 +275,6 @@ func (m *meter) Float64ObservableGauge(name string, options ...instrument.Float6 } // RegisterCallback captures the function that will be called during Collect. -// -// It is only valid to call Observe within the scope of the passed function, -// and only on the instruments that were registered with this call. func (m *meter) RegisterCallback(f metric.Callback, insts ...instrument.Asynchronous) (metric.Registration, error) { if del, ok := m.delegate.Load().(metric.Meter); ok { insts = unwrapInstruments(insts) diff --git a/metric/internal/global/meter_test.go b/metric/internal/global/meter_test.go index eeb43689b0ef..704c1f956342 100644 --- a/metric/internal/global/meter_test.go +++ b/metric/internal/global/meter_test.go @@ -45,6 +45,10 @@ func TestMeterProviderRace(t *testing.T) { close(finish) } +var zeroCallback metric.Callback = func(ctx context.Context, or metric.Observer) error { + return nil +} + func TestMeterRace(t *testing.T) { mtr := &meter{} @@ -66,7 +70,7 @@ func TestMeterRace(t *testing.T) { _, _ = mtr.Int64Counter(name) _, _ = mtr.Int64UpDownCounter(name) _, _ = mtr.Int64Histogram(name) - _, _ = mtr.RegisterCallback(func(ctx context.Context) error { return nil }) + _, _ = mtr.RegisterCallback(zeroCallback) if !once { wg.Done() once = true @@ -86,7 +90,7 @@ func TestMeterRace(t *testing.T) { func TestUnregisterRace(t *testing.T) { mtr := &meter{} - reg, err := mtr.RegisterCallback(func(ctx context.Context) error { return nil }) + reg, err := mtr.RegisterCallback(zeroCallback) require.NoError(t, err) wg := &sync.WaitGroup{} @@ -128,8 +132,8 @@ func testSetupAllInstrumentTypes(t *testing.T, m metric.Meter) (instrument.Float _, err = m.Int64ObservableGauge("test_Async_Gauge") assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - afcounter.Observe(ctx, 3) + _, err = m.RegisterCallback(func(ctx context.Context, obs metric.Observer) error { + obs.ObserveFloat64(afcounter, 3) return nil }, afcounter) require.NoError(t, err) @@ -323,7 +327,7 @@ func TestRegistrationDelegation(t *testing.T) { require.NoError(t, err) var called0 bool - reg0, err := m.RegisterCallback(func(context.Context) error { + reg0, err := m.RegisterCallback(func(context.Context, metric.Observer) error { called0 = true return nil }, actr) @@ -334,7 +338,7 @@ func TestRegistrationDelegation(t *testing.T) { assert.Equal(t, 0, mImpl.registry.Len(), "callback not unregistered") var called1 bool - reg1, err := m.RegisterCallback(func(context.Context) error { + reg1, err := m.RegisterCallback(func(context.Context, metric.Observer) error { called1 = true return nil }, actr) diff --git a/metric/internal/global/meter_types_test.go b/metric/internal/global/meter_types_test.go index 3dfc74af7b3d..84637b286f9b 100644 --- a/metric/internal/global/meter_types_test.go +++ b/metric/internal/global/meter_types_test.go @@ -17,6 +17,7 @@ package global // import "go.opentelemetry.io/otel/metric/internal/global" import ( "context" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric/instrument" ) @@ -112,9 +113,6 @@ func (m *testMeter) Float64ObservableGauge(name string, options ...instrument.Fl } // RegisterCallback captures the function that will be called during Collect. -// -// It is only valid to call Observe within the scope of the passed function, -// and only on the instruments that were registered with this call. func (m *testMeter) RegisterCallback(f metric.Callback, i ...instrument.Asynchronous) (metric.Registration, error) { m.callbacks = append(m.callbacks, f) return testReg{ @@ -136,11 +134,24 @@ func (r testReg) Unregister() error { // This enables async collection. func (m *testMeter) collect() { ctx := context.Background() + o := observationRecorder{ctx} for _, f := range m.callbacks { if f == nil { // Unregister. continue } - _ = f(ctx) + _ = f(ctx, o) } } + +type observationRecorder struct { + ctx context.Context +} + +func (o observationRecorder) ObserveFloat64(i instrument.Float64Observer, value float64, attr ...attribute.KeyValue) { + i.Observe(o.ctx, value, attr...) +} + +func (o observationRecorder) ObserveInt64(i instrument.Int64Observer, value int64, attr ...attribute.KeyValue) { + i.Observe(o.ctx, value, attr...) +} diff --git a/metric/meter.go b/metric/meter.go index d384d0df17e2..fc39f40b3d84 100644 --- a/metric/meter.go +++ b/metric/meter.go @@ -17,6 +17,7 @@ package metric // import "go.opentelemetry.io/otel/metric" import ( "context" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric/instrument" ) @@ -106,7 +107,8 @@ type Meter interface { } // Callback is a function registered with a Meter that makes observations for -// the set of instruments it is registered with. +// the set of instruments it is registered with. The Observer parameter is used +// to record measurment observations for these instruments. // // The function needs to complete in a finite amount of time and the deadline // of the passed context is expected to be honored. @@ -116,7 +118,15 @@ type Meter interface { // the same attributes as another Callback will report. // // The function needs to be concurrent safe. -type Callback func(context.Context) error +type Callback func(context.Context, Observer) error + +// Observer records measurements for multiple instruments in a Callback. +type Observer interface { + // ObserveFloat64 records the float64 value with attributes for obsrv. + ObserveFloat64(obsrv instrument.Float64Observer, value float64, attributes ...attribute.KeyValue) + // ObserveInt64 records the int64 value with attributes for obsrv. + ObserveInt64(obsrv instrument.Int64Observer, value int64, attributes ...attribute.KeyValue) +} // Registration is an token representing the unique registration of a callback // for a set of instruments with a Meter. diff --git a/sdk/metric/instrument.go b/sdk/metric/instrument.go index 5414a8db7e68..2b3c2356d3a6 100644 --- a/sdk/metric/instrument.go +++ b/sdk/metric/instrument.go @@ -16,8 +16,11 @@ package metric // import "go.opentelemetry.io/otel/sdk/metric" import ( "context" + "errors" + "fmt" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/internal/global" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/metric/unit" "go.opentelemetry.io/otel/sdk/instrumentation" @@ -170,33 +173,17 @@ type instrumentID struct { } type instrumentImpl[N int64 | float64] struct { - instrument.Asynchronous instrument.Synchronous aggregators []internal.Aggregator[N] } -var _ instrument.Float64ObservableCounter = &instrumentImpl[float64]{} -var _ instrument.Float64ObservableUpDownCounter = &instrumentImpl[float64]{} -var _ instrument.Float64ObservableGauge = &instrumentImpl[float64]{} -var _ instrument.Int64ObservableCounter = &instrumentImpl[int64]{} -var _ instrument.Int64ObservableUpDownCounter = &instrumentImpl[int64]{} -var _ instrument.Int64ObservableGauge = &instrumentImpl[int64]{} -var _ instrument.Float64Counter = &instrumentImpl[float64]{} -var _ instrument.Float64UpDownCounter = &instrumentImpl[float64]{} -var _ instrument.Float64Histogram = &instrumentImpl[float64]{} -var _ instrument.Int64Counter = &instrumentImpl[int64]{} -var _ instrument.Int64UpDownCounter = &instrumentImpl[int64]{} -var _ instrument.Int64Histogram = &instrumentImpl[int64]{} - -func (i *instrumentImpl[N]) Observe(ctx context.Context, val N, attrs ...attribute.KeyValue) { - // Only record a value if this is being called from the MetricProvider. - _, ok := ctx.Value(produceKey).(struct{}) - if !ok { - return - } - i.aggregate(ctx, val, attrs) -} +var _ instrument.Float64Counter = (*instrumentImpl[float64])(nil) +var _ instrument.Float64UpDownCounter = (*instrumentImpl[float64])(nil) +var _ instrument.Float64Histogram = (*instrumentImpl[float64])(nil) +var _ instrument.Int64Counter = (*instrumentImpl[int64])(nil) +var _ instrument.Int64UpDownCounter = (*instrumentImpl[int64])(nil) +var _ instrument.Int64Histogram = (*instrumentImpl[int64])(nil) func (i *instrumentImpl[N]) Add(ctx context.Context, val N, attrs ...attribute.KeyValue) { i.aggregate(ctx, val, attrs) @@ -214,3 +201,79 @@ func (i *instrumentImpl[N]) aggregate(ctx context.Context, val N, attrs []attrib agg.Aggregate(val, attribute.NewSet(attrs...)) } } + +// observablID is a comparable unique identifier of an observable. +type observablID[N int64 | float64] struct { + name string + description string + kind InstrumentKind + unit unit.Unit + scope instrumentation.Scope +} + +type observable[N int64 | float64] struct { + instrument.Asynchronous + observablID[N] + + aggregators []internal.Aggregator[N] +} + +func newObservable[N int64 | float64](scope instrumentation.Scope, kind InstrumentKind, name, desc string, u unit.Unit, agg []internal.Aggregator[N]) *observable[N] { + return &observable[N]{ + observablID: observablID[N]{ + name: name, + description: desc, + kind: kind, + unit: u, + scope: scope, + }, + aggregators: agg, + } +} + +var _ instrument.Float64ObservableCounter = (*observable[float64])(nil) +var _ instrument.Float64ObservableUpDownCounter = (*observable[float64])(nil) +var _ instrument.Float64ObservableGauge = (*observable[float64])(nil) +var _ instrument.Int64ObservableCounter = (*observable[int64])(nil) +var _ instrument.Int64ObservableUpDownCounter = (*observable[int64])(nil) +var _ instrument.Int64ObservableGauge = (*observable[int64])(nil) + +// Observe logs an error. +func (o *observable[N]) Observe(ctx context.Context, val N, attrs ...attribute.KeyValue) { + var zero N + err := errors.New("invalid observation") + global.Error(err, "dropping observation made outside a callback", + "name", o.name, + "description", o.description, + "unit", o.unit, + "number", fmt.Sprintf("%T", zero), + ) +} + +// observe records the val for the set of attrs. +func (o *observable[N]) observe(val N, attrs []attribute.KeyValue) { + for _, agg := range o.aggregators { + agg.Aggregate(val, attribute.NewSet(attrs...)) + } +} + +var errEmptyAgg = errors.New("no aggregators for observable instrument") + +// registerable returns an error if the observable o should not be registered, +// and nil if it should. An errEmptyAgg error is returned if o is effecively a +// no-op because it does not have any aggregators. Also, an error is returned +// if scope defines a Meter other than the one o was created by. +func (o *observable[N]) registerable(scope instrumentation.Scope) error { + if len(o.aggregators) == 0 { + return errEmptyAgg + } + if scope != o.scope { + return fmt.Errorf( + "invalid registration: observable %q from Meter %q, registered with Meter %q", + o.name, + o.scope.Name, + scope.Name, + ) + } + return nil +} diff --git a/sdk/metric/meter.go b/sdk/metric/meter.go index ab0128c98b7d..7a4d6f9d3779 100644 --- a/sdk/metric/meter.go +++ b/sdk/metric/meter.go @@ -16,11 +16,16 @@ package metric // import "go.opentelemetry.io/otel/sdk/metric" import ( "context" + "errors" + "fmt" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/internal/global" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/metric/unit" "go.opentelemetry.io/otel/sdk/instrumentation" + "go.opentelemetry.io/otel/sdk/metric/internal" ) // meter handles the creation and coordination of all metric instruments. A @@ -28,6 +33,7 @@ import ( // produced by an instrumentation scope will use metric instruments from a // single meter. type meter struct { + scope instrumentation.Scope pipes pipelines int64IP *instProvider[int64] @@ -45,6 +51,7 @@ func newMeter(s instrumentation.Scope, p pipelines) *meter { fc := newInstrumentCache[float64](nil, &viewCache) return &meter{ + scope: s, pipes: p, int64IP: newInstProvider(s, p, ic), float64IP: newInstProvider(s, p, fc), @@ -201,28 +208,150 @@ func (m *meter) Float64ObservableGauge(name string, options ...instrument.Float6 // RegisterCallback registers the function f to be called when any of the // insts Collect method is called. func (m *meter) RegisterCallback(f metric.Callback, insts ...instrument.Asynchronous) (metric.Registration, error) { + if len(insts) == 0 { + // Don't allocate a observer if not needed. + return noopRegister{}, nil + } + + reg := newObserver() + var errs multierror for _, inst := range insts { - // Only register if at least one instrument has a non-drop aggregation. - // Otherwise, calling f during collection will be wasted computation. - switch t := inst.(type) { - case *instrumentImpl[int64]: - if len(t.aggregators) > 0 { - return m.registerMultiCallback(f) + // Unwrap any global. + if u, ok := inst.(interface { + Unwrap() instrument.Asynchronous + }); ok { + inst = u.Unwrap() + } + + switch o := inst.(type) { + case *observable[int64]: + if err := o.registerable(m.scope); err != nil { + if !errors.Is(err, errEmptyAgg) { + errs.append(err) + } + continue } - case *instrumentImpl[float64]: - if len(t.aggregators) > 0 { - return m.registerMultiCallback(f) + reg.registerInt64(o.observablID) + case *observable[float64]: + if err := o.registerable(m.scope); err != nil { + if !errors.Is(err, errEmptyAgg) { + errs.append(err) + } + continue } + reg.registerFloat64(o.observablID) default: - // Instrument external to the SDK. For example, an instrument from - // the "go.opentelemetry.io/otel/metric/internal/global" package. - // - // Fail gracefully here, assume a valid instrument. - return m.registerMultiCallback(f) + // Instrument external to the SDK. + return nil, fmt.Errorf("invalid observable: from different implementation") + } + } + + if err := errs.errorOrNil(); err != nil { + return nil, err + } + + if reg.len() == 0 { + // All insts use drop aggregation. + return noopRegister{}, nil + } + + cback := func(ctx context.Context) error { + return f(ctx, reg) + } + return m.pipes.registerMultiCallback(cback), nil +} + +type observer struct { + float64 map[observablID[float64]]struct{} + int64 map[observablID[int64]]struct{} +} + +func newObserver() observer { + return observer{ + float64: make(map[observablID[float64]]struct{}), + int64: make(map[observablID[int64]]struct{}), + } +} + +func (r observer) len() int { + return len(r.float64) + len(r.int64) +} + +func (r observer) registerFloat64(id observablID[float64]) { + r.float64[id] = struct{}{} +} + +func (r observer) registerInt64(id observablID[int64]) { + r.int64[id] = struct{}{} +} + +var ( + errUnknownObserver = errors.New("unknown observable instrument") + errUnregObserver = errors.New("observable instrument not registered for callback") +) + +func (r observer) ObserveFloat64(o instrument.Float64Observer, v float64, a ...attribute.KeyValue) { + var oImpl *observable[float64] + switch conv := o.(type) { + case *observable[float64]: + oImpl = conv + case interface { + Unwrap() instrument.Asynchronous + }: + // Unwrap any global. + async := conv.Unwrap() + var ok bool + if oImpl, ok = async.(*observable[float64]); !ok { + global.Error(errUnknownObserver, "failed to record asynchronous") + return } + default: + global.Error(errUnknownObserver, "failed to record") + return + } + + if _, registered := r.float64[oImpl.observablID]; !registered { + global.Error(errUnregObserver, "failed to record", + "name", oImpl.name, + "description", oImpl.description, + "unit", oImpl.unit, + "number", fmt.Sprintf("%T", float64(0)), + ) + return + } + oImpl.observe(v, a) +} + +func (r observer) ObserveInt64(o instrument.Int64Observer, v int64, a ...attribute.KeyValue) { + var oImpl *observable[int64] + switch conv := o.(type) { + case *observable[int64]: + oImpl = conv + case interface { + Unwrap() instrument.Asynchronous + }: + // Unwrap any global. + async := conv.Unwrap() + var ok bool + if oImpl, ok = async.(*observable[int64]); !ok { + global.Error(errUnknownObserver, "failed to record asynchronous") + return + } + default: + global.Error(errUnknownObserver, "failed to record") + return } - // All insts use drop aggregation. - return noopRegister{}, nil + + if _, registered := r.int64[oImpl.observablID]; !registered { + global.Error(errUnregObserver, "failed to record", + "name", oImpl.name, + "description", oImpl.description, + "unit", oImpl.unit, + "number", fmt.Sprintf("%T", int64(0)), + ) + return + } + oImpl.observe(v, a) } type noopRegister struct{} @@ -231,10 +360,6 @@ func (noopRegister) Unregister() error { return nil } -func (m *meter) registerMultiCallback(c metric.Callback) (metric.Registration, error) { - return m.pipes.registerMultiCallback(c), nil -} - // instProvider provides all OpenTelemetry instruments. type instProvider[N int64 | float64] struct { scope instrumentation.Scope @@ -246,8 +371,7 @@ func newInstProvider[N int64 | float64](s instrumentation.Scope, p pipelines, c return &instProvider[N]{scope: s, pipes: p, resolve: newResolver(p, c)} } -// lookup returns the resolved instrumentImpl. -func (p *instProvider[N]) lookup(kind InstrumentKind, name, desc string, u unit.Unit) (*instrumentImpl[N], error) { +func (p *instProvider[N]) aggs(kind InstrumentKind, name, desc string, u unit.Unit) ([]internal.Aggregator[N], error) { inst := Instrument{ Name: name, Description: desc, @@ -255,13 +379,23 @@ func (p *instProvider[N]) lookup(kind InstrumentKind, name, desc string, u unit. Kind: kind, Scope: p.scope, } - aggs, err := p.resolve.Aggregators(inst) + return p.resolve.Aggregators(inst) +} + +// lookup returns the resolved instrumentImpl. +func (p *instProvider[N]) lookup(kind InstrumentKind, name, desc string, u unit.Unit) (*instrumentImpl[N], error) { + aggs, err := p.aggs(kind, name, desc, u) return &instrumentImpl[N]{aggregators: aggs}, err } type int64ObservProvider struct{ *instProvider[int64] } -func (p int64ObservProvider) registerCallbacks(inst *instrumentImpl[int64], cBacks []instrument.Int64Callback) { +func (p int64ObservProvider) lookup(kind InstrumentKind, name, desc string, u unit.Unit) (*observable[int64], error) { + aggs, err := p.aggs(kind, name, desc, u) + return newObservable(p.scope, kind, name, desc, u, aggs), err +} + +func (p int64ObservProvider) registerCallbacks(inst *observable[int64], cBacks []instrument.Int64Callback) { if inst == nil { // Drop aggregator. return @@ -272,13 +406,19 @@ func (p int64ObservProvider) registerCallbacks(inst *instrumentImpl[int64], cBac } } -func (p int64ObservProvider) callback(i *instrumentImpl[int64], f instrument.Int64Callback) func(context.Context) error { - return func(ctx context.Context) error { return f(ctx, i) } +func (p int64ObservProvider) callback(i *observable[int64], f instrument.Int64Callback) func(context.Context) error { + inst := callbackObserver[int64]{i} + return func(ctx context.Context) error { return f(ctx, inst) } } type float64ObservProvider struct{ *instProvider[float64] } -func (p float64ObservProvider) registerCallbacks(inst *instrumentImpl[float64], cBacks []instrument.Float64Callback) { +func (p float64ObservProvider) lookup(kind InstrumentKind, name, desc string, u unit.Unit) (*observable[float64], error) { + aggs, err := p.aggs(kind, name, desc, u) + return newObservable(p.scope, kind, name, desc, u, aggs), err +} + +func (p float64ObservProvider) registerCallbacks(inst *observable[float64], cBacks []instrument.Float64Callback) { if inst == nil { // Drop aggregator. return @@ -289,6 +429,17 @@ func (p float64ObservProvider) registerCallbacks(inst *instrumentImpl[float64], } } -func (p float64ObservProvider) callback(i *instrumentImpl[float64], f instrument.Float64Callback) func(context.Context) error { - return func(ctx context.Context) error { return f(ctx, i) } +func (p float64ObservProvider) callback(i *observable[float64], f instrument.Float64Callback) func(context.Context) error { + inst := callbackObserver[float64]{i} + return func(ctx context.Context) error { return f(ctx, inst) } +} + +// callbackObserver is an observer that records values for a wrapped +// observable. +type callbackObserver[N int64 | float64] struct { + *observable[N] +} + +func (o callbackObserver[N]) Observe(_ context.Context, val N, attrs ...attribute.KeyValue) { + o.observe(val, attrs) } diff --git a/sdk/metric/meter_test.go b/sdk/metric/meter_test.go index 76a0bf8ec10c..190528e3c516 100644 --- a/sdk/metric/meter_test.go +++ b/sdk/metric/meter_test.go @@ -16,14 +16,20 @@ package metric import ( "context" + "fmt" + "strings" "sync" "testing" + "github.com/go-logr/logr" + "github.com/go-logr/logr/testr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/global" "go.opentelemetry.io/otel/metric/instrument" "go.opentelemetry.io/otel/sdk/instrumentation" "go.opentelemetry.io/otel/sdk/metric/aggregation" @@ -91,7 +97,7 @@ func TestMeterInstrumentConcurrency(t *testing.T) { wg.Wait() } -var emptyCallback metric.Callback = func(ctx context.Context) error { return nil } +var emptyCallback metric.Callback = func(context.Context, metric.Observer) error { return nil } // A Meter Should be able register Callbacks Concurrently. func TestMeterCallbackCreationConcurrency(t *testing.T) { @@ -179,8 +185,8 @@ func TestMeterCreatesInstruments(t *testing.T) { } ctr, err := m.Int64ObservableCounter("aint", instrument.WithInt64Callback(cback)) assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 3) + _, err = m.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(ctr, 3) return nil }, ctr) assert.NoError(t, err) @@ -209,8 +215,8 @@ func TestMeterCreatesInstruments(t *testing.T) { } ctr, err := m.Int64ObservableUpDownCounter("aint", instrument.WithInt64Callback(cback)) assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 11) + _, err = m.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(ctr, 11) return nil }, ctr) assert.NoError(t, err) @@ -239,8 +245,8 @@ func TestMeterCreatesInstruments(t *testing.T) { } gauge, err := m.Int64ObservableGauge("agauge", instrument.WithInt64Callback(cback)) assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - gauge.Observe(ctx, 11) + _, err = m.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(gauge, 11) return nil }, gauge) assert.NoError(t, err) @@ -267,8 +273,8 @@ func TestMeterCreatesInstruments(t *testing.T) { } ctr, err := m.Float64ObservableCounter("afloat", instrument.WithFloat64Callback(cback)) assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 3) + _, err = m.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(ctr, 3) return nil }, ctr) assert.NoError(t, err) @@ -297,8 +303,8 @@ func TestMeterCreatesInstruments(t *testing.T) { } ctr, err := m.Float64ObservableUpDownCounter("afloat", instrument.WithFloat64Callback(cback)) assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 11) + _, err = m.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(ctr, 11) return nil }, ctr) assert.NoError(t, err) @@ -327,8 +333,8 @@ func TestMeterCreatesInstruments(t *testing.T) { } gauge, err := m.Float64ObservableGauge("agauge", instrument.WithFloat64Callback(cback)) assert.NoError(t, err) - _, err = m.RegisterCallback(func(ctx context.Context) error { - gauge.Observe(ctx, 11) + _, err = m.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(gauge, 11) return nil }, gauge) assert.NoError(t, err) @@ -496,6 +502,229 @@ func TestMeterCreatesInstruments(t *testing.T) { } } +func TestRegisterNonSDKObserverErrors(t *testing.T) { + rdr := NewManualReader() + mp := NewMeterProvider(WithReader(rdr)) + meter := mp.Meter("scope") + + type obsrv struct{ instrument.Asynchronous } + o := obsrv{} + + _, err := meter.RegisterCallback( + func(context.Context, metric.Observer) error { return nil }, + o, + ) + assert.ErrorContains( + t, + err, + "invalid observable: from different implementation", + "External instrument registred", + ) +} + +func TestMeterMixingOnRegisterErrors(t *testing.T) { + rdr := NewManualReader() + mp := NewMeterProvider(WithReader(rdr)) + + m1 := mp.Meter("scope1") + m2 := mp.Meter("scope2") + iCtr, err := m2.Int64ObservableCounter("int64 ctr") + require.NoError(t, err) + fCtr, err := m2.Float64ObservableCounter("float64 ctr") + require.NoError(t, err) + _, err = m1.RegisterCallback( + func(context.Context, metric.Observer) error { return nil }, + iCtr, fCtr, + ) + assert.ErrorContains( + t, + err, + `invalid registration: observable "int64 ctr" from Meter "scope2", registered with Meter "scope1"`, + "Instrument registred with non-creation Meter", + ) + assert.ErrorContains( + t, + err, + `invalid registration: observable "float64 ctr" from Meter "scope2", registered with Meter "scope1"`, + "Instrument registred with non-creation Meter", + ) +} + +func TestCallbackObserverNonRegistered(t *testing.T) { + rdr := NewManualReader() + mp := NewMeterProvider(WithReader(rdr)) + + m1 := mp.Meter("scope1") + valid, err := m1.Int64ObservableCounter("ctr") + require.NoError(t, err) + + m2 := mp.Meter("scope2") + iCtr, err := m2.Int64ObservableCounter("int64 ctr") + require.NoError(t, err) + fCtr, err := m2.Float64ObservableCounter("float64 ctr") + require.NoError(t, err) + + // Panics if Observe is called. + type int64Obsrv struct{ instrument.Int64Observer } + int64Foreign := int64Obsrv{} + type float64Obsrv struct{ instrument.Float64Observer } + float64Foreign := float64Obsrv{} + + _, err = m1.RegisterCallback( + func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(valid, 1) + o.ObserveInt64(iCtr, 1) + o.ObserveFloat64(fCtr, 1) + o.ObserveInt64(int64Foreign, 1) + o.ObserveFloat64(float64Foreign, 1) + return nil + }, + valid, + ) + require.NoError(t, err) + + var got metricdata.ResourceMetrics + assert.NotPanics(t, func() { + got, err = rdr.Collect(context.Background()) + }) + + assert.NoError(t, err) + want := metricdata.ResourceMetrics{ + Resource: resource.Default(), + ScopeMetrics: []metricdata.ScopeMetrics{ + { + Scope: instrumentation.Scope{ + Name: "scope1", + }, + Metrics: []metricdata.Metrics{ + { + Name: "ctr", + Data: metricdata.Sum[int64]{ + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + DataPoints: []metricdata.DataPoint[int64]{ + { + Value: 1, + }, + }, + }, + }, + }, + }, + }, + } + metricdatatest.AssertEqual(t, want, got, metricdatatest.IgnoreTimestamp()) +} + +type logSink struct { + logr.LogSink + + messages []string +} + +func newLogSink(t *testing.T) *logSink { + return &logSink{LogSink: testr.New(t).GetSink()} +} + +func (l *logSink) Info(level int, msg string, keysAndValues ...interface{}) { + l.messages = append(l.messages, msg) + l.LogSink.Info(level, msg, keysAndValues...) +} + +func (l *logSink) Error(err error, msg string, keysAndValues ...interface{}) { + l.messages = append(l.messages, fmt.Sprintf("%s: %s", err, msg)) + l.LogSink.Error(err, msg, keysAndValues...) +} + +func (l *logSink) String() string { + out := make([]string, len(l.messages)) + for i := range l.messages { + out[i] = "\t-" + l.messages[i] + } + return strings.Join(out, "\n") +} + +func TestGlobalInstRegisterCallback(t *testing.T) { + l := newLogSink(t) + otel.SetLogger(logr.New(l)) + + const mtrName = "TestGlobalInstRegisterCallback" + preMtr := global.Meter(mtrName) + preInt64Ctr, err := preMtr.Int64ObservableCounter("pre.int64.counter") + require.NoError(t, err) + preFloat64Ctr, err := preMtr.Float64ObservableCounter("pre.float64.counter") + require.NoError(t, err) + + rdr := NewManualReader() + mp := NewMeterProvider(WithReader(rdr), WithResource(resource.Empty())) + global.SetMeterProvider(mp) + + postMtr := global.Meter(mtrName) + postInt64Ctr, err := postMtr.Int64ObservableCounter("post.int64.counter") + require.NoError(t, err) + postFloat64Ctr, err := postMtr.Float64ObservableCounter("post.float64.counter") + require.NoError(t, err) + + cb := func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(preInt64Ctr, 1) + o.ObserveFloat64(preFloat64Ctr, 2) + o.ObserveInt64(postInt64Ctr, 3) + o.ObserveFloat64(postFloat64Ctr, 4) + return nil + } + + _, err = preMtr.RegisterCallback(cb, preInt64Ctr, preFloat64Ctr, postInt64Ctr, postFloat64Ctr) + assert.NoError(t, err) + + _, err = preMtr.RegisterCallback(cb, preInt64Ctr, preFloat64Ctr, postInt64Ctr, postFloat64Ctr) + assert.NoError(t, err) + + got, err := rdr.Collect(context.Background()) + assert.NoError(t, err) + assert.Lenf(t, l.messages, 0, "Warnings and errors logged:\n%s", l) + metricdatatest.AssertEqual(t, metricdata.ResourceMetrics{ + ScopeMetrics: []metricdata.ScopeMetrics{ + { + Scope: instrumentation.Scope{Name: "TestGlobalInstRegisterCallback"}, + Metrics: []metricdata.Metrics{ + { + Name: "pre.int64.counter", + Data: metricdata.Sum[int64]{ + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + DataPoints: []metricdata.DataPoint[int64]{{Value: 1}}, + }, + }, + { + Name: "pre.float64.counter", + Data: metricdata.Sum[float64]{ + DataPoints: []metricdata.DataPoint[float64]{{Value: 2}}, + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + }, + }, + { + Name: "post.int64.counter", + Data: metricdata.Sum[int64]{ + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + DataPoints: []metricdata.DataPoint[int64]{{Value: 3}}, + }, + }, + { + Name: "post.float64.counter", + Data: metricdata.Sum[float64]{ + DataPoints: []metricdata.DataPoint[float64]{{Value: 4}}, + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + }, + }, + }, + }, + }, + }, got, metricdatatest.IgnoreTimestamp()) +} + func TestMetersProvideScope(t *testing.T) { rdr := NewManualReader() mp := NewMeterProvider(WithReader(rdr)) @@ -503,8 +732,8 @@ func TestMetersProvideScope(t *testing.T) { m1 := mp.Meter("scope1") ctr1, err := m1.Float64ObservableCounter("ctr1") assert.NoError(t, err) - _, err = m1.RegisterCallback(func(ctx context.Context) error { - ctr1.Observe(ctx, 5) + _, err = m1.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(ctr1, 5) return nil }, ctr1) assert.NoError(t, err) @@ -512,8 +741,8 @@ func TestMetersProvideScope(t *testing.T) { m2 := mp.Meter("scope2") ctr2, err := m2.Int64ObservableCounter("ctr2") assert.NoError(t, err) - _, err = m1.RegisterCallback(func(ctx context.Context) error { - ctr2.Observe(ctx, 7) + _, err = m2.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(ctr2, 7) return nil }, ctr2) assert.NoError(t, err) @@ -592,7 +821,7 @@ func TestUnregisterUnregisters(t *testing.T) { var called bool reg, err := m.RegisterCallback( - func(context.Context) error { + func(context.Context, metric.Observer) error { called = true return nil }, @@ -646,7 +875,7 @@ func TestRegisterCallbackDropAggregations(t *testing.T) { var called bool _, err = m.RegisterCallback( - func(context.Context) error { + func(context.Context, metric.Observer) error { called = true return nil }, @@ -686,10 +915,10 @@ func testAttributeFilter(temporality metricdata.Temporality) func(*testing.T) { if err != nil { return err } - _, err = mtr.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 1)) - ctr.Observe(ctx, 2.0, attribute.String("foo", "bar")) - ctr.Observe(ctx, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 2)) + _, err = mtr.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(ctr, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 1)) + o.ObserveFloat64(ctr, 2.0, attribute.String("foo", "bar")) + o.ObserveFloat64(ctr, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 2)) return nil }, ctr) return err @@ -715,10 +944,10 @@ func testAttributeFilter(temporality metricdata.Temporality) func(*testing.T) { if err != nil { return err } - _, err = mtr.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 1)) - ctr.Observe(ctx, 2.0, attribute.String("foo", "bar")) - ctr.Observe(ctx, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 2)) + _, err = mtr.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(ctr, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 1)) + o.ObserveFloat64(ctr, 2.0, attribute.String("foo", "bar")) + o.ObserveFloat64(ctr, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 2)) return nil }, ctr) return err @@ -744,9 +973,9 @@ func testAttributeFilter(temporality metricdata.Temporality) func(*testing.T) { if err != nil { return err } - _, err = mtr.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 1)) - ctr.Observe(ctx, 2.0, attribute.String("foo", "bar"), attribute.Int("version", 2)) + _, err = mtr.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveFloat64(ctr, 1.0, attribute.String("foo", "bar"), attribute.Int("version", 1)) + o.ObserveFloat64(ctr, 2.0, attribute.String("foo", "bar"), attribute.Int("version", 2)) return nil }, ctr) return err @@ -770,10 +999,10 @@ func testAttributeFilter(temporality metricdata.Temporality) func(*testing.T) { if err != nil { return err } - _, err = mtr.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 10, attribute.String("foo", "bar"), attribute.Int("version", 1)) - ctr.Observe(ctx, 20, attribute.String("foo", "bar")) - ctr.Observe(ctx, 10, attribute.String("foo", "bar"), attribute.Int("version", 2)) + _, err = mtr.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(ctr, 10, attribute.String("foo", "bar"), attribute.Int("version", 1)) + o.ObserveInt64(ctr, 20, attribute.String("foo", "bar")) + o.ObserveInt64(ctr, 10, attribute.String("foo", "bar"), attribute.Int("version", 2)) return nil }, ctr) return err @@ -799,10 +1028,10 @@ func testAttributeFilter(temporality metricdata.Temporality) func(*testing.T) { if err != nil { return err } - _, err = mtr.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 10, attribute.String("foo", "bar"), attribute.Int("version", 1)) - ctr.Observe(ctx, 20, attribute.String("foo", "bar")) - ctr.Observe(ctx, 10, attribute.String("foo", "bar"), attribute.Int("version", 2)) + _, err = mtr.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(ctr, 10, attribute.String("foo", "bar"), attribute.Int("version", 1)) + o.ObserveInt64(ctr, 20, attribute.String("foo", "bar")) + o.ObserveInt64(ctr, 10, attribute.String("foo", "bar"), attribute.Int("version", 2)) return nil }, ctr) return err @@ -828,9 +1057,9 @@ func testAttributeFilter(temporality metricdata.Temporality) func(*testing.T) { if err != nil { return err } - _, err = mtr.RegisterCallback(func(ctx context.Context) error { - ctr.Observe(ctx, 10, attribute.String("foo", "bar"), attribute.Int("version", 1)) - ctr.Observe(ctx, 20, attribute.String("foo", "bar"), attribute.Int("version", 2)) + _, err = mtr.RegisterCallback(func(_ context.Context, o metric.Observer) error { + o.ObserveInt64(ctr, 10, attribute.String("foo", "bar"), attribute.Int("version", 1)) + o.ObserveInt64(ctr, 20, attribute.String("foo", "bar"), attribute.Int("version", 2)) return nil }, ctr) return err diff --git a/sdk/metric/pipeline.go b/sdk/metric/pipeline.go index 32b9125340be..666e095c169e 100644 --- a/sdk/metric/pipeline.go +++ b/sdk/metric/pipeline.go @@ -104,9 +104,11 @@ func (p *pipeline) addCallback(cback func(context.Context) error) { p.callbacks = append(p.callbacks, cback) } +type multiCallback func(context.Context) error + // addMultiCallback registers a multi-instrument callback to be run when // `produce()` is called. -func (p *pipeline) addMultiCallback(c metric.Callback) (unregister func()) { +func (p *pipeline) addMultiCallback(c multiCallback) (unregister func()) { p.Lock() defer p.Unlock() e := p.multiCallbacks.PushBack(c) @@ -146,7 +148,7 @@ func (p *pipeline) produce(ctx context.Context) (metricdata.ResourceMetrics, err } for e := p.multiCallbacks.Front(); e != nil; e = e.Next() { // TODO make the callbacks parallel. ( #3034 ) - f := e.Value.(metric.Callback) + f := e.Value.(multiCallback) if err := f(ctx); err != nil { errs.append(err) } @@ -475,7 +477,7 @@ func (p pipelines) registerCallback(cback func(context.Context) error) { } } -func (p pipelines) registerMultiCallback(c metric.Callback) metric.Registration { +func (p pipelines) registerMultiCallback(c multiCallback) metric.Registration { unregs := make([]func(), len(p)) for i, pipe := range p { unregs[i] = pipe.addMultiCallback(c) diff --git a/sdk/metric/pipeline_test.go b/sdk/metric/pipeline_test.go index 7b9f89585dcc..3c5e19b8b7db 100644 --- a/sdk/metric/pipeline_test.go +++ b/sdk/metric/pipeline_test.go @@ -54,7 +54,7 @@ func TestEmptyPipeline(t *testing.T) { }) require.NotPanics(t, func() { - pipe.addMultiCallback(emptyCallback) + pipe.addMultiCallback(func(context.Context) error { return nil }) }) output, err = pipe.produce(context.Background()) @@ -78,7 +78,7 @@ func TestNewPipeline(t *testing.T) { }) require.NotPanics(t, func() { - pipe.addMultiCallback(emptyCallback) + pipe.addMultiCallback(func(context.Context) error { return nil }) }) output, err = pipe.produce(context.Background()) @@ -121,7 +121,7 @@ func TestPipelineConcurrency(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - pipe.addMultiCallback(emptyCallback) + pipe.addMultiCallback(func(context.Context) error { return nil }) }() } wg.Wait()