diff --git a/auth/auth.go b/auth/auth.go index 45726b73..6b428061 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -65,7 +65,7 @@ func (c *Config) NewAuthorizer(ctx context.Context, api Api) (Authorizer, error) } if c.EnableMsiAuth { - a, err := NewMsiAuthorizer(ctx, c.Environment, api, c.MsiEndpoint) + a, err := NewMsiAuthorizer(ctx, c.Environment, api, c.MsiEndpoint, c.ClientID) if err != nil { return nil, fmt.Errorf("could not configure MSI Authorizer: %s", err) } @@ -97,8 +97,8 @@ func NewAzureCliAuthorizer(ctx context.Context, api Api, tenantId string) (Autho } // NewMsiAuthorizer returns an authorizer which uses managed service identity to for authentication. -func NewMsiAuthorizer(ctx context.Context, environment environments.Environment, api Api, msiEndpoint string) (Authorizer, error) { - conf, err := NewMsiConfig(ctx, resource(environment, api), msiEndpoint) +func NewMsiAuthorizer(ctx context.Context, environment environments.Environment, api Api, msiEndpoint, clientId string) (Authorizer, error) { + conf, err := NewMsiConfig(ctx, resource(environment, api), msiEndpoint, clientId) if err != nil { return nil, err } diff --git a/auth/auth_test.go b/auth/auth_test.go index d15ffa21..76450012 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -122,7 +122,7 @@ func TestMsiAuthorizer(t *testing.T) { done <- true }() } - auth, err := auth.NewMsiAuthorizer(ctx, environments.Global, auth.MsGraph, msiEndpoint) + auth, err := auth.NewMsiAuthorizer(ctx, environments.Global, auth.MsGraph, msiEndpoint, clientId) if err != nil { t.Fatalf("NewMsiAuthorizer(): %v", err) } diff --git a/auth/azcli.go b/auth/azcli.go index a81c3ce9..ad943ac9 100644 --- a/auth/azcli.go +++ b/auth/azcli.go @@ -31,7 +31,11 @@ type AzureCliAuthorizer struct { } // Token returns an access token using the Azure CLI as an authentication mechanism. -func (a AzureCliAuthorizer) Token() (*oauth2.Token, error) { +func (a *AzureCliAuthorizer) Token() (*oauth2.Token, error) { + if a.conf == nil { + return nil, fmt.Errorf("could not request token: conf is nil") + } + var token struct { AccessToken string `json:"accessToken"` ExpiresOn string `json:"expiresOn"` diff --git a/auth/clientcredentials.go b/auth/clientcredentials.go index 78d8edab..f12cd469 100644 --- a/auth/clientcredentials.go +++ b/auth/clientcredentials.go @@ -169,7 +169,11 @@ type clientAssertionAuthorizer struct { conf *ClientCredentialsConfig } -func (a clientAssertionAuthorizer) Token() (*oauth2.Token, error) { +func (a *clientAssertionAuthorizer) Token() (*oauth2.Token, error) { + if a.conf == nil { + return nil, fmt.Errorf("could not request token: conf is nil") + } + crt := a.conf.Certificate if der, _ := pem.Decode(a.conf.Certificate); der != nil { crt = der.Bytes @@ -246,7 +250,11 @@ type clientSecretAuthorizer struct { conf *ClientCredentialsConfig } -func (a clientSecretAuthorizer) Token() (*oauth2.Token, error) { +func (a *clientSecretAuthorizer) Token() (*oauth2.Token, error) { + if a.conf == nil { + return nil, fmt.Errorf("could not request token: conf is nil") + } + v := url.Values{ "client_id": {a.conf.ClientID}, "client_secret": {a.conf.ClientSecret}, diff --git a/auth/msi.go b/auth/msi.go index 909ce0b2..e87f18e8 100644 --- a/auth/msi.go +++ b/auth/msi.go @@ -27,10 +27,19 @@ type MsiAuthorizer struct { // Token returns an access token acquired from the metadata endpoint. func (a *MsiAuthorizer) Token() (*oauth2.Token, error) { + if a.conf == nil { + return nil, fmt.Errorf("could not request token: conf is nil") + } + query := url.Values{ "api-version": []string{a.conf.MsiApiVersion}, "resource": []string{a.conf.Resource}, } + + if a.conf.ClientID != "" { + query["client_id"] = []string{a.conf.ClientID} + } + url := fmt.Sprintf("%s?%s", a.conf.MsiEndpoint, query.Encode()) body, err := azureMetadata(a.ctx, url) @@ -38,7 +47,6 @@ func (a *MsiAuthorizer) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("MsiAuthorizer: failed to request token from metadata endpoint: %v", err) } - // TODO: surface the client ID for use by callers var tokenRes struct { AccessToken string `json:"access_token"` ClientID string `json:"client_id"` @@ -76,13 +84,22 @@ func (a *MsiAuthorizer) Token() (*oauth2.Token, error) { // MsiConfig configures an MsiAuthorizer. type MsiConfig struct { + // ClientID is optionally used to determine which application to assume when a resource has multiple managed identities + ClientID string + + // MsiApiVersion is the API version to use when requesting a token from the metadata service MsiApiVersion string - MsiEndpoint string - Resource string + + // MsiEndpoint is the endpoint where the metadata service can be found + MsiEndpoint string + + // Resource is the service for which to request an access token + Resource string } // NewMsiConfig returns a new MsiConfig with a configured metadata endpoint and resource. -func NewMsiConfig(ctx context.Context, resource string, msiEndpoint string) (*MsiConfig, error) { +// clientId and objectId can be left blank when a single managed identity is available +func NewMsiConfig(ctx context.Context, resource, msiEndpoint, clientId string) (*MsiConfig, error) { endpoint := msiDefaultEndpoint if msiEndpoint != "" { endpoint = msiEndpoint @@ -107,6 +124,7 @@ func NewMsiConfig(ctx context.Context, resource string, msiEndpoint string) (*Ms } return &MsiConfig{ + ClientID: clientId, Resource: resource, MsiApiVersion: msiDefaultApiVersion, MsiEndpoint: endpoint, diff --git a/internal/test/msi_stub.go b/internal/test/msi_stub.go index 0f2af4a8..0fd48c56 100644 --- a/internal/test/msi_stub.go +++ b/internal/test/msi_stub.go @@ -11,8 +11,12 @@ func MsiStubServer(ctx context.Context, port int, token string) chan bool { handler := http.NewServeMux() handler.HandleFunc("/metadata/identity/oauth2/token", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + clientId := q.Get("client_id") + resource := q.Get("resource") w.Header().Set("Content-Type", "application/json; charset=utf-8") - fmt.Fprintf(w, `{"access_token":"%s","client_id":"00000000-0000-0000-0000-000000000000","expires_in":"86391","expires_on":"1611701390","ext_expires_in":"86399","not_before":"1611614690","resource":"https://graph.microsoft.com/","token_type":"Bearer"}`, token) + fmt.Fprintf(w, `{"access_token":"%s","client_id":"%s","expires_in":"86391","expires_on":"1611701390","ext_expires_in":"86399","not_before":"1611614690","resource":"%s","token_type":"Bearer"}`, + token, clientId, resource) }) handler.HandleFunc("/metadata", func(w http.ResponseWriter, r *http.Request) {