-
Notifications
You must be signed in to change notification settings - Fork 423
/
token.go
687 lines (594 loc) · 23.4 KB
/
token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
/*
Copyright 2017-2020 by the contributors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package token
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/pkg/apis/clientauthentication"
clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
"sigs.k8s.io/aws-iam-authenticator/pkg"
"sigs.k8s.io/aws-iam-authenticator/pkg/arn"
"sigs.k8s.io/aws-iam-authenticator/pkg/filecache"
"sigs.k8s.io/aws-iam-authenticator/pkg/metrics"
)
// Identity is returned on successful Verify() results. It contains a parsed
// version of the AWS identity used to create the token.
type Identity struct {
	// ARN is the raw Amazon Resource Name returned by sts:GetCallerIdentity.
	ARN string

	// CanonicalARN is the Amazon Resource Name converted to a more canonical
	// representation. In particular, STS assumed role ARNs like
	// "arn:aws:sts::ACCOUNTID:assumed-role/ROLENAME/SESSIONNAME" are converted
	// to their IAM ARN equivalent "arn:aws:iam::ACCOUNTID:role/NAME".
	CanonicalARN string

	// AccountID is the 12 digit AWS account number.
	AccountID string

	// UserID is the unique user/role ID (e.g., "AROAAAAAAAAAAAAAAAAAA").
	UserID string

	// SessionName is the STS session name (or "" if this is not a
	// session-based identity). For EC2 instance roles, this will be the EC2
	// instance ID (e.g., "i-0123456789abcdef0"). You should only rely on it
	// if you trust that _only_ EC2 is allowed to assume the IAM Role. If IAM
	// users or other roles are allowed to assume the role, they can provide
	// (nearly) arbitrary strings here.
	SessionName string

	// AccessKeyID is the AWS Access Key ID used to authenticate the request.
	// It can be used in conjunction with CloudTrail to determine the identity
	// of the individual if they assumed an IAM role before making the request.
	AccessKeyID string

	// STSEndpoint is the AWS STS endpoint host used to authenticate the
	// request (for example "sts.us-west-2.amazonaws.com").
	STSEndpoint string
}
const (
	// v1Prefix marks a version-1 token; everything after it is the
	// base64url-encoded pre-signed STS URL.
	v1Prefix = "k8s-aws-v1."
	// maxTokenLenBytes is the largest token Verify will attempt to decode.
	maxTokenLenBytes = 1024 * 4
	// clusterIDHeader is the header that must be signed into the pre-signed
	// URL; it carries the cluster ID the token was minted for.
	clusterIDHeader = "x-k8s-aws-id"

	// requestPresignParam is passed to Presign (in seconds) and ends up as
	// x-amz-expires. STS ignores it: a signed GetCallerIdentity request is
	// valid for 15 minutes regardless. It stays at 60 for legacy reasons
	// (servers at 0.3.0 or earlier require a value between 0 and 60). If STS
	// ever honors x-amz-expires, this should become the real, configurable
	// expiration.
	requestPresignParam = 60
	// presignedURLExpiration is how long a pre-signed STS URL is actually
	// valid, measured from the timestamp in x-amz-date.
	presignedURLExpiration = 15 * time.Minute

	// dateHeaderFormat is the layout of the X-Amz-Date value used to compute
	// expiration (see https://golang.org/pkg/time/#pkg-constants).
	dateHeaderFormat = "20060102T150405Z"

	// Values used when rendering a client-go ExecCredential.
	kindExecCredential = "ExecCredential"
	execInfoEnvKey     = "KUBERNETES_EXEC_INFO"

	// stsServiceID is the STS service identifier in the SDK endpoint data.
	stsServiceID = "sts"
)
// Token is generated and used by Kubernetes client-go to authenticate with a
// Kubernetes cluster.
type Token struct {
	// Token is the bearer string handed to the cluster.
	Token string

	// Expiration is when the token should be treated as expired by clients.
	Expiration time.Time
}
// GetTokenOptions is passed to GetWithOptions to provide an extensible get
// token interface.
type GetTokenOptions struct {
	// Region, when non-empty, selects the AWS region (and its regional STS
	// endpoint) used to sign the token.
	Region string
	// ClusterID is the cluster the token is bound to. Required.
	ClusterID string
	// AssumeRoleARN, when non-empty, is the role assumed before signing.
	AssumeRoleARN string
	// AssumeRoleExternalID is the external ID sent when assuming the role.
	AssumeRoleExternalID string
	// SessionName is the role session name used when assuming the role.
	SessionName string
}
// FormatError is returned when a token (an encoded, pre-signed STS request)
// is unacceptable: a bad encoding, URL, parameter, or action, or anything else
// that prevents the STS call from being made.
type FormatError struct {
	message string
}

// Error implements the error interface.
func (e FormatError) Error() string {
	const prefix = "input token was not properly formatted: "
	return prefix + e.message
}
// STSError is returned when calling STS fails or when the data STS returns
// cannot be processed.
type STSError struct {
	message string
}

// Error implements the error interface.
func (e STSError) Error() string {
	const prefix = "sts getCallerIdentity failed: "
	return prefix + e.message
}

// NewSTSError wraps message m in an STSError.
func NewSTSError(m string) STSError {
	e := STSError{}
	e.message = m
	return e
}
// STSThrottling is returned when STS rejects the GetCallerIdentity call
// because the caller is being throttled.
type STSThrottling struct {
	message string
}

// Error implements the error interface.
func (e STSThrottling) Error() string {
	return "sts getCallerIdentity was throttled: " + e.message
}

// NewSTSThrottling creates an error of type STSThrottling carrying message m
// (typically the raw STS response body).
func NewSTSThrottling(m string) STSThrottling {
	return STSThrottling{message: m}
}
// parameterWhitelist holds, in lower case, the only query parameters a
// pre-signed GetCallerIdentity URL may carry; Verify rejects anything else.
var parameterWhitelist = map[string]bool{
	// STS API selection.
	"action":  true,
	"version": true,

	// SigV4 query-string authentication parameters.
	"x-amz-algorithm":      true,
	"x-amz-credential":     true,
	"x-amz-date":           true,
	"x-amz-expires":        true,
	"x-amz-security-token": true,
	"x-amz-signature":      true,
	"x-amz-signedheaders":  true,
}
// getCallerIdentityWrapper is the JSON shape of the sts:GetCallerIdentity
// response that Verify decodes (Verify sends "accept: application/json").
// encoding/json ignores any response fields not declared here.
type getCallerIdentityWrapper struct {
	GetCallerIdentityResponse struct {
		GetCallerIdentityResult struct {
			// Account is the 12 digit AWS account number of the caller.
			Account string `json:"Account"`
			// Arn is the caller's ARN, later canonicalized by arn.Canonicalize.
			Arn string `json:"Arn"`
			// UserID is "UniqueID" or "UniqueID:SessionName" depending on the
			// principal type (see getIdentityFromSTSResponse).
			UserID string `json:"UserId"`
		} `json:"GetCallerIdentityResult"`
		ResponseMetadata struct {
			RequestID string `json:"RequestId"`
		} `json:"ResponseMetadata"`
	} `json:"GetCallerIdentityResponse"`
}
// Generator provides new tokens for the AWS IAM Authenticator.
type Generator interface {
	// GetWithOptions returns a token built according to options.
	GetWithOptions(options *GetTokenOptions) (Token, error)

	// GetWithSTS returns a token valid for clusterID, signed through stsAPI.
	GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error)

	// FormatJSON renders a token as client-go ExecCredential JSON.
	FormatJSON(Token) string
}
// generator is the default Generator implementation.
type generator struct {
	// forwardSessionName carries an existing federated session name onto the
	// assumed-role session in GetWithOptions.
	forwardSessionName bool
	// cache enables the file-backed credential cache in GetWithOptions.
	cache bool
	// nowFunc supplies the signing time; replaceable so tests can control it.
	nowFunc func() time.Time
}

// NewGenerator creates a Generator and returns it. The error result is
// currently always nil.
func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) {
	g := generator{
		forwardSessionName: forwardSessionName,
		cache:              cache,
		nowFunc:            time.Now,
	}
	return g, nil
}
// StdinStderrTokenProvider prompts on standard error for an MFA token code and
// reads the answer from standard input.
func StdinStderrTokenProvider() (string, error) {
	fmt.Fprint(os.Stderr, "Assume Role MFA token code: ")
	var code string
	_, err := fmt.Scanln(&code)
	return code, err
}
// GetWithOptions builds a new AWS session from the "base" credentials
// available (environment variables, profile files, EC2 metadata, etc.),
// builds an STS client from it, and mints a token through GetWithSTS.
//
// If options.Region is set, the session uses that region and the regional
// STS endpoint. If the generator has caching enabled, the session credentials
// are wrapped in a file cache provider (given the cluster ID, the active
// profile, and the role ARN). If options.AssumeRoleARN is set, the token is
// signed with temporary credentials for that role instead of the base
// credentials.
//
// It returns an error if options is nil or has no ClusterID, if the session
// cannot be created, if looking up the caller identity for session-name
// forwarding fails, or if presigning fails.
func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
	// A nil options struct has no ClusterID either; report that instead of
	// panicking on the dereference.
	if options == nil || options.ClusterID == "" {
		return Token{}, fmt.Errorf("ClusterID is required")
	}

	// Create a session with the "base" credentials available
	// (from environment variables, profile files, EC2 metadata, etc.).
	sess, err := session.NewSessionWithOptions(session.Options{
		AssumeRoleTokenProvider: StdinStderrTokenProvider,
		SharedConfigState:       session.SharedConfigEnable,
	})
	if err != nil {
		return Token{}, fmt.Errorf("could not create session: %w", err)
	}

	sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
		Name: "authenticatorUserAgent",
		Fn: request.MakeAddToUserAgentHandler(
			"aws-iam-authenticator", pkg.Version),
	})

	if options.Region != "" {
		sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint))
	}

	if g.cache {
		// Work out which profile is in use; it is passed to the cache provider.
		profile := session.DefaultSharedConfigProfile
		if v := os.Getenv("AWS_PROFILE"); len(v) > 0 {
			profile = v
		}

		// Create a caching Provider wrapper around the Credentials. Caching is
		// best effort: on failure, warn and continue with the plain credentials.
		if cacheProvider, err := filecache.NewFileCacheProvider(
			options.ClusterID,
			profile,
			options.AssumeRoleARN,
			filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil {
			sess.Config.Credentials = credentials.NewCredentials(cacheProvider)
		} else {
			fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
		}
	}

	// Use an STS client based on the direct credentials.
	stsAPI := sts.New(sess)

	// If a role ARN was specified, replace the STS client with one that uses
	// temporary credentials from that role.
	if options.AssumeRoleARN != "" {
		var sessionSetters []func(*stscreds.AssumeRoleProvider)

		if options.AssumeRoleExternalID != "" {
			sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) {
				provider.ExternalID = &options.AssumeRoleExternalID
			})
		}

		if g.forwardSessionName {
			// If the current session is already a federated identity, carry
			// its session name onto the new session for easier debugging.
			resp, err := stsAPI.GetCallerIdentity(&sts.GetCallerIdentityInput{})
			if err != nil {
				return Token{}, err
			}
			// aws.StringValue tolerates a nil UserId instead of panicking.
			userIDParts := strings.Split(aws.StringValue(resp.UserId), ":")
			if len(userIDParts) == 2 {
				sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) {
					provider.RoleSessionName = userIDParts[1]
				})
			}
		} else if options.SessionName != "" {
			sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) {
				provider.RoleSessionName = options.SessionName
			})
		}

		// Create STS-based credentials that will assume the given role.
		creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...)

		// Create an STS API interface that uses the assumed role's temporary credentials.
		stsAPI = sts.New(sess, &aws.Config{Credentials: creds})
	}

	return g.GetWithSTS(options.ClusterID, stsAPI)
}
// getNamedSigningHandler returns a handler named like the SDK's default V4
// signer ("v4.SignRequestHandler") that signs requests using the time reported
// by nowFunc, so callers can swap it in to control the signing time.
func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler {
	sign := func(r *request.Request) {
		v4.SignSDKRequestWithCurrentTime(r, nowFunc)
	}
	return request.NamedHandler{Name: "v4.SignRequestHandler", Fn: sign}
}
// GetWithSTS returns a token valid for clusterID using the given STS client.
//
// The token is v1Prefix followed by a base64url (unpadded) encoding of a
// pre-signed sts:GetCallerIdentity URL that also signs the cluster ID header.
// Its Expiration is one minute before the 15-minute validity of the URL ends.
func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) {
	// Build the GetCallerIdentity request (the output value is not needed for
	// presigning) and attach the cluster ID header so it becomes signed.
	req, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
	req.HTTPRequest.Header.Add(clusterIDHeader, clusterID)

	// Swap in a signer that uses g.nowFunc so tests can control the time.
	req.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc))

	// Presign requires an expiry, which ends up in x-amz-expires. STS ignores
	// it (the URL is valid 15 minutes after x-amz-date regardless); 60 seconds
	// is kept because authenticators 0.3.0 and older expect a value in 0..60.
	// https://github.com/aws/aws-sdk-go/issues/2167
	presigned, err := req.Presign(requestPresignParam * time.Second)
	if err != nil {
		return Token{}, err
	}

	// Leave one minute of cushion before the pre-signed URL stops working.
	expiresAt := g.nowFunc().Local().Add(presignedURLExpiration - time.Minute)

	// TODO: this may need to be a constant-time base64 encoding
	encoded := base64.RawURLEncoding.EncodeToString([]byte(presigned))
	return Token{Token: v1Prefix + encoded, Expiration: expiresAt}, nil
}
// FormatJSON renders token as an ExecCredential JSON document for client-go
// exec credential plugins: kind "ExecCredential", with the token and its
// expiration in status.
//
// The apiVersion defaults to the v1beta1 client-authentication group version.
// If the KUBERNETES_EXEC_INFO environment variable holds a JSON
// ExecCredential with a non-empty apiVersion, that version is echoed back
// instead.
func (g generator) FormatJSON(token Token) string {
	apiVersion := clientauthv1beta1.SchemeGroupVersion.String()
	if env := os.Getenv(execInfoEnvKey); env != "" {
		cred := &clientauthentication.ExecCredential{}
		// Keep the default when the value is malformed or has no apiVersion,
		// rather than emitting an ExecCredential with an empty apiVersion.
		if err := json.Unmarshal([]byte(env), cred); err == nil && cred.APIVersion != "" {
			apiVersion = cred.APIVersion
		}
	}

	expirationTimestamp := metav1.NewTime(token.Expiration)
	execInput := &clientauthv1beta1.ExecCredential{
		TypeMeta: metav1.TypeMeta{
			APIVersion: apiVersion,
			Kind:       kindExecCredential,
		},
		Status: &clientauthv1beta1.ExecCredentialStatus{
			ExpirationTimestamp: &expirationTimestamp,
			Token:               token.Token,
		},
	}
	// The value holds only strings and a timestamp, so Marshal is not expected
	// to fail; the error is deliberately ignored to keep the string-only API.
	enc, _ := json.Marshal(execInput)
	return string(enc)
}
// Verifier validates tokens by calling STS and returning the identity that
// created them.
type Verifier interface {
	// Verify returns the identity behind token, or an error if the token is
	// malformed or rejected.
	Verify(token string) (*Identity, error)
}
type tokenVerifier struct {
client *http.Client
clusterID string
validSTShostnames map[string]bool
}
// getDefaultHostNameForRegion resolves the endpoint for service in region
// within partition (using the regional STS endpoint option and allowing
// services unknown to the SDK) and returns only the URL's hostname.
func getDefaultHostNameForRegion(partition *endpoints.Partition, region, service string) (string, error) {
	resolved, resolveErr := partition.EndpointFor(service, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
	if resolveErr != nil {
		return "", fmt.Errorf("Error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), resolveErr)
	}

	u, parseErr := url.Parse(resolved.URL)
	if parseErr != nil {
		return "", fmt.Errorf("Error parsing STS URL %s. err: %v", resolved.URL, parseErr)
	}
	return u.Hostname(), nil
}
// stsHostsForPartition returns the set of STS hostnames accepted for
// partitionID: the hosts of every STS endpoint the SDK knows in that
// partition, plus the default STS host for region when the SDK has no
// endpoint for that region (so newer regions still validate). Problems are
// logged and yield a smaller (possibly empty) set rather than an error.
func stsHostsForPartition(partitionID, region string) map[string]bool {
	hosts := map[string]bool{}

	// Find the requested partition among those bundled with the SDK.
	var partition *endpoints.Partition
	known := endpoints.DefaultPartitions()
	for i := range known {
		if known[i].ID() == partitionID {
			partition = &known[i]
			break
		}
	}
	if partition == nil {
		logrus.Errorf("Partition %s not valid", partitionID)
		return hosts
	}

	// addRegionDefault adds the default STS hostname for the caller's region,
	// logging (and adding nothing) if it cannot be resolved.
	addRegionDefault := func() {
		host, err := getDefaultHostNameForRegion(partition, region, stsServiceID)
		if err != nil {
			logrus.WithError(err).Error("Error getting default hostname")
			return
		}
		hosts[host] = true
	}

	stsSvc, found := partition.Services()[stsServiceID]
	if !found {
		logrus.Errorf("STS service not found in partition %s", partitionID)
		// The SDK knows no STS endpoints here; accept at least the region's host.
		addRegionDefault()
		return hosts
	}

	svcEndpoints := stsSvc.Endpoints()
	for name, endpoint := range svcEndpoints {
		resolved, err := endpoint.ResolveEndpoint(endpoints.STSRegionalEndpointOption)
		if err != nil {
			logrus.WithError(err).Errorf("Error resolving endpoint for %s in partition %s", name, partitionID)
			continue
		}
		u, err := url.Parse(resolved.URL)
		if err != nil {
			logrus.WithError(err).Errorf("Error parsing STS URL %s", resolved.URL)
			continue
		}
		hosts[u.Hostname()] = true
	}

	// The caller's region may be missing from the SDK's endpoint data.
	if _, listed := svcEndpoints[region]; !listed {
		addRegionDefault()
	}
	return hosts
}
// NewVerifier creates a Verifier bound to clusterID. It accepts only STS hosts
// from partitionID (plus the default host for region), and uses an HTTP client
// with a 10 second timeout that does not follow redirects.
func NewVerifier(clusterID, partitionID, region string) Verifier {
	// Verify records metrics unconditionally; make sure they exist first to
	// avoid a nil pointer panic when setting metric values.
	if !metrics.Initialized() {
		metrics.InitMetrics(prometheus.NewRegistry())
	}

	noRedirects := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	httpClient := &http.Client{
		CheckRedirect: noRedirects,
		Timeout:       10 * time.Second,
	}

	return tokenVerifier{
		client:            httpClient,
		clusterID:         clusterID,
		validSTShostnames: stsHostsForPartition(partitionID, region),
	}
}
// verifyHost returns a FormatError unless host is one of the verifier's known
// STS hostnames.
// See http://docs.amazonaws.cn/en_us/general/latest/gr/rande.html#sts_region
func (v tokenVerifier) verifyHost(host string) error {
	if _, known := v.validSTShostnames[host]; known {
		return nil
	}
	return FormatError{fmt.Sprintf("unexpected hostname %q in pre-signed URL", host)}
}
// Verify checks that token is a well-formed, unexpired, pre-signed
// sts:GetCallerIdentity URL that signs this verifier's cluster ID header, then
// replays it against STS. On success it returns an Identity describing the
// AWS principal that created the token. On failure it returns nil and a
// non-nil error. Tokens rejected locally mostly yield FormatError. Failed
// STS calls, or responses that cannot be used, yield STSError, and throttled
// calls yield STSThrottling.
func (v tokenVerifier) Verify(token string) (*Identity, error) {
	if len(token) > maxTokenLenBytes {
		return nil, FormatError{"token is too large"}
	}
	if !strings.HasPrefix(token, v1Prefix) {
		return nil, FormatError{fmt.Sprintf("token is missing expected %q prefix", v1Prefix)}
	}

	// TODO: this may need to be a constant-time base64 decoding
	tokenBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(token, v1Prefix))
	if err != nil {
		return nil, FormatError{err.Error()}
	}
	parsedURL, err := url.Parse(string(tokenBytes))
	if err != nil {
		return nil, FormatError{err.Error()}
	}

	// The URL is client-supplied: only allow HTTPS to a known STS host before
	// this server issues any request to it.
	if parsedURL.Scheme != "https" {
		return nil, FormatError{fmt.Sprintf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme)}
	}
	if err = v.verifyHost(parsedURL.Host); err != nil {
		return nil, err
	}
	stsRegion, err := getStsRegion(parsedURL.Host)
	if err != nil {
		return nil, err
	}
	if parsedURL.Path != "/" {
		return nil, FormatError{"unexpected path in pre-signed URL"}
	}

	// Normalize parameter names to lower case, rejecting names that collide
	// once lower-cased, names outside the whitelist, and repeated values.
	queryParamsLower := make(url.Values)
	queryParams, err := url.ParseQuery(parsedURL.RawQuery)
	if err != nil {
		return nil, FormatError{"malformed query parameter"}
	}
	if err = validateDuplicateParameters(queryParams); err != nil {
		return nil, err
	}
	for key, values := range queryParams {
		if !parameterWhitelist[strings.ToLower(key)] {
			return nil, FormatError{fmt.Sprintf("non-whitelisted query parameter %q", key)}
		}
		if len(values) != 1 {
			return nil, FormatError{"query parameter with multiple values not supported"}
		}
		queryParamsLower.Set(strings.ToLower(key), values[0])
	}

	if queryParamsLower.Get("action") != "GetCallerIdentity" {
		return nil, FormatError{"unexpected action parameter in pre-signed URL"}
	}
	if !hasSignedClusterIDHeader(&queryParamsLower) {
		return nil, FormatError{fmt.Sprintf("client did not sign the %s header in the pre-signed URL", clusterIDHeader)}
	}

	// We validate x-amz-expires is between 0 and 15 minutes (900 seconds) although currently pre-signed STS URLs, and
	// therefore tokens, expire exactly 15 minutes after the x-amz-date header, regardless of x-amz-expires.
	// Report the raw value: on a parse failure the parsed int is just 0.
	expiresParam := queryParamsLower.Get("x-amz-expires")
	expires, err := strconv.Atoi(expiresParam)
	if err != nil || expires < 0 || expires > 900 {
		return nil, FormatError{fmt.Sprintf("invalid X-Amz-Expires parameter in pre-signed URL: %s", expiresParam)}
	}

	date := queryParamsLower.Get("x-amz-date")
	if date == "" {
		return nil, FormatError{"X-Amz-Date parameter must be present in pre-signed URL"}
	}

	// Obtain AWS Access Key ID from supplied credentials
	accessKeyID := strings.Split(queryParamsLower.Get("x-amz-credential"), "/")[0]

	dateParam, err := time.Parse(dateHeaderFormat, date)
	if err != nil {
		return nil, FormatError{fmt.Sprintf("error parsing X-Amz-Date parameter %s into format %s: %s", date, dateHeaderFormat, err.Error())}
	}

	now := time.Now()
	expiration := dateParam.Add(presignedURLExpiration)
	if now.After(expiration) {
		return nil, FormatError{fmt.Sprintf("X-Amz-Date parameter is expired (%.f minute expiration) %s", presignedURLExpiration.Minutes(), dateParam)}
	}

	req, err := http.NewRequest(http.MethodGet, parsedURL.String(), nil)
	if err != nil {
		// Previously ignored; a nil req would panic on the header writes below.
		return nil, FormatError{err.Error()}
	}
	// The cluster ID header is part of the signature (checked above), so
	// sending this verifier's cluster ID ties acceptance to this cluster.
	req.Header.Set(clusterIDHeader, v.clusterID)
	req.Header.Set("accept", "application/json")

	response, err := v.client.Do(req)
	if err != nil {
		metrics.Get().StsConnectionFailure.WithLabelValues(stsRegion).Inc()
		// special case to avoid printing the full URL if possible
		if urlErr, ok := err.(*url.Error); ok {
			return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsRegion))
		}
		return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsRegion))
	}
	defer response.Body.Close()

	responseBody, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err))
	}

	metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsRegion).Inc()
	if response.StatusCode != 200 {
		responseStr := string(responseBody)
		// refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log
		// response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"}
		if strings.Contains(responseStr, "Throttling") {
			metrics.Get().StsThrottling.WithLabelValues(stsRegion).Inc()
			return nil, NewSTSThrottling(responseStr)
		}
		return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsRegion, responseStr))
	}

	var callerIdentity getCallerIdentityWrapper
	if err = json.Unmarshal(responseBody, &callerIdentity); err != nil {
		return nil, NewSTSError(err.Error())
	}

	id := &Identity{
		AccessKeyID: accessKeyID,
		STSEndpoint: parsedURL.Host,
	}
	return getIdentityFromSTSResponse(id, callerIdentity)
}
// getIdentityFromSTSResponse fills id from a decoded GetCallerIdentity
// response. It sets ARN, AccountID, and CanonicalARN (via arn.Canonicalize),
// then UserID and, for session-based principals, SessionName. It returns id,
// or nil and an STSError if the ARN cannot be canonicalized or a session-based
// UserID is not of the form "ID:SessionName". id is mutated in place even when
// an error is returned.
func getIdentityFromSTSResponse(id *Identity, wrapper getCallerIdentityWrapper) (*Identity, error) {
	res := wrapper.GetCallerIdentityResponse.GetCallerIdentityResult
	id.ARN = res.Arn
	id.AccountID = res.Account

	kind, canonical, err := arn.Canonicalize(id.ARN)
	id.CanonicalARN = canonical
	if err != nil {
		return nil, NewSTSError(err.Error())
	}

	// The user ID is one of:
	// 1. UserID:SessionName (for assumed roles)
	// 2. UserID (for IAM User principals).
	// 3. AWSAccount:CallerSpecifiedName (for federated users)
	// Federated users keep the entire UserID; otherwise it would be just the
	// account ID and indistinguishable from the root user's UserID.
	switch kind {
	case arn.FEDERATED_USER, arn.USER, arn.ROOT:
		id.UserID = res.UserID
	default:
		parts := strings.Split(res.UserID, ":")
		if len(parts) != 2 {
			return nil, NewSTSError(fmt.Sprintf("malformed UserID %q", res.UserID))
		}
		id.UserID, id.SessionName = parts[0], parts[1]
	}
	return id, nil
}
// validateDuplicateParameters returns a FormatError if two query parameter
// names are equal once lower-cased (e.g. "Action" and "action"), since later
// processing keys parameters by their lower-cased name.
func validateDuplicateParameters(queryParams url.Values) error {
	seen := make(map[string]struct{}, len(queryParams))
	for key := range queryParams {
		lower := strings.ToLower(key)
		if _, dup := seen[lower]; dup {
			return FormatError{fmt.Sprintf("duplicate query parameter found: %q", key)}
		}
		seen[lower] = struct{}{}
	}
	return nil
}
// hasSignedClusterIDHeader reports whether the semicolon-separated
// x-amz-signedheaders parameter lists the cluster ID header (compared
// case-insensitively via lower-casing). paramsLower must be keyed by
// lower-cased parameter names.
func hasSignedClusterIDHeader(paramsLower *url.Values) bool {
	want := strings.ToLower(clusterIDHeader)
	for _, hdr := range strings.Split(paramsLower.Get("x-amz-signedheaders"), ";") {
		if strings.ToLower(hdr) == want {
			return true
		}
	}
	return false
}
// getStsRegion derives the metrics region label from an STS host: "global"
// for sts.amazonaws.com, otherwise the second dot-separated label (e.g.
// "us-west-2" for "sts.us-west-2.amazonaws.com"). Hosts that are empty or
// have fewer than three labels are rejected.
func getStsRegion(host string) (string, error) {
	if host == "" {
		return "", fmt.Errorf("host is empty")
	}
	// Only the first two labels matter, so stop splitting after the third.
	labels := strings.SplitN(host, ".", 3)
	switch {
	case len(labels) < 3:
		return "", fmt.Errorf("invalid host format: %v", host)
	case host == "sts.amazonaws.com":
		return "global", nil
	default:
		return labels[1], nil
	}
}