Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use retryable http client in Azure authz provider #372

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 4 additions & 36 deletions auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,19 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"

"go.kubeguard.dev/guard/auth"
"go.kubeguard.dev/guard/auth/providers/azure/graph"
"go.kubeguard.dev/guard/util/httpclient"
azureutils "go.kubeguard.dev/guard/util/azure"

"github.com/Azure/go-autorest/autorest/azure"
"github.com/coreos/go-oidc"
"github.com/golang-jwt/jwt/v4"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"golang.org/x/oauth2"
authv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -120,7 +117,7 @@ func getOIDCIssuerProvider(issuerURL string, issuerGetRetryCount int) (*oidc.Pro

// NOTE: we start a root context here to allow background remote key set refresh
ctx := context.Background()
ctx = withRetryableHttpClient(ctx, issuerGetRetryCount)
ctx = azureutils.WithRetryableHttpClient(ctx, issuerGetRetryCount)
provider, err := oidc.NewProvider(ctx, issuerURL)
if err != nil {
// failed in this attempt, let other attempts retry
Expand Down Expand Up @@ -180,35 +177,6 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {
return c, nil
}

// makeRetryableHttpClient builds a retrying HTTP client. A request is attempted
// at most (1 + retryCount) times, and each individual attempt is limited to a
// 3 second timeout.
func makeRetryableHttpClient(retryCount int) retryablehttp.Client {
	const (
		perAttemptTimeout = 3 * time.Second
		minRetryWait      = 500 * time.Millisecond
		maxRetryWait      = 2 * time.Second
	)

	// Shallow-copy the shared default client so the timeout is applied to the
	// copy only; the copy still points at the same transport.
	base := *httpclient.DefaultHTTPClient
	base.Timeout = perAttemptTimeout

	return retryablehttp.Client{
		HTTPClient:   &base,
		RetryWaitMin: minRetryWait,
		RetryWaitMax: maxRetryWait,
		// RetryMax counts retries only: 1 initial attempt + retryCount retries.
		RetryMax:   retryCount,
		CheckRetry: retryablehttp.DefaultRetryPolicy,
		Backoff:    retryablehttp.DefaultBackoff,
		Logger:     log.Default(),
	}
}

// withRetryableHttpClient returns a child of ctx whose oauth2.HTTPClient key
// holds an *http.Client backed by makeRetryableHttpClient(retryCount).
// Some of the libraries we use pull the client out of the context via
// oauth2.HTTPClient, so this is how retries are added to external code.
func withRetryableHttpClient(ctx context.Context, retryCount int) context.Context {
	rc := makeRetryableHttpClient(retryCount)
	stdClient := rc.StandardClient()
	return context.WithValue(ctx, oauth2.HTTPClient, stdClient)
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
Expand All @@ -217,7 +185,7 @@ type metadataJSON struct {
// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(ctx context.Context, aadEndpoint, tenantID string, retryCount int) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"
retryClient := makeRetryableHttpClient(retryCount)
retryClient := azureutils.MakeRetryableHttpClient(retryCount)

request, err := retryablehttp.NewRequest("GET", metadataURL, nil)
if err != nil {
Expand Down Expand Up @@ -261,7 +229,7 @@ func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInf
}
}

ctx = withRetryableHttpClient(ctx, s.HttpClientRetryCount)
ctx = azureutils.WithRetryableHttpClient(ctx, s.HttpClientRetryCount)
idToken, err := s.verifier.Verify(ctx, token)
if err != nil {
if klog.V(7).Enabled() {
Expand Down
10 changes: 8 additions & 2 deletions authz/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"go.kubeguard.dev/guard/authz"
authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options"
"go.kubeguard.dev/guard/authz/providers/azure/rbac"
azureutils "go.kubeguard.dev/guard/util/azure"
errutils "go.kubeguard.dev/guard/util/error"

"github.com/Azure/go-autorest/autorest/azure"
Expand All @@ -49,7 +50,8 @@ func init() {
}

// Authorizer holds the state used by Check: the RBAC access client and the
// retry count handed to the retryable HTTP client.
// NOTE(review): rbacClient is listed twice because the diff view shows both the
// removed and the re-aligned added line; the real struct declares it once.
type Authorizer struct {
rbacClient *rbac.AccessInfo
rbacClient *rbac.AccessInfo
// httpClientRetryCount is the number of retries for outgoing HTTP requests.
httpClientRetryCount int
}

func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) {
Expand All @@ -64,7 +66,9 @@ func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error)
}

func newAuthzClient(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) {
c := &Authorizer{}
c := &Authorizer{
httpClientRetryCount: authopts.HttpClientRetryCount,
}

authzInfoVal, err := getAuthzInfo(authopts.Environment)
if err != nil {
Expand Down Expand Up @@ -120,6 +124,8 @@ func (s Authorizer) Check(ctx context.Context, request *authzv1.SubjectAccessRev
return &authzv1.SubjectAccessReviewStatus{Allowed: true, Reason: rbac.AccessAllowedVerdict}, nil
}

ctx = azureutils.WithRetryableHttpClient(ctx, s.httpClientRetryCount)

if s.rbacClient.IsTokenExpired() {
if err := s.rbacClient.RefreshToken(ctx); err != nil {
return nil, errutils.WithCode(err, http.StatusInternalServerError)
Expand Down
14 changes: 9 additions & 5 deletions authz/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package azure

import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
Expand All @@ -37,7 +38,8 @@ import (
)

const (
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
httpClientRetryCount = 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a unit test covering the retry behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

)

func clientSetup(serverUrl, mode string) (*Authorizer, error) {
Expand All @@ -52,9 +54,10 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) {
}

authOpts := auth.Options{
ClientID: "client_id",
ClientSecret: "client_secret",
TenantID: "tenant_id",
ClientID: "client_id",
ClientSecret: "client_secret",
TenantID: "tenant_id",
HttpClientRetryCount: httpClientRetryCount,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we assert that it tries 2 times?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

}

authzInfo := rbac.AuthzInfo{
Expand Down Expand Up @@ -171,6 +174,7 @@ func TestCheck(t *testing.T) {
assert.Nilf(t, resp, "response should be nil")
assert.NotNilf(t, err, "should get error")
assert.Contains(t, err.Error(), "Error occured during authorization check")
assert.Contains(t, err.Error(), fmt.Sprintf("giving up after %d attempt", httpClientRetryCount+1))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bcho @julienstroheker this line can assert that it gives up after 3 attempts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please have another round of review? Thanks

if v, ok := err.(errutils.HttpStatusCode); ok {
assert.Equal(t, v.Code(), http.StatusInternalServerError)
}
Expand All @@ -194,7 +198,7 @@ func TestCheck(t *testing.T) {
resp, err := client.Check(ctx, request, store)
assert.Nilf(t, resp, "response should be nil")
assert.NotNilf(t, err, "should get error")
assert.Contains(t, err.Error(), "Checkaccess requests have timed out")
assert.Contains(t, err.Error(), "context deadline exceeded")
if v, ok := err.(errutils.HttpStatusCode); ok {
assert.Equal(t, v.Code(), http.StatusInternalServerError)
}
Expand Down
10 changes: 7 additions & 3 deletions authz/providers/azure/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type AccessInfo struct {
skipAuthzForNonAADUsers bool
allowNonResDiscoveryPathAccess bool
useNamespaceResourceScopeFormat bool
httpClientRetryCount int
lock sync.RWMutex
}

Expand Down Expand Up @@ -155,7 +156,7 @@ func getClusterType(clsType string) string {
}
}

func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options) (*AccessInfo, error) {
func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options, authopts auth.Options) (*AccessInfo, error) {
u := &AccessInfo{
client: httpclient.DefaultHTTPClient,
headers: http.Header{
Expand All @@ -169,6 +170,7 @@ func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts aut
skipAuthzForNonAADUsers: opts.SkipAuthzForNonAADUsers,
allowNonResDiscoveryPathAccess: opts.AllowNonResDiscoveryPathAccess,
useNamespaceResourceScopeFormat: opts.UseNamespaceResourceScopeFormat,
httpClientRetryCount: authopts.HttpClientRetryCount,
}

u.skipCheck = make(map[string]void, len(opts.SkipAuthzCheck))
Expand Down Expand Up @@ -207,7 +209,7 @@ func New(opts authzOpts.Options, authopts auth.Options, authzInfo *AuthzInfo) (*
tokenProvider = graph.NewAKSTokenProvider(opts.AKSAuthzTokenURL, authopts.TenantID)
}

return newAccessInfo(tokenProvider, rbacURL, opts)
return newAccessInfo(tokenProvider, rbacURL, opts, authopts)
}

func (a *AccessInfo) RefreshToken(ctx context.Context) error {
Expand Down Expand Up @@ -328,6 +330,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut
// create a request id for every checkaccess request
requestUUID := uuid.New()
reqContext := context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()})
reqContext = azureutils.WithRetryableHttpClient(reqContext, a.httpClientRetryCount)
err := a.sendCheckAccessRequest(reqContext, checkAccessUsername, checkAccessURL, body, ch)
if err != nil {
code := http.StatusInternalServerError
Expand Down Expand Up @@ -397,7 +400,8 @@ func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessUser
// start time to calculate checkaccess duration
start := time.Now()
klog.V(5).Infof("Sending checkAccess request with correlationID: %s", correlationID[0])
resp, err := a.client.Do(req)
client := azureutils.LoadClientWithContext(ctx, a.client)
resp, err := client.Do(req)
duration := time.Since(start).Seconds()
if err != nil {
checkAccessTotal.WithLabelValues(internalServerCode).Inc()
Expand Down
40 changes: 40 additions & 0 deletions util/azure/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"io"
"log"
"net/http"
"path"
"strconv"
Expand All @@ -32,9 +33,11 @@ import (
"go.kubeguard.dev/guard/util/httpclient"

"github.com/Azure/go-autorest/autorest/azure"
"github.com/hashicorp/go-retryablehttp"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/oauth2"
v "gomodules.xyz/x/version"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
Expand Down Expand Up @@ -502,6 +505,43 @@ func fetchDataActionsList(ctx context.Context) ([]Operation, error) {
return finalOperations, nil
}

// MakeRetryableHttpClient builds a retrying HTTP client. A request is attempted
// at most (1 + retryCount) times, and each individual attempt is limited to a
// 3 second timeout.
func MakeRetryableHttpClient(retryCount int) retryablehttp.Client {
	const (
		perAttemptTimeout = 3 * time.Second
		minRetryWait      = 500 * time.Millisecond
		maxRetryWait      = 2 * time.Second
	)

	// Shallow-copy the shared default client so the timeout is applied to the
	// copy only; the copy still points at the same transport.
	base := *httpclient.DefaultHTTPClient
	base.Timeout = perAttemptTimeout

	return retryablehttp.Client{
		HTTPClient:   &base,
		RetryWaitMin: minRetryWait,
		RetryWaitMax: maxRetryWait,
		// RetryMax counts retries only: 1 initial attempt + retryCount retries.
		RetryMax:   retryCount,
		CheckRetry: retryablehttp.DefaultRetryPolicy,
		Backoff:    retryablehttp.DefaultBackoff,
		Logger:     log.Default(),
	}
}

// WithRetryableHttpClient returns a child of ctx whose oauth2.HTTPClient key
// holds an *http.Client backed by MakeRetryableHttpClient(retryCount).
// Some of the libraries we use pull the client out of the context via
// oauth2.HTTPClient, so this is how retries are added to external code.
func WithRetryableHttpClient(ctx context.Context, retryCount int) context.Context {
	rc := MakeRetryableHttpClient(retryCount)
	stdClient := rc.StandardClient()
	return context.WithValue(ctx, oauth2.HTTPClient, stdClient)
}

// LoadClientWithContext returns the *http.Client stored in ctx under the
// oauth2.HTTPClient key (the key WithRetryableHttpClient writes to), or
// defaultClient when ctx carries no usable client.
//
// A value of a different type, or a nil *http.Client, is treated as absent, so
// callers never get a nil client back unless defaultClient itself is nil.
func LoadClientWithContext(ctx context.Context, defaultClient *http.Client) *http.Client {
	// A typed-nil *http.Client passes the type assertion; guard against it so
	// the caller's client.Do does not dereference nil.
	if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok && c != nil {
		return c
	}

	return defaultClient
}

// init registers this package's Prometheus collectors for resource discovery
// and operations fetching. MustRegister panics if a registration fails.
func init() {
	prometheus.MustRegister(
		DiscoverResourcesTotalDuration,
		discoverResourcesAzureCallDuration,
		discoverResourcesApiServerCallDuration,
		counterDiscoverResources,
		counterGetOperationsResources,
	)
}
Loading