diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 4f730598c..b2eb2555a 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,12 +6,13 @@ import ( ) const ( - Namespace = "aws_iam_authenticator" - Malformed = "malformed_request" - Invalid = "invalid_token" - STSError = "sts_error" - Unknown = "uknown_user" - Success = "success" + Namespace = "aws_iam_authenticator" + Malformed = "malformed_request" + Invalid = "invalid_token" + STSError = "sts_error" + STSThrottling = "sts_throttling" + Unknown = "uknown_user" + Success = "success" ) var authenticatorMetrics Metrics diff --git a/pkg/server/server.go b/pkg/server/server.go index 092894f5c..ddf21f142 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -330,7 +330,13 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) // if the token is invalid, reject with a 403 identity, err := h.verifier.Verify(tokenReview.Spec.Token) if err != nil { - if _, ok := err.(token.STSError); ok { + if _, ok := err.(token.STSThrottling); ok { + metrics.Get().Latency.WithLabelValues(metrics.STSThrottling).Observe(duration(start)) + log.WithError(err).Warn("access denied") + w.WriteHeader(http.StatusTooManyRequests) + w.Write(tokenReviewDenyJSON) + return + } else if _, ok := err.(token.STSError); ok { metrics.Get().Latency.WithLabelValues(metrics.STSError).Observe(duration(start)) } else { metrics.Get().Latency.WithLabelValues(metrics.Invalid).Observe(duration(start)) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index eb2ce541e..8b4e256bb 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -117,7 +117,7 @@ func createIndexer() cache.Indexer { // Count of expected metrics type validateOpts struct { // The expected number of latency entries for each label. - malformed, invalidToken, unknownUser, success, stsError uint64 + malformed, invalidToken, unknownUser, success, stsError, stsThrottling uint64 } func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) { @@ -135,7 +135,7 @@ func validateMetrics(t *testing.T, opts validateOpts) { } for _, m := range metricFamilies { if strings.HasPrefix(m.GetName(), "aws_iam_authenticator_authenticate_latency_seconds") { - var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64 + var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError, actualSTSThrottling uint64 for _, metric := range m.GetMetric() { if len(metric.Label) != 1 { t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label) @@ -155,6 +155,8 @@ func validateMetrics(t *testing.T, opts validateOpts) { actualUnknown = metric.GetHistogram().GetSampleCount() case metrics.STSError: actualSTSError = metric.GetHistogram().GetSampleCount() + case metrics.STSThrottling: + actualSTSThrottling = metric.GetHistogram().GetSampleCount() default: t.Errorf("Unknown result for latency label: %s", *label.Value) @@ -165,6 +167,7 @@ func validateMetrics(t *testing.T, opts validateOpts) { checkHistogramSampleCount(t, metrics.Invalid, actualInvalid, opts.invalidToken) checkHistogramSampleCount(t, metrics.Unknown, actualUnknown, opts.unknownUser) checkHistogramSampleCount(t, metrics.STSError, actualSTSError, opts.stsError) + checkHistogramSampleCount(t, metrics.STSThrottling, actualSTSThrottling, opts.stsThrottling) } } } @@ -364,6 +367,27 @@ func TestAuthenticateVerifierErrorCRD(t *testing.T) { validateMetrics(t, validateOpts{invalidToken: 1}) } +func TestAuthenticateVerifierSTSThrottling(t *testing.T) { + resp := httptest.NewRecorder() + + data, err := json.Marshal(authenticationv1beta1.TokenReview{ + Spec: authenticationv1beta1.TokenReviewSpec{ + Token: "token", + }, + }) + if err != nil { + t.Fatalf("Could not marshal in put data: %v", err) + } + req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data)) + h := setup(&testVerifier{err: token.STSThrottling{}}) + h.authenticateEndpoint(resp, req) + if resp.Code != http.StatusTooManyRequests { + t.Errorf("Expected status code %d, was %d", http.StatusTooManyRequests, resp.Code) + } + verifyBodyContains(t, resp, string(tokenReviewDenyJSON)) + validateMetrics(t, validateOpts{stsThrottling: 1}) +} + func TestAuthenticateVerifierSTSError(t *testing.T) { resp := httptest.NewRecorder() diff --git a/pkg/token/token.go b/pkg/token/token.go index 1582b3c3e..35e75e217 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -139,6 +139,14 @@ func NewSTSError(m string) STSError { return STSError{message: m} } +// STSThrottling is returned when there was STS Throttling. +type STSThrottling struct { +} + +func (e STSThrottling) Error() string { + return "STSThrottling" +} + var parameterWhitelist = map[string]bool{ "action": true, "version": true, @@ -570,7 +578,11 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc() if response.StatusCode != 200 { - return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, string(responseBody[:]))) + responseStr := string(responseBody[:]) + if strings.Contains(responseStr, "Throttling") { + return nil, STSThrottling{} + } + return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr)) } var callerIdentity getCallerIdentityWrapper diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index 1e2ad6c32..62d3ac06f 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -60,6 +60,13 @@ func assertSTSError(t *testing.T, err error) { } } +func assertSTSThrottling(t *testing.T, err error) { + t.Helper() + if _, ok := err.(STSThrottling); !ok { + t.Errorf("Expected err %v to be an STSThrottling but was not", err) + } +} + var ( now = time.Now() timeStr = now.UTC().Format("20060102T150405Z") @@ -194,6 +201,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) { validationErrorTest(t, "aws", toToken(fmt.Sprintf("https://sts.us-west-2.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=ASIAAAAAAAAAAAAAAAAA%%2F20220601%%2Fus-west-2%%2Fsts%%2Faws4_request&X-Amz-Date=%s&X-Amz-Expires=900&X-Amz-Security-Token=XXXXXXXXXXXXX&X-Amz-SignedHeaders=host%%3Bx-k8s-aws-id&x-amz-credential=eve&X-Amz-Signature=999999999999999999", timeStr)), "input token was not properly formatted: duplicate query parameter found:") } +func TestVerifyHTTPThrottling(t *testing.T) { + testVerifier := newVerifier("aws", 400, "{\\\"Error\\\":{\\\"Code\\\":\\\"Throttling\\\",\\\"Message\\\":\\\"Rate exceeded\\\",\\\"Type\\\":\\\"Sender\\\"},\\\"RequestId\\\":\\\"8c2d3520-24e1-4d5c-ac55-7e226335f447\\\"}", nil) + _, err := testVerifier.Verify(validToken) + errorContains(t, err, "Throttling") + assertSTSThrottling(t, err) +} + func TestVerifyHTTPError(t *testing.T) { _, err := newVerifier("aws", 0, "", errors.New("an error")).Verify(validToken) errorContains(t, err, "error during GET: an error")