Skip to content

Commit

Permalink
provider: log the metdata configuration mode, tidy up provider tests
Browse files Browse the repository at this point in the history
  • Loading branch information
manicminer committed Apr 11, 2024
1 parent c088cef commit f8849bc
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 57 deletions.
24 changes: 19 additions & 5 deletions internal/provider/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,26 @@ package provider
import (
"encoding/base64"
"fmt"
"log"
"os"
"strings"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk"
)

// logEntry avoids log entries showing up in test output
func logEntry(f string, v ...interface{}) {
if os.Getenv("TF_LOG") == "" {
return
}

if os.Getenv("TF_ACC") != "" {
return
}

log.Printf(f, v...)
}

func decodeCertificate(clientCertificate string) ([]byte, error) {
var pfx []byte
if clientCertificate != "" {
Expand All @@ -25,7 +39,7 @@ func decodeCertificate(clientCertificate string) ([]byte, error) {
return pfx, nil
}

func getOidcToken(d *schema.ResourceData) (*string, error) {
func getOidcToken(d *pluginsdk.ResourceData) (*string, error) {
idToken := strings.TrimSpace(d.Get("oidc_token").(string))

if path := d.Get("oidc_token_file_path").(string); path != "" {
Expand Down Expand Up @@ -64,7 +78,7 @@ func getOidcToken(d *schema.ResourceData) (*string, error) {
return &idToken, nil
}

func getClientId(d *schema.ResourceData) (*string, error) {
func getClientId(d *pluginsdk.ResourceData) (*string, error) {
clientId := strings.TrimSpace(d.Get("client_id").(string))

if path := d.Get("client_id_file_path").(string); path != "" {
Expand Down Expand Up @@ -94,7 +108,7 @@ func getClientId(d *schema.ResourceData) (*string, error) {
return &clientId, nil
}

func getClientSecret(d *schema.ResourceData) (*string, error) {
func getClientSecret(d *pluginsdk.ResourceData) (*string, error) {
clientSecret := strings.TrimSpace(d.Get("client_secret").(string))

if path := d.Get("client_secret_file_path").(string); path != "" {
Expand All @@ -116,7 +130,7 @@ func getClientSecret(d *schema.ResourceData) (*string, error) {
return &clientSecret, nil
}

func getTenantId(d *schema.ResourceData) (*string, error) {
func getTenantId(d *pluginsdk.ResourceData) (*string, error) {
tenantId := strings.TrimSpace(d.Get("tenant_id").(string))

if d.Get("use_aks_workload_identity").(bool) && os.Getenv("AZURE_TENANT_ID") != "" {
Expand Down
50 changes: 14 additions & 36 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package provider
import (
"context"
"fmt"
"log"
"os"
"strings"
"time"
Expand Down Expand Up @@ -37,18 +36,6 @@ func ValidatePartnerID(i interface{}, k string) ([]string, []error) {
// * a valid UUID prefixed with "pid-"
// * a valid UUID prefixed with "pid-" and suffixed with "-partnercenter"

debugLog := func(f string, v ...interface{}) {
if os.Getenv("TF_LOG") == "" {
return
}

if os.Getenv("TF_ACC") != "" {
return
}

log.Printf(f, v...)
}

v, ok := i.(string)
if !ok {
return nil, []error{fmt.Errorf("expected type of %q to be string", k)}
Expand All @@ -67,7 +54,7 @@ func ValidatePartnerID(i interface{}, k string) ([]string, []error) {
return nil, []error{fmt.Errorf("expected %q to contain a valid UUID", v)}
}

debugLog("[DEBUG] %q partner_id matches pid-<GUID>-partnercenter...", v)
logEntry("[DEBUG] %q partner_id matches pid-<GUID>-partnercenter...", v)
return nil, nil
}

Expand All @@ -79,39 +66,26 @@ func ValidatePartnerID(i interface{}, k string) ([]string, []error) {
return nil, []error{fmt.Errorf("expected %q to be a valid UUID", k)}
}

debugLog("[DEBUG] %q partner_id matches pid-<GUID>...", v)
logEntry("[DEBUG] %q partner_id matches pid-<GUID>...", v)
return nil, nil
}

// Check for straight UUID
if _, err := validation.IsUUID(v, ""); err != nil {
return nil, []error{fmt.Errorf("expected %q to be a valid UUID", k)}
} else {
debugLog("[DEBUG] %q partner_id is an un-prefixed UUID...", v)
logEntry("[DEBUG] %q partner_id is an un-prefixed UUID...", v)
return nil, nil
}
}

func azureProvider(supportLegacyTestSuite bool) *schema.Provider {
// avoids this showing up in test output
debugLog := func(f string, v ...interface{}) {
if os.Getenv("TF_LOG") == "" {
return
}

if os.Getenv("TF_ACC") != "" {
return
}

log.Printf(f, v...)
}

dataSources := make(map[string]*schema.Resource)
resources := make(map[string]*schema.Resource)

// first handle the typed services
for _, service := range SupportedTypedServices() {
debugLog("[DEBUG] Registering Data Sources for %q..", service.Name())
logEntry("[DEBUG] Registering Data Sources for %q..", service.Name())
for _, ds := range service.DataSources() {
key := ds.ResourceType()
if existing := dataSources[key]; existing != nil {
Expand All @@ -127,7 +101,7 @@ func azureProvider(supportLegacyTestSuite bool) *schema.Provider {
dataSources[key] = dataSource
}

debugLog("[DEBUG] Registering Resources for %q..", service.Name())
logEntry("[DEBUG] Registering Resources for %q..", service.Name())
for _, r := range service.Resources() {
key := r.ResourceType()
if existing := resources[key]; existing != nil {
Expand All @@ -145,7 +119,7 @@ func azureProvider(supportLegacyTestSuite bool) *schema.Provider {

// then handle the untyped services
for _, service := range SupportedUntypedServices() {
debugLog("[DEBUG] Registering Data Sources for %q..", service.Name())
logEntry("[DEBUG] Registering Data Sources for %q..", service.Name())
for k, v := range service.SupportedDataSources() {
if existing := dataSources[k]; existing != nil {
panic(fmt.Sprintf("An existing Data Source exists for %q", k))
Expand All @@ -154,7 +128,7 @@ func azureProvider(supportLegacyTestSuite bool) *schema.Provider {
dataSources[k] = v
}

debugLog("[DEBUG] Registering Resources for %q..", service.Name())
logEntry("[DEBUG] Registering Resources for %q..", service.Name())
for k, v := range service.SupportedResources() {
if existing := resources[k]; existing != nil {
panic(fmt.Sprintf("An existing Resource exists for %q", k))
Expand Down Expand Up @@ -207,7 +181,7 @@ func azureProvider(supportLegacyTestSuite bool) *schema.Provider {
Type: schema.TypeString,
Required: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_ENVIRONMENT", "public"),
Description: "The Cloud Environment which should be used. Possible values are public, usgovernment, and china. Defaults to public. Not used when `metadata_host` is specified.",
Description: "The Cloud Environment which should be used. Possible values are public, usgovernment, and china. Defaults to public. Not used and should not be specified when `metadata_host` is specified.",
},

"metadata_host": {
Expand Down Expand Up @@ -420,11 +394,15 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
)

if metadataHost != "" {
logEntry("[DEBUG] Configuring cloud environment from Metadata Service at %q", metadataHost)
if env, err = environments.FromEndpoint(ctx, fmt.Sprintf("https://%s", metadataHost)); err != nil {
return nil, diag.FromErr(err)
}
} else if env, err = environments.FromName(envName); err != nil {
return nil, diag.FromErr(err)
} else {
logEntry("[DEBUG] Configuring built-in cloud environment by name: %q", envName)
if env, err = environments.FromName(envName); err != nil {
return nil, diag.FromErr(err)
}
}

var (
Expand Down
72 changes: 56 additions & 16 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,24 @@ func TestAccProvider_clientCertificateAuth(t *testing.T) {
}
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
}

tenantId, err := getTenantId(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: d.Get("client_id").(string),
Environment: *env,
TenantID: *tenantId,
ClientID: *clientId,
ClientCertificateData: certData,
ClientCertificatePath: d.Get("client_certificate_path").(string),
ClientCertificatePassword: d.Get("client_certificate_password").(string),
EnableAuthenticatingUsingClientCertificate: true,
ClientCertificateData: certData,
ClientCertificatePath: d.Get("client_certificate_path").(string),
ClientCertificatePassword: d.Get("client_certificate_password").(string),
}

return buildClient(ctx, provider, d, authConfig)
Expand Down Expand Up @@ -267,12 +277,17 @@ func testAccProvider_clientSecretAuthFromEnvironment(t *testing.T) {
return nil, diag.FromErr(err)
}

tenantId, err := getTenantId(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
TenantID: *tenantId,
ClientID: *clientId,
EnableAuthenticatingUsingClientSecret: true,
ClientSecret: *clientSecret,
EnableAuthenticatingUsingClientSecret: true,
}

return buildClient(ctx, provider, d, authConfig)
Expand Down Expand Up @@ -330,12 +345,17 @@ func testAccProvider_clientSecretAuthFromFiles(t *testing.T) {
return nil, diag.FromErr(err)
}

tenantId, err := getTenantId(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
TenantID: *tenantId,
ClientID: *clientId,
EnableAuthenticatingUsingClientSecret: true,
ClientSecret: *clientSecret,
EnableAuthenticatingUsingClientSecret: true,
}

return buildClient(ctx, provider, d, authConfig)
Expand Down Expand Up @@ -380,10 +400,20 @@ func TestAccProvider_genericOidcAuth(t *testing.T) {
return nil, diag.FromErr(err)
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
}

tenantId, err := getTenantId(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: d.Get("client_id").(string),
TenantID: *tenantId,
ClientID: *clientId,
EnableAuthenticationUsingOIDC: true,
OIDCAssertionToken: *oidcToken,
}
Expand Down Expand Up @@ -428,13 +458,23 @@ func TestAccProvider_githubOidcAuth(t *testing.T) {
t.Fatalf("configuring environment %q: %v", envName, err)
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
}

tenantId, err := getTenantId(d)
if err != nil {
return nil, diag.FromErr(err)
}

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: d.Get("client_id").(string),
EnableAuthenticationUsingGitHubOIDC: true,
TenantID: *tenantId,
ClientID: *clientId,
GitHubOIDCTokenRequestToken: d.Get("oidc_request_token").(string),
GitHubOIDCTokenRequestURL: d.Get("oidc_request_url").(string),
EnableAuthenticationUsingGitHubOIDC: true,
}

return buildClient(ctx, provider, d, authConfig)
Expand Down Expand Up @@ -499,8 +539,8 @@ func TestAccProvider_aksWorkloadIdentityAuth(t *testing.T) {
Environment: *env,
TenantID: *tenantId,
ClientID: *clientId,
EnableAuthenticationUsingOIDC: true,
OIDCAssertionToken: *oidcToken,
EnableAuthenticationUsingOIDC: true,
}

return buildClient(ctx, provider, d, authConfig)
Expand Down

0 comments on commit f8849bc

Please sign in to comment.