From 1d2f3e808f39e8526bda7b1197b53886b4fb5d8f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 27 Aug 2024 23:53:00 +0100 Subject: [PATCH 01/32] Initial system assgined for acquire token --- apps/managedidentity/managedidentity.go | 115 ++++++++++++++---- .../devapps/client_certificate_sample.go | 106 ++++++++-------- apps/tests/devapps/main.go | 6 +- apps/tests/devapps/managedidentity_sample.go | 41 ++++--- apps/tests/devapps/sample_utils.go | 24 ++-- 5 files changed, 179 insertions(+), 113 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index ddbe71b0..7636290d 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -11,12 +11,20 @@ package managedidentity import ( "context" + "encoding/json" "fmt" + "io" + "net/http" + "net/url" "sync" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" + // "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" + // "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" + // "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" ) const ( @@ -29,10 +37,10 @@ const ( // Client is a client that provides access to Managed Identity token calls. type Client struct { - AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). also may remove from here + // AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). also may remove from here cacheAccessorMu *sync.RWMutex - // base ops.HTTPClient - // managedIdentityType Type + httpClient ops.HTTPClient + MiType ID // Token *oauth.Client // pmanager manager // todo : expose the manager from base. // cacheAccessor cache.ExportReplace @@ -47,13 +55,18 @@ type clientOptions struct { // clientId string } -type withClaimsOption struct{ Claims string } -type withHTTPClientOption struct{ HttpClient ops.HTTPClient } +// type withClaimsOption struct{ Claims string } +type withHTTPClientOption struct { + HttpClient ops.HTTPClient +} // Option is an optional argument to New(). type Option interface{ apply(*clientOptions) } type ClientOption interface{ ClientOption() } -type AcquireTokenOption interface{ AcquireTokenOption() } +type AcquireTokenOptions struct { + Claims string +} +type AcquireTokenOption interface{ apply(*AcquireTokenOptions) } // Source represents the managed identity sources supported. type Source int @@ -77,14 +90,14 @@ func (c ClientID) value() string { return string(c) } func (o ObjectID) value() string { return string(o) } func (r ResourceID) value() string { return string(r) } -func (w withClaimsOption) AcquireTokenOption() {} -func (w withHTTPClientOption) AcquireTokenOption() {} -func (w withHTTPClientOption) apply(opts *clientOptions) { opts.httpClient = w.HttpClient } +func (w AcquireTokenOptions) AcquireTokenOption() {} +func (w withHTTPClientOption) AcquireTokenOption() {} +func (w Client) apply(opts *clientOptions) { opts.httpClient = w.HttpClient } // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. -func WithClaims(claims string) AcquireTokenOption { - return withClaimsOption{Claims: claims} +func WithClaims(claims string) AcquireTokenOptions { + return AcquireTokenOptions{Claims: claims} } // WithHTTPClient allows for a custom HTTP client to be set. @@ -99,28 +112,29 @@ func WithHTTPClient(httpClient ops.HTTPClient) Option { func New(id ID, options ...Option) (Client, error) { fmt.Println("idType: ", id.value()) - opts := clientOptions{ - claims: "claims", + opts := clientOptions{ // work on this side where + httpClient: shared.DefaultClient, } for _, option := range options { option.apply(&opts) } - authInfo, err := authority.NewInfoFromAuthorityURI("authorityURI", true, false) - if err != nil { - return Client{}, err + client := Client{ // TODO :: check for http client + MiType: id, } - authParams := authority.NewAuthParams(id.value(), authInfo) - client := Client{ // Note: Hey, don't even THINK about making Base into *Base. See "design notes" in public.go and confidential.go - AuthParams: authParams, - cacheAccessorMu: &sync.RWMutex{}, - // manager: storage.New(token), - // pmanager: storage.NewPartitionedManager(token), - } + return client, nil +} - return client, err +type responseJson struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn string `json:"expires_in"` + ExpiresOn string `json:"expires_on"` + NotBefore string `json:"not_before"` + Resource string `json:"resource"` + TokenType string `json:"token_type"` } // Acquires tokens from the configured managed identity on an azure resource. @@ -128,6 +142,57 @@ func New(id ID, options ...Option) (Client, error) { // Resource: scopes application is requesting access to // Options: [WithClaims] func (client Client) AcquireToken(context context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { + o := AcquireTokenOptions{} + + for _, option := range options { + option.apply(&o) + } + + if client.MiType == SystemAssigned() { + systemUrl := "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01" + var msiEndpoint *url.URL + msi_endpoint, err := url.Parse(systemUrl) + if err != nil { + fmt.Println("Error creating URL: ", err) + return base.AuthResult{}, nil + } + msiParameters := msi_endpoint.Query() + msiParameters.Add("resource", "https://management.azure.com/") + msiEndpoint.RawQuery = msiParameters.Encode() + req, err := http.NewRequest(http.MethodGet, msiEndpoint.String(), nil) + if err != nil { + fmt.Println("Error creating HTTP request: ", err) + return base.AuthResult{}, nil + } + req.Header.Add("Metadata", "true") + + resp, err := client.httpClient.Do(req) + if err != nil { + fmt.Println("Error calling token endpoint: ", err) + return base.AuthResult{}, nil + } + + // Pull out response body + responseBytes, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + fmt.Println("Error reading response body : ", err) + return base.AuthResult{}, nil + } + + // Unmarshall response body into struct + var r accesstokens.TokenResponse + err = json.Unmarshal(responseBytes, &r) + if err != nil { + fmt.Println("Error unmarshalling the response:", err) + return base.AuthResult{}, nil + } + + println("Access token :: ", r.AccessToken) + return base.NewAuthResult(r, shared.Account{}) + } + + // all the other options. return base.AuthResult{}, nil } diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index c55ed571..d32dcb68 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -1,56 +1,56 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. +// // Copyright (c) Microsoft Corporation. +// // Licensed under the MIT license. package main -import ( - "context" - "fmt" - "log" - "os" - - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" -) - -var _config2 *Config = CreateConfig("confidential_config.json") - -// Keep the ConfidentialClient application object around, because it maintains a token cache -// For simplicity, the sample uses global variables. -// For user flows (web site, web api) or for large multi-tenant apps use a cache per user or per tenant -var _app2 *confidential.Client = createAppWithCert() - -func createAppWithCert() *confidential.Client { - - pemData, err := os.ReadFile(_config2.PemData) - if err != nil { - log.Fatal(err) - } - - // This extracts our public certificates and private key from the PEM file. If it is - // encrypted, the second argument must be password to decode. - // IMPORTANT SECURITY NOTICE: never store passwords in code. The recommended pattern is to keep the certificate in a vault (e.g. Azure KeyVault) - // and to download it when the application starts. - certs, privateKey, err := confidential.CertFromPEM(pemData, "") - if err != nil { - log.Fatal(err) - } - cred, err := confidential.NewCredFromCert(certs, privateKey) - if err != nil { - log.Fatal(err) - } - app, err := confidential.New(_config2.Authority, _config2.ClientID, cred, confidential.WithCache(cacheAccessor)) - if err != nil { - log.Fatal(err) - } - return &app -} - -func acquireTokenClientCertificate() { - - result, err := _app2.AcquireTokenByCredential(context.Background(), _config1.Scopes) - if err != nil { - log.Fatal(err) - } - - fmt.Println("A Bearer token was acquired, it expires on: ", result.ExpiresOn) -} +// import ( +// "context" +// "fmt" +// "log" +// "os" + +// "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" +// ) + +// var _config2 *Config = CreateConfig("confidential_config.json") + +// // Keep the ConfidentialClient application object around, because it maintains a token cache +// // For simplicity, the sample uses global variables. +// // For user flows (web site, web api) or for large multi-tenant apps use a cache per user or per tenant +// var _app2 *confidential.Client = createAppWithCert() + +// func createAppWithCert() *confidential.Client { + +// pemData, err := os.ReadFile(_config2.PemData) +// if err != nil { +// log.Fatal(err) +// } + +// // This extracts our public certificates and private key from the PEM file. If it is +// // encrypted, the second argument must be password to decode. +// // IMPORTANT SECURITY NOTICE: never store passwords in code. The recommended pattern is to keep the certificate in a vault (e.g. Azure KeyVault) +// // and to download it when the application starts. +// certs, privateKey, err := confidential.CertFromPEM(pemData, "") +// if err != nil { +// log.Fatal(err) +// } +// cred, err := confidential.NewCredFromCert(certs, privateKey) +// if err != nil { +// log.Fatal(err) +// } +// app, err := confidential.New(_config2.Authority, _config2.ClientID, cred, confidential.WithCache(cacheAccessor)) +// if err != nil { +// log.Fatal(err) +// } +// return &app +// } + +// func acquireTokenClientCertificate() { + +// result, err := _app2.AcquireTokenByCredential(context.Background(), _config1.Scopes) +// if err != nil { +// log.Fatal(err) +// } + +// fmt.Println("A Bearer token was acquired, it expires on: ", result.ExpiresOn) +// } diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index 027bd6e4..d1da7260 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -13,7 +13,7 @@ func main() { ctx := context.Background() // Choose a sammple to run. - exampleType := "5" + exampleType := "7" if exampleType == "1" { acquireTokenDeviceCode() @@ -36,8 +36,8 @@ func main() { } else if exampleType == "6" { // This sample does not use a serialized cache - it relies on in-memory cache by reusing the app object // This works well for app tokens, because there is only 1 token per resource, per tenant. - acquireTokenClientCertificate() - + // acquireTokenClientCertificate() + println("Nothing to do in this option") // this time the token comes from the cache! // acquireTokenClientCertificate() } else if exampleType == "7" { diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index 337b1128..ab7fe090 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -11,28 +11,29 @@ import ( func RunManagedIdentity() { customHttpClient := &http.Client{} - miSystemAssigned, error := mi.New(mi.SystemAssigned()) + miSystemAssigned, error := mi.New(mi.SystemAssigned(), mi.WithHTTPClient(customHttpClient)) if error != nil { fmt.Println(error) } - miClientIdAssigned, error := mi.New(mi.ClientID("client id 123"), mi.WithHTTPClient(customHttpClient)) - if error != nil { - fmt.Println(error) - } - - miResourceIdAssigned, error := mi.New(mi.ResourceID("resource id 123")) - if error != nil { - fmt.Println(error) - } - - miObjectIdAssigned, error := mi.New(mi.ObjectID("object id 123")) - if error != nil { - fmt.Println(error) - } - - miSystemAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) - miClientIdAssigned.AcquireToken(context.Background(), "resource") - miResourceIdAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) - miObjectIdAssigned.AcquireToken(context.Background(), "resource") + // miClientIdAssigned, error := mi.New(mi.ClientID("951d3571-c442-42e0-9efd-e1d7e1a21030"), + // mi.WithHTTPClient(customHttpClient)) + // if error != nil { + // fmt.Println(error) + // } + + // miResourceIdAssigned, error := mi.New(mi.ResourceID("resource id 123")) + // if error != nil { + // fmt.Println(error) + // } + + // miObjectIdAssigned, error := mi.New(mi.ObjectID("object id 123")) + // if error != nil { + // fmt.Println(error) + // } + + miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") + // miClientIdAssigned.AcquireToken(context.Background(), "resource") + // miResourceIdAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) + // miObjectIdAssigned.AcquireToken(context.Background(), "resource") } diff --git a/apps/tests/devapps/sample_utils.go b/apps/tests/devapps/sample_utils.go index b311927b..05e8b2ff 100644 --- a/apps/tests/devapps/sample_utils.go +++ b/apps/tests/devapps/sample_utils.go @@ -11,18 +11,18 @@ import ( // Config represents the config.json required to run the samples type Config struct { - ClientID string `json:"client_id"` - Authority string `json:"authority"` - Scopes []string `json:"scopes"` - Username string `json:"username"` - Password string `json:"password"` - RedirectURI string `json:"redirect_uri"` - CodeChallenge string `json:"code_challenge"` - CodeChallengeMethod string `json:"code_challenge_method"` - State string `json:"state"` - ClientSecret string `json:"client_secret"` - Thumbprint string `json:"thumbprint"` - PemData string `json:"pem_file"` + ClientID string `json:"client_id"` + Authority string `json:"authority"` + Scopes []string `json:"scopes"` + Username string `json:"username"` + Password string `json:"password"` + RedirectURI string `json:"redirect_uri"` + // CodeChallenge string `json:"code_challenge"` + // CodeChallengeMethod string `json:"code_challenge_method"` + // State string `json:"state"` + ClientSecret string `json:"client_secret"` + // Thumbprint string `json:"thumbprint"` + // PemData string `json:"pem_file"` } // CreateConfig creates the Config struct from a json file. From 63e6bed3c3a998b7f4ae443f8db45210d79777f2 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 28 Aug 2024 00:38:35 +0100 Subject: [PATCH 02/32] Added a simple version of getting token. Added a simple version of getting token and printing it reformatting code. --- apps/managedidentity/managedidentity.go | 93 +++++++++++--------- apps/tests/devapps/managedidentity_sample.go | 5 +- 2 files changed, 52 insertions(+), 46 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 7636290d..da81b3e4 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -16,7 +16,6 @@ import ( "io" "net/http" "net/url" - "sync" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" @@ -35,12 +34,36 @@ const ( AzureArc = 1 ) +// id type managed identity +type Source int + +type systemAssignedValue string + +type ID interface { + value() string +} + +func SystemAssigned() ID { + return systemAssignedValue("") +} + +type ClientID string +type ObjectID string +type ResourceID string + +func (s systemAssignedValue) value() string { return string(s) } +func (c ClientID) value() string { return string(c) } +func (o ObjectID) value() string { return string(o) } +func (r ResourceID) value() string { return string(r) } + +//------------------ + // Client is a client that provides access to Managed Identity token calls. type Client struct { // AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). also may remove from here - cacheAccessorMu *sync.RWMutex - httpClient ops.HTTPClient - MiType ID + // cacheAccessorMu *sync.RWMutex + httpClient ops.HTTPClient + MiType ID // Token *oauth.Client // pmanager manager // todo : expose the manager from base. // cacheAccessor cache.ExportReplace @@ -48,7 +71,7 @@ type Client struct { // clientOptions are optional settings for New(). These options are set using various functions // returning Option calls. -type clientOptions struct { +type ClientOptions struct { claims string // bypasses cache, does nothing else httpClient ops.HTTPClient // disableInstanceDiscovery bool // always false @@ -56,43 +79,26 @@ type clientOptions struct { } // type withClaimsOption struct{ Claims string } -type withHTTPClientOption struct { - HttpClient ops.HTTPClient -} +// type withHTTPClientOption struct { +// HttpClient ops.HTTPClient +// } // Option is an optional argument to New(). -type Option interface{ apply(*clientOptions) } -type ClientOption interface{ ClientOption() } +type Option interface{ apply(*ClientOptions) } + +// type ClientOptions interface{ ClientOption() } type AcquireTokenOptions struct { Claims string } type AcquireTokenOption interface{ apply(*AcquireTokenOptions) } // Source represents the managed identity sources supported. -type Source int -type systemAssignedValue string - -type ID interface { - value() string -} +func (w AcquireTokenOptions) AcquireTokenOption() {} +func (w ClientOptions) ClientOptions() {} -func SystemAssigned() ID { - return systemAssignedValue("") -} - -type ClientID string -type ObjectID string -type ResourceID string - -func (s systemAssignedValue) value() string { return string(s) } -func (c ClientID) value() string { return string(c) } -func (o ObjectID) value() string { return string(o) } -func (r ResourceID) value() string { return string(r) } - -func (w AcquireTokenOptions) AcquireTokenOption() {} -func (w withHTTPClientOption) AcquireTokenOption() {} -func (w Client) apply(opts *clientOptions) { opts.httpClient = w.HttpClient } +// func (w *Client) apply(opts ) +func (w Client) apply(opts *ClientOptions) { w.httpClient = opts.httpClient } // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. @@ -101,8 +107,8 @@ func WithClaims(claims string) AcquireTokenOptions { } // WithHTTPClient allows for a custom HTTP client to be set. -func WithHTTPClient(httpClient ops.HTTPClient) Option { - return withHTTPClientOption{HttpClient: httpClient} +func WithHTTPClient(httpClient ops.HTTPClient) ClientOptions { + return ClientOptions{httpClient: httpClient} } // Client to be used to acquire tokens for managed identity. @@ -112,7 +118,7 @@ func WithHTTPClient(httpClient ops.HTTPClient) Option { func New(id ID, options ...Option) (Client, error) { fmt.Println("idType: ", id.value()) - opts := clientOptions{ // work on this side where + opts := ClientOptions{ // work on this side where httpClient: shared.DefaultClient, } @@ -121,7 +127,8 @@ func New(id ID, options ...Option) (Client, error) { } client := Client{ // TODO :: check for http client - MiType: id, + MiType: id, + httpClient: opts.httpClient, } return client, nil @@ -141,22 +148,22 @@ type responseJson struct { // // Resource: scopes application is requesting access to // Options: [WithClaims] -func (client Client) AcquireToken(context context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { - o := AcquireTokenOptions{} +func (client Client) AcquireToken(context context.Context, resource string) (base.AuthResult, error) { + // o := AcquireTokenOptions{} - for _, option := range options { - option.apply(&o) - } + // for _, option := range options { + // option.apply(&o) + // } if client.MiType == SystemAssigned() { systemUrl := "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01" var msiEndpoint *url.URL - msi_endpoint, err := url.Parse(systemUrl) + msiEndpoint, err := url.Parse(systemUrl) if err != nil { fmt.Println("Error creating URL: ", err) return base.AuthResult{}, nil } - msiParameters := msi_endpoint.Query() + msiParameters := msiEndpoint.Query() msiParameters.Add("resource", "https://management.azure.com/") msiEndpoint.RawQuery = msiParameters.Encode() req, err := http.NewRequest(http.MethodGet, msiEndpoint.String(), nil) diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index ab7fe090..3c8394e2 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -3,15 +3,14 @@ package main import ( "context" "fmt" - "net/http" mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" ) func RunManagedIdentity() { - customHttpClient := &http.Client{} + // customHttpClient := &http.Client{} - miSystemAssigned, error := mi.New(mi.SystemAssigned(), mi.WithHTTPClient(customHttpClient)) + miSystemAssigned, error := mi.New(mi.SystemAssigned()) if error != nil { fmt.Println(error) } From 69a039cfe334261aeeddf7eb87ebf156dfb90c15 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 2 Sep 2024 12:03:17 +0100 Subject: [PATCH 03/32] added IMDB for SAMI Added tests and implementation for SAMI IMDS --- apps/managedidentity/managedidentity.go | 200 +++++++++++-------- apps/managedidentity/managedidentity_test.go | 67 +++++-- 2 files changed, 167 insertions(+), 100 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index da81b3e4..c9d1f371 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -21,9 +21,57 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" - // "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" - // "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" - // "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" +) + +const ( + MetaHTTPHeadderName = "Metadata" + APIVersionQuerryParameterName = "api-version" + ResourceBodyOrQuerryParameterName = "resource" +) + +const ( + MIQuerryParameterClientId = "client_id" + MIQuerryParameterObjectId = "object_id" + MIQuerryParameterResourceId = "mi_res_id" +) + +// AZURE_POD_IDENTITY_AUTHORITY_HOST: "AZURE_POD_IDENTITY_AUTHORITY_HOST", +// +// IDENTITY_ENDPOINT: "IDENTITY_ENDPOINT", +// IDENTITY_HEADER: "IDENTITY_HEADER", +// IDENTITY_SERVER_THUMBPRINT: "IDENTITY_SERVER_THUMBPRINT", +// IMDS_ENDPOINT: "IMDS_ENDPOINT", +// MSI_ENDPOINT: "MSI_ENDPOINT", +// APP_SERVICE: "AppService", +// +// These are the MI source names +// AZURE_ARC: "AzureArc", +// CLOUD_SHELL: "CloudShell", +// DEFAULT_TO_IMDS: "DefaultToImds", +// IMDS: "Imds", +// SERVICE_FABRIC: "ServiceFabric", + +// CouldShell +// This also comes from enviourment point :: :: +const () + +// Appservice +// end point comes from enviournment variable ?? ?!?!?!?!? +const ( + AppServiceMSIEndPointVersion = "2019-08-01" +) + +// Arc +const ( + ARCAPIEndpoint = "http://127.0.0.1:40342/metadata/identity/oauth2/token" + ARCAPIVersion = "2019-11-01" +) + +// IMDS +const ( + IMDSTokenPath = "/metadata/identity/oauth2/token" + IMDSEndpoint = "http://169.254.169.254" + IMDSTokenPath + IMDSAPIVersion = "2018-02-01" ) const ( @@ -37,16 +85,11 @@ const ( // id type managed identity type Source int -type systemAssignedValue string - type ID interface { value() string } -func SystemAssigned() ID { - return systemAssignedValue("") -} - +type systemAssignedValue string // its private for a reason to make the input consistent. type ClientID string type ObjectID string type ResourceID string @@ -55,67 +98,54 @@ func (s systemAssignedValue) value() string { return string(s) } func (c ClientID) value() string { return string(c) } func (o ObjectID) value() string { return string(o) } func (r ResourceID) value() string { return string(r) } +func SystemAssigned() ID { + return systemAssignedValue("") +} //------------------ +//-- construction of the structues and the API's // Client is a client that provides access to Managed Identity token calls. type Client struct { - // AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). also may remove from here // cacheAccessorMu *sync.RWMutex httpClient ops.HTTPClient MiType ID - // Token *oauth.Client // pmanager manager // todo : expose the manager from base. // cacheAccessor cache.ExportReplace } -// clientOptions are optional settings for New(). These options are set using various functions -// returning Option calls. type ClientOptions struct { - claims string // bypasses cache, does nothing else httpClient ops.HTTPClient - // disableInstanceDiscovery bool // always false - // clientId string } -// type withClaimsOption struct{ Claims string } -// type withHTTPClientOption struct { -// HttpClient ops.HTTPClient -// } - -// Option is an optional argument to New(). -type Option interface{ apply(*ClientOptions) } - -// type ClientOptions interface{ ClientOption() } type AcquireTokenOptions struct { Claims string } -type AcquireTokenOption interface{ apply(*AcquireTokenOptions) } - -// Source represents the managed identity sources supported. -func (w AcquireTokenOptions) AcquireTokenOption() {} -func (w ClientOptions) ClientOptions() {} +type ClientOption func(o *ClientOptions) -// func (w *Client) apply(opts ) -func (w Client) apply(opts *ClientOptions) { w.httpClient = opts.httpClient } +type AcquireTokenOption func(o *AcquireTokenOptions) // WithClaims sets additional claims to request for the token, such as those required by conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. -func WithClaims(claims string) AcquireTokenOptions { - return AcquireTokenOptions{Claims: claims} +func WithClaims(claims string) AcquireTokenOption { + return func(o *AcquireTokenOptions) { + o.Claims = claims + } } // WithHTTPClient allows for a custom HTTP client to be set. -func WithHTTPClient(httpClient ops.HTTPClient) ClientOptions { - return ClientOptions{httpClient: httpClient} +func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { + return func(o *ClientOptions) { + o.httpClient = httpClient + } } // Client to be used to acquire tokens for managed identity. // ID: [SystemAssigned()], [ClientID("clientID")], [ResourceID("resourceID")], [ObjectID("objectID")] // // Options: [WithHTTPClient] -func New(id ID, options ...Option) (Client, error) { +func New(id ID, options ...ClientOption) (Client, error) { fmt.Println("idType: ", id.value()) opts := ClientOptions{ // work on this side where @@ -123,7 +153,7 @@ func New(id ID, options ...Option) (Client, error) { } for _, option := range options { - option.apply(&opts) + option(&opts) } client := Client{ // TODO :: check for http client @@ -134,69 +164,71 @@ func New(id ID, options ...Option) (Client, error) { return client, nil } -type responseJson struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn string `json:"expires_in"` - ExpiresOn string `json:"expires_on"` - NotBefore string `json:"not_before"` - Resource string `json:"resource"` - TokenType string `json:"token_type"` +func getTokenForURL(url *url.URL, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { + req, err := http.NewRequest(http.MethodGet, url.String(), nil) + if err != nil { + return accesstokens.TokenResponse{}, err + } + req.Header.Add("Metadata", "true") + + resp, err := httpClient.Do(req) + if err != nil { + return accesstokens.TokenResponse{}, err + } + + // Pull out response body + responseBytes, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return accesstokens.TokenResponse{}, err + } + + // Unmarshall response body into struct + var r accesstokens.TokenResponse + err = json.Unmarshal(responseBytes, &r) + if err != nil { + return accesstokens.TokenResponse{}, err + } + return r, nil + } // Acquires tokens from the configured managed identity on an azure resource. // // Resource: scopes application is requesting access to // Options: [WithClaims] -func (client Client) AcquireToken(context context.Context, resource string) (base.AuthResult, error) { - // o := AcquireTokenOptions{} +func (client Client) AcquireToken(context context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { + o := AcquireTokenOptions{} - // for _, option := range options { - // option.apply(&o) - // } + for _, option := range options { + option(&o) + } + // try and find some resource which cna be accessed + // service fabric GET + // app service GET + // could shell POST request + // azure arc GET + // default :: IMDS GET + // + // Sources that send GET requests: App Service, Azure Arc, IMDS, Service Fabric + // + // Sources that send POST requests: Cloud Shell if client.MiType == SystemAssigned() { - systemUrl := "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01" var msiEndpoint *url.URL - msiEndpoint, err := url.Parse(systemUrl) + msiEndpoint, err := url.Parse(IMDSEndpoint) if err != nil { fmt.Println("Error creating URL: ", err) return base.AuthResult{}, nil } msiParameters := msiEndpoint.Query() - msiParameters.Add("resource", "https://management.azure.com/") + msiParameters.Add("api-version", "2018-02-01") + msiParameters.Add("resource", resource) msiEndpoint.RawQuery = msiParameters.Encode() - req, err := http.NewRequest(http.MethodGet, msiEndpoint.String(), nil) - if err != nil { - fmt.Println("Error creating HTTP request: ", err) - return base.AuthResult{}, nil - } - req.Header.Add("Metadata", "true") - - resp, err := client.httpClient.Do(req) - if err != nil { - fmt.Println("Error calling token endpoint: ", err) - return base.AuthResult{}, nil - } - - // Pull out response body - responseBytes, err := io.ReadAll(resp.Body) - defer resp.Body.Close() - if err != nil { - fmt.Println("Error reading response body : ", err) - return base.AuthResult{}, nil - } - - // Unmarshall response body into struct - var r accesstokens.TokenResponse - err = json.Unmarshal(responseBytes, &r) - if err != nil { - fmt.Println("Error unmarshalling the response:", err) - return base.AuthResult{}, nil - } - println("Access token :: ", r.AccessToken) - return base.NewAuthResult(r, shared.Account{}) + token, err := getTokenForURL(msiEndpoint, client.httpClient) + println("Access token :: ", token.AccessToken) + return base.NewAuthResult(token, shared.Account{}) } // all the other options. diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 4bf34540..fa9670d6 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -4,43 +4,78 @@ package managedidentity import ( "context" + "fmt" + "io" + "net/http" + "strings" "testing" + "time" + + internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" ) -func fakeClient(mangedIdentityId ID, options ...Option) (Client, error) { - client, err := New(mangedIdentityId, options...) +func fakeMIClient(mangedIdentityId ID, options ...ClientOption) (Client, error) { + fakeClient, err := New(mangedIdentityId, options...) if err != nil { return Client{}, err } - return client, nil + return fakeClient, nil } -func TestManagedIdentity(t *testing.T) { - client, err := fakeClient(SystemAssigned()) +type errorClient struct{} - if err != nil { - t.Fatal(err) - } +func (*errorClient) Do(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String()) +} - _, err = client.AcquireToken(context.Background(), "scope", WithClaims("claim")) +type fakeClient struct{} - if err == nil { - t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") +func (*fakeClient) CloseIdleConnections() {} +func (*fakeClient) Do(req *http.Request) (*http.Response, error) { + w := http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ + "access_token": "fakeToken", + "refresh_token": "", + "expires_in": "3599", + "expires_on": "1506484173", + "not_before": "1506480273", + "resource": "fakeresource", + "token_type": "Bearer" + }`)), + Header: make(http.Header), } + return &w, nil } -func TestManagedIdentityWithClaims(t *testing.T) { - client, err := fakeClient(ClientID("123")) +func TestManagedIdentity(t *testing.T) { + fakeHTTPClient := fakeClient{} + + client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) if err != nil { t.Fatal(err) } - _, err = client.AcquireToken(context.Background(), "scope", WithClaims("claim")) + result, err := client.AcquireToken(context.Background(), "fakeresource") - if err == nil { - t.Errorf("TestManagedIdentityWithClaims: unexpected nil error from TestManagedIdentityWithClaims") + if err != nil { + t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") } + var tokenScope = []string{"the_scope"} + + expected := accesstokens.TokenResponse{ + AccessToken: "fakeToken", + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + TokenType: "TokenType", + } + if result.AccessToken != expected.AccessToken { + t.Fatalf(`unexpected access token "%s"`, result.AccessToken) + } + } From 7c9418256dbf26ba7a19688f0e1ad815b988381f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 2 Sep 2024 12:07:55 +0100 Subject: [PATCH 04/32] Reverted the test app to original state Reverted changes in the test app --- .../devapps/client_certificate_sample.go | 102 +++++++++--------- apps/tests/devapps/main.go | 7 +- apps/tests/devapps/sample_utils.go | 24 ++--- 3 files changed, 66 insertions(+), 67 deletions(-) diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index d32dcb68..467a45bc 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -3,54 +3,54 @@ package main -// import ( -// "context" -// "fmt" -// "log" -// "os" - -// "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" -// ) - -// var _config2 *Config = CreateConfig("confidential_config.json") - -// // Keep the ConfidentialClient application object around, because it maintains a token cache -// // For simplicity, the sample uses global variables. -// // For user flows (web site, web api) or for large multi-tenant apps use a cache per user or per tenant -// var _app2 *confidential.Client = createAppWithCert() - -// func createAppWithCert() *confidential.Client { - -// pemData, err := os.ReadFile(_config2.PemData) -// if err != nil { -// log.Fatal(err) -// } - -// // This extracts our public certificates and private key from the PEM file. If it is -// // encrypted, the second argument must be password to decode. -// // IMPORTANT SECURITY NOTICE: never store passwords in code. The recommended pattern is to keep the certificate in a vault (e.g. Azure KeyVault) -// // and to download it when the application starts. -// certs, privateKey, err := confidential.CertFromPEM(pemData, "") -// if err != nil { -// log.Fatal(err) -// } -// cred, err := confidential.NewCredFromCert(certs, privateKey) -// if err != nil { -// log.Fatal(err) -// } -// app, err := confidential.New(_config2.Authority, _config2.ClientID, cred, confidential.WithCache(cacheAccessor)) -// if err != nil { -// log.Fatal(err) -// } -// return &app -// } - -// func acquireTokenClientCertificate() { - -// result, err := _app2.AcquireTokenByCredential(context.Background(), _config1.Scopes) -// if err != nil { -// log.Fatal(err) -// } - -// fmt.Println("A Bearer token was acquired, it expires on: ", result.ExpiresOn) -// } +import ( + "context" + "fmt" + "log" + "os" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" +) + +var _config2 *Config = CreateConfig("confidential_config.json") + +// Keep the ConfidentialClient application object around, because it maintains a token cache +// For simplicity, the sample uses global variables. +// For user flows (web site, web api) or for large multi-tenant apps use a cache per user or per tenant +var _app2 *confidential.Client = createAppWithCert() + +func createAppWithCert() *confidential.Client { + + pemData, err := os.ReadFile(_config2.PemData) + if err != nil { + log.Fatal(err) + } + + // This extracts our public certificates and private key from the PEM file. If it is + // encrypted, the second argument must be password to decode. + // IMPORTANT SECURITY NOTICE: never store passwords in code. The recommended pattern is to keep the certificate in a vault (e.g. Azure KeyVault) + // and to download it when the application starts. + certs, privateKey, err := confidential.CertFromPEM(pemData, "") + if err != nil { + log.Fatal(err) + } + cred, err := confidential.NewCredFromCert(certs, privateKey) + if err != nil { + log.Fatal(err) + } + app, err := confidential.New(_config2.Authority, _config2.ClientID, cred, confidential.WithCache(cacheAccessor)) + if err != nil { + log.Fatal(err) + } + return &app +} + +func acquireTokenClientCertificate() { + + result, err := _app2.AcquireTokenByCredential(context.Background(), _config1.Scopes) + if err != nil { + log.Fatal(err) + } + + fmt.Println("A Bearer token was acquired, it expires on: ", result.ExpiresOn) +} diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index d1da7260..b15d0a9e 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -5,7 +5,7 @@ import ( ) var ( - //config = CreateConfig("config.json") + config = CreateConfig("config.json") cacheAccessor = &TokenCache{file: "serialized_cache.json"} ) @@ -36,10 +36,9 @@ func main() { } else if exampleType == "6" { // This sample does not use a serialized cache - it relies on in-memory cache by reusing the app object // This works well for app tokens, because there is only 1 token per resource, per tenant. - // acquireTokenClientCertificate() - println("Nothing to do in this option") + acquireTokenClientCertificate() // this time the token comes from the cache! - // acquireTokenClientCertificate() + acquireTokenClientCertificate() } else if exampleType == "7" { RunManagedIdentity() } diff --git a/apps/tests/devapps/sample_utils.go b/apps/tests/devapps/sample_utils.go index 05e8b2ff..b311927b 100644 --- a/apps/tests/devapps/sample_utils.go +++ b/apps/tests/devapps/sample_utils.go @@ -11,18 +11,18 @@ import ( // Config represents the config.json required to run the samples type Config struct { - ClientID string `json:"client_id"` - Authority string `json:"authority"` - Scopes []string `json:"scopes"` - Username string `json:"username"` - Password string `json:"password"` - RedirectURI string `json:"redirect_uri"` - // CodeChallenge string `json:"code_challenge"` - // CodeChallengeMethod string `json:"code_challenge_method"` - // State string `json:"state"` - ClientSecret string `json:"client_secret"` - // Thumbprint string `json:"thumbprint"` - // PemData string `json:"pem_file"` + ClientID string `json:"client_id"` + Authority string `json:"authority"` + Scopes []string `json:"scopes"` + Username string `json:"username"` + Password string `json:"password"` + RedirectURI string `json:"redirect_uri"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + State string `json:"state"` + ClientSecret string `json:"client_secret"` + Thumbprint string `json:"thumbprint"` + PemData string `json:"pem_file"` } // CreateConfig creates the Config struct from a json file. From 264641817343f0659442d43799a7878c6fac1dab Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 2 Sep 2024 12:09:01 +0100 Subject: [PATCH 05/32] Formatting changes Formatting changes --- apps/tests/devapps/client_certificate_sample.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index 467a45bc..c55ed571 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -1,5 +1,5 @@ -// // Copyright (c) Microsoft Corporation. -// // Licensed under the MIT license. +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. package main From 4db1c7e3a33f38b2d8743dc01aa6008e17627443 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 3 Sep 2024 16:25:33 +0100 Subject: [PATCH 06/32] Added methods for UAMI Added method for UAMI --- apps/managedidentity/managedidentity.go | 84 +++++++++----------- apps/managedidentity/managedidentity_test.go | 28 +++++-- 2 files changed, 59 insertions(+), 53 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index c9d1f371..e4f17f3f 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -23,42 +23,24 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) +// General request querry parameter names const ( MetaHTTPHeadderName = "Metadata" APIVersionQuerryParameterName = "api-version" ResourceBodyOrQuerryParameterName = "resource" ) +// UAMI querry parameter name const ( MIQuerryParameterClientId = "client_id" MIQuerryParameterObjectId = "object_id" MIQuerryParameterResourceId = "mi_res_id" ) -// AZURE_POD_IDENTITY_AUTHORITY_HOST: "AZURE_POD_IDENTITY_AUTHORITY_HOST", -// -// IDENTITY_ENDPOINT: "IDENTITY_ENDPOINT", -// IDENTITY_HEADER: "IDENTITY_HEADER", -// IDENTITY_SERVER_THUMBPRINT: "IDENTITY_SERVER_THUMBPRINT", -// IMDS_ENDPOINT: "IMDS_ENDPOINT", -// MSI_ENDPOINT: "MSI_ENDPOINT", -// APP_SERVICE: "AppService", -// -// These are the MI source names -// AZURE_ARC: "AzureArc", -// CLOUD_SHELL: "CloudShell", -// DEFAULT_TO_IMDS: "DefaultToImds", -// IMDS: "Imds", -// SERVICE_FABRIC: "ServiceFabric", - -// CouldShell -// This also comes from enviourment point :: :: -const () - // Appservice -// end point comes from enviournment variable ?? ?!?!?!?!? +// end point comes from enviournment variable ?? const ( - AppServiceMSIEndPointVersion = "2019-08-01" + AppServiceMSIEndPointAPIVersion = "2019-08-01" ) // Arc @@ -69,8 +51,7 @@ const ( // IMDS const ( - IMDSTokenPath = "/metadata/identity/oauth2/token" - IMDSEndpoint = "http://169.254.169.254" + IMDSTokenPath + IMDSEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" IMDSAPIVersion = "2018-02-01" ) @@ -102,9 +83,6 @@ func SystemAssigned() ID { return systemAssignedValue("") } -//------------------ -//-- construction of the structues and the API's - // Client is a client that provides access to Managed Identity token calls. type Client struct { // cacheAccessorMu *sync.RWMutex @@ -204,7 +182,7 @@ func (client Client) AcquireToken(context context.Context, resource string, opti option(&o) } - // try and find some resource which cna be accessed + // try and find some resource which can be accessed // service fabric GET // app service GET // could shell POST request @@ -214,25 +192,41 @@ func (client Client) AcquireToken(context context.Context, resource string, opti // Sources that send GET requests: App Service, Azure Arc, IMDS, Service Fabric // // Sources that send POST requests: Cloud Shell - if client.MiType == SystemAssigned() { - var msiEndpoint *url.URL - msiEndpoint, err := url.Parse(IMDSEndpoint) - if err != nil { - fmt.Println("Error creating URL: ", err) - return base.AuthResult{}, nil - } - msiParameters := msiEndpoint.Query() - msiParameters.Add("api-version", "2018-02-01") - msiParameters.Add("resource", resource) - msiEndpoint.RawQuery = msiParameters.Encode() - - token, err := getTokenForURL(msiEndpoint, client.httpClient) - println("Access token :: ", token.AccessToken) - return base.NewAuthResult(token, shared.Account{}) + + var msiEndpoint *url.URL + msiEndpoint, err := url.Parse(IMDSEndpoint) + if err != nil { + fmt.Println("Error creating URL: ", err) + return base.AuthResult{}, nil } + msiParameters := msiEndpoint.Query() + msiParameters.Add("api-version", "2018-02-01") + msiParameters.Add("resource", resource) - // all the other options. - return base.AuthResult{}, nil + if len(o.Claims) > 0 { + msiParameters.Add("claims", o.Claims) + } + + switch client.MiType.(type) { + case ClientID: + msiParameters.Add(MIQuerryParameterClientId, client.MiType.value()) + case ResourceID: + msiParameters.Add(MIQuerryParameterResourceId, client.MiType.value()) + case ObjectID: + msiParameters.Add(MIQuerryParameterObjectId, client.MiType.value()) + case systemAssignedValue: // not adding anything + default: + return base.AuthResult{}, fmt.Errorf("Type not suported") + + } + + msiEndpoint.RawQuery = msiParameters.Encode() + token, err := getTokenForURL(msiEndpoint, client.httpClient) + if err != nil { + return base.AuthResult{}, fmt.Errorf("URL not formed") + } + println("Access token ** ", token.AccessToken) + return base.NewAuthResult(token, shared.Account{}) } // Detects and returns the managed identity source available on the environment. diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index fa9670d6..bad9b4b0 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -27,6 +27,7 @@ func fakeMIClient(mangedIdentityId ID, options ...ClientOption) (Client, error) type errorClient struct{} +func (*errorClient) CloseIdleConnections() {} func (*errorClient) Do(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String()) } @@ -43,7 +44,6 @@ func (*fakeClient) Do(req *http.Request) (*http.Response, error) { "expires_in": "3599", "expires_on": "1506484173", "not_before": "1506480273", - "resource": "fakeresource", "token_type": "Bearer" }`)), Header: make(http.Header), @@ -51,7 +51,7 @@ func (*fakeClient) Do(req *http.Request) (*http.Response, error) { return &w, nil } -func TestManagedIdentity(t *testing.T) { +func TestManagedIdentityIMDS_SAMISuccess(t *testing.T) { fakeHTTPClient := fakeClient{} client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) @@ -65,17 +65,29 @@ func TestManagedIdentity(t *testing.T) { if err != nil { t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") } - var tokenScope = []string{"the_scope"} expected := accesstokens.TokenResponse{ - AccessToken: "fakeToken", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, - TokenType: "TokenType", + AccessToken: "fakeToken", + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + TokenType: "Bearer", } if result.AccessToken != expected.AccessToken { t.Fatalf(`unexpected access token "%s"`, result.AccessToken) } } +func TestManagedIdentityIMDS_SAMIError(t *testing.T) { + fakeHTTPClient := errorClient{} + + client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) + + if err != nil { + t.Fatal(err) + } + + if _, err := client.AcquireToken(context.Background(), "fakeresource"); err == nil { + t.Errorf("TestManagedIdentity: Should have returned error for incorrect http request.") + } + +} From 3bf038396567de78711ab9fc2fcb7df2bf8fdf69 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 4 Sep 2024 16:03:23 +0100 Subject: [PATCH 07/32] Updated and cleaned up MI for SAMI Updated the some code and cleaned up some comments and print statement --- apps/managedidentity/managedidentity.go | 100 ++++++------------- apps/managedidentity/managedidentity_test.go | 54 +++++++--- apps/tests/devapps/managedidentity_sample.go | 29 +----- 3 files changed, 78 insertions(+), 105 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index e4f17f3f..4cf952a3 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -25,47 +25,32 @@ import ( // General request querry parameter names const ( - MetaHTTPHeadderName = "Metadata" - APIVersionQuerryParameterName = "api-version" - ResourceBodyOrQuerryParameterName = "resource" + metaHTTPHeadderName = "Metadata" + apiVersionQuerryParameterName = "api-version" + resourceQuerryParameterName = "resource" ) // UAMI querry parameter name const ( - MIQuerryParameterClientId = "client_id" - MIQuerryParameterObjectId = "object_id" - MIQuerryParameterResourceId = "mi_res_id" -) - -// Appservice -// end point comes from enviournment variable ?? -const ( - AppServiceMSIEndPointAPIVersion = "2019-08-01" -) - -// Arc -const ( - ARCAPIEndpoint = "http://127.0.0.1:40342/metadata/identity/oauth2/token" - ARCAPIVersion = "2019-11-01" + miQuerryParameterClientId = "client_id" + miQuerryParameterObjectId = "object_id" + miQuerryParameterResourceId = "mi_res_id" ) // IMDS const ( - IMDSEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" - IMDSAPIVersion = "2018-02-01" + imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + imdsAPIVersion = "2018-02-01" ) const ( // DefaultToIMDS indicates that the source is defaulted to IMDS since no environment variables are set. - DefaultToIMDS = 0 + defaultToIMDS = 0 // AzureArc represents the source to acquire token for managed identity is Azure Arc. - AzureArc = 1 + azureArc = 1 ) -// id type managed identity -type Source int - type ID interface { value() string } @@ -83,13 +68,9 @@ func SystemAssigned() ID { return systemAssignedValue("") } -// Client is a client that provides access to Managed Identity token calls. type Client struct { - // cacheAccessorMu *sync.RWMutex httpClient ops.HTTPClient - MiType ID - // pmanager manager // todo : expose the manager from base. - // cacheAccessor cache.ExportReplace + miType ID } type ClientOptions struct { @@ -97,7 +78,7 @@ type ClientOptions struct { } type AcquireTokenOptions struct { - Claims string + claims string } type ClientOption func(o *ClientOptions) @@ -108,7 +89,7 @@ type AcquireTokenOption func(o *AcquireTokenOptions) // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. func WithClaims(claims string) AcquireTokenOption { return func(o *AcquireTokenOptions) { - o.Claims = claims + o.claims = claims } } @@ -124,8 +105,6 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { // // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { - fmt.Println("idType: ", id.value()) - opts := ClientOptions{ // work on this side where httpClient: shared.DefaultClient, } @@ -134,8 +113,8 @@ func New(id ID, options ...ClientOption) (Client, error) { option(&opts) } - client := Client{ // TODO :: check for http client - MiType: id, + client := Client{ + miType: id, httpClient: opts.httpClient, } @@ -147,13 +126,15 @@ func getTokenForURL(url *url.URL, httpClient ops.HTTPClient) (accesstokens.Token if err != nil { return accesstokens.TokenResponse{}, err } - req.Header.Add("Metadata", "true") + req.Header.Add(metaHTTPHeadderName, "true") resp, err := httpClient.Do(req) if err != nil { return accesstokens.TokenResponse{}, err } - + if resp.StatusCode != http.StatusOK { + return accesstokens.TokenResponse{}, fmt.Errorf("Error code was non Ok %T ", resp.StatusCode) + } // Pull out response body responseBytes, err := io.ReadAll(resp.Body) defer resp.Body.Close() @@ -182,54 +163,37 @@ func (client Client) AcquireToken(context context.Context, resource string, opti option(&o) } - // try and find some resource which can be accessed - // service fabric GET - // app service GET - // could shell POST request - // azure arc GET - // default :: IMDS GET - // - // Sources that send GET requests: App Service, Azure Arc, IMDS, Service Fabric - // - // Sources that send POST requests: Cloud Shell - var msiEndpoint *url.URL - msiEndpoint, err := url.Parse(IMDSEndpoint) + msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { fmt.Println("Error creating URL: ", err) return base.AuthResult{}, nil } msiParameters := msiEndpoint.Query() - msiParameters.Add("api-version", "2018-02-01") - msiParameters.Add("resource", resource) + msiParameters.Add(apiVersionQuerryParameterName, "2018-02-01") + msiParameters.Add(resourceQuerryParameterName, resource) - if len(o.Claims) > 0 { - msiParameters.Add("claims", o.Claims) + if len(o.claims) > 0 { + msiParameters.Add("claims", o.claims) } - switch client.MiType.(type) { + switch client.miType.(type) { case ClientID: - msiParameters.Add(MIQuerryParameterClientId, client.MiType.value()) + msiParameters.Add(miQuerryParameterClientId, client.miType.value()) case ResourceID: - msiParameters.Add(MIQuerryParameterResourceId, client.MiType.value()) + msiParameters.Add(miQuerryParameterResourceId, client.miType.value()) case ObjectID: - msiParameters.Add(MIQuerryParameterObjectId, client.MiType.value()) + msiParameters.Add(miQuerryParameterObjectId, client.miType.value()) case systemAssignedValue: // not adding anything default: - return base.AuthResult{}, fmt.Errorf("Type not suported") + return base.AuthResult{}, fmt.Errorf("unsupported type %T", client.miType) } msiEndpoint.RawQuery = msiParameters.Encode() - token, err := getTokenForURL(msiEndpoint, client.httpClient) + tokenResponse, err := getTokenForURL(msiEndpoint, client.httpClient) if err != nil { - return base.AuthResult{}, fmt.Errorf("URL not formed") + return base.AuthResult{}, err } - println("Access token ** ", token.AccessToken) - return base.NewAuthResult(token, shared.Account{}) -} - -// Detects and returns the managed identity source available on the environment. -func GetSource() Source { - return DefaultToIMDS + return base.NewAuthResult(tokenResponse, shared.Account{}) } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index bad9b4b0..0ff73ae2 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -32,13 +32,22 @@ func (*errorClient) Do(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String()) } -type fakeClient struct{} +type FakeClient struct { + responseType int +} -func (*fakeClient) CloseIdleConnections() {} -func (*fakeClient) Do(req *http.Request) (*http.Response, error) { - w := http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(`{ +func (c *FakeClient) CloseIdleConnections() {} +func (c *FakeClient) Do(req *http.Request) (*http.Response, error) { + println(c.responseType) + w := makeResponse(c.responseType) + return &w, nil +} + +func makeResponse(responseType int) http.Response { + if responseType == 1 { + return http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ "access_token": "fakeToken", "refresh_token": "", "expires_in": "3599", @@ -46,13 +55,18 @@ func (*fakeClient) Do(req *http.Request) (*http.Response, error) { "not_before": "1506480273", "token_type": "Bearer" }`)), - Header: make(http.Header), + Header: make(http.Header), + } + } else { + return http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + } } - return &w, nil } - func TestManagedIdentityIMDS_SAMISuccess(t *testing.T) { - fakeHTTPClient := fakeClient{} + fakeHTTPClient := FakeClient{responseType: 1} client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) @@ -63,7 +77,7 @@ func TestManagedIdentityIMDS_SAMISuccess(t *testing.T) { result, err := client.AcquireToken(context.Background(), "fakeresource") if err != nil { - t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") + t.Fatal("TestManagedIdentity: unexpected nil error from TestManagedIdentity") } expected := accesstokens.TokenResponse{ @@ -77,6 +91,22 @@ func TestManagedIdentityIMDS_SAMISuccess(t *testing.T) { } } + +func TestManagedIdentityIMDS_SAMIHttpRequestFailure(t *testing.T) { + fakeHTTPClient := FakeClient{responseType: 2} + + client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) + + if err != nil { + t.Fatal(err) + } + + if _, err := client.AcquireToken(context.Background(), "fakeresource"); err == nil { + t.Fatal("TestManagedIdentity: Should have returned error for incorrect http request.") + } + +} + func TestManagedIdentityIMDS_SAMIError(t *testing.T) { fakeHTTPClient := errorClient{} @@ -87,7 +117,7 @@ func TestManagedIdentityIMDS_SAMIError(t *testing.T) { } if _, err := client.AcquireToken(context.Background(), "fakeresource"); err == nil { - t.Errorf("TestManagedIdentity: Should have returned error for incorrect http request.") + t.Fatal("TestManagedIdentity: Should have returned error for incorrect http request.") } } diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index 3c8394e2..9fba6ddb 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -8,31 +8,10 @@ import ( ) func RunManagedIdentity() { - // customHttpClient := &http.Client{} - - miSystemAssigned, error := mi.New(mi.SystemAssigned()) - if error != nil { - fmt.Println(error) + miSystemAssigned, err := mi.New(mi.SystemAssigned()) + if err != nil { + fmt.Println(err) } - - // miClientIdAssigned, error := mi.New(mi.ClientID("951d3571-c442-42e0-9efd-e1d7e1a21030"), - // mi.WithHTTPClient(customHttpClient)) - // if error != nil { - // fmt.Println(error) - // } - - // miResourceIdAssigned, error := mi.New(mi.ResourceID("resource id 123")) - // if error != nil { - // fmt.Println(error) - // } - - // miObjectIdAssigned, error := mi.New(mi.ObjectID("object id 123")) - // if error != nil { - // fmt.Println(error) - // } - miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") - // miClientIdAssigned.AcquireToken(context.Background(), "resource") - // miResourceIdAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) - // miObjectIdAssigned.AcquireToken(context.Background(), "resource") + } From 8c3fed1dbf78c4adbbace1fab3c0efdeb3c2dae4 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:17:10 +0100 Subject: [PATCH 08/32] Update apps/managedidentity/managedidentity.go Updated the key for the resource Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/managedidentity/managedidentity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 4cf952a3..bb7f7bea 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -34,7 +34,7 @@ const ( const ( miQuerryParameterClientId = "client_id" miQuerryParameterObjectId = "object_id" - miQuerryParameterResourceId = "mi_res_id" + miQuerryParameterResourceId = "msi_res_id" ) // IMDS From 5eb2919f9a9544f17b0d985a37981b3ca11d9d9a Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 10 Sep 2024 09:06:56 +0100 Subject: [PATCH 09/32] Resolved some comments. Updated the token from url function to a reaquest based function --- apps/managedidentity/managedidentity.go | 97 +++++++++----- apps/managedidentity/managedidentity_test.go | 133 +++++++++++++++++++ apps/tests/devapps/managedidentity_sample.go | 7 +- 3 files changed, 202 insertions(+), 35 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 4cf952a3..49633cab 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -16,6 +16,7 @@ import ( "io" "net/http" "net/url" + "strings" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" @@ -34,7 +35,7 @@ const ( const ( miQuerryParameterClientId = "client_id" miQuerryParameterObjectId = "object_id" - miQuerryParameterResourceId = "mi_res_id" + miQuerryParameterResourceId = "msi_res_id" ) // IMDS @@ -105,7 +106,7 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { // // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { - opts := ClientOptions{ // work on this side where + opts := ClientOptions{ httpClient: shared.DefaultClient, } @@ -121,11 +122,58 @@ func New(id ID, options ...ClientOption) (Client, error) { return client, nil } -func getTokenForURL(url *url.URL, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { - req, err := http.NewRequest(http.MethodGet, url.String(), nil) +func createIMDSAuthRequest(_ context.Context, id ID, resource string, claims string) (*http.Request, error) { + var msiEndpoint *url.URL + msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { - return accesstokens.TokenResponse{}, err + return &http.Request{}, fmt.Errorf("Error creating URL as parsing the URL filed") + } + + msiParameters := msiEndpoint.Query() + msiParameters.Add(apiVersionQuerryParameterName, "2018-02-01") + + resource = removeSuffix(resource, "/.default") + print(resource) + msiParameters.Add(resourceQuerryParameterName, resource) + + if len(claims) > 0 { + msiParameters.Add("claims", claims) + } + + switch t := id.(type) { + case ClientID: + if len(string(t)) > 0 { + msiParameters.Add(miQuerryParameterClientId, string(t)) + } else { + return &http.Request{}, fmt.Errorf("ClientId parameter is empty for %T", t) + } + case ResourceID: + if len(string(t)) > 0 { + msiParameters.Add(miQuerryParameterResourceId, string(t)) + } else { + return &http.Request{}, fmt.Errorf("ResourceID parameter is empty for %T", t) + } + case ObjectID: + if len(string(t)) > 0 { + msiParameters.Add(miQuerryParameterObjectId, string(t)) + } else { + return &http.Request{}, fmt.Errorf("ObjectID parameter is empty for %T", t) + } + case systemAssignedValue: // not adding anything + default: + return &http.Request{}, fmt.Errorf("unsupported type %T", id) } + + msiEndpoint.RawQuery = msiParameters.Encode() + fmt.Println(msiEndpoint) + req, err := http.NewRequest(http.MethodGet, msiEndpoint.String(), nil) + if err != nil { + return &http.Request{}, fmt.Errorf("Error creating request") + } + return req, nil +} + +func getTokenForRequest(_ context.Context, req *http.Request, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { req.Header.Add(metaHTTPHeadderName, "true") resp, err := httpClient.Do(req) @@ -149,49 +197,32 @@ func getTokenForURL(url *url.URL, httpClient ops.HTTPClient) (accesstokens.Token return accesstokens.TokenResponse{}, err } return r, nil +} +// RemoveSuffix removes the specified 'suffix' from 'str' if it exists. +func removeSuffix(str, suffix string) string { + if strings.HasSuffix(str, suffix) { + return str[:len(str)-len(suffix)] // Remove the suffix if it exists + } + return str // Return the original string if suffix doesn't exist } // Acquires tokens from the configured managed identity on an azure resource. // // Resource: scopes application is requesting access to // Options: [WithClaims] -func (client Client) AcquireToken(context context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { +func (client Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { o := AcquireTokenOptions{} for _, option := range options { option(&o) } - - var msiEndpoint *url.URL - msiEndpoint, err := url.Parse(imdsEndpoint) + req, err := createIMDSAuthRequest(ctx, client.miType, resource, o.claims) if err != nil { fmt.Println("Error creating URL: ", err) - return base.AuthResult{}, nil - } - msiParameters := msiEndpoint.Query() - msiParameters.Add(apiVersionQuerryParameterName, "2018-02-01") - msiParameters.Add(resourceQuerryParameterName, resource) - - if len(o.claims) > 0 { - msiParameters.Add("claims", o.claims) + return base.AuthResult{}, fmt.Errorf("Error while creating request") } - - switch client.miType.(type) { - case ClientID: - msiParameters.Add(miQuerryParameterClientId, client.miType.value()) - case ResourceID: - msiParameters.Add(miQuerryParameterResourceId, client.miType.value()) - case ObjectID: - msiParameters.Add(miQuerryParameterObjectId, client.miType.value()) - case systemAssignedValue: // not adding anything - default: - return base.AuthResult{}, fmt.Errorf("unsupported type %T", client.miType) - - } - - msiEndpoint.RawQuery = msiParameters.Encode() - tokenResponse, err := getTokenForURL(msiEndpoint, client.httpClient) + tokenResponse, err := getTokenForRequest(ctx, req, client.httpClient) if err != nil { return base.AuthResult{}, err } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 0ff73ae2..1f1ab0d3 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -92,6 +92,139 @@ func TestManagedIdentityIMDS_SAMISuccess(t *testing.T) { } +func TestCreateIMDSAuthRequest(t *testing.T) { + tests := []struct { + name string + id ID + resource string + claims string + wantErr bool + }{ + { + name: "System Assigned", + id: SystemAssigned(), + resource: "https://management.azure.com", + claims: "", + wantErr: false, + }, + { + name: "System Assigned", + id: SystemAssigned(), + resource: "https://management.azure.com/.default", + claims: "", + wantErr: false, + }, + { + name: "Client ID", + id: ClientID("test-client-id"), + resource: "https://storage.azure.com", + claims: "", + wantErr: false, + }, + { + name: "Resource ID", + id: ResourceID("test-resource-id"), + resource: "https://vault.azure.net", + claims: "", + wantErr: false, + }, + { + name: "Object ID", + id: ObjectID("test-object-id"), + resource: "https://graph.microsoft.com", + claims: "", + wantErr: false, + }, + { + name: "With Claims", + id: SystemAssigned(), + resource: "https://management.azure.com", + claims: "test-claims", + wantErr: false, + }, + { + name: "Empty Client ID", + id: ClientID(""), + resource: "https://management.azure.com", + claims: "", + wantErr: true, + }, + { + name: "Empty Resource ID", + id: ResourceID(""), + resource: "https://management.azure.com", + claims: "", + wantErr: true, + }, + { + name: "Empty Object ID", + id: ObjectID(""), + resource: "https://management.azure.com", + claims: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) + if tt.wantErr { + if err == nil { + t.Errorf("createIMDSAuthRequest() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + + if err != nil { + t.Errorf("createIMDSAuthRequest() unexpected error = %v", err) + return + } + + if req == nil { + t.Errorf("createIMDSAuthRequest() returned nil request") + return + } + + if req.Method != http.MethodGet { + t.Errorf("createIMDSAuthRequest() method = %v, want %v", req.Method, http.MethodGet) + } + + if !strings.HasPrefix(req.URL.String(), imdsEndpoint) { + t.Errorf("createIMDSAuthRequest() URL = %v, want prefix %v", req.URL.String(), imdsEndpoint) + } + + query := req.URL.Query() + + if query.Get(apiVersionQuerryParameterName) != "2018-02-01" { + t.Errorf("createIMDSAuthRequest() api-version = %v, want %v", query.Get(apiVersionQuerryParameterName), "2018-02-01") + } + + if query.Get(resourceQuerryParameterName) != removeSuffix(tt.resource, "/.default") { + t.Errorf("createIMDSAuthRequest() resource = %v, want %v", query.Get(resourceQuerryParameterName), removeSuffix(tt.resource, "/.default")) + } + + if tt.claims != "" && query.Get("claims") != tt.claims { + t.Errorf("createIMDSAuthRequest() claims = %v, want %v", query.Get("claims"), tt.claims) + } + + switch tt.id.(type) { + case ClientID: + if query.Get(miQuerryParameterClientId) != tt.id.value() { + t.Errorf("createIMDSAuthRequest() client_id = %v, want %v", query.Get(miQuerryParameterClientId), tt.id.value()) + } + case ResourceID: + if query.Get(miQuerryParameterResourceId) != tt.id.value() { + t.Errorf("createIMDSAuthRequest() msi_res_id = %v, want %v", query.Get(miQuerryParameterResourceId), tt.id.value()) + } + case ObjectID: + if query.Get(miQuerryParameterObjectId) != tt.id.value() { + t.Errorf("createIMDSAuthRequest() object_id = %v, want %v", query.Get(miQuerryParameterObjectId), tt.id.value()) + } + } + }) + } +} + func TestManagedIdentityIMDS_SAMIHttpRequestFailure(t *testing.T) { fakeHTTPClient := FakeClient{responseType: 2} diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index 9fba6ddb..95a67d9d 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -12,6 +12,9 @@ func RunManagedIdentity() { if err != nil { fmt.Println(err) } - miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") - + temp, err := miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") + if err != nil { + println(err.Error()) + } + fmt.Println("token : ", temp.AccessToken) } From 64e470503d35c2d84990a6529bb367dace0b04e4 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 10 Sep 2024 15:45:11 +0100 Subject: [PATCH 10/32] Updated test Updated test to fail not only return error --- apps/managedidentity/managedidentity_test.go | 27 +++++++------------- 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 1f1ab0d3..3d614fd5 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -170,55 +170,46 @@ func TestCreateIMDSAuthRequest(t *testing.T) { req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) if tt.wantErr { if err == nil { - t.Errorf("createIMDSAuthRequest() error = %v, wantErr %v", err, tt.wantErr) + t.Fatal(err) } return } - if err != nil { - t.Errorf("createIMDSAuthRequest() unexpected error = %v", err) - return - } - if req == nil { - t.Errorf("createIMDSAuthRequest() returned nil request") + t.Fatal("createIMDSAuthRequest() returned nil request") return } if req.Method != http.MethodGet { - t.Errorf("createIMDSAuthRequest() method = %v, want %v", req.Method, http.MethodGet) + t.Fatal("createIMDSAuthRequest() method is not GET") } if !strings.HasPrefix(req.URL.String(), imdsEndpoint) { - t.Errorf("createIMDSAuthRequest() URL = %v, want prefix %v", req.URL.String(), imdsEndpoint) + t.Fatal("createIMDSAuthRequest() URL is not matched.") } query := req.URL.Query() if query.Get(apiVersionQuerryParameterName) != "2018-02-01" { - t.Errorf("createIMDSAuthRequest() api-version = %v, want %v", query.Get(apiVersionQuerryParameterName), "2018-02-01") + t.Fatal("createIMDSAuthRequest() api-version missmatch") } if query.Get(resourceQuerryParameterName) != removeSuffix(tt.resource, "/.default") { - t.Errorf("createIMDSAuthRequest() resource = %v, want %v", query.Get(resourceQuerryParameterName), removeSuffix(tt.resource, "/.default")) - } - - if tt.claims != "" && query.Get("claims") != tt.claims { - t.Errorf("createIMDSAuthRequest() claims = %v, want %v", query.Get("claims"), tt.claims) + t.Fatal("createIMDSAuthRequest() resource does not ahve suffix removed ") } switch tt.id.(type) { case ClientID: if query.Get(miQuerryParameterClientId) != tt.id.value() { - t.Errorf("createIMDSAuthRequest() client_id = %v, want %v", query.Get(miQuerryParameterClientId), tt.id.value()) + t.Fatal("createIMDSAuthRequest() client_id does not match with the id value") } case ResourceID: if query.Get(miQuerryParameterResourceId) != tt.id.value() { - t.Errorf("createIMDSAuthRequest() msi_res_id = %v, want %v", query.Get(miQuerryParameterResourceId), tt.id.value()) + t.Fatal("createIMDSAuthRequest() resource id does not match with the id value") } case ObjectID: if query.Get(miQuerryParameterObjectId) != tt.id.value() { - t.Errorf("createIMDSAuthRequest() object_id = %v, want %v", query.Get(miQuerryParameterObjectId), tt.id.value()) + t.Fatal("createIMDSAuthRequest() object id does not match with the id value") } } }) From a7e760a70ac4baef480a25e35149ba8de01952e0 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 11 Sep 2024 18:50:03 +0100 Subject: [PATCH 11/32] Updated the Identity method for feedback Added tests for failure and success for SAMI --- apps/managedidentity/managedidentity.go | 61 ++-- apps/managedidentity/managedidentity_test.go | 305 ++++++++++++++----- 2 files changed, 269 insertions(+), 97 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 49633cab..2c49e028 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -12,6 +12,7 @@ package managedidentity import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -24,6 +25,14 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) +const ( + DefaultToIMDS Source = 0 + AzureArc Source = 1 + ServiceFabric Source = 2 + CloudShell Source = 3 + AppService Source = 4 +) + // General request querry parameter names const ( metaHTTPHeadderName = "Metadata" @@ -52,6 +61,25 @@ const ( azureArc = 1 ) +type Source int + +func (s Source) String() string { + switch s { + case DefaultToIMDS: + return "DefaultToIMDS" + case AzureArc: + return "AzureArc" + case ServiceFabric: + return "ServiceFabric" + case CloudShell: + return "CloudShell" + case AppService: + return "AppService" + default: + return fmt.Sprintf("UnknownSource(%d)", s) + } +} + type ID interface { value() string } @@ -126,14 +154,14 @@ func createIMDSAuthRequest(_ context.Context, id ID, resource string, claims str var msiEndpoint *url.URL msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { - return &http.Request{}, fmt.Errorf("Error creating URL as parsing the URL filed") + return nil, errors.New("error creating URL as parsing the URL filed") } msiParameters := msiEndpoint.Query() msiParameters.Add(apiVersionQuerryParameterName, "2018-02-01") - resource = removeSuffix(resource, "/.default") - print(resource) + resource = strings.TrimSuffix(resource, "/.default") + msiParameters.Add(resourceQuerryParameterName, resource) if len(claims) > 0 { @@ -145,43 +173,41 @@ func createIMDSAuthRequest(_ context.Context, id ID, resource string, claims str if len(string(t)) > 0 { msiParameters.Add(miQuerryParameterClientId, string(t)) } else { - return &http.Request{}, fmt.Errorf("ClientId parameter is empty for %T", t) + return nil, fmt.Errorf("clientId parameter is empty for %T", t) } case ResourceID: if len(string(t)) > 0 { msiParameters.Add(miQuerryParameterResourceId, string(t)) } else { - return &http.Request{}, fmt.Errorf("ResourceID parameter is empty for %T", t) + return nil, fmt.Errorf("resourceID parameter is empty for %T", t) } case ObjectID: if len(string(t)) > 0 { msiParameters.Add(miQuerryParameterObjectId, string(t)) } else { - return &http.Request{}, fmt.Errorf("ObjectID parameter is empty for %T", t) + return nil, fmt.Errorf("objectID parameter is empty for %T", t) } case systemAssignedValue: // not adding anything default: - return &http.Request{}, fmt.Errorf("unsupported type %T", id) + return nil, fmt.Errorf("unsupported type %T", id) } msiEndpoint.RawQuery = msiParameters.Encode() - fmt.Println(msiEndpoint) req, err := http.NewRequest(http.MethodGet, msiEndpoint.String(), nil) if err != nil { - return &http.Request{}, fmt.Errorf("Error creating request") + return nil, errors.New("error creating request") } return req, nil } -func getTokenForRequest(_ context.Context, req *http.Request, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { +func getTokenForRequest(ctx context.Context, req *http.Request, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { req.Header.Add(metaHTTPHeadderName, "true") - resp, err := httpClient.Do(req) if err != nil { return accesstokens.TokenResponse{}, err } if resp.StatusCode != http.StatusOK { - return accesstokens.TokenResponse{}, fmt.Errorf("Error code was non Ok %T ", resp.StatusCode) + return accesstokens.TokenResponse{}, fmt.Errorf("failed to authenticate with status code %d ", resp.StatusCode) } // Pull out response body responseBytes, err := io.ReadAll(resp.Body) @@ -199,14 +225,6 @@ func getTokenForRequest(_ context.Context, req *http.Request, httpClient ops.HTT return r, nil } -// RemoveSuffix removes the specified 'suffix' from 'str' if it exists. -func removeSuffix(str, suffix string) string { - if strings.HasSuffix(str, suffix) { - return str[:len(str)-len(suffix)] // Remove the suffix if it exists - } - return str // Return the original string if suffix doesn't exist -} - // Acquires tokens from the configured managed identity on an azure resource. // // Resource: scopes application is requesting access to @@ -219,8 +237,7 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options } req, err := createIMDSAuthRequest(ctx, client.miType, resource, o.claims) if err != nil { - fmt.Println("Error creating URL: ", err) - return base.AuthResult{}, fmt.Errorf("Error while creating request") + return base.AuthResult{}, errors.New("error while creating request") } tokenResponse, err := getTokenForRequest(ctx, req, client.httpClient) if err != nil { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 3d614fd5..4b16e830 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -4,9 +4,12 @@ package managedidentity import ( "context" + "encoding/json" "fmt" "io" "net/http" + "net/url" + "strconv" "strings" "testing" "time" @@ -15,6 +18,37 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" ) +const ( + // test Resources + resource = "https://demo.azure.com" + resourceDefaultSuffix = "https://demo.azure.com/.default" +) + +type HttpRequest struct { + Source Source + Resource string + Identity ID +} + +type SuccessfulResponse struct { + AccessToken string `json:"access_token"` + ExpiresOn int64 `json:"expires_on"` + Resource string `json:"resource"` + TokenType string `json:"token_type"` + ClientID string `json:"client_id"` +} + +type ErrorResponse struct { + StatusCode int `json:"statusCode"` + Message string `json:"message"` + CorrelationID string `json:"correlationId,omitempty"` +} + +type fakeClient struct{} +type errorClient struct { + errResponse ErrorResponse +} + func fakeMIClient(mangedIdentityId ID, options ...ClientOption) (Client, error) { fakeClient, err := New(mangedIdentityId, options...) @@ -25,71 +59,221 @@ func fakeMIClient(mangedIdentityId ID, options ...ClientOption) (Client, error) return fakeClient, nil } -type errorClient struct{} - +func (*fakeClient) CloseIdleConnections() {} func (*errorClient) CloseIdleConnections() {} -func (*errorClient) Do(req *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String()) -} -type FakeClient struct { - responseType int +func (*fakeClient) Do(req *http.Request) (*http.Response, error) { + w := http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(getSuccessfulResponse(resource))), + Header: make(http.Header), + } + return &w, nil } -func (c *FakeClient) CloseIdleConnections() {} -func (c *FakeClient) Do(req *http.Request) (*http.Response, error) { - println(c.responseType) - w := makeResponse(c.responseType) +func (e *errorClient) Do(req *http.Request) (*http.Response, error) { + w := http.Response{ + StatusCode: e.errResponse.StatusCode, + Body: io.NopCloser(strings.NewReader(makeResponseWithErrorData(e.errResponse))), + Header: make(http.Header), + } return &w, nil } -func makeResponse(responseType int) http.Response { - if responseType == 1 { - return http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(`{ - "access_token": "fakeToken", - "refresh_token": "", - "expires_in": "3599", - "expires_on": "1506484173", - "not_before": "1506480273", - "token_type": "Bearer" - }`)), - Header: make(http.Header), - } - } else { - return http.Response{ - StatusCode: http.StatusBadRequest, - Body: io.NopCloser(strings.NewReader(`{}`)), - Header: make(http.Header), +func getSuccessfulResponse(resource string) string { + expiresOn := time.Now().Add(1 * time.Hour).Unix() + response := SuccessfulResponse{ + AccessToken: "fakeToken", + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", + ClientID: "client_id", + } + jsonResponse, _ := json.Marshal(response) + return string(jsonResponse) +} + +func makeResponseWithErrorData(errRsp ErrorResponse) string { + response := ErrorResponse{ + StatusCode: errRsp.StatusCode, + Message: errRsp.Message, + CorrelationID: errRsp.CorrelationID, + } + jsonResponse, _ := json.Marshal(response) + return string(jsonResponse) +} + +func getMsiErrorResponse() string { + response := ErrorResponse{ + StatusCode: 500, + Message: "An unexpected error occurred while fetching the AAD Token.", + CorrelationID: "7d0c9763-ff1d-4842-a3f3-6d49e64f4513", + } + jsonResponse, _ := json.Marshal(response) + return string(jsonResponse) +} + +func getMsiErrorResponseNotFound() string { + response := ErrorResponse{ + StatusCode: 500, + Message: "An unexpected error occurred while fetching the AAD Token.", + CorrelationID: "7d0c9763-ff1d-4842-a3f3-6d49e64f4513", + } + jsonResponse, _ := json.Marshal(response) + return string(jsonResponse) +} + +func getMsiErrorResponseNoRetry() string { + response := ErrorResponse{ + StatusCode: 123, + Message: "Not one of the retryable error responses", + CorrelationID: "7d0c9763-ff1d-4842-a3f3-6d49e64f4513", + } + jsonResponse, _ := json.Marshal(response) + return string(jsonResponse) +} + +func computeUri(endpoint string, queryParameters map[string][]string) string { + if len(queryParameters) == 0 { + return endpoint + } + + queryString := url.Values{} + for key, values := range queryParameters { + for _, value := range values { + queryString.Add(key, value) } } + + return endpoint + "?" + queryString.Encode() +} + +func expectedRequest(source Source, resource string, id ID) (*http.Request, error) { + return expectedRequestWithId(source, resource, id) } -func TestManagedIdentityIMDS_SAMISuccess(t *testing.T) { - fakeHTTPClient := FakeClient{responseType: 1} - client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) +func expectedRequestWithId(_ Source, resource string, id ID) (*http.Request, error) { + var endpoint string + headers := http.Header{} + queryParameters := make(map[string][]string) + + //check with source when added different sources. + endpoint = imdsEndpoint + queryParameters["api-version"] = []string{"2018-02-01"} + queryParameters["resource"] = []string{resource} + headers.Add("Metadata", "true") + + switch id.(type) { + case ClientID: + queryParameters[miQuerryParameterClientId] = []string{id.value()} + case ResourceID: + queryParameters[miQuerryParameterResourceId] = []string{id.value()} + case ObjectID: + queryParameters[miQuerryParameterObjectId] = []string{id.value()} + case systemAssignedValue: + // not adding anything + default: + return nil, fmt.Errorf("Type not supported") + } + uri, err := url.Parse(computeUri(endpoint, queryParameters)) if err != nil { - t.Fatal(err) + return nil, err } - result, err := client.AcquireToken(context.Background(), "fakeresource") + req := &http.Request{ + Method: "GET", + URL: uri, + Header: headers, + } - if err != nil { - t.Fatal("TestManagedIdentity: unexpected nil error from TestManagedIdentity") + return req, nil +} + +func ExpectedResponse(statusCode int, response string) http.Response { + return http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(response)), } +} + +type resourceTestData struct { + source Source + endpoint string + resource string +} - expected := accesstokens.TokenResponse{ - AccessToken: "fakeToken", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - TokenType: "Bearer", +func createResourceData() []resourceTestData { + return []resourceTestData{ + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix}, } - if result.AccessToken != expected.AccessToken { - t.Fatalf(`unexpected access token "%s"`, result.AccessToken) +} + +func Test_SystemAssigned_Returns_Token_Failure(t *testing.T) { + testCases := []ErrorResponse{ + {StatusCode: 404, Message: "IMDS service not available", CorrelationID: "121212"}, + {StatusCode: 501, Message: "Service error 1", CorrelationID: "121212"}, + {StatusCode: 503, Message: "Service error 2", CorrelationID: "121212"}, + {StatusCode: 400, Message: "invalid id", CorrelationID: "121212"}, } + for _, testCase := range testCases { + t.Run(strconv.Itoa(testCase.StatusCode), func(t *testing.T) { + fakeErrorClient := errorClient{errResponse: testCase} + client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) + + if err != nil { + t.Fatal(err) + } + + resp, err := client.AcquireToken(context.Background(), resource) + + if resp.AccessToken != "" { + t.Fatalf("testManagedIdentity: accesstoken should be nil") + } + if err == nil { + t.Fatalf("testManagedIdentity: Should have encountered the error") + } + if err.Error() != fmt.Errorf("failed to authenticate with status code %d ", testCase.StatusCode).Error() { + t.Fatalf(`unexpected error "%s"`, err) + + } + }) + } +} + +func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { + testCases := createResourceData() + + for _, testCase := range testCases { + + t.Run(testCase.source.String(), func(t *testing.T) { + fakeHTTPClient := fakeClient{} + client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) + + if err != nil { + t.Fatal(err) + } + + result, err := client.AcquireToken(context.Background(), testCase.resource) + + if err != nil { + t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") + } + var tokenScope = []string{"the_scope"} + expected := accesstokens.TokenResponse{ + AccessToken: "fakeToken", + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + TokenType: "TokenType", + } + if result.AccessToken != expected.AccessToken { + t.Fatalf(`unexpected access token "%s"`, result.AccessToken) + } + }) + } } func TestCreateIMDSAuthRequest(t *testing.T) { @@ -166,6 +350,7 @@ func TestCreateIMDSAuthRequest(t *testing.T) { } for _, tt := range tests { + t.Log("0------") t.Run(tt.name, func(t *testing.T) { req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) if tt.wantErr { @@ -194,7 +379,7 @@ func TestCreateIMDSAuthRequest(t *testing.T) { t.Fatal("createIMDSAuthRequest() api-version missmatch") } - if query.Get(resourceQuerryParameterName) != removeSuffix(tt.resource, "/.default") { + if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(tt.resource, "/.default") { t.Fatal("createIMDSAuthRequest() resource does not ahve suffix removed ") } @@ -215,33 +400,3 @@ func TestCreateIMDSAuthRequest(t *testing.T) { }) } } - -func TestManagedIdentityIMDS_SAMIHttpRequestFailure(t *testing.T) { - fakeHTTPClient := FakeClient{responseType: 2} - - client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) - - if err != nil { - t.Fatal(err) - } - - if _, err := client.AcquireToken(context.Background(), "fakeresource"); err == nil { - t.Fatal("TestManagedIdentity: Should have returned error for incorrect http request.") - } - -} - -func TestManagedIdentityIMDS_SAMIError(t *testing.T) { - fakeHTTPClient := errorClient{} - - client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) - - if err != nil { - t.Fatal(err) - } - - if _, err := client.AcquireToken(context.Background(), "fakeresource"); err == nil { - t.Fatal("TestManagedIdentity: Should have returned error for incorrect http request.") - } - -} From df2ad5a54bc6cb9f0a5e0cfa396d5f5c834b38d5 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 12 Sep 2024 00:42:55 +0100 Subject: [PATCH 12/32] Passed context to http request added context to request --- apps/managedidentity/managedidentity.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 2c49e028..f28e812b 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -150,7 +150,7 @@ func New(id ID, options ...ClientOption) (Client, error) { return client, nil } -func createIMDSAuthRequest(_ context.Context, id ID, resource string, claims string) (*http.Request, error) { +func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims string) (*http.Request, error) { var msiEndpoint *url.URL msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { @@ -193,14 +193,14 @@ func createIMDSAuthRequest(_ context.Context, id ID, resource string, claims str } msiEndpoint.RawQuery = msiParameters.Encode() - req, err := http.NewRequest(http.MethodGet, msiEndpoint.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) if err != nil { return nil, errors.New("error creating request") } return req, nil } -func getTokenForRequest(ctx context.Context, req *http.Request, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { +func getTokenForRequest(req *http.Request, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { req.Header.Add(metaHTTPHeadderName, "true") resp, err := httpClient.Do(req) if err != nil { From 287963e675320582d62646f537a180a840802a43 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 13 Sep 2024 15:55:36 +0100 Subject: [PATCH 13/32] Updated service errors handling and tests Updated the tests to check for errors more correctly --- apps/managedidentity/managedidentity.go | 93 +++++----- apps/managedidentity/managedidentity_test.go | 182 ++++++++++--------- apps/tests/devapps/main.go | 6 +- 3 files changed, 146 insertions(+), 135 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index f28e812b..f0f403fe 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -12,13 +12,13 @@ package managedidentity import ( "context" "encoding/json" - "errors" "fmt" "io" "net/http" "net/url" "strings" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" @@ -26,6 +26,7 @@ import ( ) const ( + // DefaultToIMDS indicates that the source is defaulted to IMDS when no environment variables are set. DefaultToIMDS Source = 0 AzureArc Source = 1 ServiceFabric Source = 2 @@ -53,14 +54,6 @@ const ( imdsAPIVersion = "2018-02-01" ) -const ( - // DefaultToIMDS indicates that the source is defaulted to IMDS since no environment variables are set. - defaultToIMDS = 0 - - // AzureArc represents the source to acquire token for managed identity is Azure Arc. - azureArc = 1 -) - type Source int func (s Source) String() string { @@ -114,7 +107,7 @@ type ClientOption func(o *ClientOptions) type AcquireTokenOption func(o *AcquireTokenOptions) -// WithClaims sets additional claims to request for the token, such as those required by conditional access policies. +// WithClaims sets additional claims to request for the token, such as those required by token revocation or conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. func WithClaims(claims string) AcquireTokenOption { return func(o *AcquireTokenOptions) { @@ -141,7 +134,23 @@ func New(id ID, options ...ClientOption) (Client, error) { for _, option := range options { option(&opts) } - + switch t := id.(type) { + case ClientID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("clientId parameter is empty for %T", t) + } + case ResourceID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("resourceID parameter is empty for %T", t) + } + case ObjectID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("objectID parameter is empty for %T", t) + } + case systemAssignedValue: + default: + return Client{}, fmt.Errorf("unsupported type %T", id) + } client := Client{ miType: id, httpClient: opts.httpClient, @@ -154,14 +163,11 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s var msiEndpoint *url.URL msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { - return nil, errors.New("error creating URL as parsing the URL filed") + return nil, fmt.Errorf("error creating URL \n %s", err.Error()) } - msiParameters := msiEndpoint.Query() msiParameters.Add(apiVersionQuerryParameterName, "2018-02-01") - resource = strings.TrimSuffix(resource, "/.default") - msiParameters.Add(resourceQuerryParameterName, resource) if len(claims) > 0 { @@ -170,23 +176,11 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s switch t := id.(type) { case ClientID: - if len(string(t)) > 0 { - msiParameters.Add(miQuerryParameterClientId, string(t)) - } else { - return nil, fmt.Errorf("clientId parameter is empty for %T", t) - } + msiParameters.Add(miQuerryParameterClientId, string(t)) case ResourceID: - if len(string(t)) > 0 { - msiParameters.Add(miQuerryParameterResourceId, string(t)) - } else { - return nil, fmt.Errorf("resourceID parameter is empty for %T", t) - } + msiParameters.Add(miQuerryParameterResourceId, string(t)) case ObjectID: - if len(string(t)) > 0 { - msiParameters.Add(miQuerryParameterObjectId, string(t)) - } else { - return nil, fmt.Errorf("objectID parameter is empty for %T", t) - } + msiParameters.Add(miQuerryParameterObjectId, string(t)) case systemAssignedValue: // not adding anything default: return nil, fmt.Errorf("unsupported type %T", id) @@ -195,28 +189,43 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s msiEndpoint.RawQuery = msiParameters.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) if err != nil { - return nil, errors.New("error creating request") + return nil, fmt.Errorf("error creating http request %s", err) } + req.Header.Add(metaHTTPHeadderName, "true") return req, nil } -func getTokenForRequest(req *http.Request, httpClient ops.HTTPClient) (accesstokens.TokenResponse, error) { - req.Header.Add(metaHTTPHeadderName, "true") - resp, err := httpClient.Do(req) +func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { + resp, err := client.httpClient.Do(req) if err != nil { return accesstokens.TokenResponse{}, err } - if resp.StatusCode != http.StatusOK { - return accesstokens.TokenResponse{}, fmt.Errorf("failed to authenticate with status code %d ", resp.StatusCode) - } - // Pull out response body responseBytes, err := io.ReadAll(resp.Body) defer resp.Body.Close() if err != nil { return accesstokens.TokenResponse{}, err } - - // Unmarshall response body into struct + switch resp.StatusCode { + case 200, 201: + default: + sd := strings.TrimSpace(string(responseBytes)) + if sd != "" { + return accesstokens.TokenResponse{}, errors.CallErr{ + Req: req, + Resp: resp, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", + req.URL.String(), + req.Method, + resp.StatusCode, + sd), + } + } + return accesstokens.TokenResponse{}, errors.CallErr{ + Req: req, + Resp: resp, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, resp.StatusCode), + } + } var r accesstokens.TokenResponse err = json.Unmarshal(responseBytes, &r) if err != nil { @@ -237,9 +246,9 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options } req, err := createIMDSAuthRequest(ctx, client.miType, resource, o.claims) if err != nil { - return base.AuthResult{}, errors.New("error while creating request") + return base.AuthResult{}, err } - tokenResponse, err := getTokenForRequest(ctx, req, client.httpClient) + tokenResponse, err := client.getTokenForRequest(req) if err != nil { return base.AuthResult{}, err } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 4b16e830..dce145ce 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" ) @@ -74,7 +75,7 @@ func (*fakeClient) Do(req *http.Request) (*http.Response, error) { func (e *errorClient) Do(req *http.Request) (*http.Response, error) { w := http.Response{ StatusCode: e.errResponse.StatusCode, - Body: io.NopCloser(strings.NewReader(makeResponseWithErrorData(e.errResponse))), + Body: io.NopCloser(strings.NewReader(e.errResponse.Message)), Header: make(http.Header), } return &w, nil @@ -103,36 +104,6 @@ func makeResponseWithErrorData(errRsp ErrorResponse) string { return string(jsonResponse) } -func getMsiErrorResponse() string { - response := ErrorResponse{ - StatusCode: 500, - Message: "An unexpected error occurred while fetching the AAD Token.", - CorrelationID: "7d0c9763-ff1d-4842-a3f3-6d49e64f4513", - } - jsonResponse, _ := json.Marshal(response) - return string(jsonResponse) -} - -func getMsiErrorResponseNotFound() string { - response := ErrorResponse{ - StatusCode: 500, - Message: "An unexpected error occurred while fetching the AAD Token.", - CorrelationID: "7d0c9763-ff1d-4842-a3f3-6d49e64f4513", - } - jsonResponse, _ := json.Marshal(response) - return string(jsonResponse) -} - -func getMsiErrorResponseNoRetry() string { - response := ErrorResponse{ - StatusCode: 123, - Message: "Not one of the retryable error responses", - CorrelationID: "7d0c9763-ff1d-4842-a3f3-6d49e64f4513", - } - jsonResponse, _ := json.Marshal(response) - return string(jsonResponse) -} - func computeUri(endpoint string, queryParameters map[string][]string) string { if len(queryParameters) == 0 { return endpoint @@ -212,33 +183,39 @@ func createResourceData() []resourceTestData { func Test_SystemAssigned_Returns_Token_Failure(t *testing.T) { testCases := []ErrorResponse{ - {StatusCode: 404, Message: "IMDS service not available", CorrelationID: "121212"}, - {StatusCode: 501, Message: "Service error 1", CorrelationID: "121212"}, - {StatusCode: 503, Message: "Service error 2", CorrelationID: "121212"}, - {StatusCode: 400, Message: "invalid id", CorrelationID: "121212"}, + {StatusCode: http.StatusNotFound, Message: ``, CorrelationID: "121212"}, + {StatusCode: http.StatusNotImplemented, Message: ``, CorrelationID: "121212"}, + {StatusCode: http.StatusServiceUnavailable, Message: ``, CorrelationID: "121212"}, + {StatusCode: http.StatusBadRequest, + Message: `{"error": "invalid_request", "error_description": "Identity not found"}`, + CorrelationID: "121212", + }, } for _, testCase := range testCases { t.Run(strconv.Itoa(testCase.StatusCode), func(t *testing.T) { fakeErrorClient := errorClient{errResponse: testCase} client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) - if err != nil { t.Fatal(err) } - resp, err := client.AcquireToken(context.Background(), resource) - - if resp.AccessToken != "" { - t.Fatalf("testManagedIdentity: accesstoken should be nil") - } if err == nil { t.Fatalf("testManagedIdentity: Should have encountered the error") } - if err.Error() != fmt.Errorf("failed to authenticate with status code %d ", testCase.StatusCode).Error() { - t.Fatalf(`unexpected error "%s"`, err) - + switch e := err.(type) { + case errors.CallErr: + if actual := err.Error(); !strings.Contains(e.Error(), testCase.Message) { + t.Fatalf("testManagedIdentity: expected response body in error, got %q", actual) + } + if e.Resp.StatusCode != testCase.StatusCode { + t.Fatal("testManagedIdentity: got unexpected status code.") + } + } + if resp.AccessToken != "" { + t.Fatalf("testManagedIdentity: accesstoken should be nil") } + }) } } @@ -276,6 +253,65 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { } } +func TestCreatingIMDSClient(t *testing.T) { + tests := []struct { + name string + id ID + wantErr bool + }{ + { + name: "System Assigned", + id: SystemAssigned(), + }, + { + name: "Client ID", + id: ClientID("test-client-id"), + }, + { + name: "Resource ID", + id: ResourceID("test-resource-id"), + }, + { + name: "Object ID", + id: ObjectID("test-object-id"), + }, + { + name: "Empty Client ID", + id: ClientID(""), + wantErr: true, + }, + { + name: "Empty Resource ID", + id: ResourceID(""), + wantErr: true, + }, + { + name: "Empty Object ID", + id: ObjectID(""), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := New(tt.id) + if tt.wantErr { + if err == nil { + t.Fatal("client New() should return a error but did not.") + } + return + } + if err != nil { + t.Fatal("client New() error while creating client") + } else { + if client.miType.value() != tt.id.value() { + t.Fatal("client New() did not assign a correct value to type.") + } + } + }) + } + +} func TestCreateIMDSAuthRequest(t *testing.T) { tests := []struct { name string @@ -288,69 +324,36 @@ func TestCreateIMDSAuthRequest(t *testing.T) { name: "System Assigned", id: SystemAssigned(), resource: "https://management.azure.com", - claims: "", - wantErr: false, }, { name: "System Assigned", id: SystemAssigned(), resource: "https://management.azure.com/.default", - claims: "", - wantErr: false, }, { name: "Client ID", id: ClientID("test-client-id"), resource: "https://storage.azure.com", - claims: "", - wantErr: false, }, { name: "Resource ID", id: ResourceID("test-resource-id"), resource: "https://vault.azure.net", - claims: "", - wantErr: false, }, { name: "Object ID", id: ObjectID("test-object-id"), resource: "https://graph.microsoft.com", - claims: "", - wantErr: false, }, { name: "With Claims", id: SystemAssigned(), resource: "https://management.azure.com", claims: "test-claims", - wantErr: false, - }, - { - name: "Empty Client ID", - id: ClientID(""), - resource: "https://management.azure.com", - claims: "", - wantErr: true, - }, - { - name: "Empty Resource ID", - id: ResourceID(""), - resource: "https://management.azure.com", - claims: "", - wantErr: true, - }, - { - name: "Empty Object ID", - id: ObjectID(""), - resource: "https://management.azure.com", - claims: "", - wantErr: true, }, } for _, tt := range tests { - t.Log("0------") t.Run(tt.name, func(t *testing.T) { req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) if tt.wantErr { @@ -359,44 +362,43 @@ func TestCreateIMDSAuthRequest(t *testing.T) { } return } - if req == nil { t.Fatal("createIMDSAuthRequest() returned nil request") return } - if req.Method != http.MethodGet { t.Fatal("createIMDSAuthRequest() method is not GET") } - if !strings.HasPrefix(req.URL.String(), imdsEndpoint) { t.Fatal("createIMDSAuthRequest() URL is not matched.") } - query := req.URL.Query() if query.Get(apiVersionQuerryParameterName) != "2018-02-01" { t.Fatal("createIMDSAuthRequest() api-version missmatch") } - if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(tt.resource, "/.default") { t.Fatal("createIMDSAuthRequest() resource does not ahve suffix removed ") } - - switch tt.id.(type) { + switch i := tt.id.(type) { case ClientID: - if query.Get(miQuerryParameterClientId) != tt.id.value() { - t.Fatal("createIMDSAuthRequest() client_id does not match with the id value") + if query.Get(miQuerryParameterClientId) != i.value() { + t.Fatal("createIMDSAuthRequest() resource client-id is incorrect") } case ResourceID: - if query.Get(miQuerryParameterResourceId) != tt.id.value() { - t.Fatal("createIMDSAuthRequest() resource id does not match with the id value") + if query.Get(miQuerryParameterResourceId) != i.value() { + t.Fatal("createIMDSAuthRequest() resource resource-id is incorrect") } case ObjectID: - if query.Get(miQuerryParameterObjectId) != tt.id.value() { - t.Fatal("createIMDSAuthRequest() object id does not match with the id value") + if query.Get(miQuerryParameterObjectId) != i.value() { + t.Fatal("createIMDSAuthRequest() resource objectiid is incorrect") } + case systemAssignedValue: // not adding anything + default: + t.Fatal("createIMDSAuthRequest() unsupported type") + } + }) } } diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index b15d0a9e..e03239aa 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -36,9 +36,9 @@ func main() { } else if exampleType == "6" { // This sample does not use a serialized cache - it relies on in-memory cache by reusing the app object // This works well for app tokens, because there is only 1 token per resource, per tenant. - acquireTokenClientCertificate() - // this time the token comes from the cache! - acquireTokenClientCertificate() + // acquireTokenClientCertificate() + // // this time the token comes from the cache! + // acquireTokenClientCertificate() } else if exampleType == "7" { RunManagedIdentity() } From df9faf16fa6334d59906f059482fc294090b6fb2 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 16 Sep 2024 17:18:23 +0100 Subject: [PATCH 14/32] Updated tests to use mock Update some test and used mock.Client some refactoring for comments --- apps/internal/mock/mock.go | 8 + apps/managedidentity/managedidentity.go | 37 +-- apps/managedidentity/managedidentity_test.go | 226 ++++++------------- 3 files changed, 87 insertions(+), 184 deletions(-) diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 5de171fd..8a1d02e2 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -59,6 +59,14 @@ func (c *Client) AppendResponse(opts ...responseOption) { c.resp = append(c.resp, r) } +func (c *Client) AppendCustomResponse(status int, opts ...responseOption) { + r := response{code: status, headers: http.Header{}} + for _, o := range opts { + o.apply(&r) + } + c.resp = append(c.resp, r) +} + func (c *Client) Do(req *http.Request) (*http.Response, error) { if len(c.resp) == 0 { panic(fmt.Sprintf(`no response for "%s"`, req.URL.String())) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index f0f403fe..6763db42 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -27,11 +27,11 @@ import ( const ( // DefaultToIMDS indicates that the source is defaulted to IMDS when no environment variables are set. - DefaultToIMDS Source = 0 - AzureArc Source = 1 - ServiceFabric Source = 2 - CloudShell Source = 3 - AppService Source = 4 + DefaultToIMDS Source = "DefaultToIMDS" + AzureArc Source = "AzureArc" + ServiceFabric Source = "ServiceFabric" + CloudShell Source = "CloudShell" + AppService Source = "AppService" ) // General request querry parameter names @@ -54,24 +54,7 @@ const ( imdsAPIVersion = "2018-02-01" ) -type Source int - -func (s Source) String() string { - switch s { - case DefaultToIMDS: - return "DefaultToIMDS" - case AzureArc: - return "AzureArc" - case ServiceFabric: - return "ServiceFabric" - case CloudShell: - return "CloudShell" - case AppService: - return "AppService" - default: - return fmt.Sprintf("UnknownSource(%d)", s) - } -} +type Source string type ID interface { value() string @@ -93,6 +76,7 @@ func SystemAssigned() ID { type Client struct { httpClient ops.HTTPClient miType ID + source Source } type ClientOptions struct { @@ -206,7 +190,7 @@ func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRe return accesstokens.TokenResponse{}, err } switch resp.StatusCode { - case 200, 201: + case http.StatusOK, http.StatusAccepted: default: sd := strings.TrimSpace(string(responseBytes)) if sd != "" { @@ -228,10 +212,7 @@ func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRe } var r accesstokens.TokenResponse err = json.Unmarshal(responseBytes, &r) - if err != nil { - return accesstokens.TokenResponse{}, err - } - return r, nil + return r, err } // Acquires tokens from the configured managed identity on an azure resource. diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index dce145ce..cc4d37a3 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -5,31 +5,23 @@ package managedidentity import ( "context" "encoding/json" - "fmt" - "io" "net/http" - "net/url" "strconv" "strings" "testing" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" - internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" ) const ( // test Resources resource = "https://demo.azure.com" resourceDefaultSuffix = "https://demo.azure.com/.default" -) -type HttpRequest struct { - Source Source - Resource string - Identity ID -} + token = "fakeToken" +) type SuccessfulResponse struct { AccessToken string `json:"access_token"` @@ -39,133 +31,41 @@ type SuccessfulResponse struct { ClientID string `json:"client_id"` } -type ErrorResponse struct { - StatusCode int `json:"statusCode"` - Message string `json:"message"` - CorrelationID string `json:"correlationId,omitempty"` -} - -type fakeClient struct{} -type errorClient struct { - errResponse ErrorResponse -} - -func fakeMIClient(mangedIdentityId ID, options ...ClientOption) (Client, error) { - fakeClient, err := New(mangedIdentityId, options...) - - if err != nil { - return Client{}, err - } - - return fakeClient, nil +type ErrorRespone struct { + Err string `json:"error"` + Desc string `json:"error_description"` } -func (*fakeClient) CloseIdleConnections() {} -func (*errorClient) CloseIdleConnections() {} - -func (*fakeClient) Do(req *http.Request) (*http.Response, error) { - w := http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(getSuccessfulResponse(resource))), - Header: make(http.Header), - } - return &w, nil +type response struct { + body []byte + callback func(*http.Request) + code int } -func (e *errorClient) Do(req *http.Request) (*http.Response, error) { - w := http.Response{ - StatusCode: e.errResponse.StatusCode, - Body: io.NopCloser(strings.NewReader(e.errResponse.Message)), - Header: make(http.Header), - } - return &w, nil -} - -func getSuccessfulResponse(resource string) string { +func getSuccessfulResponse(resource string) []byte { expiresOn := time.Now().Add(1 * time.Hour).Unix() response := SuccessfulResponse{ - AccessToken: "fakeToken", + AccessToken: token, ExpiresOn: expiresOn, Resource: resource, TokenType: "Bearer", ClientID: "client_id", } jsonResponse, _ := json.Marshal(response) - return string(jsonResponse) + return jsonResponse } -func makeResponseWithErrorData(errRsp ErrorResponse) string { - response := ErrorResponse{ - StatusCode: errRsp.StatusCode, - Message: errRsp.Message, - CorrelationID: errRsp.CorrelationID, +func makeResponseWithErrorData(err string, desc string) []byte { + responseBody := ErrorRespone{ + Err: err, + Desc: desc, } - jsonResponse, _ := json.Marshal(response) - return string(jsonResponse) -} - -func computeUri(endpoint string, queryParameters map[string][]string) string { - if len(queryParameters) == 0 { - return endpoint - } - - queryString := url.Values{} - for key, values := range queryParameters { - for _, value := range values { - queryString.Add(key, value) - } - } - - return endpoint + "?" + queryString.Encode() -} - -func expectedRequest(source Source, resource string, id ID) (*http.Request, error) { - return expectedRequestWithId(source, resource, id) -} - -func expectedRequestWithId(_ Source, resource string, id ID) (*http.Request, error) { - var endpoint string - headers := http.Header{} - queryParameters := make(map[string][]string) - - //check with source when added different sources. - endpoint = imdsEndpoint - queryParameters["api-version"] = []string{"2018-02-01"} - queryParameters["resource"] = []string{resource} - headers.Add("Metadata", "true") - - switch id.(type) { - case ClientID: - queryParameters[miQuerryParameterClientId] = []string{id.value()} - case ResourceID: - queryParameters[miQuerryParameterResourceId] = []string{id.value()} - case ObjectID: - queryParameters[miQuerryParameterObjectId] = []string{id.value()} - case systemAssignedValue: - // not adding anything - default: - return nil, fmt.Errorf("Type not supported") - } - - uri, err := url.Parse(computeUri(endpoint, queryParameters)) - if err != nil { - return nil, err - } - - req := &http.Request{ - Method: "GET", - URL: uri, - Header: headers, - } - - return req, nil -} - -func ExpectedResponse(statusCode int, response string) http.Response { - return http.Response{ - StatusCode: statusCode, - Body: io.NopCloser(strings.NewReader(response)), + if len(err) == 0 && len(desc) == 0 { + jsonResponse, _ := json.Marshal(responseBody) + return jsonResponse } + jsonResponse, _ := json.Marshal(responseBody) + return jsonResponse } type resourceTestData struct { @@ -174,6 +74,13 @@ type resourceTestData struct { resource string } +type errorTestData struct { + code int + err string + desc string + correlationid string +} + func createResourceData() []resourceTestData { return []resourceTestData{ {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource}, @@ -181,21 +88,32 @@ func createResourceData() []resourceTestData { } } -func Test_SystemAssigned_Returns_Token_Failure(t *testing.T) { - testCases := []ErrorResponse{ - {StatusCode: http.StatusNotFound, Message: ``, CorrelationID: "121212"}, - {StatusCode: http.StatusNotImplemented, Message: ``, CorrelationID: "121212"}, - {StatusCode: http.StatusServiceUnavailable, Message: ``, CorrelationID: "121212"}, - {StatusCode: http.StatusBadRequest, - Message: `{"error": "invalid_request", "error_description": "Identity not found"}`, - CorrelationID: "121212", +func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { + testCases := []errorTestData{ + {code: http.StatusNotFound, + err: "", + desc: "", + correlationid: "121212"}, + {code: http.StatusNotImplemented, + err: "", + desc: "", + correlationid: "121212"}, + {code: http.StatusServiceUnavailable, + err: "", + desc: "", + correlationid: "121212"}, + {code: http.StatusBadRequest, + err: "invalid_request", + desc: "Identity not found", + correlationid: "121212", }, } for _, testCase := range testCases { - t.Run(strconv.Itoa(testCase.StatusCode), func(t *testing.T) { - fakeErrorClient := errorClient{errResponse: testCase} - client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) + t.Run(strconv.Itoa(testCase.code), func(t *testing.T) { + fakeErrorClient := mock.Client{} + fakeErrorClient.AppendCustomResponse(testCase.code, mock.WithBody(makeResponseWithErrorData(testCase.err, testCase.desc))) + client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) if err != nil { t.Fatal(err) } @@ -203,51 +121,47 @@ func Test_SystemAssigned_Returns_Token_Failure(t *testing.T) { if err == nil { t.Fatalf("testManagedIdentity: Should have encountered the error") } - switch e := err.(type) { - case errors.CallErr: - if actual := err.Error(); !strings.Contains(e.Error(), testCase.Message) { - t.Fatalf("testManagedIdentity: expected response body in error, got %q", actual) + var callErr errors.CallErr + if errors.As(err, &callErr) { + callErr = err.(errors.CallErr) + if !strings.Contains(err.Error(), testCase.err) { + t.Fatalf("testManagedIdentity: expected message '%s' in error, got %q", testCase.err, callErr.Error()) } - if e.Resp.StatusCode != testCase.StatusCode { - t.Fatal("testManagedIdentity: got unexpected status code.") + if callErr.Resp.StatusCode != testCase.code { + t.Fatalf("testManagedIdentity: expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) } + } else { + t.Fatalf("testManagedIdentity: expected error of type %T, got %T", callErr.Error(), err) } if resp.AccessToken != "" { t.Fatalf("testManagedIdentity: accesstoken should be nil") } - }) } } func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { testCases := createResourceData() - for _, testCase := range testCases { t.Run(testCase.source.String(), func(t *testing.T) { - fakeHTTPClient := fakeClient{} - client, err := fakeMIClient(SystemAssigned(), WithHTTPClient(&fakeHTTPClient)) + var url string + mockClient := mock.Client{} + mockClient.AppendCustomResponse(http.StatusOK, mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) + client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } - result, err := client.AcquireToken(context.Background(), testCase.resource) - - if err != nil { - t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") + if !strings.HasPrefix(url, testCase.endpoint) { + t.Fatalf("TestManagedIdentity: URL request is not on %s fgot %s", testCase.endpoint, url) } - var tokenScope = []string{"the_scope"} - expected := accesstokens.TokenResponse{ - AccessToken: "fakeToken", - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, - TokenType: "TokenType", + if err != nil { + t.Fatalf("TestManagedIdentity: unexpected nil error from TestManagedIdentity %s", err.Error()) } - if result.AccessToken != expected.AccessToken { - t.Fatalf(`unexpected access token "%s"`, result.AccessToken) + if result.AccessToken != token { + t.Fatalf("TestManagedIdentity: wanted %q, got %q", token, result.AccessToken) } }) } From 5395b9ac5e4be72cbccf1530ac594d4b043edb25 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 16 Sep 2024 17:19:27 +0100 Subject: [PATCH 15/32] small update --- apps/managedidentity/managedidentity_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index cc4d37a3..4f8975de 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -144,7 +144,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { testCases := createResourceData() for _, testCase := range testCases { - t.Run(testCase.source.String(), func(t *testing.T) { + t.Run(string(testCase.source), func(t *testing.T) { var url string mockClient := mock.Client{} mockClient.AppendCustomResponse(http.StatusOK, mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) From 6a72df2e03f4273c207ca526390bfb872ce8a80f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 17 Sep 2024 15:34:34 +0100 Subject: [PATCH 16/32] Added a withStatusCode method in mock --- apps/internal/mock/mock.go | 15 +++++++-------- apps/managedidentity/managedidentity_test.go | 5 +++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 8a1d02e2..853950bc 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,6 +46,13 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } +// WithCallback sets a callback to invoke before returning the response. +func WithHttpStatusCode(statusCode int) responseOption { + return respOpt(func(r *response) { + r.code = statusCode + }) +} + // Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. type Client struct { resp []response @@ -59,14 +66,6 @@ func (c *Client) AppendResponse(opts ...responseOption) { c.resp = append(c.resp, r) } -func (c *Client) AppendCustomResponse(status int, opts ...responseOption) { - r := response{code: status, headers: http.Header{}} - for _, o := range opts { - o.apply(&r) - } - c.resp = append(c.resp, r) -} - func (c *Client) Do(req *http.Request) (*http.Response, error) { if len(c.resp) == 0 { panic(fmt.Sprintf(`no response for "%s"`, req.URL.String())) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 4f8975de..9926fbbb 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -112,7 +112,8 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { for _, testCase := range testCases { t.Run(strconv.Itoa(testCase.code), func(t *testing.T) { fakeErrorClient := mock.Client{} - fakeErrorClient.AppendCustomResponse(testCase.code, mock.WithBody(makeResponseWithErrorData(testCase.err, testCase.desc))) + fakeErrorClient.AppendResponse(mock.WithHttpStatusCode(testCase.code), + mock.WithBody(makeResponseWithErrorData(testCase.err, testCase.desc))) client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) if err != nil { t.Fatal(err) @@ -147,7 +148,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { t.Run(string(testCase.source), func(t *testing.T) { var url string mockClient := mock.Client{} - mockClient.AppendCustomResponse(http.StatusOK, mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) + mockClient.AppendResponse(mock.WithHttpStatusCode(http.StatusOK), mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) if err != nil { From b293a60a91c54b0934d20de608cf9292c4f149e9 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:36:50 +0100 Subject: [PATCH 17/32] Update apps/internal/mock/mock.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/internal/mock/mock.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 853950bc..255e5223 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -47,7 +47,7 @@ func WithCallback(callback func(*http.Request)) responseOption { } // WithCallback sets a callback to invoke before returning the response. -func WithHttpStatusCode(statusCode int) responseOption { +func WithHTTPStatusCode(statusCode int) responseOption { return respOpt(func(r *response) { r.code = statusCode }) From c2b9127e8e802e776190f0162444c2ad218da286 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 17 Sep 2024 17:38:09 +0100 Subject: [PATCH 18/32] Updated the method usage for WithHTTPStatusCode --- apps/managedidentity/managedidentity_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 9926fbbb..de3f11c1 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -112,7 +112,7 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { for _, testCase := range testCases { t.Run(strconv.Itoa(testCase.code), func(t *testing.T) { fakeErrorClient := mock.Client{} - fakeErrorClient.AppendResponse(mock.WithHttpStatusCode(testCase.code), + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), mock.WithBody(makeResponseWithErrorData(testCase.err, testCase.desc))) client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) if err != nil { @@ -148,7 +148,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { t.Run(string(testCase.source), func(t *testing.T) { var url string mockClient := mock.Client{} - mockClient.AppendResponse(mock.WithHttpStatusCode(http.StatusOK), mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) if err != nil { From e451611eab9cc3a9fec1b3732f0447983401b4e5 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:13:38 +0100 Subject: [PATCH 19/32] Update apps/managedidentity/managedidentity_test.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/managedidentity/managedidentity_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index de3f11c1..71d40bbb 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -284,8 +284,8 @@ func TestCreateIMDSAuthRequest(t *testing.T) { if req.Method != http.MethodGet { t.Fatal("createIMDSAuthRequest() method is not GET") } - if !strings.HasPrefix(req.URL.String(), imdsEndpoint) { - t.Fatal("createIMDSAuthRequest() URL is not matched.") + if got := req.URL.String(); !strings.HasPrefix(got, imdsEndpoint) { + t.Fatalf("wanted %q, got %q", imdsEndpoint, got) } query := req.URL.Query() From 9912ee92157e98784a1cb60f4aed1eadaffe92c0 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:14:04 +0100 Subject: [PATCH 20/32] Update apps/managedidentity/managedidentity_test.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/managedidentity/managedidentity_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 71d40bbb..31f1868c 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -279,7 +279,6 @@ func TestCreateIMDSAuthRequest(t *testing.T) { } if req == nil { t.Fatal("createIMDSAuthRequest() returned nil request") - return } if req.Method != http.MethodGet { t.Fatal("createIMDSAuthRequest() method is not GET") From 7f147d4a85711d41ecfcb9bb28743412aca98bcd Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 20 Sep 2024 15:17:34 +0100 Subject: [PATCH 21/32] Removed typed data from test --- apps/managedidentity/managedidentity_test.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index de3f11c1..8f32e5b7 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -81,13 +81,6 @@ type errorTestData struct { correlationid string } -func createResourceData() []resourceTestData { - return []resourceTestData{ - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix}, - } -} - func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { testCases := []errorTestData{ {code: http.StatusNotFound, @@ -142,7 +135,10 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { - testCases := createResourceData() + testCases := []resourceTestData{ + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix}, + } for _, testCase := range testCases { t.Run(string(testCase.source), func(t *testing.T) { From 82b11551d4968354706806a3e6ecdd167ac490b5 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 20 Sep 2024 15:30:20 +0100 Subject: [PATCH 22/32] Updated test to return json error --- apps/managedidentity/managedidentity_test.go | 28 +++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 9210bbc9..4c1e2f6b 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -42,7 +42,7 @@ type response struct { code int } -func getSuccessfulResponse(resource string) []byte { +func getSuccessfulResponse(resource string) ([]byte, error) { expiresOn := time.Now().Add(1 * time.Hour).Unix() response := SuccessfulResponse{ AccessToken: token, @@ -51,21 +51,21 @@ func getSuccessfulResponse(resource string) []byte { TokenType: "Bearer", ClientID: "client_id", } - jsonResponse, _ := json.Marshal(response) - return jsonResponse + jsonResponse, err := json.Marshal(response) + return jsonResponse, err } -func makeResponseWithErrorData(err string, desc string) []byte { +func makeResponseWithErrorData(err string, desc string) ([]byte, error) { responseBody := ErrorRespone{ Err: err, Desc: desc, } if len(err) == 0 && len(desc) == 0 { - jsonResponse, _ := json.Marshal(responseBody) - return jsonResponse + jsonResponse, error := json.Marshal(responseBody) + return jsonResponse, error } - jsonResponse, _ := json.Marshal(responseBody) - return jsonResponse + jsonResponse, error := json.Marshal(responseBody) + return jsonResponse, error } type resourceTestData struct { @@ -105,8 +105,12 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { for _, testCase := range testCases { t.Run(strconv.Itoa(testCase.code), func(t *testing.T) { fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) + if err != nil { + t.Fatalf("TestManagedIdentity: Error while forming json response : %s", err.Error()) + } fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), - mock.WithBody(makeResponseWithErrorData(testCase.err, testCase.desc))) + mock.WithBody(responseBody)) client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) if err != nil { t.Fatal(err) @@ -144,7 +148,11 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { t.Run(string(testCase.source), func(t *testing.T) { var url string mockClient := mock.Client{} - mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(getSuccessfulResponse(resource)), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) + responseBody, err := getSuccessfulResponse(resource) + if err != nil { + t.Fatalf("TestManagedIdentity: Error while forming json response : %s", err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) if err != nil { From 522883a7d3db8788780aba63da75af2fd5ddd95f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 20 Sep 2024 15:35:51 +0100 Subject: [PATCH 23/32] Updating sample app Removed printing token Variable name updated. --- apps/tests/devapps/managedidentity_sample.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index 95a67d9d..b49ec23f 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -12,9 +12,9 @@ func RunManagedIdentity() { if err != nil { fmt.Println(err) } - temp, err := miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") + result, err := miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") if err != nil { println(err.Error()) } - fmt.Println("token : ", temp.AccessToken) + fmt.Println("token : ", result.ExpiresOn) } From 6ad761f10b31feebdacd1671fde8ba48e28607fa Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 23 Sep 2024 17:11:45 +0100 Subject: [PATCH 24/32] Updated the MI identity for UAMI with "UserAssigned" as prefix Updated the MI identity for UAMI with "UserAssigned" as prefix --- apps/internal/mock/mock.go | 2 +- apps/managedidentity/managedidentity.go | 72 ++++----- apps/managedidentity/managedidentity_test.go | 154 ++++--------------- apps/tests/devapps/managedidentity_sample.go | 4 +- 4 files changed, 63 insertions(+), 169 deletions(-) diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 255e5223..af684a32 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,7 +46,7 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } -// WithCallback sets a callback to invoke before returning the response. +// WithHTTPStatusCode sets the HTTP statusCode of response to the specified value. func WithHTTPStatusCode(statusCode int) responseOption { return respOpt(func(r *response) { r.code = statusCode diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 6763db42..1eb9e56e 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -32,24 +32,18 @@ const ( ServiceFabric Source = "ServiceFabric" CloudShell Source = "CloudShell" AppService Source = "AppService" -) -// General request querry parameter names -const ( - metaHTTPHeadderName = "Metadata" + // General request querry parameter names + metaHTTPHeaderName = "Metadata" apiVersionQuerryParameterName = "api-version" resourceQuerryParameterName = "resource" -) -// UAMI querry parameter name -const ( - miQuerryParameterClientId = "client_id" - miQuerryParameterObjectId = "object_id" - miQuerryParameterResourceId = "msi_res_id" -) + // UAMI querry parameter name + miQueryParameterClientId = "client_id" + miQueryParameterObjectId = "object_id" + miQueryParameterResourceId = "msi_res_id" -// IMDS -const ( + // IMDS imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" imdsAPIVersion = "2018-02-01" ) @@ -61,14 +55,14 @@ type ID interface { } type systemAssignedValue string // its private for a reason to make the input consistent. -type ClientID string -type ObjectID string -type ResourceID string - -func (s systemAssignedValue) value() string { return string(s) } -func (c ClientID) value() string { return string(c) } -func (o ObjectID) value() string { return string(o) } -func (r ResourceID) value() string { return string(r) } +type UserAssignedClientID string +type UserAssignedObjectID string +type UserAssignedResourceID string + +func (s systemAssignedValue) value() string { return string(s) } +func (c UserAssignedClientID) value() string { return string(c) } +func (o UserAssignedObjectID) value() string { return string(o) } +func (r UserAssignedResourceID) value() string { return string(r) } func SystemAssigned() ID { return systemAssignedValue("") } @@ -107,7 +101,7 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { } // Client to be used to acquire tokens for managed identity. -// ID: [SystemAssigned()], [ClientID("clientID")], [ResourceID("resourceID")], [ObjectID("objectID")] +// ID: [SystemAssigned()], [UserAssignedClientID("clientID")], [UserAssignedResourceID("resourceID")], [UserAssignedObjectID("objectID")] // // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { @@ -119,17 +113,17 @@ func New(id ID, options ...ClientOption) (Client, error) { option(&opts) } switch t := id.(type) { - case ClientID: + case UserAssignedClientID: if len(string(t)) == 0 { - return Client{}, fmt.Errorf("clientId parameter is empty for %T", t) + return Client{}, fmt.Errorf("empty %T", t) } - case ResourceID: + case UserAssignedResourceID: if len(string(t)) == 0 { - return Client{}, fmt.Errorf("resourceID parameter is empty for %T", t) + return Client{}, fmt.Errorf("empty %T", t) } - case ObjectID: + case UserAssignedObjectID: if len(string(t)) == 0 { - return Client{}, fmt.Errorf("objectID parameter is empty for %T", t) + return Client{}, fmt.Errorf("empty %T", t) } case systemAssignedValue: default: @@ -147,24 +141,24 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s var msiEndpoint *url.URL msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { - return nil, fmt.Errorf("error creating URL \n %s", err.Error()) + return nil, fmt.Errorf("couldn't parse %q: %s", imdsEndpoint, err) } msiParameters := msiEndpoint.Query() - msiParameters.Add(apiVersionQuerryParameterName, "2018-02-01") + msiParameters.Set(apiVersionQuerryParameterName, "2018-02-01") resource = strings.TrimSuffix(resource, "/.default") - msiParameters.Add(resourceQuerryParameterName, resource) + msiParameters.Set(resourceQuerryParameterName, resource) if len(claims) > 0 { - msiParameters.Add("claims", claims) + msiParameters.Set("claims", claims) } switch t := id.(type) { - case ClientID: - msiParameters.Add(miQuerryParameterClientId, string(t)) - case ResourceID: - msiParameters.Add(miQuerryParameterResourceId, string(t)) - case ObjectID: - msiParameters.Add(miQuerryParameterObjectId, string(t)) + case UserAssignedClientID: + msiParameters.Set(miQueryParameterClientId, string(t)) + case UserAssignedResourceID: + msiParameters.Set(miQueryParameterResourceId, string(t)) + case UserAssignedObjectID: + msiParameters.Set(miQueryParameterObjectId, string(t)) case systemAssignedValue: // not adding anything default: return nil, fmt.Errorf("unsupported type %T", id) @@ -175,7 +169,7 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s if err != nil { return nil, fmt.Errorf("error creating http request %s", err) } - req.Header.Add(metaHTTPHeadderName, "true") + req.Header.Set(metaHTTPHeaderName, "true") return req, nil } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 4c1e2f6b..7b5f2a9a 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "net/http" - "strconv" "strings" "testing" "time" @@ -36,12 +35,6 @@ type ErrorRespone struct { Desc string `json:"error_description"` } -type response struct { - body []byte - callback func(*http.Request) - code int -} - func getSuccessfulResponse(resource string) ([]byte, error) { expiresOn := time.Now().Add(1 * time.Hour).Unix() response := SuccessfulResponse{ @@ -60,18 +53,15 @@ func makeResponseWithErrorData(err string, desc string) ([]byte, error) { Err: err, Desc: desc, } - if len(err) == 0 && len(desc) == 0 { - jsonResponse, error := json.Marshal(responseBody) - return jsonResponse, error - } - jsonResponse, error := json.Marshal(responseBody) - return jsonResponse, error + jsonResponse, e := json.Marshal(responseBody) + return jsonResponse, e } type resourceTestData struct { source Source endpoint string resource string + miType ID } type errorTestData struct { @@ -103,11 +93,11 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } for _, testCase := range testCases { - t.Run(strconv.Itoa(testCase.code), func(t *testing.T) { + t.Run(http.StatusText(testCase.code), func(t *testing.T) { fakeErrorClient := mock.Client{} responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) if err != nil { - t.Fatalf("TestManagedIdentity: Error while forming json response : %s", err.Error()) + t.Fatalf("error while forming json response : %s", err.Error()) } fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), mock.WithBody(responseBody)) @@ -117,22 +107,21 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { } resp, err := client.AcquireToken(context.Background(), resource) if err == nil { - t.Fatalf("testManagedIdentity: Should have encountered the error") + t.Fatalf("should have encountered the error") } var callErr errors.CallErr if errors.As(err, &callErr) { - callErr = err.(errors.CallErr) if !strings.Contains(err.Error(), testCase.err) { - t.Fatalf("testManagedIdentity: expected message '%s' in error, got %q", testCase.err, callErr.Error()) + t.Fatalf("expected message '%s' in error, got %q", testCase.err, callErr.Error()) } if callErr.Resp.StatusCode != testCase.code { - t.Fatalf("testManagedIdentity: expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) + t.Fatalf("expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) } } else { - t.Fatalf("testManagedIdentity: expected error of type %T, got %T", callErr.Error(), err) + t.Fatalf("expected error of type %T, got %T", callErr, err) } if resp.AccessToken != "" { - t.Fatalf("testManagedIdentity: accesstoken should be nil") + t.Fatalf("accesstoken should be empty") } }) } @@ -140,33 +129,34 @@ func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { testCases := []resourceTestData{ - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: SystemAssigned()}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: SystemAssigned()}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedClientID("asd")}, } for _, testCase := range testCases { t.Run(string(testCase.source), func(t *testing.T) { - var url string + url := testCase.endpoint mockClient := mock.Client{} responseBody, err := getSuccessfulResponse(resource) if err != nil { - t.Fatalf("TestManagedIdentity: Error while forming json response : %s", err.Error()) + t.Fatalf("error while forming json response : %s", err.Error()) } - mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { url = r.URL.String() })) - client, err := New(SystemAssigned(), WithHTTPClient(&mockClient)) + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody)) + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } result, err := client.AcquireToken(context.Background(), testCase.resource) if !strings.HasPrefix(url, testCase.endpoint) { - t.Fatalf("TestManagedIdentity: URL request is not on %s fgot %s", testCase.endpoint, url) + t.Fatalf("url request is not on %s got %s", testCase.endpoint, url) } if err != nil { - t.Fatalf("TestManagedIdentity: unexpected nil error from TestManagedIdentity %s", err.Error()) + t.Fatal(err) } if result.AccessToken != token { - t.Fatalf("TestManagedIdentity: wanted %q, got %q", token, result.AccessToken) + t.Fatalf("wanted %q, got %q", token, result.AccessToken) } }) } @@ -184,29 +174,29 @@ func TestCreatingIMDSClient(t *testing.T) { }, { name: "Client ID", - id: ClientID("test-client-id"), + id: UserAssignedClientID("test-client-id"), }, { name: "Resource ID", - id: ResourceID("test-resource-id"), + id: UserAssignedResourceID("test-resource-id"), }, { name: "Object ID", - id: ObjectID("test-object-id"), + id: UserAssignedObjectID("test-object-id"), }, { name: "Empty Client ID", - id: ClientID(""), + id: UserAssignedClientID(""), wantErr: true, }, { name: "Empty Resource ID", - id: ResourceID(""), + id: UserAssignedResourceID(""), wantErr: true, }, { name: "Empty Object ID", - id: ObjectID(""), + id: UserAssignedObjectID(""), wantErr: true, }, } @@ -221,7 +211,7 @@ func TestCreatingIMDSClient(t *testing.T) { return } if err != nil { - t.Fatal("client New() error while creating client") + t.Fatal(err) } else { if client.miType.value() != tt.id.value() { t.Fatal("client New() did not assign a correct value to type.") @@ -229,94 +219,4 @@ func TestCreatingIMDSClient(t *testing.T) { } }) } - -} -func TestCreateIMDSAuthRequest(t *testing.T) { - tests := []struct { - name string - id ID - resource string - claims string - wantErr bool - }{ - { - name: "System Assigned", - id: SystemAssigned(), - resource: "https://management.azure.com", - }, - { - name: "System Assigned", - id: SystemAssigned(), - resource: "https://management.azure.com/.default", - }, - { - name: "Client ID", - id: ClientID("test-client-id"), - resource: "https://storage.azure.com", - }, - { - name: "Resource ID", - id: ResourceID("test-resource-id"), - resource: "https://vault.azure.net", - }, - { - name: "Object ID", - id: ObjectID("test-object-id"), - resource: "https://graph.microsoft.com", - }, - { - name: "With Claims", - id: SystemAssigned(), - resource: "https://management.azure.com", - claims: "test-claims", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) - if tt.wantErr { - if err == nil { - t.Fatal(err) - } - return - } - if req == nil { - t.Fatal("createIMDSAuthRequest() returned nil request") - } - if req.Method != http.MethodGet { - t.Fatal("createIMDSAuthRequest() method is not GET") - } - if got := req.URL.String(); !strings.HasPrefix(got, imdsEndpoint) { - t.Fatalf("wanted %q, got %q", imdsEndpoint, got) - } - query := req.URL.Query() - - if query.Get(apiVersionQuerryParameterName) != "2018-02-01" { - t.Fatal("createIMDSAuthRequest() api-version missmatch") - } - if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(tt.resource, "/.default") { - t.Fatal("createIMDSAuthRequest() resource does not ahve suffix removed ") - } - switch i := tt.id.(type) { - case ClientID: - if query.Get(miQuerryParameterClientId) != i.value() { - t.Fatal("createIMDSAuthRequest() resource client-id is incorrect") - } - case ResourceID: - if query.Get(miQuerryParameterResourceId) != i.value() { - t.Fatal("createIMDSAuthRequest() resource resource-id is incorrect") - } - case ObjectID: - if query.Get(miQuerryParameterObjectId) != i.value() { - t.Fatal("createIMDSAuthRequest() resource objectiid is incorrect") - } - case systemAssignedValue: // not adding anything - default: - t.Fatal("createIMDSAuthRequest() unsupported type") - - } - - }) - } } diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index b49ec23f..65e71b54 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -14,7 +14,7 @@ func RunManagedIdentity() { } result, err := miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") if err != nil { - println(err.Error()) + fmt.Println(err) } - fmt.Println("token : ", result.ExpiresOn) + fmt.Println("token expire at : ", result.ExpiresOn) } From 1dcad54c9aab2a11323d3543bf7b25d9e892fe36 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 24 Sep 2024 11:21:17 +0100 Subject: [PATCH 25/32] Added Correct response format in test --- apps/managedidentity/managedidentity_test.go | 50 ++++++++++++++++---- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 7b5f2a9a..3ad12b3e 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -5,6 +5,7 @@ package managedidentity import ( "context" "encoding/json" + "fmt" "net/http" "strings" "testing" @@ -25,9 +26,12 @@ const ( type SuccessfulResponse struct { AccessToken string `json:"access_token"` ExpiresOn int64 `json:"expires_on"` + ExpiresIn int64 `json:"expires_in"` Resource string `json:"resource"` TokenType string `json:"token_type"` ClientID string `json:"client_id"` + ObjectID string `json:"object_id"` + ResourceID string `json:"msi_res_id"` } type ErrorRespone struct { @@ -35,14 +39,44 @@ type ErrorRespone struct { Desc string `json:"error_description"` } -func getSuccessfulResponse(resource string) ([]byte, error) { +func getSuccessfulResponse(resource string, miType ID) ([]byte, error) { expiresOn := time.Now().Add(1 * time.Hour).Unix() - response := SuccessfulResponse{ - AccessToken: token, - ExpiresOn: expiresOn, - Resource: resource, - TokenType: "Bearer", - ClientID: "client_id", + var response SuccessfulResponse + switch miType.(type) { + case UserAssignedClientID: + response = SuccessfulResponse{ + AccessToken: token, + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", + ClientID: "client_id", + } + case UserAssignedResourceID: + response = SuccessfulResponse{ + AccessToken: token, + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", + ResourceID: "msi_res_id", + } + case UserAssignedObjectID: + response = SuccessfulResponse{ + AccessToken: token, + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", + ObjectID: "object_id", + } + case systemAssignedValue: + response = SuccessfulResponse{ + AccessToken: token, + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", + ObjectID: "object_id", + } + default: + return nil, fmt.Errorf("unsupported type %T", miType) } jsonResponse, err := json.Marshal(response) return jsonResponse, err @@ -138,7 +172,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { t.Run(string(testCase.source), func(t *testing.T) { url := testCase.endpoint mockClient := mock.Client{} - responseBody, err := getSuccessfulResponse(resource) + responseBody, err := getSuccessfulResponse(resource, testCase.miType) if err != nil { t.Fatalf("error while forming json response : %s", err.Error()) } From 149c6aa01e0de885c1acaef9478b0692bf71fd6e Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 24 Sep 2024 15:14:56 +0100 Subject: [PATCH 26/32] Removed Elements from the response that were not used --- apps/managedidentity/managedidentity_test.go | 46 +++----------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 3ad12b3e..4ee78833 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -5,7 +5,6 @@ package managedidentity import ( "context" "encoding/json" - "fmt" "net/http" "strings" "testing" @@ -29,9 +28,6 @@ type SuccessfulResponse struct { ExpiresIn int64 `json:"expires_in"` Resource string `json:"resource"` TokenType string `json:"token_type"` - ClientID string `json:"client_id"` - ObjectID string `json:"object_id"` - ResourceID string `json:"msi_res_id"` } type ErrorRespone struct { @@ -41,43 +37,13 @@ type ErrorRespone struct { func getSuccessfulResponse(resource string, miType ID) ([]byte, error) { expiresOn := time.Now().Add(1 * time.Hour).Unix() - var response SuccessfulResponse - switch miType.(type) { - case UserAssignedClientID: - response = SuccessfulResponse{ - AccessToken: token, - ExpiresOn: expiresOn, - Resource: resource, - TokenType: "Bearer", - ClientID: "client_id", - } - case UserAssignedResourceID: - response = SuccessfulResponse{ - AccessToken: token, - ExpiresOn: expiresOn, - Resource: resource, - TokenType: "Bearer", - ResourceID: "msi_res_id", - } - case UserAssignedObjectID: - response = SuccessfulResponse{ - AccessToken: token, - ExpiresOn: expiresOn, - Resource: resource, - TokenType: "Bearer", - ObjectID: "object_id", - } - case systemAssignedValue: - response = SuccessfulResponse{ - AccessToken: token, - ExpiresOn: expiresOn, - Resource: resource, - TokenType: "Bearer", - ObjectID: "object_id", - } - default: - return nil, fmt.Errorf("unsupported type %T", miType) + response := SuccessfulResponse{ + AccessToken: token, + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", } + jsonResponse, err := json.Marshal(response) return jsonResponse, err } From e24ca264474aa2ae1a5d7fe3a693c82e9bbcd0b0 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 24 Sep 2024 15:27:11 +0100 Subject: [PATCH 27/32] Removed un used fields reformatted code --- apps/managedidentity/managedidentity_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 4ee78833..f10b5fff 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -25,7 +25,6 @@ const ( type SuccessfulResponse struct { AccessToken string `json:"access_token"` ExpiresOn int64 `json:"expires_on"` - ExpiresIn int64 `json:"expires_in"` Resource string `json:"resource"` TokenType string `json:"token_type"` } @@ -43,7 +42,6 @@ func getSuccessfulResponse(resource string, miType ID) ([]byte, error) { Resource: resource, TokenType: "Bearer", } - jsonResponse, err := json.Marshal(response) return jsonResponse, err } From cac4441a9ab4dfa575bef159d921213fc3440d51 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 24 Sep 2024 17:56:35 +0100 Subject: [PATCH 28/32] Removed unused vairable. --- apps/managedidentity/managedidentity_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index f10b5fff..997ac3fe 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -34,7 +34,7 @@ type ErrorRespone struct { Desc string `json:"error_description"` } -func getSuccessfulResponse(resource string, miType ID) ([]byte, error) { +func getSuccessfulResponse(resource string) ([]byte, error) { expiresOn := time.Now().Add(1 * time.Hour).Unix() response := SuccessfulResponse{ AccessToken: token, @@ -136,7 +136,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { t.Run(string(testCase.source), func(t *testing.T) { url := testCase.endpoint mockClient := mock.Client{} - responseBody, err := getSuccessfulResponse(resource, testCase.miType) + responseBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf("error while forming json response : %s", err.Error()) } From d967d312f10b526ef0ce3461da5bad03955ec5d6 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary <107404295+4gust@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:12:19 +0100 Subject: [PATCH 29/32] Update apps/managedidentity/managedidentity.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/managedidentity/managedidentity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 1eb9e56e..4e7a010d 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -101,7 +101,7 @@ func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { } // Client to be used to acquire tokens for managed identity. -// ID: [SystemAssigned()], [UserAssignedClientID("clientID")], [UserAssignedResourceID("resourceID")], [UserAssignedObjectID("objectID")] +// ID: [SystemAssigned], [UserAssignedClientID], [UserAssignedResourceID], [UserAssignedObjectID] // // Options: [WithHTTPClient] func New(id ID, options ...ClientOption) (Client, error) { From 3367c044b6b971a633c24624c4a49aa2351c1940 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 24 Sep 2024 21:54:16 +0100 Subject: [PATCH 30/32] Updated to have more coverage Updated to have more coverage --- apps/managedidentity/managedidentity_test.go | 138 +++++++++++++++++-- 1 file changed, 129 insertions(+), 9 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 997ac3fe..f4cee963 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -6,6 +6,7 @@ import ( "context" "encoding/json" "net/http" + "net/url" "strings" "testing" "time" @@ -129,26 +130,55 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { testCases := []resourceTestData{ {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: SystemAssigned()}, {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: SystemAssigned()}, - {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedClientID("asd")}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: UserAssignedClientID("clientId")}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")}, } for _, testCase := range testCases { t.Run(string(testCase.source), func(t *testing.T) { - url := testCase.endpoint + var localUrl *url.URL mockClient := mock.Client{} responseBody, err := getSuccessfulResponse(resource) if err != nil { t.Fatalf("error while forming json response : %s", err.Error()) } - mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody)) + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) client, err := New(testCase.miType, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } result, err := client.AcquireToken(context.Background(), testCase.resource) - if !strings.HasPrefix(url, testCase.endpoint) { - t.Fatalf("url request is not on %s got %s", testCase.endpoint, url) + if !strings.HasPrefix(localUrl.String(), testCase.endpoint) { + t.Fatalf("url request is not on %s got %s", testCase.endpoint, localUrl) + } + if !strings.Contains(localUrl.String(), testCase.miType.value()) { + t.Fatalf("url request does not contain the %s got %s", testCase.endpoint, localUrl) + } + query := localUrl.Query() + + if query.Get(apiVersionQuerryParameterName) != imdsAPIVersion { + t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQuerryParameterName)) + } + if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + switch i := testCase.miType.(type) { + case UserAssignedClientID: + if query.Get(miQueryParameterClientId) != i.value() { + t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterClientId)) + } + case UserAssignedResourceID: + if query.Get(miQueryParameterResourceId) != i.value() { + t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId)) + } + case UserAssignedObjectID: + if query.Get(miQueryParameterObjectId) != i.value() { + t.Fatalf("resource objectiid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) + } } if err != nil { t.Fatal(err) @@ -156,8 +186,99 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { if result.AccessToken != token { t.Fatalf("wanted %q, got %q", token, result.AccessToken) } + + }) + } + + // Testing createIMDSAuthRequest + tests := []struct { + name string + id ID + resource string + claims string + wantErr bool + }{ + { + name: "System Assigned", + id: SystemAssigned(), + resource: "https://management.azure.com", + }, + { + name: "System Assigned", + id: SystemAssigned(), + resource: "https://management.azure.com/.default", + }, + { + name: "Client ID", + id: UserAssignedClientID("test-client-id"), + resource: "https://storage.azure.com", + }, + { + name: "Resource ID", + id: UserAssignedResourceID("test-resource-id"), + resource: "https://vault.azure.net", + }, + { + name: "Object ID", + id: UserAssignedObjectID("test-object-id"), + resource: "https://graph.microsoft.com", + }, + { + name: "With Claims", + id: SystemAssigned(), + resource: "https://management.azure.com", + claims: "test-claims", + }, + } + // testing IMDSAuthRequest Creation method. + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) + if tt.wantErr { + if err == nil { + t.Fatal(err) + } + return + } + if req == nil { + t.Fatal("createIMDSAuthRequest() returned nil request") + } + if req.Method != http.MethodGet { + t.Fatal("createIMDSAuthRequest() method is not GET") + } + if got := req.URL.String(); !strings.HasPrefix(got, imdsEndpoint) { + t.Fatalf("wanted %q, got %q", imdsEndpoint, got) + } + query := req.URL.Query() + + if query.Get(apiVersionQuerryParameterName) != "2018-02-01" { + t.Fatal("createIMDSAuthRequest() api-version missmatch") + } + if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(tt.resource, "/.default") { + t.Fatal("createIMDSAuthRequest() resource does not ahve suffix removed ") + } + switch i := tt.id.(type) { + case UserAssignedClientID: + if query.Get(miQueryParameterClientId) != i.value() { + t.Fatal("createIMDSAuthRequest() resource client-id is incorrect") + } + case UserAssignedResourceID: + if query.Get(miQueryParameterResourceId) != i.value() { + t.Fatal("createIMDSAuthRequest() resource resource-id is incorrect") + } + case UserAssignedObjectID: + if query.Get(miQueryParameterObjectId) != i.value() { + t.Fatal("createIMDSAuthRequest() resource objectiid is incorrect") + } + case systemAssignedValue: // not adding anything + default: + t.Fatal("createIMDSAuthRequest() unsupported type") + + } + }) } + } func TestCreatingIMDSClient(t *testing.T) { @@ -210,10 +331,9 @@ func TestCreatingIMDSClient(t *testing.T) { } if err != nil { t.Fatal(err) - } else { - if client.miType.value() != tt.id.value() { - t.Fatal("client New() did not assign a correct value to type.") - } + } + if client.miType.value() != tt.id.value() { + t.Fatal("client New() did not assign a correct value to type.") } }) } From 795cd670768ea74ba4ad4de16ab33dcc9b11a9d3 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 24 Sep 2024 21:59:07 +0100 Subject: [PATCH 31/32] Updated tests to test request --- apps/managedidentity/managedidentity.go | 2 +- apps/managedidentity/managedidentity_test.go | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 4e7a010d..8a66cbf5 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -144,7 +144,7 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims s return nil, fmt.Errorf("couldn't parse %q: %s", imdsEndpoint, err) } msiParameters := msiEndpoint.Query() - msiParameters.Set(apiVersionQuerryParameterName, "2018-02-01") + msiParameters.Set(apiVersionQuerryParameterName, imdsAPIVersion) resource = strings.TrimSuffix(resource, "/.default") msiParameters.Set(resourceQuerryParameterName, resource) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index f4cee963..6085dba7 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -177,7 +177,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { } case UserAssignedObjectID: if query.Get(miQueryParameterObjectId) != i.value() { - t.Fatalf("resource objectiid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) + t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) } } if err != nil { @@ -241,38 +241,38 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { return } if req == nil { - t.Fatal("createIMDSAuthRequest() returned nil request") + t.Fatal("returned nil request") } if req.Method != http.MethodGet { - t.Fatal("createIMDSAuthRequest() method is not GET") + t.Fatalf("method is wannted GET, got %s", req.Method) } if got := req.URL.String(); !strings.HasPrefix(got, imdsEndpoint) { t.Fatalf("wanted %q, got %q", imdsEndpoint, got) } query := req.URL.Query() - if query.Get(apiVersionQuerryParameterName) != "2018-02-01" { - t.Fatal("createIMDSAuthRequest() api-version missmatch") + if query.Get(apiVersionQuerryParameterName) != imdsAPIVersion { + t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQuerryParameterName)) } if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(tt.resource, "/.default") { - t.Fatal("createIMDSAuthRequest() resource does not ahve suffix removed ") + t.Fatal("resource does not have suffix removed") } switch i := tt.id.(type) { case UserAssignedClientID: if query.Get(miQueryParameterClientId) != i.value() { - t.Fatal("createIMDSAuthRequest() resource client-id is incorrect") + t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterClientId)) } case UserAssignedResourceID: if query.Get(miQueryParameterResourceId) != i.value() { - t.Fatal("createIMDSAuthRequest() resource resource-id is incorrect") + t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId)) } case UserAssignedObjectID: if query.Get(miQueryParameterObjectId) != i.value() { - t.Fatal("createIMDSAuthRequest() resource objectiid is incorrect") + t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) } case systemAssignedValue: // not adding anything default: - t.Fatal("createIMDSAuthRequest() unsupported type") + t.Fatal(" unsupported type") } From 6b9cd68c0fb828dc8aa02a91e2e3388534e54623 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 25 Sep 2024 09:00:15 +0100 Subject: [PATCH 32/32] Removed some tests which were redundant --- apps/managedidentity/managedidentity_test.go | 90 -------------------- apps/tests/devapps/main.go | 2 +- 2 files changed, 1 insertion(+), 91 deletions(-) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 6085dba7..e813f1bc 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -189,96 +189,6 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { }) } - - // Testing createIMDSAuthRequest - tests := []struct { - name string - id ID - resource string - claims string - wantErr bool - }{ - { - name: "System Assigned", - id: SystemAssigned(), - resource: "https://management.azure.com", - }, - { - name: "System Assigned", - id: SystemAssigned(), - resource: "https://management.azure.com/.default", - }, - { - name: "Client ID", - id: UserAssignedClientID("test-client-id"), - resource: "https://storage.azure.com", - }, - { - name: "Resource ID", - id: UserAssignedResourceID("test-resource-id"), - resource: "https://vault.azure.net", - }, - { - name: "Object ID", - id: UserAssignedObjectID("test-object-id"), - resource: "https://graph.microsoft.com", - }, - { - name: "With Claims", - id: SystemAssigned(), - resource: "https://management.azure.com", - claims: "test-claims", - }, - } - // testing IMDSAuthRequest Creation method. - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req, err := createIMDSAuthRequest(context.Background(), tt.id, tt.resource, tt.claims) - if tt.wantErr { - if err == nil { - t.Fatal(err) - } - return - } - if req == nil { - t.Fatal("returned nil request") - } - if req.Method != http.MethodGet { - t.Fatalf("method is wannted GET, got %s", req.Method) - } - if got := req.URL.String(); !strings.HasPrefix(got, imdsEndpoint) { - t.Fatalf("wanted %q, got %q", imdsEndpoint, got) - } - query := req.URL.Query() - - if query.Get(apiVersionQuerryParameterName) != imdsAPIVersion { - t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQuerryParameterName)) - } - if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(tt.resource, "/.default") { - t.Fatal("resource does not have suffix removed") - } - switch i := tt.id.(type) { - case UserAssignedClientID: - if query.Get(miQueryParameterClientId) != i.value() { - t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterClientId)) - } - case UserAssignedResourceID: - if query.Get(miQueryParameterResourceId) != i.value() { - t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId)) - } - case UserAssignedObjectID: - if query.Get(miQueryParameterObjectId) != i.value() { - t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) - } - case systemAssignedValue: // not adding anything - default: - t.Fatal(" unsupported type") - - } - - }) - } - } func TestCreatingIMDSClient(t *testing.T) { diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index e03239aa..f927468f 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -36,7 +36,7 @@ func main() { } else if exampleType == "6" { // This sample does not use a serialized cache - it relies on in-memory cache by reusing the app object // This works well for app tokens, because there is only 1 token per resource, per tenant. - // acquireTokenClientCertificate() + acquireTokenClientCertificate() // // this time the token comes from the cache! // acquireTokenClientCertificate() } else if exampleType == "7" {