diff --git a/firehose/firehose.go b/firehose/firehose.go index a1eb5dc..0946d56 100644 --- a/firehose/firehose.go +++ b/firehose/firehose.go @@ -68,7 +68,7 @@ type OutputPlugin struct { // NewOutputPlugin creates an OutputPlugin object func NewOutputPlugin(region, deliveryStream, dataKeys, roleARN, firehoseEndpoint, stsEndpoint, timeKey, timeFmt, logKey string, pluginID int) (*OutputPlugin, error) { - client, err := newPutRecordBatcher(roleARN, region, firehoseEndpoint, stsEndpoint) + client, err := newPutRecordBatcher(roleARN, region, firehoseEndpoint, stsEndpoint, pluginID) if err != nil { return nil, err } @@ -111,7 +111,7 @@ func NewOutputPlugin(region, deliveryStream, dataKeys, roleARN, firehoseEndpoint }, nil } -func newPutRecordBatcher(roleARN, region, firehoseEndpoint, stsEndpoint string) (*firehose.Firehose, error) { +func newPutRecordBatcher(roleARN, region, firehoseEndpoint, stsEndpoint string, pluginID int) (*firehose.Firehose, error) { customResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { if service == endpoints.FirehoseServiceID && firehoseEndpoint != "" { return endpoints.ResolvedEndpoint{ @@ -125,22 +125,49 @@ func newPutRecordBatcher(roleARN, region, firehoseEndpoint, stsEndpoint string) return endpoints.DefaultResolver().EndpointFor(service, region, optFns...) } - sess, err := session.NewSession(&aws.Config{ + // Fetching base credentials + baseConfig := &aws.Config{ Region: aws.String(region), EndpointResolver: endpoints.ResolverFunc(customResolverFn), CredentialsChainVerboseErrors: aws.Bool(true), - }) + } + + sess, err := session.NewSession(baseConfig) if err != nil { return nil, err } - svcConfig := &aws.Config{} + var svcSess = sess + var svcConfig = baseConfig + eksRole := os.Getenv("EKS_POD_EXECUTION_ROLE") + if eksRole != "" { + logrus.Debugf("[firehose %d] Fetching EKS pod credentials.\n", pluginID) + eksConfig := &aws.Config{} + creds := stscreds.NewCredentials(svcSess, eksRole) + eksConfig.Credentials = creds + eksConfig.Region = aws.String(region) + svcConfig = eksConfig + + svcSess, err = session.NewSession(svcConfig) + if err != nil { + return nil, err + } + } if roleARN != "" { - creds := stscreds.NewCredentials(sess, roleARN) - svcConfig.Credentials = creds + logrus.Debugf("[firehose %d] Fetching credentials for %s\n", pluginID, roleARN) + stsConfig := &aws.Config{} + creds := stscreds.NewCredentials(svcSess, roleARN) + stsConfig.Credentials = creds + stsConfig.Region = aws.String(region) + svcConfig = stsConfig + + svcSess, err = session.NewSession(svcConfig) + if err != nil { + return nil, err + } } - client := firehose.New(sess, svcConfig) + client := firehose.New(svcSess, svcConfig) client.Handlers.Build.PushBackNamed(plugins.CustomUserAgentHandler()) return client, nil }