Skip to content

Commit

Permalink
Improve error from communicating with Azure API.
Browse files Browse the repository at this point in the history
This changes the behaviour of our use of the hamilton msgraph client to
return errors when connecting which will pass the message through.

This is a reimplementation of the fix that was upstreamed here
manicminer/hamilton#280
  • Loading branch information
bigkevmcd committed May 1, 2024
1 parent 514545f commit 08483bc
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 9 deletions.
36 changes: 27 additions & 9 deletions pkg/auth/providers/azure/clients/ms_graph_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
Expand Down Expand Up @@ -44,8 +45,9 @@ func (c azureMSGraphClient) MarshalTokenJSON() (string, error) {
func (c azureMSGraphClient) GetUser(id string) (v3.Principal, error) {
user, _, err := c.userClient.Get(context.Background(), id, odata.Query{})
if err != nil {
return v3.Principal{}, err
return v3.Principal{}, fmt.Errorf("failed to get user from Azure: %w", err)
}

return c.userToPrincipal(*user)
}

Expand Down Expand Up @@ -202,15 +204,15 @@ type AccessTokenCache struct {
}

// Replace fetches the access token from a secret in Kubernetes.
func (c AccessTokenCache) Replace(cache cache.Unmarshaler, key string) {
func (c AccessTokenCache) Replace(unmarshaler cache.Unmarshaler, key string) {
secretName := fmt.Sprintf("%s:%s", common.SecretsNamespace, AccessTokenSecretName)
secret, err := common.ReadFromSecret(c.Secrets, secretName, "access-token")
if err != nil {
logrus.Errorf("[%s] failed to read the access token from Kubernetes: %v", cacheLogPrefix, err)
return
}

err = cache.Unmarshal([]byte(secret))
err = unmarshaler.Unmarshal([]byte(secret))
if err != nil {
logrus.Errorf("[%s] failed to unmarshal the access token: %v", cacheLogPrefix, err)
}
Expand All @@ -235,18 +237,19 @@ func (c AccessTokenCache) Export(cache cache.Marshaler, key string) {
// If that fails, it tries to acquire it directly from the auth provider with the credential (application secret in Azure).
// It also checks that the access token has the necessary permissions.
func NewMSGraphClient(config *v32.AzureADConfig, secrets corev1.SecretInterface) (AzureClient, error) {
c := &azureMSGraphClient{}
cred, err := confidential.NewCredFromSecret(config.ApplicationSecret)
if err != nil {
return nil, fmt.Errorf("could not create a cred from a secret: %w", err)
}

tokenCache := AccessTokenCache{Secrets: secrets}
confidentialClientApp, err := confidential.New(config.ApplicationID, cred,
confidential.WithAccessor(tokenCache),
confidential.WithAuthority(fmt.Sprintf("%s%s", config.Endpoint, config.TenantID)))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create Azure client: %w", err)
}

scope := fmt.Sprintf("%s/%s", config.GraphEndpoint, ".default")

var ar confidential.AuthResult
Expand All @@ -260,6 +263,10 @@ func NewMSGraphClient(config *v32.AzureADConfig, secrets corev1.SecretInterface)
}
}

return newMSGraphClient(config, ar), nil
}

func newMSGraphClient(config *v32.AzureADConfig, ar confidential.AuthResult) *azureMSGraphClient {
authResult := getCustomAuthResult(&ar)
authorizer := authorizer{authResult: authResult}

Expand All @@ -268,17 +275,28 @@ func NewMSGraphClient(config *v32.AzureADConfig, secrets corev1.SecretInterface)
userClient.BaseClient.Authorizer = &authorizer
userClient.BaseClient.ApiVersion = msgraph.Version10
userClient.BaseClient.DisableRetries = true
userClient.BaseClient.RetryableClient.ErrorHandler = retryableErrorHandler

groupClient := msgraph.NewGroupsClient(config.TenantID)
groupClient.BaseClient.Endpoint = environments.ApiEndpoint(config.GraphEndpoint)
groupClient.BaseClient.Authorizer = &authorizer
groupClient.BaseClient.ApiVersion = msgraph.Version10
groupClient.BaseClient.DisableRetries = true
groupClient.BaseClient.RetryableClient.ErrorHandler = retryableErrorHandler

return &azureMSGraphClient{
authResult: authResult,
userClient: userClient,
groupClient: groupClient,
}
}

func retryableErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
if resp == nil {
return nil, err
}

c.authResult = authResult
c.userClient = userClient
c.groupClient = groupClient
return c, err
return resp, nil
}

func getCustomAuthResult(result *confidential.AuthResult) *customAuthResult {
Expand Down
174 changes: 174 additions & 0 deletions pkg/auth/providers/azure/clients/ms_graph_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package clients

import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
mgmtv3 "github.com/rancher/rancher/pkg/apis/management.cattle.io/v3"
)

func TestAzureClient_connection_failures(t *testing.T) {
// This creates a listener on a random available port.
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
port := l.Addr().(*net.TCPAddr).Port
l.Close()

config := &mgmtv3.AzureADConfig{
GraphEndpoint: fmt.Sprintf("https://localhost:%d/", port),
TenantID: "test-tenant",
}

connectionTests := []struct {
name string
call func(client AzureClient) error
}{
{
name: "GetUser()",
call: func(client AzureClient) error {
_, err := client.GetUser("test-user-id")
return err
},
},
{
name: "ListUsers()",
call: func(client AzureClient) error {
_, err := client.ListUsers("LastName eq 'Smith'")
return err
},
},
{
name: "GetGroup()",
call: func(client AzureClient) error {
_, err := client.GetGroup("testing-group")
return err
},
},
{
name: "ListGroups()",
call: func(client AzureClient) error {
_, err := client.ListGroups("mailEnabled eq true")
return err
},
},
{
name: "ListGroupMemberships()",
call: func(client AzureClient) error {
_, err := client.ListGroupMemberships("test-user-id")
return err
},
},
}

for _, tt := range connectionTests {
t.Run(tt.name, func(t *testing.T) {
client := newMSGraphClient(config, confidential.AuthResult{
AccessToken: "test-token",
})
client.userClient.BaseClient.RetryableClient.RetryMax = 1
client.groupClient.BaseClient.RetryableClient.RetryMax = 1

err := tt.call(client)

if err == nil {
t.Error("expected to get an error, got nil")
}

if msg := err.Error(); !strings.Contains(msg, "connect: connection refused") {
t.Errorf("got %s, want message with 'connection refused'", msg)
}
})
}
}

func TestAzureClient_invalid_responses(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
paths := []string{
"/v1.0/test-tenant/users/test-user-id",
"/v1.0/test-tenant/users",
"/v1.0/test-tenant/groups/testing-group",
"/v1.0/test-tenant/groups",
"/v1.0/test-tenant/users/test-user-id/transitiveMemberOf",
}
if slices.Contains(paths, r.URL.Path) {
fmt.Fprintln(w, `{ "`+strings.Repeat("a", 513)+`" 1 }`)
return
}
http.Error(w, fmt.Sprintf("didn't match: %s", r.URL.Path), http.StatusNotFound)
}))
defer ts.Close()

config := &mgmtv3.AzureADConfig{
GraphEndpoint: ts.URL,
TenantID: "test-tenant",
}

connectionTests := []struct {
name string
call func(client AzureClient) error
}{
{
name: "GetUser()",
call: func(client AzureClient) error {
_, err := client.GetUser("test-user-id")
return err
},
},
{
name: "ListUsers()",
call: func(client AzureClient) error {
_, err := client.ListUsers("LastName eq 'Smith'")
return err
},
},
{
name: "GetGroup()",
call: func(client AzureClient) error {
_, err := client.GetGroup("testing-group")
return err
},
},
{
name: "ListGroups()",
call: func(client AzureClient) error {
_, err := client.ListGroups("mailEnabled eq true")
return err
},
},
{
name: "ListGroupMemberships()",
call: func(client AzureClient) error {
_, err := client.ListGroupMemberships("test-user-id")
return err
},
},
}

for _, tt := range connectionTests {
t.Run(tt.name, func(t *testing.T) {
client := newMSGraphClient(config, confidential.AuthResult{
AccessToken: "test-token",
})
client.userClient.BaseClient.RetryableClient.RetryMax = 1
client.groupClient.BaseClient.RetryableClient.RetryMax = 1

err := tt.call(client)

if err == nil {
t.Error("expected to get an error, got nil")
}

if msg := err.Error(); !strings.Contains(msg, "invalid character") {
t.Errorf("got %s, want message with 'invalid character'", msg)
}
})
}
}

0 comments on commit 08483bc

Please sign in to comment.