diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index af684a32..a612c2c1 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,6 +46,13 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } +// WithHTTPHeader sets the HTTP headers of the response to the specified value. +func WithHTTPHeader(header http.Header) responseOption { + return respOpt(func(r *response) { + r.headers = header + }) +} + // WithHTTPStatusCode sets the HTTP statusCode of response to the specified value. func WithHTTPStatusCode(statusCode int) responseOption { return respOpt(func(r *response) { diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 4b14ca48..d68eb1d4 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -16,6 +16,9 @@ import ( "io" "net/http" "net/url" + "os" + "path/filepath" + "runtime" "strings" "time" @@ -36,22 +39,42 @@ const ( CloudShell Source = "CloudShell" AppService Source = "AppService" - // General request querry parameter names - metaHTTPHeaderName = "Metadata" - apiVersionQuerryParameterName = "api-version" - resourceQuerryParameterName = "resource" + // General request query parameter names + metaHTTPHeaderName = "Metadata" + apiVersionQueryParameterName = "api-version" + resourceQueryParameterName = "resource" + wwwAuthenticateHeaderName = "www-authenticate" - // UAMI querry parameter name + // UAMI query parameter name miQueryParameterClientId = "client_id" miQueryParameterObjectId = "object_id" miQueryParameterResourceId = "msi_res_id" // IMDS - imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" - imdsAPIVersion = "2018-02-01" - + imdsDefaultEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + imdsAPIVersion = "2018-02-01" systemAssignedManagedIdentity = "system_assigned_managed_identity" - defaultRetryCount = 3 + + // Azure Arc + azureArcEndpoint = "http://127.0.0.1:40342/metadata/identity/oauth2/token" + azureArcAPIVersion = "2020-06-01" + azureArcFileExtension = ".key" + azureArcMaxFileSizeBytes int64 = 4096 + linuxTokenPath = "/var/opt/azcmagent/tokens" + linuxHimdsPath = "/opt/azcmagent/bin/himds" + azureConnectedMachine = "AzureConnectedMachineAgent" + himdsExecutableName = "himds.exe" + tokenName = "Tokens" + + // Environment Variables + identityEndpointEnvVar = "IDENTITY_ENDPOINT" + identityHeaderEnvVar = "IDENTITY_HEADER" + azurePodIdentityAuthorityHostEnvVar = "AZURE_POD_IDENTITY_AUTHORITY_HOST" + imdsEndVar = "IMDS_ENDPOINT" + msiEndpointEnvVar = "MSI_ENDPOINT" + identityServerThumbprintEnvVar = "IDENTITY_SERVER_THUMBPRINT" + + defaultRetryCount = 3 ) // retry codes for IMDS @@ -75,6 +98,29 @@ var retryStatusCodes = []int{ http.StatusInternalServerError, // 500 http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 + +} + +var getAzureArcPlatformPath = func(platform string) string { + switch platform { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), azureConnectedMachine, tokenName) + case "linux": + return linuxTokenPath + default: + return "" + } +} + +var getAzureArcHimdsFilePath = func(platform string) string { + switch platform { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), azureConnectedMachine, himdsExecutableName) + case "linux": + return linuxHimdsPath + default: + return "" + } } type Source string @@ -103,6 +149,7 @@ type Client struct { httpClient ops.HTTPClient miType ID source Source + authParams authority.AuthParams retryPolicyEnabled bool } @@ -145,6 +192,18 @@ func WithRetryPolicyDisabled() ClientOption { // // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { + source, err := GetSource() + if err != nil { + return Client{}, err + } + + // If source is Azure Arc return an error, as Azure Arc allow accepts System Assigned managed identities. + if source == AzureArc { + switch id.(type) { + case UserAssignedClientID, UserAssignedResourceID, UserAssignedObjectID: + return Client{}, errors.New("azure Arc doesn't support user assigned managed identities") + } + } opts := ClientOptions{ httpClient: shared.DefaultClient, retryPolicyEnabled: true, @@ -173,44 +232,132 @@ func New(id ID, options ...ClientOption) (Client, error) { miType: id, httpClient: opts.httpClient, retryPolicyEnabled: opts.retryPolicyEnabled, - source: DefaultToIMDS, + source: source, + } + fakeAuthInfo, err := authority.NewInfoFromAuthorityURI("https://login.microsoftonline.com/managed_identity", false, true) + if err != nil { + return Client{}, err } + client.authParams = authority.NewAuthParams(client.miType.value(), fakeAuthInfo) return client, nil } -func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims string) (*http.Request, error) { - var msiEndpoint *url.URL - msiEndpoint, err := url.Parse(imdsEndpoint) - if err != nil { - return nil, fmt.Errorf("couldn't parse %q: %s", imdsEndpoint, err) +// GetSource detects and returns the managed identity source available on the environment. +func GetSource() (Source, error) { + identityEndpoint := os.Getenv(identityEndpointEnvVar) + identityHeader := os.Getenv(identityHeaderEnvVar) + identityServerThumbprint := os.Getenv(identityServerThumbprintEnvVar) + msiEndpoint := os.Getenv(msiEndpointEnvVar) + imdsEndpoint := os.Getenv(imdsEndVar) + + if identityEndpoint != "" && identityHeader != "" { + if identityServerThumbprint != "" { + return ServiceFabric, nil + } + return AppService, nil + } else if msiEndpoint != "" { + return CloudShell, nil + } else if isAzureArcEnvironment(identityEndpoint, imdsEndpoint) { + return AzureArc, nil } - msiParameters := msiEndpoint.Query() - msiParameters.Set(apiVersionQuerryParameterName, imdsAPIVersion) - msiParameters.Set(resourceQuerryParameterName, resource) - if len(claims) > 0 { - msiParameters.Set("claims", claims) + return DefaultToIMDS, nil +} + +// Acquires tokens from the configured managed identity on an azure resource. +// +// Resource: scopes application is requesting access to +// Options: [WithClaims] +func (c Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { + resource = strings.TrimSuffix(resource, "/.default") + o := AcquireTokenOptions{} + for _, option := range options { + option(&o) } + c.authParams.Scopes = []string{resource} - switch t := id.(type) { - case UserAssignedClientID: - msiParameters.Set(miQueryParameterClientId, string(t)) - case UserAssignedResourceID: - msiParameters.Set(miQueryParameterResourceId, string(t)) - case UserAssignedObjectID: - msiParameters.Set(miQueryParameterObjectId, string(t)) - case systemAssignedValue: // not adding anything + // ignore cached access tokens when given claims + if o.claims == "" { + storageTokenResponse, err := cacheManager.Read(ctx, c.authParams) + if err != nil { + return base.AuthResult{}, err + } + ar, err := base.AuthResultFromStorage(storageTokenResponse) + if err == nil { + ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err + } + } + + switch c.source { + case AzureArc: + return acquireTokenForAzureArc(ctx, c, resource) + case DefaultToIMDS: + return acquireTokenForIMDS(ctx, c, resource) default: - return nil, fmt.Errorf("unsupported type %T", id) + return base.AuthResult{}, fmt.Errorf("unsupported source %q", c.source) } +} - msiEndpoint.RawQuery = msiParameters.Encode() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) +func acquireTokenForIMDS(ctx context.Context, client Client, resource string) (base.AuthResult, error) { + req, err := createIMDSAuthRequest(ctx, client.miType, resource) if err != nil { - return nil, fmt.Errorf("error creating http request %s", err) + return base.AuthResult{}, err } - req.Header.Set(metaHTTPHeaderName, "true") - return req, nil + tokenResponse, err := client.getTokenForRequest(req) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(client.authParams, tokenResponse) +} + +func acquireTokenForAzureArc(ctx context.Context, client Client, resource string) (base.AuthResult, error) { + req, err := createAzureArcAuthRequest(ctx, resource, "") + if err != nil { + return base.AuthResult{}, err + } + + response, err := client.httpClient.Do(req) + if err != nil { + return base.AuthResult{}, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusUnauthorized { + return base.AuthResult{}, fmt.Errorf("expected a 401 response, received %d", response.StatusCode) + } + + secret, err := client.getAzureArcSecretKey(response, runtime.GOOS) + if err != nil { + return base.AuthResult{}, err + } + + secondRequest, err := createAzureArcAuthRequest(ctx, resource, string(secret)) + if err != nil { + return base.AuthResult{}, err + } + + tokenResponse, err := client.getTokenForRequest(secondRequest) + if err != nil { + return base.AuthResult{}, err + } + return authResultFromToken(client.authParams, tokenResponse) +} + +func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { + if cacheManager == nil { + return base.AuthResult{}, errors.New("cache instance is nil") + } + account, err := cacheManager.Write(authParams, token) + if err != nil { + return base.AuthResult{}, err + } + ar, err := base.NewAuthResult(token, account) + if err != nil { + return base.AuthResult{}, err + } + ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err } // Contains checks if the element is present in the list. @@ -257,25 +404,26 @@ func retry(maxRetries int, c ops.HTTPClient, req *http.Request, s Source) (*http } func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { - retryCount := 3 + r := accesstokens.TokenResponse{} + retryCount := defaultRetryCount if !client.retryPolicyEnabled { retryCount = 1 } resp, err := retry(retryCount, client.httpClient, req, client.source) if err != nil { - return accesstokens.TokenResponse{}, err + return r, err } responseBytes, err := io.ReadAll(resp.Body) defer resp.Body.Close() if err != nil { - return accesstokens.TokenResponse{}, err + return r, err } switch resp.StatusCode { case http.StatusOK, http.StatusAccepted: default: sd := strings.TrimSpace(string(responseBytes)) if sd != "" { - return accesstokens.TokenResponse{}, errors.CallErr{ + return r, errors.CallErr{ Req: req, Resp: resp, Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", @@ -285,73 +433,136 @@ func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRe sd), } } - return accesstokens.TokenResponse{}, errors.CallErr{ + return r, errors.CallErr{ Req: req, Resp: resp, Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, resp.StatusCode), } } - var r accesstokens.TokenResponse + err = json.Unmarshal(responseBytes, &r) - r.GrantedScopes.Slice = append(r.GrantedScopes.Slice, req.URL.Query().Get(resourceQuerryParameterName)) + r.GrantedScopes.Slice = append(r.GrantedScopes.Slice, req.URL.Query().Get(resourceQueryParameterName)) return r, err } -// Acquires tokens from the configured managed identity on an azure resource. -// -// Resource: scopes application is requesting access to -// Options: [WithClaims] -func (client Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { - resource = strings.TrimSuffix(resource, "/.default") - o := AcquireTokenOptions{} - for _, option := range options { - option(&o) - } - req, err := createIMDSAuthRequest(ctx, client.miType, resource, o.claims) +func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) { + msiEndpoint, err := url.Parse(imdsDefaultEndpoint) if err != nil { - return base.AuthResult{}, err + return nil, fmt.Errorf("couldn't parse %q: %s", imdsDefaultEndpoint, err) } + msiParameters := msiEndpoint.Query() + msiParameters.Set(apiVersionQueryParameterName, imdsAPIVersion) + msiParameters.Set(resourceQueryParameterName, resource) - authInfo, err := authority.NewInfoFromAuthorityURI("https://login.microsoftonline.com/managed_identity", false, true) + switch t := id.(type) { + case UserAssignedClientID: + msiParameters.Set(miQueryParameterClientId, string(t)) + case UserAssignedResourceID: + msiParameters.Set(miQueryParameterResourceId, string(t)) + case UserAssignedObjectID: + msiParameters.Set(miQueryParameterObjectId, string(t)) + case systemAssignedValue: // not adding anything + default: + return nil, fmt.Errorf("unsupported type %T", id) + } + + msiEndpoint.RawQuery = msiParameters.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) if err != nil { - return base.AuthResult{}, err + return nil, fmt.Errorf("error creating http request %s", err) } - authParams := authority.NewAuthParams(client.miType.value(), authInfo) - authParams.Scopes = []string{resource} - // ignore cached access tokens when given claims - if o.claims == "" { - if cacheManager == nil { - return base.AuthResult{}, errors.New("cache instance is nil") - } - storageTokenResponse, err := cacheManager.Read(ctx, authParams) - if err != nil { - return base.AuthResult{}, err - } - ar, err := base.AuthResultFromStorage(storageTokenResponse) - if err == nil { - ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - return ar, err - } + req.Header.Set(metaHTTPHeaderName, "true") + return req, nil +} + +func createAzureArcAuthRequest(ctx context.Context, resource string, key string) (*http.Request, error) { + identityEndpoint := os.Getenv(identityEndpointEnvVar) + if identityEndpoint == "" { + identityEndpoint = azureArcEndpoint } - tokenResponse, err := client.getTokenForRequest(req) + msiEndpoint, parseErr := url.Parse(identityEndpoint) + + if parseErr != nil { + return nil, fmt.Errorf("couldn't parse %q: %s", identityEndpoint, parseErr) + } + + msiParameters := msiEndpoint.Query() + msiParameters.Set(apiVersionQueryParameterName, azureArcAPIVersion) + msiParameters.Set(resourceQueryParameterName, resource) + + msiEndpoint.RawQuery = msiParameters.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) if err != nil { - return base.AuthResult{}, err + return nil, fmt.Errorf("error creating http request %s", err) } - return authResultFromToken(authParams, tokenResponse) + req.Header.Set(metaHTTPHeaderName, "true") + + if key != "" { + req.Header.Set("Authorization", fmt.Sprintf("Basic %s", key)) + } + + return req, nil } -func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) { - if cacheManager == nil { - return base.AuthResult{}, fmt.Errorf("cache instance is nil") +func isAzureArcEnvironment(identityEndpoint, imdsEndpoint string) bool { + if identityEndpoint != "" && imdsEndpoint != "" { + return true } - account, err := cacheManager.Write(authParams, token) + himdsFilePath := getAzureArcHimdsFilePath(runtime.GOOS) + if himdsFilePath != "" { + if _, err := os.Stat(himdsFilePath); err == nil { + return true + } + } + return false +} + +func (c *Client) getAzureArcSecretKey(response *http.Response, platform string) (string, error) { + wwwAuthenticateHeader := response.Header.Get(wwwAuthenticateHeaderName) + + if len(wwwAuthenticateHeader) == 0 { + return "", errors.New("response has no www-authenticate header") + } + + // check if the platform is supported + expectedSecretFilePath := getAzureArcPlatformPath(platform) + if expectedSecretFilePath == "" { + return "", errors.New("platform not supported, expected linux or windows") + } + + parts := strings.Split(wwwAuthenticateHeader, "Basic realm=") + if len(parts) < 2 { + return "", fmt.Errorf("basic realm= not found in the string, instead found: %s", wwwAuthenticateHeader) + } + + secretFilePath := parts + + // check that the file in the file path is a .key file + fileName := filepath.Base(secretFilePath[1]) + if !strings.HasSuffix(fileName, azureArcFileExtension) { + return "", fmt.Errorf("invalid file extension, expected %s, got %s", azureArcFileExtension, filepath.Ext(fileName)) + } + + // check that file path from header matches the expected file path for the platform + if expectedSecretFilePath != filepath.Dir(secretFilePath[1]) { + return "", fmt.Errorf("invalid file path, expected %s, got %s", expectedSecretFilePath, filepath.Dir(secretFilePath[1])) + } + + fileInfo, err := os.Stat(secretFilePath[1]) if err != nil { - return base.AuthResult{}, err + return "", fmt.Errorf("failed to get metadata for %s due to error: %s", secretFilePath[1], err) } - ar, err := base.NewAuthResult(token, account) + + // Throw an error if the secret file's size is greater than 4096 bytes + if s := fileInfo.Size(); s > azureArcMaxFileSizeBytes { + return "", fmt.Errorf("invalid secret file size, expected %d, file size was %d", azureArcMaxFileSizeBytes, s) + } + + // Attempt to read the contents of the secret file + secret, err := os.ReadFile(secretFilePath[1]) if err != nil { - return base.AuthResult{}, err + return "", fmt.Errorf("failed to read %q due to error: %s", secretFilePath[1], err) } - ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - return ar, err + + return string(secret), nil } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 1e8f1609..efc24cf6 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -10,6 +10,8 @@ import ( "io" "net/http" "net/url" + "os" + "path/filepath" "strings" "testing" "time" @@ -21,10 +23,16 @@ import ( ) const ( - // test Resources - resource = "https://demo.azure.com" - resourceDefaultSuffix = "https://demo.azure.com/.default" + // Test Resources + resource = "https://management.azure.com" + resourceDefaultSuffix = "https://management.azure.com/.default" token = "fake-access-token" + fakeAzureArcFilePath = "fake/fake" + secretKey = "secret.key" + basicRealm = "Basic realm=" + + errorExpectedButGot = "expected %v, got %v" + errorFormingJsonResponse = "error while forming json response : %s" ) type SuccessfulResponse struct { @@ -34,7 +42,7 @@ type SuccessfulResponse struct { TokenType string `json:"token_type"` } -type ErrorRespone struct { +type ErrorResponse struct { Err string `json:"error"` Desc string `json:"error_description"` } @@ -53,7 +61,7 @@ func getSuccessfulResponse(resource string) ([]byte, error) { } func makeResponseWithErrorData(err string, desc string) ([]byte, error) { - responseBody := ErrorRespone{ + responseBody := ErrorResponse{ Err: err, Desc: desc, } @@ -61,72 +69,78 @@ func makeResponseWithErrorData(err string, desc string) ([]byte, error) { return jsonResponse, e } -type resourceTestData struct { - source Source - endpoint string - resource string - miType ID +func createMockFile(t *testing.T, path string, size int64) { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("failed to create directory: %v", err) + } + + f, err := os.Create(path) + if err != nil { + t.Fatalf("failed to create file: %v", err) + } + defer f.Close() + + if size > 0 { + if err := f.Truncate(size); err != nil { + t.Fatalf("failed to truncate file: %v", err) + } + } + + // Write the content to the file + if _, err := f.WriteString("secret file data"); err != nil { + t.Fatalf("failed to write to file: %v", err) + } + t.Cleanup(func() { os.Remove(path) }) } -type errorTestData struct { - code int - err string - desc string - correlationid string +func setEnvVars(t *testing.T, source Source) { + switch source { + case AzureArc: + t.Setenv(identityEndpointEnvVar, "http://127.0.0.1:40342/metadata/identity/oauth2/token") + t.Setenv(imdsEndVar, "http://169.254.169.254/metadata/identity/oauth2/token") + case AppService: + t.Setenv(identityEndpointEnvVar, "http://127.0.0.1:41564/msi/token") + t.Setenv(identityHeaderEnvVar, "secret") + case CloudShell: + t.Setenv(msiEndpointEnvVar, "http://localhost:40342/metadata/identity/oauth2/token") + case ServiceFabric: + t.Setenv(identityEndpointEnvVar, "http://localhost:40342/metadata/identity/oauth2/token") + t.Setenv(identityHeaderEnvVar, "secret") + t.Setenv(identityServerThumbprintEnvVar, "thumbprint") + } } -func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { - testCases := []errorTestData{ - {code: http.StatusNotFound, - err: "", - desc: "", - correlationid: "121212"}, - {code: http.StatusNotImplemented, - err: "", - desc: "", - correlationid: "121212"}, - {code: http.StatusServiceUnavailable, - err: "", - desc: "", - correlationid: "121212"}, - {code: http.StatusBadRequest, - err: "invalid_request", - desc: "Identity not found", - correlationid: "121212", - }, +func setCustomAzureArcPlatformPath(t *testing.T, path string) { + originalFunc := getAzureArcPlatformPath + getAzureArcPlatformPath = func(string) string { + return path } - for _, testCase := range testCases { - t.Run(http.StatusText(testCase.code), func(t *testing.T) { + t.Cleanup(func() { getAzureArcPlatformPath = originalFunc }) +} - fakeErrorClient := mock.Client{} - responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) - if err != nil { - t.Fatalf("error while forming json response : %s", err.Error()) - } - fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), - mock.WithBody(responseBody)) - client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) +func setCustomAzureArcFilePath(t *testing.T, path string) { + originalFunc := getAzureArcHimdsFilePath + getAzureArcHimdsFilePath = func(string) string { + return path + } + + t.Cleanup(func() { getAzureArcHimdsFilePath = originalFunc }) +} + +func TestSource(t *testing.T) { + for _, testCase := range []Source{AzureArc, DefaultToIMDS} { + t.Run(string(testCase), func(t *testing.T) { + setEnvVars(t, testCase) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + + actualSource, err := GetSource() if err != nil { - t.Fatal(err) - } - resp, err := client.AcquireToken(context.Background(), resource) - if err == nil { - t.Fatalf("should have encountered the error") - } - var callErr errors.CallErr - if errors.As(err, &callErr) { - if !strings.Contains(err.Error(), testCase.err) { - t.Fatalf("expected message '%s' in error, got %q", testCase.err, callErr.Error()) - } - if callErr.Resp.StatusCode != testCase.code { - t.Fatalf("expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) - } - } else { - t.Fatalf("expected error of type %T, got %T", callErr, err) + t.Fatalf("error while getting source: %s", err.Error()) } - if resp.AccessToken != "" { - t.Fatalf("accesstoken should be empty") + if actualSource != testCase { + t.Fatalf(errorExpectedButGot, testCase, actualSource) } }) } @@ -343,22 +357,45 @@ func TestCacheScopes(t *testing.T) { } } -func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { - testCases := []resourceTestData{ - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: SystemAssigned()}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: SystemAssigned()}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: UserAssignedClientID("clientId")}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")}, +func TestAzureArcReturnsWhenHimdsFound(t *testing.T) { + mockFilePath := filepath.Join(t.TempDir(), "himds") + setCustomAzureArcFilePath(t, mockFilePath) + + // Create the mock himds file + createMockFile(t, mockFilePath, 1024) + + actualSource, err := GetSource() + if err != nil { + t.Fatalf("error while getting source: %s", err.Error()) + } + + if actualSource != AzureArc { + t.Fatalf(errorExpectedButGot, AzureArc, actualSource) + } +} + +func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) { + testCases := []struct { + resource string + miType ID + }{ + {resource: resource, miType: SystemAssigned()}, + {resource: resourceDefaultSuffix, miType: SystemAssigned()}, + {resource: resource, miType: UserAssignedClientID("clientId")}, + {resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")}, + {resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")}, } for _, testCase := range testCases { - t.Run(string(testCase.source)+"-"+testCase.miType.value(), func(t *testing.T) { + t.Run(string(DefaultToIMDS)+"-"+testCase.miType.value(), func(t *testing.T) { + endpoint := imdsDefaultEndpoint + var localUrl *url.URL mockClient := mock.Client{} responseBody, err := getSuccessfulResponse(resource) if err != nil { - t.Fatalf("error while forming json response : %s", err.Error()) + t.Fatalf(errorFormingJsonResponse, err.Error()) } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { localUrl = r.URL })) @@ -366,6 +403,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { before := cacheManager defer func() { cacheManager = before }() cacheManager = storage.New(nil) + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) @@ -374,20 +412,15 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { if err != nil { t.Fatal(err) } - if localUrl == nil || !strings.HasPrefix(localUrl.String(), testCase.endpoint) { - t.Fatalf("url request is not on %s got %s", testCase.endpoint, localUrl) - } - if testCase.miType.value() != systemAssignedManagedIdentity { - if !strings.Contains(localUrl.String(), testCase.miType.value()) { - t.Fatalf("url request does not contain the %s got %s", testCase.endpoint, testCase.miType.value()) - } + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) } query := localUrl.Query() - if query.Get(apiVersionQuerryParameterName) != imdsAPIVersion { - t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQuerryParameterName)) + if query.Get(apiVersionQueryParameterName) != imdsAPIVersion { + t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQueryParameterName)) } - if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") { + if query.Get(resourceQueryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") { t.Fatal("suffix /.default was not removed.") } switch i := testCase.miType.(type) { @@ -432,6 +465,281 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { } } +func TestAzureArc(t *testing.T) { + testCaseFilePath := filepath.Join(t.TempDir(), azureConnectedMachine) + + endpoint := azureArcEndpoint + setEnvVars(t, AzureArc) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + + var localUrl *url.URL + mockClient := mock.Client{} + + mockFilePath := filepath.Join(testCaseFilePath, secretKey) + setCustomAzureArcPlatformPath(t, testCaseFilePath) + + createMockFile(t, mockFilePath, 0) + + headers := http.Header{} + headers.Set(wwwAuthenticateHeaderName, basicRealm+filepath.Join(testCaseFilePath, secretKey)) + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithHTTPHeader(headers), + mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + + responseBody, err := getSuccessfulResponse(resource) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithHTTPHeader(headers), + mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + + // resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), resourceDefaultSuffix) + if err != nil { + t.Fatal(err) + } + + if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) { + t.Fatalf("url request is not on %s got %s", endpoint, localUrl) + } + + query := localUrl.Query() + + if query.Get(apiVersionQueryParameterName) != azureArcAPIVersion { + t.Fatalf("api-version not on %s got %s", azureArcAPIVersion, query.Get(apiVersionQueryParameterName)) + } + if query.Get(resourceQueryParameterName) != strings.TrimSuffix(resourceDefaultSuffix, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + if result.Metadata.TokenSource != base.IdentityProvider { + t.Fatalf("expected IndenityProvider tokensource, got %d", result.Metadata.TokenSource) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + result, err = client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource) + } + secondFakeClient, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err = secondFakeClient.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + if result.Metadata.TokenSource != base.Cache { + t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource) + } + +} + +func TestAzureArcOnlySystemAssignedSupported(t *testing.T) { + setEnvVars(t, AzureArc) + mockClient := mock.Client{} + + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + for _, testCase := range []ID{ + UserAssignedClientID("client"), + UserAssignedObjectID("ObjectId"), + UserAssignedResourceID("resourceid")} { + _, err := New(testCase, WithHTTPClient(&mockClient)) + if err == nil { + t.Fatal(`expected error: azure arc not supported error"`) + + } + if err.Error() != "azure Arc doesn't support user assigned managed identities" { + t.Fatalf(`expected error: azure arc not supported error, got error: "%v"`, err) + } + + } +} +func TestAzureArcPlatformSupported(t *testing.T) { + setEnvVars(t, AzureArc) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + mockClient := mock.Client{} + headers := http.Header{} + headers.Set(wwwAuthenticateHeaderName, "Basic realm=/path/to/secret.key") + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithHTTPHeader(headers), + ) + setCustomAzureArcPlatformPath(t, "") + + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), resource) + if err == nil || !strings.Contains(err.Error(), "platform not supported") { + println(result.AccessToken) + t.Fatalf(`expected error: "%v" got error: "%v"`, "platform not supported", err) + + } + if result.AccessToken != "" { + t.Fatalf("access token should be empty") + } +} + +func TestAzureArcErrors(t *testing.T) { + setEnvVars(t, AzureArc) + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + testCaseFilePath := filepath.Join(t.TempDir(), "AzureConnectedMachineAgent") + + testCases := []struct { + name string + headerValue string + expectedError string + fileSize int64 + }{ + { + name: "No www-authenticate header", + expectedError: "response has no www-authenticate header", + }, + { + name: "Basic realm= not found", + headerValue: "Basic ", + expectedError: "basic realm= not found in the string, instead found: Basic ", + }, + { + name: "Invalid file extension", + headerValue: "Basic realm=/path/to/secret.txt", + expectedError: "invalid file extension, expected .key, got .txt", + }, + { + name: "Invalid file path", + headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey), + expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"), + }, + { + name: "Unable to get file info", + headerValue: basicRealm + filepath.Join(testCaseFilePath, "2secret.key"), + expectedError: "failed to get metadata", + }, + { + name: "Invalid secret file size", + headerValue: basicRealm + filepath.Join(testCaseFilePath, secretKey), + expectedError: "invalid secret file size, expected 4096, file size was 5000", + fileSize: 5000, + }, + } + + for _, testCase := range testCases { + t.Run(string(testCase.name), func(t *testing.T) { + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + mockClient := mock.Client{} + mockFilePath := filepath.Join(testCaseFilePath, secretKey) + setCustomAzureArcPlatformPath(t, testCaseFilePath) + createMockFile(t, mockFilePath, testCase.fileSize) + headers := http.Header{} + headers.Set(wwwAuthenticateHeaderName, testCase.headerValue) + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), + mock.WithHTTPHeader(headers), + ) + + responseBody, err := getSuccessfulResponse(resource) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithHTTPHeader(headers), + mock.WithBody(responseBody)) + + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + return + } + result, err := client.AcquireToken(context.Background(), resource) + if err == nil || !strings.Contains(err.Error(), testCase.expectedError) { + t.Fatalf(`expected error: "%v" got error: "%v"`, testCase.expectedError, err) + + } + if result.AccessToken != "" { + t.Fatal("access token should be empty") + } + }) + } +} + +func TestSystemAssignedReturnsAcquireTokenFailure(t *testing.T) { + testCases := []struct { + code int + err string + desc string + }{ + {code: http.StatusNotFound}, + {code: http.StatusNotImplemented}, + {code: http.StatusServiceUnavailable}, + {code: http.StatusBadRequest, + err: "invalid_request", + desc: "Identity not found", + }, + } + + for _, testCase := range testCases { + t.Run(http.StatusText(testCase.code), func(t *testing.T) { + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) + if err != nil { + t.Fatalf(errorFormingJsonResponse, err.Error()) + } + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), + mock.WithBody(responseBody)) + client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient), WithRetryPolicyDisabled()) + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatalf("should have encountered the error") + } + var callErr errors.CallErr + if errors.As(err, &callErr) { + if !strings.Contains(err.Error(), testCase.err) { + t.Fatalf("expected message '%s' in error, got %q", testCase.err, callErr.Error()) + } + if callErr.Resp.StatusCode != testCase.code { + t.Fatalf("expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) + } + } else { + t.Fatalf("expected error of type %T, got %T", callErr, err) + } + if resp.AccessToken != "" { + t.Fatalf("access token should be empty") + } + }) + } +} + func TestCreatingIMDSClient(t *testing.T) { tests := []struct { name string @@ -473,6 +781,7 @@ func TestCreatingIMDSClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + setCustomAzureArcFilePath(t, fakeAzureArcFilePath) client, err := New(tt.id) if tt.wantErr { if err == nil {