diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 5de171fd..af684a32 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,6 +46,13 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } +// WithHTTPStatusCode sets the HTTP statusCode of response to the specified value. +func WithHTTPStatusCode(statusCode int) responseOption { + return respOpt(func(r *response) { + r.code = statusCode + }) +} + // Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. type Client struct { resp []response diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index ddbe71b0..8a66cbf5 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -11,127 +11,221 @@ package managedidentity import ( "context" + "encoding/json" "fmt" - "sync" + "io" + "net/http" + "net/url" + "strings" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared" ) const ( - // DefaultToIMDS indicates that the source is defaulted to IMDS since no environment variables are set. - DefaultToIMDS = 0 - - // AzureArc represents the source to acquire token for managed identity is Azure Arc. - AzureArc = 1 + // DefaultToIMDS indicates that the source is defaulted to IMDS when no environment variables are set. + DefaultToIMDS Source = "DefaultToIMDS" + AzureArc Source = "AzureArc" + ServiceFabric Source = "ServiceFabric" + CloudShell Source = "CloudShell" + AppService Source = "AppService" + + // General request querry parameter names + metaHTTPHeaderName = "Metadata" + apiVersionQuerryParameterName = "api-version" + resourceQuerryParameterName = "resource" + + // UAMI querry parameter name + miQueryParameterClientId = "client_id" + miQueryParameterObjectId = "object_id" + miQueryParameterResourceId = "msi_res_id" + + // IMDS + imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + imdsAPIVersion = "2018-02-01" ) -// Client is a client that provides access to Managed Identity token calls. -type Client struct { - AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). also may remove from here - cacheAccessorMu *sync.RWMutex - // base ops.HTTPClient - // managedIdentityType Type - // Token *oauth.Client - // pmanager manager // todo : expose the manager from base. - // cacheAccessor cache.ExportReplace -} - -// clientOptions are optional settings for New(). These options are set using various functions -// returning Option calls. -type clientOptions struct { - claims string // bypasses cache, does nothing else - httpClient ops.HTTPClient - // disableInstanceDiscovery bool // always false - // clientId string -} - -type withClaimsOption struct{ Claims string } -type withHTTPClientOption struct{ HttpClient ops.HTTPClient } - -// Option is an optional argument to New(). -type Option interface{ apply(*clientOptions) } -type ClientOption interface{ ClientOption() } -type AcquireTokenOption interface{ AcquireTokenOption() } - -// Source represents the managed identity sources supported. -type Source int - -type systemAssignedValue string +type Source string type ID interface { value() string } +type systemAssignedValue string // its private for a reason to make the input consistent. +type UserAssignedClientID string +type UserAssignedObjectID string +type UserAssignedResourceID string + +func (s systemAssignedValue) value() string { return string(s) } +func (c UserAssignedClientID) value() string { return string(c) } +func (o UserAssignedObjectID) value() string { return string(o) } +func (r UserAssignedResourceID) value() string { return string(r) } func SystemAssigned() ID { return systemAssignedValue("") } -type ClientID string -type ObjectID string -type ResourceID string +type Client struct { + httpClient ops.HTTPClient + miType ID + source Source +} -func (s systemAssignedValue) value() string { return string(s) } -func (c ClientID) value() string { return string(c) } -func (o ObjectID) value() string { return string(o) } -func (r ResourceID) value() string { return string(r) } +type ClientOptions struct { + httpClient ops.HTTPClient +} -func (w withClaimsOption) AcquireTokenOption() {} -func (w withHTTPClientOption) AcquireTokenOption() {} -func (w withHTTPClientOption) apply(opts *clientOptions) { opts.httpClient = w.HttpClient } +type AcquireTokenOptions struct { + claims string +} -// WithClaims sets additional claims to request for the token, such as those required by conditional access policies. +type ClientOption func(o *ClientOptions) + +type AcquireTokenOption func(o *AcquireTokenOptions) + +// WithClaims sets additional claims to request for the token, such as those required by token revocation or conditional access policies. // Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded. func WithClaims(claims string) AcquireTokenOption { - return withClaimsOption{Claims: claims} + return func(o *AcquireTokenOptions) { + o.claims = claims + } } // WithHTTPClient allows for a custom HTTP client to be set. -func WithHTTPClient(httpClient ops.HTTPClient) Option { - return withHTTPClientOption{HttpClient: httpClient} +func WithHTTPClient(httpClient ops.HTTPClient) ClientOption { + return func(o *ClientOptions) { + o.httpClient = httpClient + } } // Client to be used to acquire tokens for managed identity. -// ID: [SystemAssigned()], [ClientID("clientID")], [ResourceID("resourceID")], [ObjectID("objectID")] +// ID: [SystemAssigned], [UserAssignedClientID], [UserAssignedResourceID], [UserAssignedObjectID] // // Options: [WithHTTPClient] -func New(id ID, options ...Option) (Client, error) { - fmt.Println("idType: ", id.value()) - - opts := clientOptions{ - claims: "claims", +func New(id ID, options ...ClientOption) (Client, error) { + opts := ClientOptions{ + httpClient: shared.DefaultClient, } for _, option := range options { - option.apply(&opts) + option(&opts) } + switch t := id.(type) { + case UserAssignedClientID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("empty %T", t) + } + case UserAssignedResourceID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("empty %T", t) + } + case UserAssignedObjectID: + if len(string(t)) == 0 { + return Client{}, fmt.Errorf("empty %T", t) + } + case systemAssignedValue: + default: + return Client{}, fmt.Errorf("unsupported type %T", id) + } + client := Client{ + miType: id, + httpClient: opts.httpClient, + } + + return client, nil +} - authInfo, err := authority.NewInfoFromAuthorityURI("authorityURI", true, false) +func createIMDSAuthRequest(ctx context.Context, id ID, resource string, claims string) (*http.Request, error) { + var msiEndpoint *url.URL + msiEndpoint, err := url.Parse(imdsEndpoint) if err != nil { - return Client{}, err + return nil, fmt.Errorf("couldn't parse %q: %s", imdsEndpoint, err) + } + msiParameters := msiEndpoint.Query() + msiParameters.Set(apiVersionQuerryParameterName, imdsAPIVersion) + resource = strings.TrimSuffix(resource, "/.default") + msiParameters.Set(resourceQuerryParameterName, resource) + + if len(claims) > 0 { + msiParameters.Set("claims", claims) } - authParams := authority.NewAuthParams(id.value(), authInfo) - client := Client{ // Note: Hey, don't even THINK about making Base into *Base. See "design notes" in public.go and confidential.go - AuthParams: authParams, - cacheAccessorMu: &sync.RWMutex{}, - // manager: storage.New(token), - // pmanager: storage.NewPartitionedManager(token), + switch t := id.(type) { + case UserAssignedClientID: + msiParameters.Set(miQueryParameterClientId, string(t)) + case UserAssignedResourceID: + msiParameters.Set(miQueryParameterResourceId, string(t)) + case UserAssignedObjectID: + msiParameters.Set(miQueryParameterObjectId, string(t)) + case systemAssignedValue: // not adding anything + default: + return nil, fmt.Errorf("unsupported type %T", id) } - return client, err + msiEndpoint.RawQuery = msiParameters.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, msiEndpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("error creating http request %s", err) + } + req.Header.Set(metaHTTPHeaderName, "true") + return req, nil +} + +func (client Client) getTokenForRequest(req *http.Request) (accesstokens.TokenResponse, error) { + resp, err := client.httpClient.Do(req) + if err != nil { + return accesstokens.TokenResponse{}, err + } + responseBytes, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return accesstokens.TokenResponse{}, err + } + switch resp.StatusCode { + case http.StatusOK, http.StatusAccepted: + default: + sd := strings.TrimSpace(string(responseBytes)) + if sd != "" { + return accesstokens.TokenResponse{}, errors.CallErr{ + Req: req, + Resp: resp, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", + req.URL.String(), + req.Method, + resp.StatusCode, + sd), + } + } + return accesstokens.TokenResponse{}, errors.CallErr{ + Req: req, + Resp: resp, + Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, resp.StatusCode), + } + } + var r accesstokens.TokenResponse + err = json.Unmarshal(responseBytes, &r) + return r, err } // Acquires tokens from the configured managed identity on an azure resource. // // Resource: scopes application is requesting access to // Options: [WithClaims] -func (client Client) AcquireToken(context context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { - return base.AuthResult{}, nil -} +func (client Client) AcquireToken(ctx context.Context, resource string, options ...AcquireTokenOption) (base.AuthResult, error) { + o := AcquireTokenOptions{} -// Detects and returns the managed identity source available on the environment. -func GetSource() Source { - return DefaultToIMDS + for _, option := range options { + option(&o) + } + req, err := createIMDSAuthRequest(ctx, client.miType, resource, o.claims) + if err != nil { + return base.AuthResult{}, err + } + tokenResponse, err := client.getTokenForRequest(req) + if err != nil { + return base.AuthResult{}, err + } + return base.NewAuthResult(tokenResponse, shared.Account{}) } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 4bf34540..e813f1bc 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -4,43 +4,247 @@ package managedidentity import ( "context" + "encoding/json" + "net/http" + "net/url" + "strings" "testing" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" ) -func fakeClient(mangedIdentityId ID, options ...Option) (Client, error) { - client, err := New(mangedIdentityId, options...) +const ( + // test Resources + resource = "https://demo.azure.com" + resourceDefaultSuffix = "https://demo.azure.com/.default" - if err != nil { - return Client{}, err - } + token = "fakeToken" +) + +type SuccessfulResponse struct { + AccessToken string `json:"access_token"` + ExpiresOn int64 `json:"expires_on"` + Resource string `json:"resource"` + TokenType string `json:"token_type"` +} - return client, nil +type ErrorRespone struct { + Err string `json:"error"` + Desc string `json:"error_description"` } -func TestManagedIdentity(t *testing.T) { - client, err := fakeClient(SystemAssigned()) +func getSuccessfulResponse(resource string) ([]byte, error) { + expiresOn := time.Now().Add(1 * time.Hour).Unix() + response := SuccessfulResponse{ + AccessToken: token, + ExpiresOn: expiresOn, + Resource: resource, + TokenType: "Bearer", + } + jsonResponse, err := json.Marshal(response) + return jsonResponse, err +} - if err != nil { - t.Fatal(err) +func makeResponseWithErrorData(err string, desc string) ([]byte, error) { + responseBody := ErrorRespone{ + Err: err, + Desc: desc, } + jsonResponse, e := json.Marshal(responseBody) + return jsonResponse, e +} - _, err = client.AcquireToken(context.Background(), "scope", WithClaims("claim")) +type resourceTestData struct { + source Source + endpoint string + resource string + miType ID +} - if err == nil { - t.Errorf("TestManagedIdentity: unexpected nil error from TestManagedIdentity") +type errorTestData struct { + code int + err string + desc string + correlationid string +} + +func Test_SystemAssigned_Returns_AcquireToken_Failure(t *testing.T) { + testCases := []errorTestData{ + {code: http.StatusNotFound, + err: "", + desc: "", + correlationid: "121212"}, + {code: http.StatusNotImplemented, + err: "", + desc: "", + correlationid: "121212"}, + {code: http.StatusServiceUnavailable, + err: "", + desc: "", + correlationid: "121212"}, + {code: http.StatusBadRequest, + err: "invalid_request", + desc: "Identity not found", + correlationid: "121212", + }, + } + + for _, testCase := range testCases { + t.Run(http.StatusText(testCase.code), func(t *testing.T) { + fakeErrorClient := mock.Client{} + responseBody, err := makeResponseWithErrorData(testCase.err, testCase.desc) + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + fakeErrorClient.AppendResponse(mock.WithHTTPStatusCode(testCase.code), + mock.WithBody(responseBody)) + client, err := New(SystemAssigned(), WithHTTPClient(&fakeErrorClient)) + if err != nil { + t.Fatal(err) + } + resp, err := client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatalf("should have encountered the error") + } + var callErr errors.CallErr + if errors.As(err, &callErr) { + if !strings.Contains(err.Error(), testCase.err) { + t.Fatalf("expected message '%s' in error, got %q", testCase.err, callErr.Error()) + } + if callErr.Resp.StatusCode != testCase.code { + t.Fatalf("expected status code %d, got %d", testCase.code, callErr.Resp.StatusCode) + } + } else { + t.Fatalf("expected error of type %T, got %T", callErr, err) + } + if resp.AccessToken != "" { + t.Fatalf("accesstoken should be empty") + } + }) } } -func TestManagedIdentityWithClaims(t *testing.T) { - client, err := fakeClient(ClientID("123")) +func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { + testCases := []resourceTestData{ + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: SystemAssigned()}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: SystemAssigned()}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resource, miType: UserAssignedClientID("clientId")}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")}, + {source: DefaultToIMDS, endpoint: imdsEndpoint, resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")}, + } + for _, testCase := range testCases { + + t.Run(string(testCase.source), func(t *testing.T) { + var localUrl *url.URL + mockClient := mock.Client{} + responseBody, err := getSuccessfulResponse(resource) + if err != nil { + t.Fatalf("error while forming json response : %s", err.Error()) + } + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) { + localUrl = r.URL + })) + client, err := New(testCase.miType, WithHTTPClient(&mockClient)) - if err != nil { - t.Fatal(err) + if err != nil { + t.Fatal(err) + } + result, err := client.AcquireToken(context.Background(), testCase.resource) + if !strings.HasPrefix(localUrl.String(), testCase.endpoint) { + t.Fatalf("url request is not on %s got %s", testCase.endpoint, localUrl) + } + if !strings.Contains(localUrl.String(), testCase.miType.value()) { + t.Fatalf("url request does not contain the %s got %s", testCase.endpoint, localUrl) + } + query := localUrl.Query() + + if query.Get(apiVersionQuerryParameterName) != imdsAPIVersion { + t.Fatalf("api-version not on %s got %s", imdsAPIVersion, query.Get(apiVersionQuerryParameterName)) + } + if query.Get(resourceQuerryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") { + t.Fatal("suffix /.default was not removed.") + } + switch i := testCase.miType.(type) { + case UserAssignedClientID: + if query.Get(miQueryParameterClientId) != i.value() { + t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterClientId)) + } + case UserAssignedResourceID: + if query.Get(miQueryParameterResourceId) != i.value() { + t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId)) + } + case UserAssignedObjectID: + if query.Get(miQueryParameterObjectId) != i.value() { + t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId)) + } + } + if err != nil { + t.Fatal(err) + } + if result.AccessToken != token { + t.Fatalf("wanted %q, got %q", token, result.AccessToken) + } + + }) } +} - _, err = client.AcquireToken(context.Background(), "scope", WithClaims("claim")) +func TestCreatingIMDSClient(t *testing.T) { + tests := []struct { + name string + id ID + wantErr bool + }{ + { + name: "System Assigned", + id: SystemAssigned(), + }, + { + name: "Client ID", + id: UserAssignedClientID("test-client-id"), + }, + { + name: "Resource ID", + id: UserAssignedResourceID("test-resource-id"), + }, + { + name: "Object ID", + id: UserAssignedObjectID("test-object-id"), + }, + { + name: "Empty Client ID", + id: UserAssignedClientID(""), + wantErr: true, + }, + { + name: "Empty Resource ID", + id: UserAssignedResourceID(""), + wantErr: true, + }, + { + name: "Empty Object ID", + id: UserAssignedObjectID(""), + wantErr: true, + }, + } - if err == nil { - t.Errorf("TestManagedIdentityWithClaims: unexpected nil error from TestManagedIdentityWithClaims") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := New(tt.id) + if tt.wantErr { + if err == nil { + t.Fatal("client New() should return a error but did not.") + } + return + } + if err != nil { + t.Fatal(err) + } + if client.miType.value() != tt.id.value() { + t.Fatal("client New() did not assign a correct value to type.") + } + }) } } diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index 027bd6e4..f927468f 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -5,7 +5,7 @@ import ( ) var ( - //config = CreateConfig("config.json") + config = CreateConfig("config.json") cacheAccessor = &TokenCache{file: "serialized_cache.json"} ) @@ -13,7 +13,7 @@ func main() { ctx := context.Background() // Choose a sammple to run. - exampleType := "5" + exampleType := "7" if exampleType == "1" { acquireTokenDeviceCode() @@ -37,8 +37,7 @@ func main() { // This sample does not use a serialized cache - it relies on in-memory cache by reusing the app object // This works well for app tokens, because there is only 1 token per resource, per tenant. acquireTokenClientCertificate() - - // this time the token comes from the cache! + // // this time the token comes from the cache! // acquireTokenClientCertificate() } else if exampleType == "7" { RunManagedIdentity() diff --git a/apps/tests/devapps/managedidentity_sample.go b/apps/tests/devapps/managedidentity_sample.go index 337b1128..65e71b54 100644 --- a/apps/tests/devapps/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity_sample.go @@ -3,36 +3,18 @@ package main import ( "context" "fmt" - "net/http" mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" ) func RunManagedIdentity() { - customHttpClient := &http.Client{} - - miSystemAssigned, error := mi.New(mi.SystemAssigned()) - if error != nil { - fmt.Println(error) - } - - miClientIdAssigned, error := mi.New(mi.ClientID("client id 123"), mi.WithHTTPClient(customHttpClient)) - if error != nil { - fmt.Println(error) + miSystemAssigned, err := mi.New(mi.SystemAssigned()) + if err != nil { + fmt.Println(err) } - - miResourceIdAssigned, error := mi.New(mi.ResourceID("resource id 123")) - if error != nil { - fmt.Println(error) + result, err := miSystemAssigned.AcquireToken(context.Background(), "https://management.azure.com/") + if err != nil { + fmt.Println(err) } - - miObjectIdAssigned, error := mi.New(mi.ObjectID("object id 123")) - if error != nil { - fmt.Println(error) - } - - miSystemAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) - miClientIdAssigned.AcquireToken(context.Background(), "resource") - miResourceIdAssigned.AcquireToken(context.Background(), "resource", mi.WithClaims("claim")) - miObjectIdAssigned.AcquireToken(context.Background(), "resource") + fmt.Println("token expire at : ", result.ExpiresOn) }