Skip to content

Commit

Permalink
Updated the Arc default logic
Browse files Browse the repository at this point in the history
Updated tests and merged andyohart/managed-identity
  • Loading branch information
4gust committed Nov 11, 2024
1 parent 7ba9974 commit 3f42cd8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 58 deletions.
11 changes: 7 additions & 4 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption {
//
// Options: [WithHTTPClient]
func New(id ID, options ...ClientOption) (Client, error) {
source, err := GetSource()
source, err := getSource()
if err != nil {
return Client{}, err
}
Expand Down Expand Up @@ -211,8 +211,8 @@ func New(id ID, options ...ClientOption) (Client, error) {
return client, nil
}

// GetSource detects and returns the managed identity source available on the environment.
func GetSource() (Source, error) {
// getSource detects and returns the managed identity source available on the environment.
func getSource() (Source, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
identityHeader := os.Getenv(identityHeaderEnvVar)
identityServerThumbprint := os.Getenv(identityServerThumbprintEnvVar)
Expand Down Expand Up @@ -418,7 +418,10 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.R
}

func createAzureArcAuthRequest(ctx context.Context, resource string, key string) (*http.Request, error) {
identityEndpoint := azureArcEndpoint
identityEndpoint := os.Getenv(identityEndpointEnvVar)
if identityEndpoint == "" {
identityEndpoint = azureArcEndpoint
}
msiEndpoint, parseErr := url.Parse(identityEndpoint)

if parseErr != nil {
Expand Down
74 changes: 20 additions & 54 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,14 +499,6 @@ func setEnvVars(t *testing.T, source Source) {
}
}

func unsetEnvVars(t *testing.T) {
t.Setenv(identityEndpointEnvVar, "")
t.Setenv(identityHeaderEnvVar, "")
t.Setenv(identityServerThumbprintEnvVar, "")
t.Setenv(imdsEndVar, "")
t.Setenv(msiEndpointEnvVar, "")
}

func setCustomAzureArcPlatformPath(t *testing.T, path string) {
originalFunc := getAzureArcPlatformPath
getAzureArcPlatformPath = func(string) string {
Expand All @@ -528,11 +520,10 @@ func setCustomAzureArcFilePath(t *testing.T, path string) {
func TestSource(t *testing.T) {
for _, testCase := range []Source{AzureArc, DefaultToIMDS} {
t.Run(string(testCase), func(t *testing.T) {
unsetEnvVars(t)
setEnvVars(t, testCase)
setCustomAzureArcFilePath(t, fakeAzureArcFilePath)

actualSource, err := GetSource()
actualSource, err := getSource()
if err != nil {
t.Fatalf("error while getting source: %s", err.Error())
}
Expand All @@ -544,7 +535,6 @@ func TestSource(t *testing.T) {
}

func TestAzureArcReturnsWhenHimdsFound(t *testing.T) {
unsetEnvVars(t)
mockFilePath := filepath.Join(t.TempDir(), "himds")
setCustomAzureArcFilePath(t, mockFilePath)

Expand Down Expand Up @@ -681,7 +671,6 @@ func TestAzureArc(t *testing.T) {
testCaseFilePath := getMockFilePath(t)

endpoint := azureArcEndpoint
unsetEnvVars(t)
setEnvVars(t, AzureArc)
setCustomAzureArcFilePath(t, fakeAzureArcFilePath)

Expand Down Expand Up @@ -719,9 +708,9 @@ func TestAzureArc(t *testing.T) {

client, err := New(SystemAssigned(), WithHTTPClient(&mockClient))
if err != nil {
t.Fatalf(`error not expected, got error: "%s"`, err)
t.Fatal(err)
}
result, err := client.AcquireToken(context.Background(), resource)
result, err := client.AcquireToken(context.Background(), resourceDefaultSuffix)
if err != nil {
t.Fatalf(`error not expected, got error: "%s"`, err)
}
Expand All @@ -735,7 +724,7 @@ func TestAzureArc(t *testing.T) {
if query.Get(apiVersionQueryParameterName) != azureArcAPIVersion {
t.Fatalf("api-version not on %s got %s", azureArcAPIVersion, query.Get(apiVersionQueryParameterName))
}
if query.Get(resourceQueryParameterName) != strings.TrimSuffix(resource, "/.default") {
if query.Get(resourceQueryParameterName) != strings.TrimSuffix(resourceDefaultSuffix, "/.default") {
t.Fatal("suffix /.default was not removed.")
}
if result.Metadata.TokenSource != base.IdentityProvider {
Expand All @@ -744,7 +733,7 @@ func TestAzureArc(t *testing.T) {
if result.AccessToken != token {
t.Fatalf("wanted %q, got %q", token, result.AccessToken)
}
result, err = client.AcquireToken(context.Background(), resource)
result, err = client.AcquireToken(context.Background(), resourceDefaultSuffix)
if err != nil {
t.Fatal(err)
}
Expand All @@ -755,7 +744,7 @@ func TestAzureArc(t *testing.T) {
if err != nil {
t.Fatal(err)
}
result, err = secondFakeClient.AcquireToken(context.Background(), resource)
result, err = secondFakeClient.AcquireToken(context.Background(), resourceDefaultSuffix)
if err != nil {
t.Fatal(err)
}
Expand All @@ -766,7 +755,7 @@ func TestAzureArc(t *testing.T) {
}

func TestAzureArcOnlySystemAssignedSupported(t *testing.T) {
unsetEnvVars(t)

setEnvVars(t, AzureArc)
mockClient := mock.Client{}

Expand All @@ -776,19 +765,15 @@ func TestAzureArcOnlySystemAssignedSupported(t *testing.T) {
UserAssignedObjectID("ObjectId"),
UserAssignedResourceID("resourceid")} {
_, err := New(testCase, WithHTTPClient(&mockClient))
if err != nil {
println(err.Error())
if err.Error() != "azure Arc doesn't support user assigned managed identities" {
t.Fatalf(`expected error: azure arc not supported error, got error: "%v"`, err)
}
return
} else {
if err == nil {
t.Fatal(`expected error: azure arc not supported error"`)
}
if err.Error() != "azure Arc doesn't support user assigned managed identities" {
t.Fatalf(`expected error: azure arc not supported error, got error: "%v"`, err)
}
}
}
func TestAzureArcPlatformSupported(t *testing.T) {
unsetEnvVars(t)
setEnvVars(t, AzureArc)
setCustomAzureArcFilePath(t, fakeAzureArcFilePath)
mockClient := mock.Client{}
Expand All @@ -815,7 +800,6 @@ func TestAzureArcPlatformSupported(t *testing.T) {
}

func TestAzureArcErrors(t *testing.T) {
unsetEnvVars(t)
setEnvVars(t, AzureArc)
setCustomAzureArcFilePath(t, fakeAzureArcFilePath)
testCaseFilePath := getMockFilePath(t)
Expand Down Expand Up @@ -878,14 +862,9 @@ func TestAzureArcErrors(t *testing.T) {
mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithHTTPHeader(headers),
mock.WithBody(responseBody))

// resetting cache
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

client, err := New(SystemAssigned(), WithHTTPClient(&mockClient))
if err != nil {
t.Fatalf(`expected error: "%v" "`, testCase.expectedError)
t.Fatalf(`unexpected error: "%v" "`, testCase.expectedError)
return
}
result, err := client.AcquireToken(context.Background(), resource)
Expand All @@ -902,33 +881,21 @@ func TestAzureArcErrors(t *testing.T) {

func TestSystemAssignedReturnsAcquireTokenFailure(t *testing.T) {
testCases := []struct {
code int
err string
desc string
correlationID string
code int
err string
desc string
}{
{code: http.StatusNotFound,
err: "",
desc: "",
correlationID: "121212"},
{code: http.StatusNotImplemented,
err: "",
desc: "",
correlationID: "121212"},
{code: http.StatusServiceUnavailable,
err: "",
desc: "",
correlationID: "121212"},
{code: http.StatusNotFound},
{code: http.StatusNotImplemented},
{code: http.StatusServiceUnavailable},
{code: http.StatusBadRequest,
err: "invalid_request",
desc: "Identity not found",
correlationID: "121212",
err: "invalid_request",
desc: "Identity not found",
},
}

for _, testCase := range testCases {
t.Run(http.StatusText(testCase.code), func(t *testing.T) {
unsetEnvVars(t)
setCustomAzureArcFilePath(t, fakeAzureArcFilePath)

fakeErrorClient := mock.Client{}
Expand Down Expand Up @@ -1005,7 +972,6 @@ func TestCreatingIMDSClient(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
unsetEnvVars(t)
setCustomAzureArcFilePath(t, fakeAzureArcFilePath)

client, err := New(tt.id)
Expand Down

0 comments on commit 3f42cd8

Please sign in to comment.