diff --git a/consumer/consumertest/sink.go b/consumer/consumertest/sink.go index c8a6d2a4864..cd39256fd6a 100644 --- a/consumer/consumertest/sink.go +++ b/consumer/consumertest/sink.go @@ -22,23 +22,10 @@ import ( "go.opentelemetry.io/collector/consumer/pdata" ) -type baseErrorConsumer struct { - mu sync.Mutex - consumeError error // to be returned by ConsumeTraces, if set -} - -// SetConsumeError sets an error that will be returned by the Consume function. -// TODO: Remove this when all calls are switched to the new ErrConsumer. -func (bec *baseErrorConsumer) SetConsumeError(err error) { - bec.mu.Lock() - defer bec.mu.Unlock() - bec.consumeError = err -} - // TracesSink is a consumer.TracesConsumer that acts like a sink that // stores all traces and allows querying them for testing. type TracesSink struct { - baseErrorConsumer + mu sync.Mutex traces []pdata.Traces spansCount int } @@ -50,10 +37,6 @@ func (ste *TracesSink) ConsumeTraces(_ context.Context, td pdata.Traces) error { ste.mu.Lock() defer ste.mu.Unlock() - if ste.consumeError != nil { - return ste.consumeError - } - ste.traces = append(ste.traces, td) ste.spansCount += td.SpanCount() @@ -89,7 +72,7 @@ func (ste *TracesSink) Reset() { // MetricsSink is a consumer.MetricsConsumer that acts like a sink that // stores all metrics and allows querying them for testing. type MetricsSink struct { - baseErrorConsumer + mu sync.Mutex metrics []pdata.Metrics metricsCount int } @@ -100,9 +83,6 @@ var _ consumer.MetricsConsumer = (*MetricsSink)(nil) func (sme *MetricsSink) ConsumeMetrics(_ context.Context, md pdata.Metrics) error { sme.mu.Lock() defer sme.mu.Unlock() - if sme.consumeError != nil { - return sme.consumeError - } sme.metrics = append(sme.metrics, md) sme.metricsCount += md.MetricCount() @@ -139,7 +119,7 @@ func (sme *MetricsSink) Reset() { // LogsSink is a consumer.LogsConsumer that acts like a sink that // stores all logs and allows querying them for testing. type LogsSink struct { - baseErrorConsumer + mu sync.Mutex logs []pdata.Logs logRecordsCount int } @@ -150,9 +130,6 @@ var _ consumer.LogsConsumer = (*LogsSink)(nil) func (sle *LogsSink) ConsumeLogs(_ context.Context, ld pdata.Logs) error { sle.mu.Lock() defer sle.mu.Unlock() - if sle.consumeError != nil { - return sle.consumeError - } sle.logs = append(sle.logs, ld) sle.logRecordsCount += ld.LogRecordCount() diff --git a/consumer/consumertest/sink_test.go b/consumer/consumertest/sink_test.go index 17990885f62..b8ef1a6095d 100644 --- a/consumer/consumertest/sink_test.go +++ b/consumer/consumertest/sink_test.go @@ -16,7 +16,6 @@ package consumertest import ( "context" - "errors" "testing" "github.com/stretchr/testify/assert" @@ -41,15 +40,6 @@ func TestTracesSink(t *testing.T) { assert.Equal(t, 0, sink.SpansCount()) } -func TestTracesSink_Error(t *testing.T) { - sink := new(TracesSink) - sink.SetConsumeError(errors.New("my error")) - td := testdata.GenerateTraceDataOneSpan() - require.Error(t, sink.ConsumeTraces(context.Background(), td)) - assert.Len(t, sink.AllTraces(), 0) - assert.Equal(t, 0, sink.SpansCount()) -} - func TestMetricsSink(t *testing.T) { sink := new(MetricsSink) md := testdata.GenerateMetricsOneMetric() @@ -65,15 +55,6 @@ func TestMetricsSink(t *testing.T) { assert.Equal(t, 0, sink.MetricsCount()) } -func TestMetricsSink_Error(t *testing.T) { - sink := new(MetricsSink) - sink.SetConsumeError(errors.New("my error")) - md := testdata.GenerateMetricsOneMetric() - require.Error(t, sink.ConsumeMetrics(context.Background(), md)) - assert.Len(t, sink.AllMetrics(), 0) - assert.Equal(t, 0, sink.MetricsCount()) -} - func TestLogsSink(t *testing.T) { sink := new(LogsSink) md := testdata.GenerateLogDataOneLogNoResource() @@ -88,12 +69,3 @@ func TestLogsSink(t *testing.T) { assert.Equal(t, 0, len(sink.AllLogs())) assert.Equal(t, 0, sink.LogRecordsCount()) } - -func TestLogsSink_Error(t *testing.T) { - sink := new(LogsSink) - sink.SetConsumeError(errors.New("my error")) - ld := testdata.GenerateLogDataOneLogNoResource() - require.Error(t, sink.ConsumeLogs(context.Background(), ld)) - assert.Len(t, sink.AllLogs(), 0) - assert.Equal(t, 0, sink.LogRecordsCount()) -} diff --git a/internal/internalconsumertest/err_or_sink_consumer.go b/internal/internalconsumertest/err_or_sink_consumer.go new file mode 100644 index 00000000000..847e12f3b83 --- /dev/null +++ b/internal/internalconsumertest/err_or_sink_consumer.go @@ -0,0 +1,75 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internalconsumertest + +import ( + "context" + "sync" + + "go.opentelemetry.io/collector/consumer/consumertest" + "go.opentelemetry.io/collector/consumer/pdata" +) + +type ErrOrSinkConsumer struct { + *consumertest.TracesSink + *consumertest.MetricsSink + mu sync.Mutex + consumeError error // to be returned by ConsumeTraces, if set +} + +// SetConsumeError sets an error that will be returned by the Consume function. +func (esc *ErrOrSinkConsumer) SetConsumeError(err error) { + esc.mu.Lock() + defer esc.mu.Unlock() + esc.consumeError = err +} + +// ConsumeTraces stores traces to this sink. +func (esc *ErrOrSinkConsumer) ConsumeTraces(ctx context.Context, td pdata.Traces) error { + esc.mu.Lock() + defer esc.mu.Unlock() + + if esc.consumeError != nil { + return esc.consumeError + } + + return esc.TracesSink.ConsumeTraces(ctx, td) +} + +// ConsumeTraces stores traces to this sink. +func (esc *ErrOrSinkConsumer) ConsumeMetrics(ctx context.Context, md pdata.Metrics) error { + esc.mu.Lock() + defer esc.mu.Unlock() + + if esc.consumeError != nil { + return esc.consumeError + } + + return esc.MetricsSink.ConsumeMetrics(ctx, md) +} + +// Reset deletes any stored in the sinks, resets error to nil. +func (esc *ErrOrSinkConsumer) Reset() { + esc.mu.Lock() + defer esc.mu.Unlock() + + esc.consumeError = nil + if esc.TracesSink != nil { + esc.TracesSink.Reset() + } + if esc.MetricsSink != nil { + esc.MetricsSink.Reset() + } +} diff --git a/receiver/opencensusreceiver/opencensus_test.go b/receiver/opencensusreceiver/opencensus_test.go index 5d9803e802d..07e16b12290 100644 --- a/receiver/opencensusreceiver/opencensus_test.go +++ b/receiver/opencensusreceiver/opencensus_test.go @@ -45,6 +45,7 @@ import ( "go.opentelemetry.io/collector/component/componenttest" "go.opentelemetry.io/collector/consumer/consumertest" + "go.opentelemetry.io/collector/internal/internalconsumertest" "go.opentelemetry.io/collector/obsreport/obsreporttest" "go.opentelemetry.io/collector/testutil" "go.opentelemetry.io/collector/translator/internaldata" @@ -426,7 +427,7 @@ func TestOCReceiverTrace_HandleNextConsumerResponse(t *testing.T) { require.NoError(t, err) defer doneFn() - sink := new(consumertest.TracesSink) + sink := &internalconsumertest.ErrOrSinkConsumer{TracesSink: new(consumertest.TracesSink)} var opts []ocOption ocr, err := newOpenCensusReceiver(exporter.receiverTag, "tcp", addr, nil, nil, opts...) @@ -575,7 +576,7 @@ func TestOCReceiverMetrics_HandleNextConsumerResponse(t *testing.T) { require.NoError(t, err) defer doneFn() - sink := new(consumertest.MetricsSink) + sink := &internalconsumertest.ErrOrSinkConsumer{MetricsSink: new(consumertest.MetricsSink)} var opts []ocOption ocr, err := newOpenCensusReceiver(exporter.receiverTag, "tcp", addr, nil, nil, opts...) diff --git a/receiver/otlpreceiver/otlp_test.go b/receiver/otlpreceiver/otlp_test.go index 9465a7beeb5..ebd68625e95 100644 --- a/receiver/otlpreceiver/otlp_test.go +++ b/receiver/otlpreceiver/otlp_test.go @@ -52,6 +52,7 @@ import ( otlpcommon "go.opentelemetry.io/collector/internal/data/protogen/common/v1" otlpresource "go.opentelemetry.io/collector/internal/data/protogen/resource/v1" otlptrace "go.opentelemetry.io/collector/internal/data/protogen/trace/v1" + "go.opentelemetry.io/collector/internal/internalconsumertest" "go.opentelemetry.io/collector/internal/testdata" "go.opentelemetry.io/collector/obsreport/obsreporttest" "go.opentelemetry.io/collector/testutil" @@ -156,7 +157,7 @@ func TestJsonHttp(t *testing.T) { addr := testutil.GetAvailableLocalAddress(t) // Set the buffer count to 1 to make it flush the test span immediately. - sink := new(consumertest.TracesSink) + sink := &internalconsumertest.ErrOrSinkConsumer{TracesSink: new(consumertest.TracesSink)} ocr := newHTTPReceiver(t, addr, sink, nil) require.NoError(t, ocr.Start(context.Background(), componenttest.NewNopHost()), "Failed to start trace receiver") @@ -183,7 +184,7 @@ func TestJsonHttp(t *testing.T) { } } -func testHTTPJSONRequest(t *testing.T, url string, sink *consumertest.TracesSink, encoding string, expectedErr error) { +func testHTTPJSONRequest(t *testing.T, url string, sink *internalconsumertest.ErrOrSinkConsumer, encoding string, expectedErr error) { var buf *bytes.Buffer var err error switch encoding { @@ -334,9 +335,8 @@ func TestProtoHttp(t *testing.T) { addr := testutil.GetAvailableLocalAddress(t) // Set the buffer count to 1 to make it flush the test span immediately. - tSink := new(consumertest.TracesSink) - mSink := new(consumertest.MetricsSink) - ocr := newHTTPReceiver(t, addr, tSink, mSink) + tSink := &internalconsumertest.ErrOrSinkConsumer{TracesSink: new(consumertest.TracesSink)} + ocr := newHTTPReceiver(t, addr, tSink, consumertest.NewMetricsNop()) require.NoError(t, ocr.Start(context.Background(), componenttest.NewNopHost()), "Failed to start trace receiver") defer ocr.Shutdown(context.Background()) @@ -396,7 +396,7 @@ func createHTTPProtobufRequest( func testHTTPProtobufRequest( t *testing.T, url string, - tSink *consumertest.TracesSink, + tSink *internalconsumertest.ErrOrSinkConsumer, encoding string, traceBytes []byte, expectedErr error, @@ -645,7 +645,7 @@ func TestOTLPReceiverTrace_HandleNextConsumerResponse(t *testing.T) { require.NoError(t, err) defer doneFn() - sink := new(consumertest.TracesSink) + sink := &internalconsumertest.ErrOrSinkConsumer{TracesSink: new(consumertest.TracesSink)} ocr := newGRPCReceiver(t, exporter.receiverTag, addr, sink, nil) require.NotNil(t, ocr)