Skip to content

Commit

Permalink
Merge pull request #115 from manicminer/auth/managed-identity-support…
Browse files Browse the repository at this point in the history
…-client-id

Managed Identity: Support specifying the client ID when requesting a token from the metadata service
  • Loading branch information
manicminer authored Oct 14, 2021
2 parents bd069e5 + 781f6c0 commit 58b726a
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 12 deletions.
6 changes: 3 additions & 3 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (c *Config) NewAuthorizer(ctx context.Context, api Api) (Authorizer, error)
}

if c.EnableMsiAuth {
a, err := NewMsiAuthorizer(ctx, c.Environment, api, c.MsiEndpoint)
a, err := NewMsiAuthorizer(ctx, c.Environment, api, c.MsiEndpoint, c.ClientID)
if err != nil {
return nil, fmt.Errorf("could not configure MSI Authorizer: %s", err)
}
Expand Down Expand Up @@ -97,8 +97,8 @@ func NewAzureCliAuthorizer(ctx context.Context, api Api, tenantId string) (Autho
}

// NewMsiAuthorizer returns an authorizer which uses managed service identity to for authentication.
func NewMsiAuthorizer(ctx context.Context, environment environments.Environment, api Api, msiEndpoint string) (Authorizer, error) {
conf, err := NewMsiConfig(ctx, resource(environment, api), msiEndpoint)
func NewMsiAuthorizer(ctx context.Context, environment environments.Environment, api Api, msiEndpoint, clientId string) (Authorizer, error) {
conf, err := NewMsiConfig(ctx, resource(environment, api), msiEndpoint, clientId)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func TestMsiAuthorizer(t *testing.T) {
done <- true
}()
}
auth, err := auth.NewMsiAuthorizer(ctx, environments.Global, auth.MsGraph, msiEndpoint)
auth, err := auth.NewMsiAuthorizer(ctx, environments.Global, auth.MsGraph, msiEndpoint, clientId)
if err != nil {
t.Fatalf("NewMsiAuthorizer(): %v", err)
}
Expand Down
6 changes: 5 additions & 1 deletion auth/azcli.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ type AzureCliAuthorizer struct {
}

// Token returns an access token using the Azure CLI as an authentication mechanism.
func (a AzureCliAuthorizer) Token() (*oauth2.Token, error) {
func (a *AzureCliAuthorizer) Token() (*oauth2.Token, error) {
if a.conf == nil {
return nil, fmt.Errorf("could not request token: conf is nil")
}

var token struct {
AccessToken string `json:"accessToken"`
ExpiresOn string `json:"expiresOn"`
Expand Down
12 changes: 10 additions & 2 deletions auth/clientcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ type clientAssertionAuthorizer struct {
conf *ClientCredentialsConfig
}

func (a clientAssertionAuthorizer) Token() (*oauth2.Token, error) {
func (a *clientAssertionAuthorizer) Token() (*oauth2.Token, error) {
if a.conf == nil {
return nil, fmt.Errorf("could not request token: conf is nil")
}

crt := a.conf.Certificate
if der, _ := pem.Decode(a.conf.Certificate); der != nil {
crt = der.Bytes
Expand Down Expand Up @@ -246,7 +250,11 @@ type clientSecretAuthorizer struct {
conf *ClientCredentialsConfig
}

func (a clientSecretAuthorizer) Token() (*oauth2.Token, error) {
func (a *clientSecretAuthorizer) Token() (*oauth2.Token, error) {
if a.conf == nil {
return nil, fmt.Errorf("could not request token: conf is nil")
}

v := url.Values{
"client_id": {a.conf.ClientID},
"client_secret": {a.conf.ClientSecret},
Expand Down
26 changes: 22 additions & 4 deletions auth/msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,26 @@ type MsiAuthorizer struct {

// Token returns an access token acquired from the metadata endpoint.
func (a *MsiAuthorizer) Token() (*oauth2.Token, error) {
if a.conf == nil {
return nil, fmt.Errorf("could not request token: conf is nil")
}

query := url.Values{
"api-version": []string{a.conf.MsiApiVersion},
"resource": []string{a.conf.Resource},
}

if a.conf.ClientID != "" {
query["client_id"] = []string{a.conf.ClientID}
}

url := fmt.Sprintf("%s?%s", a.conf.MsiEndpoint, query.Encode())

body, err := azureMetadata(a.ctx, url)
if err != nil {
return nil, fmt.Errorf("MsiAuthorizer: failed to request token from metadata endpoint: %v", err)
}

// TODO: surface the client ID for use by callers
var tokenRes struct {
AccessToken string `json:"access_token"`
ClientID string `json:"client_id"`
Expand Down Expand Up @@ -76,13 +84,22 @@ func (a *MsiAuthorizer) Token() (*oauth2.Token, error) {

// MsiConfig configures an MsiAuthorizer.
type MsiConfig struct {
// ClientID is optionally used to determine which application to assume when a resource has multiple managed identities
ClientID string

// MsiApiVersion is the API version to use when requesting a token from the metadata service
MsiApiVersion string
MsiEndpoint string
Resource string

// MsiEndpoint is the endpoint where the metadata service can be found
MsiEndpoint string

// Resource is the service for which to request an access token
Resource string
}

// NewMsiConfig returns a new MsiConfig with a configured metadata endpoint and resource.
func NewMsiConfig(ctx context.Context, resource string, msiEndpoint string) (*MsiConfig, error) {
// clientId and objectId can be left blank when a single managed identity is available
func NewMsiConfig(ctx context.Context, resource, msiEndpoint, clientId string) (*MsiConfig, error) {
endpoint := msiDefaultEndpoint
if msiEndpoint != "" {
endpoint = msiEndpoint
Expand All @@ -107,6 +124,7 @@ func NewMsiConfig(ctx context.Context, resource string, msiEndpoint string) (*Ms
}

return &MsiConfig{
ClientID: clientId,
Resource: resource,
MsiApiVersion: msiDefaultApiVersion,
MsiEndpoint: endpoint,
Expand Down
6 changes: 5 additions & 1 deletion internal/test/msi_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ func MsiStubServer(ctx context.Context, port int, token string) chan bool {
handler := http.NewServeMux()

handler.HandleFunc("/metadata/identity/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
clientId := q.Get("client_id")
resource := q.Get("resource")
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, `{"access_token":"%s","client_id":"00000000-0000-0000-0000-000000000000","expires_in":"86391","expires_on":"1611701390","ext_expires_in":"86399","not_before":"1611614690","resource":"https://graph.microsoft.com/","token_type":"Bearer"}`, token)
fmt.Fprintf(w, `{"access_token":"%s","client_id":"%s","expires_in":"86391","expires_on":"1611701390","ext_expires_in":"86399","not_before":"1611614690","resource":"%s","token_type":"Bearer"}`,
token, clientId, resource)
})

handler.HandleFunc("/metadata", func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 58b726a

Please sign in to comment.