diff --git a/apps/public/public.go b/apps/public/public.go index e346ff3d..392e5e43 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -217,11 +217,13 @@ func WithClaims(claims string) interface { func WithAuthenticationScheme(authnScheme AuthenticationScheme) interface { AcquireSilentOption AcquireInteractiveOption + AcquireByUsernamePasswordOption options.CallOption } { return struct { AcquireSilentOption AcquireInteractiveOption + AcquireByUsernamePasswordOption options.CallOption }{ CallOption: options.NewCallOption( @@ -231,6 +233,8 @@ func WithAuthenticationScheme(authnScheme AuthenticationScheme) interface { t.authnScheme = authnScheme case *interactiveAuthOptions: t.authnScheme = authnScheme + case *acquireTokenByUsernamePasswordOptions: + t.authnScheme = authnScheme default: return fmt.Errorf("unexpected options type %T", a) } @@ -349,6 +353,7 @@ func (pca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts // acquireTokenByUsernamePasswordOptions contains optional configuration for AcquireTokenByUsernamePassword type acquireTokenByUsernamePasswordOptions struct { claims, tenantID string + authnScheme AuthenticationScheme } // AcquireByUsernamePasswordOption is implemented by options for AcquireTokenByUsernamePassword @@ -374,6 +379,9 @@ func (pca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []s authParams.Claims = o.claims authParams.Username = username authParams.Password = password + if o.authnScheme != nil { + authParams.AuthnScheme = o.authnScheme + } token, err := pca.base.Token.UsernamePassword(ctx, authParams) if err != nil { diff --git a/apps/public/public_test.go b/apps/public/public_test.go index ad106747..c0fb9b33 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -859,36 +859,111 @@ func TestWithDomainHint(t *testing.T) { func TestWithAuthenticationScheme(t *testing.T) { clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo, tenant := "login.microsoftonline.com", "tenant" - authori := fmt.Sprintf(authorityFmt, lmo, tenant) accessToken, idToken, refreshToken := "at", mock.GetIDToken(tenant, lmo), "rt" authScheme := mock.NewTestAuthnScheme() + var ar AuthResult + var client Client + var err error - mockClient := mock.Client{} - mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600))) - client, err := New("client-id", WithAuthority(authori), WithHTTPClient(&mockClient)) - if err != nil { - t.Fatal(err) + for _, testCase := range []struct { + name string + responses [][]byte + }{ + { + name: "interactive", + responses: [][]byte{ + mock.GetTenantDiscoveryBody(lmo, tenant), + mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600), + }, + }, + { + name: "password", + responses: [][]byte{ + mock.GetTenantDiscoveryBody(lmo, tenant), + []byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`), + mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600), + }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + ctx := context.Background() + + // get a fresh client to avoid any overflow from other tests + client, err = getNewClientWithMockedResponses( + testCase.responses, + lmo, + tenant, + accessToken, + idToken, + refreshToken, + clientInfo, + true, + ) + if err != nil { + t.Fatal(err) + } + + // first, run the test case and try to acquire a token with the mock flow + switch testCase.name { + case "interactive": + ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithAuthenticationScheme(authScheme), WithOpenURL(fakeBrowserOpenURL)) + case "password": + ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithAuthenticationScheme(authScheme)) + default: + t.Fatalf("test bug: no test for " + testCase.name) + } + + // validate that the token is created correctly + if err != nil { + t.Fatal(err) + } + + // and that the token is formatted according to the authentication scheme + if ar.AccessToken != fmt.Sprintf(mock.Authnschemeformat, accessToken) { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + + // next, try acquiring the token again silently from the cache, and validate the token again + ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithAuthenticationScheme(authScheme)) + if err != nil { + t.Fatal(err) + } + + if ar.AccessToken != fmt.Sprintf(mock.Authnschemeformat, accessToken) { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + }) } - ctx := context.Background() - var ar AuthResult - ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithAuthenticationScheme(authScheme), WithOpenURL(fakeBrowserOpenURL)) +} - if err != nil { - t.Fatal(err) +// gets a new public.Client instance with the given responses pre-added to the underlying mock HttpClient +func getNewClientWithMockedResponses( + responses [][]byte, + lmo, + tenant, + accessToken, + idToken, + refreshToken, + clientInfo string, + includeAcquireSilentResponses bool, +) (Client, error) { + authority := fmt.Sprintf(authorityFmt, lmo, tenant) + + mockClient := mock.Client{} + for _, response := range responses { + mockClient.AppendResponse(mock.WithBody(response)) } - if ar.AccessToken != fmt.Sprintf(mock.Authnschemeformat, accessToken) { - t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + + if includeAcquireSilentResponses { + // we will be testing the AcquireTokenSilent flow after the initial flow, so append the correct responses + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600))) } - mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600))) - ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithAuthenticationScheme(authScheme)) + client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient)) if err != nil { - t.Fatal(err) - } - if ar.AccessToken != fmt.Sprintf(mock.Authnschemeformat, accessToken) { - t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + return client, fmt.Errorf("failed to create new public client with error: %w", err) } + return client, nil } diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index 4f29d3aa..c55ed571 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -28,8 +28,8 @@ func createAppWithCert() *confidential.Client { // This extracts our public certificates and private key from the PEM file. If it is // encrypted, the second argument must be password to decode. - // IMPORTANT SECURITY NOTICE: never store passwords in code. The recommended pattern is to keep the certificate in a vault (e.g. Azure KeyVault) - // and to download it when the application starts. + // IMPORTANT SECURITY NOTICE: never store passwords in code. The recommended pattern is to keep the certificate in a vault (e.g. Azure KeyVault) + // and to download it when the application starts. certs, privateKey, err := confidential.CertFromPEM(pemData, "") if err != nil { log.Fatal(err)