Skip to content

Commit

Permalink
Pass context from parents
Browse files Browse the repository at this point in the history
  • Loading branch information
donald-cheung committed Sep 5, 2024
1 parent ddda641 commit 1c502fd
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 21 deletions.
16 changes: 8 additions & 8 deletions exporter/kafkaexporter/kafka_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (e *kafkaTracesProducer) Close(context.Context) error {
return e.producer.Close()
}

func (e *kafkaTracesProducer) start(_ context.Context, _ component.Host) error {
producer, err := newSaramaProducer(e.cfg)
func (e *kafkaTracesProducer) start(ctx context.Context, _ component.Host) error {
producer, err := newSaramaProducer(ctx, e.cfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -107,8 +107,8 @@ func (e *kafkaMetricsProducer) Close(context.Context) error {
return e.producer.Close()
}

func (e *kafkaMetricsProducer) start(_ context.Context, _ component.Host) error {
producer, err := newSaramaProducer(e.cfg)
func (e *kafkaMetricsProducer) start(ctx context.Context, _ component.Host) error {
producer, err := newSaramaProducer(ctx, e.cfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -149,16 +149,16 @@ func (e *kafkaLogsProducer) Close(context.Context) error {
return e.producer.Close()
}

func (e *kafkaLogsProducer) start(_ context.Context, _ component.Host) error {
producer, err := newSaramaProducer(e.cfg)
func (e *kafkaLogsProducer) start(ctx context.Context, _ component.Host) error {
producer, err := newSaramaProducer(ctx, e.cfg)
if err != nil {
return err
}
e.producer = producer
return nil
}

func newSaramaProducer(config Config) (sarama.SyncProducer, error) {
func newSaramaProducer(ctx context.Context, config Config) (sarama.SyncProducer, error) {
c := sarama.NewConfig()

c.ClientID = config.ClientID
Expand Down Expand Up @@ -187,7 +187,7 @@ func newSaramaProducer(config Config) (sarama.SyncProducer, error) {
c.Version = version
}

if err := kafka.ConfigureAuthentication(config.Authentication, c); err != nil {
if err := kafka.ConfigureAuthentication(ctx, config.Authentication, c); err != nil {
return nil, err
}

Expand Down
11 changes: 7 additions & 4 deletions internal/kafka/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ type AWSMSKConfig struct {
Region string `mapstructure:"region"`
// BrokerAddr is the client is connecting to in order to perform the auth required
BrokerAddr string `mapstructure:"broker_addr"`
// Context
ctx context.Context
}

// Token return the AWS session token for the AWS_MSK_IAM_OAUTHBEARER mechanism
func (c *AWSMSKConfig) Token() (*sarama.AccessToken, error) {
token, _, err := signer.GenerateAuthToken(context.TODO(), c.Region)
token, _, err := signer.GenerateAuthToken(c.ctx, c.Region)

return &sarama.AccessToken{Token: token}, err
}
Expand All @@ -74,7 +76,7 @@ type KerberosConfig struct {
}

// ConfigureAuthentication configures authentication in sarama.Config.
func ConfigureAuthentication(config Authentication, saramaConfig *sarama.Config) error {
func ConfigureAuthentication(ctx context.Context, config Authentication, saramaConfig *sarama.Config) error {
if config.PlainText != nil {
configurePlaintext(*config.PlainText, saramaConfig)
}
Expand All @@ -84,7 +86,7 @@ func ConfigureAuthentication(config Authentication, saramaConfig *sarama.Config)
}
}
if config.SASL != nil {
if err := configureSASL(*config.SASL, saramaConfig); err != nil {
if err := configureSASL(ctx, *config.SASL, saramaConfig); err != nil {
return err
}
}
Expand All @@ -101,7 +103,7 @@ func configurePlaintext(config PlainTextConfig, saramaConfig *sarama.Config) {
saramaConfig.Net.SASL.Password = config.Password
}

func configureSASL(config SASLConfig, saramaConfig *sarama.Config) error {
func configureSASL(ctx context.Context, config SASLConfig, saramaConfig *sarama.Config) error {

if config.Username == "" && config.Mechanism != "AWS_MSK_IAM_OAUTHBEARER" {
return fmt.Errorf("username have to be provided")
Expand Down Expand Up @@ -130,6 +132,7 @@ func configureSASL(config SASLConfig, saramaConfig *sarama.Config) error {
}
saramaConfig.Net.SASL.Mechanism = awsmsk.Mechanism
case "AWS_MSK_IAM_OAUTHBEARER":
config.AWSMSK.ctx = ctx
saramaConfig.Net.SASL.Mechanism = sarama.SASLTypeOAuth
saramaConfig.Net.SASL.TokenProvider = &config.AWSMSK
tlsConfig := tls.Config{}
Expand Down
2 changes: 1 addition & 1 deletion internal/kafka/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func TestAuthentication(t *testing.T) {
for _, test := range tests {
t.Run("", func(t *testing.T) {
config := &sarama.Config{}
err := ConfigureAuthentication(test.auth, config)
err := ConfigureAuthentication(context.Background(), test.auth, config)
if test.err != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), test.err)
Expand Down
2 changes: 1 addition & 1 deletion receiver/kafkametricsreceiver/receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var newMetricsReceiver = func(
}
sc.Version = version
}
if err := kafka.ConfigureAuthentication(config.Authentication, sc); err != nil {
if err := kafka.ConfigureAuthentication(ctx, config.Authentication, sc); err != nil {
return nil, err
}
scraperControllerOptions := make([]scraperhelper.ScraperControllerOption, 0, len(config.Scrapers))
Expand Down
14 changes: 7 additions & 7 deletions receiver/kafkareceiver/kafka_receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func newTracesReceiver(config Config, set receiver.Settings, unmarshaler TracesU
}, nil
}

func createKafkaClient(config Config) (sarama.ConsumerGroup, error) {
func createKafkaClient(ctx context.Context, config Config) (sarama.ConsumerGroup, error) {
saramaConfig := sarama.NewConfig()
saramaConfig.ClientID = config.ClientID
saramaConfig.Metadata.Full = config.Metadata.Full
Expand All @@ -138,7 +138,7 @@ func createKafkaClient(config Config) (sarama.ConsumerGroup, error) {
return nil, err
}
}
if err := kafka.ConfigureAuthentication(config.Authentication, saramaConfig); err != nil {
if err := kafka.ConfigureAuthentication(ctx, config.Authentication, saramaConfig); err != nil {
return nil, err
}
return sarama.NewConsumerGroup(config.Brokers, config.GroupID, saramaConfig)
Expand All @@ -157,7 +157,7 @@ func (c *kafkaTracesConsumer) Start(_ context.Context, host component.Host) erro
}
// consumerGroup may be set in tests to inject fake implementation.
if c.consumerGroup == nil {
if c.consumerGroup, err = createKafkaClient(c.config); err != nil {
if c.consumerGroup, err = createKafkaClient(ctx, c.config); err != nil {
return err
}
}
Expand Down Expand Up @@ -238,7 +238,7 @@ func newMetricsReceiver(config Config, set receiver.Settings, unmarshaler Metric
}, nil
}

func (c *kafkaMetricsConsumer) Start(_ context.Context, host component.Host) error {
func (c *kafkaMetricsConsumer) Start(ctx context.Context, host component.Host) error {
ctx, cancel := context.WithCancel(context.Background())
c.cancelConsumeLoop = cancel
obsrecv, err := receiverhelper.NewObsReport(receiverhelper.ObsReportSettings{
Expand All @@ -251,7 +251,7 @@ func (c *kafkaMetricsConsumer) Start(_ context.Context, host component.Host) err
}
// consumerGroup may be set in tests to inject fake implementation.
if c.consumerGroup == nil {
if c.consumerGroup, err = createKafkaClient(c.config); err != nil {
if c.consumerGroup, err = createKafkaClient(ctx, c.config); err != nil {
return err
}
}
Expand Down Expand Up @@ -332,7 +332,7 @@ func newLogsReceiver(config Config, set receiver.Settings, unmarshaler LogsUnmar
}, nil
}

func (c *kafkaLogsConsumer) Start(_ context.Context, host component.Host) error {
func (c *kafkaLogsConsumer) Start(ctx context.Context, host component.Host) error {
ctx, cancel := context.WithCancel(context.Background())
c.cancelConsumeLoop = cancel
obsrecv, err := receiverhelper.NewObsReport(receiverhelper.ObsReportSettings{
Expand All @@ -345,7 +345,7 @@ func (c *kafkaLogsConsumer) Start(_ context.Context, host component.Host) error
}
// consumerGroup may be set in tests to inject fake implementation.
if c.consumerGroup == nil {
if c.consumerGroup, err = createKafkaClient(c.config); err != nil {
if c.consumerGroup, err = createKafkaClient(ctx, c.config); err != nil {
return err
}
}
Expand Down

0 comments on commit 1c502fd

Please sign in to comment.