Skip to content

Commit

Permalink
return 429 for STS throttling
Browse files Browse the repository at this point in the history
  • Loading branch information
nnmin-aws committed Sep 19, 2023
1 parent 8b1fe5c commit 4de7547
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 10 deletions.
13 changes: 7 additions & 6 deletions pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
)

const (
Namespace = "aws_iam_authenticator"
Malformed = "malformed_request"
Invalid = "invalid_token"
STSError = "sts_error"
Unknown = "uknown_user"
Success = "success"
Namespace = "aws_iam_authenticator"
Malformed = "malformed_request"
Invalid = "invalid_token"
STSError = "sts_error"
STSThrottling = "sts_throttling"
Unknown = "uknown_user"
Success = "success"
)

var authenticatorMetrics Metrics
Expand Down
8 changes: 7 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,13 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
// if the token is invalid, reject with a 403
identity, err := h.verifier.Verify(tokenReview.Spec.Token)
if err != nil {
if _, ok := err.(token.STSError); ok {
if _, ok := err.(token.STSThrottling); ok {
metrics.Get().Latency.WithLabelValues(metrics.STSThrottling).Observe(duration(start))
log.WithError(err).Warn("access denied")
w.WriteHeader(http.StatusTooManyRequests)
w.Write(tokenReviewDenyJSON)
return
} else if _, ok := err.(token.STSError); ok {
metrics.Get().Latency.WithLabelValues(metrics.STSError).Observe(duration(start))
} else {
metrics.Get().Latency.WithLabelValues(metrics.Invalid).Observe(duration(start))
Expand Down
28 changes: 26 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func createIndexer() cache.Indexer {
// Count of expected metrics
type validateOpts struct {
// The expected number of latency entries for each label.
malformed, invalidToken, unknownUser, success, stsError uint64
malformed, invalidToken, unknownUser, success, stsError, stsThrottling uint64
}

func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) {
Expand All @@ -135,7 +135,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
}
for _, m := range metricFamilies {
if strings.HasPrefix(m.GetName(), "aws_iam_authenticator_authenticate_latency_seconds") {
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError, actualSTSThrottling uint64
for _, metric := range m.GetMetric() {
if len(metric.Label) != 1 {
t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label)
Expand All @@ -155,6 +155,8 @@ func validateMetrics(t *testing.T, opts validateOpts) {
actualUnknown = metric.GetHistogram().GetSampleCount()
case metrics.STSError:
actualSTSError = metric.GetHistogram().GetSampleCount()
case metrics.STSThrottling:
actualSTSThrottling = metric.GetHistogram().GetSampleCount()
default:
t.Errorf("Unknown result for latency label: %s", *label.Value)

Expand All @@ -165,6 +167,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
checkHistogramSampleCount(t, metrics.Invalid, actualInvalid, opts.invalidToken)
checkHistogramSampleCount(t, metrics.Unknown, actualUnknown, opts.unknownUser)
checkHistogramSampleCount(t, metrics.STSError, actualSTSError, opts.stsError)
checkHistogramSampleCount(t, metrics.STSThrottling, actualSTSThrottling, opts.stsThrottling)
}
}
}
Expand Down Expand Up @@ -364,6 +367,27 @@ func TestAuthenticateVerifierErrorCRD(t *testing.T) {
validateMetrics(t, validateOpts{invalidToken: 1})
}

func TestAuthenticateVerifierSTSThrottling(t *testing.T) {
resp := httptest.NewRecorder()

data, err := json.Marshal(authenticationv1beta1.TokenReview{
Spec: authenticationv1beta1.TokenReviewSpec{
Token: "token",
},
})
if err != nil {
t.Fatalf("Could not marshal in put data: %v", err)
}
req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data))
h := setup(&testVerifier{err: token.STSThrottling{}})
h.authenticateEndpoint(resp, req)
if resp.Code != http.StatusTooManyRequests {
t.Errorf("Expected status code %d, was %d", http.StatusTooManyRequests, resp.Code)
}
verifyBodyContains(t, resp, string(tokenReviewDenyJSON))
validateMetrics(t, validateOpts{stsThrottling: 1})
}

func TestAuthenticateVerifierSTSError(t *testing.T) {
resp := httptest.NewRecorder()

Expand Down
14 changes: 13 additions & 1 deletion pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ func NewSTSError(m string) STSError {
return STSError{message: m}
}

// STSError is returned when there was STS Throttling.
type STSThrottling struct {
}

func (e STSThrottling) Error() string {
return "STSThrottling"
}

var parameterWhitelist = map[string]bool{
"action": true,
"version": true,
Expand Down Expand Up @@ -570,7 +578,11 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {

metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc()
if response.StatusCode != 200 {
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, string(responseBody[:])))
responseStr := string(responseBody[:])
if strings.Contains(responseStr, "Throttling") {
return nil, STSThrottling{}
}
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr))
}

var callerIdentity getCallerIdentityWrapper
Expand Down

0 comments on commit 4de7547

Please sign in to comment.