Skip to content

Commit

Permalink
Merge pull request #1189 from mpminardi/mpminardi/add-client-id-secre…
Browse files Browse the repository at this point in the history
…t-file-paths

Allow for setting client ID / secret through files
  • Loading branch information
manicminer authored Sep 14, 2023
2 parents 4081389 + c910c2e commit a924023
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 5 deletions.
74 changes: 71 additions & 3 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"log"
"os"
"strings"

"github.com/hashicorp/go-azure-sdk/sdk/auth"
"github.com/hashicorp/go-azure-sdk/sdk/environments"
Expand Down Expand Up @@ -81,6 +82,13 @@ func AzureADProvider() *schema.Provider {
Description: "The Client ID which should be used for service principal authentication",
},

"client_id_file_path": {
Type: schema.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_ID_FILE_PATH", ""),
Description: "The path to a file containing the Client ID which should be used for service principal authentication",
},

"tenant_id": {
Type: schema.TypeString,
Optional: true,
Expand Down Expand Up @@ -132,6 +140,13 @@ func AzureADProvider() *schema.Provider {
Description: "The application password to use when authenticating as a Service Principal using a Client Secret",
},

"client_secret_file_path": {
Type: schema.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_SECRET_FILE_PATH", ""),
Description: "The path to a file containing the application password to use when authenticating as a Service Principal using a Client Secret",
},

// OIDC specific fields
"use_oidc": {
Type: schema.TypeBool,
Expand Down Expand Up @@ -228,9 +243,18 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
}
}

clientSecret, err := getClientSecret(d)
if err != nil {
return nil, diag.FromErr(err)
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
}

var (
env *environments.Environment
err error

envName = d.Get("environment").(string)
metadataHost = d.Get("metadata_host").(string)
Expand Down Expand Up @@ -258,11 +282,11 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: d.Get("client_id").(string),
ClientID: *clientId,
ClientCertificateData: certData,
ClientCertificatePassword: d.Get("client_certificate_password").(string),
ClientCertificatePath: d.Get("client_certificate_path").(string),
ClientSecret: d.Get("client_secret").(string),
ClientSecret: *clientSecret,
OIDCAssertionToken: idToken,
GitHubOIDCTokenRequestURL: d.Get("oidc_request_url").(string),
GitHubOIDCTokenRequestToken: d.Get("oidc_request_token").(string),
Expand Down Expand Up @@ -339,3 +363,47 @@ func oidcToken(d *schema.ResourceData) (string, error) {

return idToken, nil
}

func getClientId(d *schema.ResourceData) (*string, error) {
clientId := strings.TrimSpace(d.Get("client_id").(string))

if path := d.Get("client_id_file_path").(string); path != "" {
fileClientIdRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading Client ID from file %q: %v", path, err)
}

fileClientId := strings.TrimSpace(string(fileClientIdRaw))

if clientId != "" && clientId != fileClientId {
return nil, fmt.Errorf("mismatch between supplied Client ID and supplied Client ID file contents - please either remove one or ensure they match")
}

clientId = fileClientId
}

return &clientId, nil
}

func getClientSecret(d *schema.ResourceData) (*string, error) {
clientSecret := strings.TrimSpace(d.Get("client_secret").(string))

if path := d.Get("client_secret_file_path").(string); path != "" {
fileSecretRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading Client Secret from file %q: %v", path, err)
}

fileSecret := strings.TrimSpace(string(fileSecretRaw))

if clientSecret != "" && clientSecret != fileSecret {
return nil, fmt.Errorf("mismatch between supplied Client Secret and supplied Client Secret file contents - please either remove one or ensure they match")
}

clientSecret = fileSecret
}

return &clientSecret, nil
}
91 changes: 89 additions & 2 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,25 @@ func TestAccProvider_clientCertificateInlineAuth(t *testing.T) {
}

func TestAccProvider_clientSecretAuth(t *testing.T) {
t.Run("fromEnvironment", testAccProvider_clientSecretAuthFromEnvironment)
t.Run("fromFiles", testAccProvider_clientSecretAuthFromFiles)
}

func testAccProvider_clientSecretAuthFromEnvironment(t *testing.T) {
if os.Getenv("TF_ACC") == "" {
t.Skip("TF_ACC not set")
}
if os.Getenv("ARM_CLIENT_ID") == "" {
t.Skip("ARM_CLIENT_ID not set")
}
if os.Getenv("ARM_CLIENT_SECRET") == "" {
t.Skip("ARM_CLIENT_SECRET not set")
}

// Ensure we are running using the expected env-vars
// t.SetEnv does automatic cleanup / resets the values after the test
t.Setenv("ARM_CLIENT_ID_FILE_PATH", "")
t.Setenv("ARM_CLIENT_SECRET_FILE_PATH", "")

provider := AzureADProvider()
ctx := context.Background()
Expand All @@ -172,13 +188,84 @@ func TestAccProvider_clientSecretAuth(t *testing.T) {
t.Fatalf("configuring environment %q: %v", envName, err)
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
}

clientSecret, err := getClientSecret(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: d.Get("client_id").(string),
ClientID: *clientId,

EnableAuthenticatingUsingClientSecret: true,
ClientSecret: *clientSecret,
}

return buildClient(ctx, provider, authConfig, "")
}

d := provider.Configure(ctx, terraform.NewResourceConfigRaw(nil))
if d != nil && d.HasError() {
t.Fatalf("err: %+v", d)
}

if errs := testCheckProvider(provider); len(errs) > 0 {
for _, err := range errs {
t.Error(err)
}
}
}

func testAccProvider_clientSecretAuthFromFiles(t *testing.T) {
if os.Getenv("TF_ACC") == "" {
t.Skip("TF_ACC not set")
}
if os.Getenv("ARM_CLIENT_ID_FILE_PATH") == "" {
t.Skip("ARM_CLIENT_ID_FILE_PATH not set")
}
if os.Getenv("ARM_CLIENT_SECRET_FILE_PATH") == "" {
t.Skip("ARM_CLIENT_SECRET_FILE_PATH not set")
}

// Ensure we are running using the expected env-vars
// t.SetEnv does automatic cleanup / resets the values after the test
t.Setenv("ARM_CLIENT_ID", "")
t.Setenv("ARM_CLIENT_SECRET", "")

provider := AzureADProvider()
ctx := context.Background()

// Support only client secret authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
t.Fatalf("configuring environment %q: %v", envName, err)
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
}

clientSecret, err := getClientSecret(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: *clientId,

EnableAuthenticatingUsingClientSecret: true,
ClientSecret: d.Get("client_secret").(string),
ClientSecret: *clientSecret,
}

return buildClient(ctx, provider, authConfig, "")
Expand Down

0 comments on commit a924023

Please sign in to comment.