diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 590e4f77e2..c755ba1854 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "os" + "strings" "github.com/hashicorp/go-azure-sdk/sdk/auth" "github.com/hashicorp/go-azure-sdk/sdk/environments" @@ -81,6 +82,13 @@ func AzureADProvider() *schema.Provider { Description: "The Client ID which should be used for service principal authentication", }, + "client_id_file_path": { + Type: schema.TypeString, + Optional: true, + DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_ID_FILE_PATH", ""), + Description: "The path to a file containing the Client ID which should be used for service principal authentication", + }, + "tenant_id": { Type: schema.TypeString, Optional: true, @@ -132,6 +140,13 @@ func AzureADProvider() *schema.Provider { Description: "The application password to use when authenticating as a Service Principal using a Client Secret", }, + "client_secret_file_path": { + Type: schema.TypeString, + Optional: true, + DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_SECRET_FILE_PATH", ""), + Description: "The path to a file containing the application password to use when authenticating as a Service Principal using a Client Secret", + }, + // OIDC specific fields "use_oidc": { Type: schema.TypeBool, @@ -228,9 +243,18 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc { } } + clientSecret, err := getClientSecret(d) + if err != nil { + return nil, diag.FromErr(err) + } + + clientId, err := getClientId(d) + if err != nil { + return nil, diag.FromErr(err) + } + var ( env *environments.Environment - err error envName = d.Get("environment").(string) metadataHost = d.Get("metadata_host").(string) @@ -258,11 +282,11 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc { authConfig := &auth.Credentials{ Environment: *env, TenantID: d.Get("tenant_id").(string), - ClientID: d.Get("client_id").(string), + ClientID: *clientId, ClientCertificateData: certData, ClientCertificatePassword: d.Get("client_certificate_password").(string), ClientCertificatePath: d.Get("client_certificate_path").(string), - ClientSecret: d.Get("client_secret").(string), + ClientSecret: *clientSecret, OIDCAssertionToken: idToken, GitHubOIDCTokenRequestURL: d.Get("oidc_request_url").(string), GitHubOIDCTokenRequestToken: d.Get("oidc_request_token").(string), @@ -339,3 +363,47 @@ func oidcToken(d *schema.ResourceData) (string, error) { return idToken, nil } + +func getClientId(d *schema.ResourceData) (*string, error) { + clientId := strings.TrimSpace(d.Get("client_id").(string)) + + if path := d.Get("client_id_file_path").(string); path != "" { + fileClientIdRaw, err := os.ReadFile(path) + + if err != nil { + return nil, fmt.Errorf("reading Client ID from file %q: %v", path, err) + } + + fileClientId := strings.TrimSpace(string(fileClientIdRaw)) + + if clientId != "" && clientId != fileClientId { + return nil, fmt.Errorf("mismatch between supplied Client ID and supplied Client ID file contents - please either remove one or ensure they match") + } + + clientId = fileClientId + } + + return &clientId, nil +} + +func getClientSecret(d *schema.ResourceData) (*string, error) { + clientSecret := strings.TrimSpace(d.Get("client_secret").(string)) + + if path := d.Get("client_secret_file_path").(string); path != "" { + fileSecretRaw, err := os.ReadFile(path) + + if err != nil { + return nil, fmt.Errorf("reading Client Secret from file %q: %v", path, err) + } + + fileSecret := strings.TrimSpace(string(fileSecretRaw)) + + if clientSecret != "" && clientSecret != fileSecret { + return nil, fmt.Errorf("mismatch between supplied Client Secret and supplied Client Secret file contents - please either remove one or ensure they match") + } + + clientSecret = fileSecret + } + + return &clientSecret, nil +} diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index ee0dea0acc..326b1378ef 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -157,9 +157,25 @@ func TestAccProvider_clientCertificateInlineAuth(t *testing.T) { } func TestAccProvider_clientSecretAuth(t *testing.T) { + t.Run("fromEnvironment", testAccProvider_clientSecretAuthFromEnvironment) + t.Run("fromFiles", testAccProvider_clientSecretAuthFromFiles) +} + +func testAccProvider_clientSecretAuthFromEnvironment(t *testing.T) { if os.Getenv("TF_ACC") == "" { t.Skip("TF_ACC not set") } + if os.Getenv("ARM_CLIENT_ID") == "" { + t.Skip("ARM_CLIENT_ID not set") + } + if os.Getenv("ARM_CLIENT_SECRET") == "" { + t.Skip("ARM_CLIENT_SECRET not set") + } + + // Ensure we are running using the expected env-vars + // t.SetEnv does automatic cleanup / resets the values after the test + t.Setenv("ARM_CLIENT_ID_FILE_PATH", "") + t.Setenv("ARM_CLIENT_SECRET_FILE_PATH", "") provider := AzureADProvider() ctx := context.Background() @@ -172,13 +188,84 @@ func TestAccProvider_clientSecretAuth(t *testing.T) { t.Fatalf("configuring environment %q: %v", envName, err) } + clientId, err := getClientId(d) + if err != nil { + return nil, diag.FromErr(err) + } + + clientSecret, err := getClientSecret(d) + if err != nil { + return nil, diag.FromErr(err) + } + authConfig := &auth.Credentials{ Environment: *env, TenantID: d.Get("tenant_id").(string), - ClientID: d.Get("client_id").(string), + ClientID: *clientId, + + EnableAuthenticatingUsingClientSecret: true, + ClientSecret: *clientSecret, + } + + return buildClient(ctx, provider, authConfig, "") + } + + d := provider.Configure(ctx, terraform.NewResourceConfigRaw(nil)) + if d != nil && d.HasError() { + t.Fatalf("err: %+v", d) + } + + if errs := testCheckProvider(provider); len(errs) > 0 { + for _, err := range errs { + t.Error(err) + } + } +} + +func testAccProvider_clientSecretAuthFromFiles(t *testing.T) { + if os.Getenv("TF_ACC") == "" { + t.Skip("TF_ACC not set") + } + if os.Getenv("ARM_CLIENT_ID_FILE_PATH") == "" { + t.Skip("ARM_CLIENT_ID_FILE_PATH not set") + } + if os.Getenv("ARM_CLIENT_SECRET_FILE_PATH") == "" { + t.Skip("ARM_CLIENT_SECRET_FILE_PATH not set") + } + + // Ensure we are running using the expected env-vars + // t.SetEnv does automatic cleanup / resets the values after the test + t.Setenv("ARM_CLIENT_ID", "") + t.Setenv("ARM_CLIENT_SECRET", "") + + provider := AzureADProvider() + ctx := context.Background() + + // Support only client secret authentication + provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) { + envName := d.Get("environment").(string) + env, err := environments.FromName(envName) + if err != nil { + t.Fatalf("configuring environment %q: %v", envName, err) + } + + clientId, err := getClientId(d) + if err != nil { + return nil, diag.FromErr(err) + } + + clientSecret, err := getClientSecret(d) + if err != nil { + return nil, diag.FromErr(err) + } + + authConfig := &auth.Credentials{ + Environment: *env, + TenantID: d.Get("tenant_id").(string), + ClientID: *clientId, EnableAuthenticatingUsingClientSecret: true, - ClientSecret: d.Get("client_secret").(string), + ClientSecret: *clientSecret, } return buildClient(ctx, provider, authConfig, "")