Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase azidentity test coverage #21345

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sdk/azidentity/client_assertion_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ func TestClientAssertionCredentialCallbackError(t *testing.T) {
}
}

func TestClientAssertionCredentialNilCallback(t *testing.T) {
_, err := NewClientAssertionCredential(fakeTenantID, fakeClientID, nil, nil)
if err == nil {
t.Fatal("expected an error")
}
}

func TestClientAssertionCredential_Live(t *testing.T) {
data, err := os.ReadFile(liveSP.pemPath)
if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions sdk/azidentity/environment_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -93,6 +94,30 @@ func TestEnvironmentCredential_ClientSecretSet(t *testing.T) {
}
}

func TestEnvironmentCredential_CertificateErrors(t *testing.T) {
resetEnvironmentVarsForTest()
for _, test := range []struct {
name, path string
}{
{"file doesn't exist", filepath.Join(t.TempDir(), t.Name())},
{"invalid file", "testdata/certificate-wrong-key.pem"},
} {
t.Run(test.name, func(t *testing.T) {
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureClientCertificatePath: test.path,
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
_, err := NewEnvironmentCredential(nil)
if err == nil {
t.Fatal("expected an error")
}
}
})
}
}

func TestEnvironmentCredential_ClientCertificatePathSet(t *testing.T) {
resetEnvironmentVarsForTest()
err := os.Setenv(azureTenantID, fakeTenantID)
Expand Down
61 changes: 61 additions & 0 deletions sdk/azidentity/managed_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)

type userAgentValidatingPolicy struct {
Expand Down Expand Up @@ -73,3 +74,63 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) {
t.Fatal(err)
}
}

func TestManagedIdentityClient_UserAssignedIDWarning(t *testing.T) {
for _, test := range []struct {
name string
createRequest func(*managedIdentityClient) error
}{
{
name: "Azure Arc",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createAzureArcAuthRequest(context.Background(), client.id, []string{liveTestScope}, "key")
return err
},
},
{
name: "Cloud Shell",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createCloudShellAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
{
name: "Service Fabric",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createServiceFabricAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
} {
for _, id := range []ManagedIDKind{ClientID(fakeClientID), ResourceID(fakeResourceID)} {
s := "-ClientID"
if id.String() == fakeResourceID {
s = "-ResourceID"
}
t.Run(test.name+s, func(t *testing.T) {
msgs := []string{}
log.SetListener(func(event log.Event, msg string) {
if event == EventAuthentication {
msgs = append(msgs, msg)
}
})
client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{
ID: id,
})
if err != nil {
t.Fatal(err)
}
err = test.createRequest(client)
if err != nil {
t.Fatal(err)
}
for _, msg := range msgs {
if strings.Contains(msg, test.name) && strings.Contains(msg, "user-assigned") {
return
}
}
t.Fatalf("expected warning about user-assigned ID, got:\n%s", strings.Join(msgs, "\n"))
})
}
}
}
71 changes: 71 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,77 @@ func TestManagedIdentityCredential_AzureArc(t *testing.T) {
testGetTokenSuccess(t, cred)
}

func TestManagedIdentityCredential_AzureArcErrors(t *testing.T) {
for k, v := range map[string]string{
arcIMDSEndpoint: "https://localhost",
identityEndpoint: "https://localhost",
} {
t.Setenv(k, v)
}

for _, test := range []struct {
challenge, name string
statusCode int
}{
{name: "no challenge", statusCode: http.StatusUnauthorized},
{name: "malformed challenge", challenge: "Basic realm", statusCode: http.StatusUnauthorized},
{name: "unexpected status code", statusCode: http.StatusOK},
} {
t.Run(test.name, func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", test.challenge),
mock.WithStatusCode(test.statusCode),
)
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{Transport: srv},
})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
})
}
t.Run("failed to get key", func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetError(fmt.Errorf("it didn't work"))
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{
Retry: policy.RetryOptions{MaxRetries: -1},
Transport: srv,
},
})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
})
t.Run("no key file", func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", "Basic realm="+filepath.Join(t.TempDir(), t.Name())),
mock.WithStatusCode(http.StatusUnauthorized),
)
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
})
}

func TestManagedIdentityCredential_CloudShell(t *testing.T) {
validateReq := func(req *http.Request) *http.Response {
err := req.ParseForm()
Expand Down
23 changes: 23 additions & 0 deletions sdk/azidentity/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ func TestWorkloadIdentityCredential(t *testing.T) {
t.Fatal(err)
}
testGetTokenSuccess(t, cred)
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}})
if err != nil {
t.Fatal(err)
}
}

func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
Expand Down Expand Up @@ -186,6 +190,25 @@ func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
}
}

func TestWorkloadIdentityCredential_NoFile(t *testing.T) {
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureFederatedTokenFile: filepath.Join(t.TempDir(), t.Name()),
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: &mockSTS{}},
})
if err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), testTRO); err == nil {
t.Fatal("expected an error")
}
}

func TestWorkloadIdentityCredential_Options(t *testing.T) {
clientID := "not-" + fakeClientID
tenantID := "not-" + fakeTenantID
Expand Down