diff --git a/postgresql/helpers.go b/postgresql/helpers.go index 2eb93e88..5349d4e8 100644 --- a/postgresql/helpers.go +++ b/postgresql/helpers.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/lib/pq" ) @@ -657,3 +658,19 @@ func quoteTableName(tableName string) string { } return strings.Join(parts, ".") } + +func cloudConfigFromName(name string) (cloud.Configuration, error) { + switch strings.ToLower(name) { + case "china": + return cloud.AzureChina, nil + + case "usgovernment": + return cloud.AzureGovernment, nil + + case "public": + return cloud.AzurePublic, nil + + default: + return cloud.Configuration{}, fmt.Errorf("unsupported Azure Cloud environment: %s", name) + } +} diff --git a/postgresql/provider.go b/postgresql/provider.go index f0df3ea9..5f3e368f 100644 --- a/postgresql/provider.go +++ b/postgresql/provider.go @@ -7,9 +7,10 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "os" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/blang/semver" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" @@ -21,10 +22,23 @@ import ( ) const ( - defaultProviderMaxOpenConnections = 20 - defaultExpectedPostgreSQLVersion = "9.0.0" + defaultProviderMaxOpenConnections = 20 + defaultExpectedPostgreSQLVersion = "9.0.0" + serviceName cloud.ServiceName = "ossrdbms-aad" ) +func init() { + cloud.AzureChina.Services[serviceName] = cloud.ServiceConfiguration{ + Audience: "https://ossrdbms-aad.database.chinacloudapi.cn", + } + cloud.AzureGovernment.Services[serviceName] = cloud.ServiceConfiguration{ + Audience: "https://ossrdbms-aad.database.usgovcloudapi.net", + } + cloud.AzurePublic.Services[serviceName] = cloud.ServiceConfiguration{ + Audience: "https://ossrdbms-aad.database.windows.net", + } +} + // Provider returns a terraform.ResourceProvider. func Provider() *schema.Provider { return &schema.Provider{ @@ -106,6 +120,13 @@ func Provider() *schema.Provider { "(see: https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/how-to-configure-sign-in-azure-ad-authentication)", }, + "azure_environment": { + Type: schema.TypeString, + Optional: true, + Default: "public", + Description: "MS Azure Cloud environment (see: https://registry.terraform.io/providers/hashicorp/azurerm/latest/docs#environment)", + }, + "azure_tenant_id": { Type: schema.TypeString, Optional: true, @@ -310,14 +331,25 @@ func createGoogleCredsFileIfNeeded() error { return os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", tmpFile.Name()) } -func acquireAzureOauthToken(tenantId string) (string, error) { +func acquireAzureOauthToken(environment, tenantId string) (string, error) { + cloudConfig, err := cloudConfigFromName(environment) + if err != nil { + return "", err + } credential, err := azidentity.NewDefaultAzureCredential( - &azidentity.DefaultAzureCredentialOptions{TenantID: tenantId}) + &azidentity.DefaultAzureCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + Cloud: cloudConfig, + }, + TenantID: tenantId, + }) if err != nil { return "", err } token, err := credential.GetToken(context.Background(), policy.TokenRequestOptions{ - Scopes: []string{"https://ossrdbms-aad.database.windows.net/.default"}, + Scopes: []string{ + cloudConfig.Services[serviceName].Audience + "/.default", + }, TenantID: tenantId, }) if err != nil { @@ -354,12 +386,13 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) { return nil, err } } else if d.Get("azure_identity_auth").(bool) { + environment := d.Get("azure_environment").(string) tenantId := d.Get("azure_tenant_id").(string) if tenantId == "" { return nil, fmt.Errorf("postgresql: azure_identity_auth is enabled, azure_tenant_id must be provided also") } var err error - password, err = acquireAzureOauthToken(tenantId) + password, err = acquireAzureOauthToken(environment, tenantId) if err != nil { return nil, err } diff --git a/website/docs/index.html.markdown b/website/docs/index.html.markdown index ce178fbc..9a60bdc2 100644 --- a/website/docs/index.html.markdown +++ b/website/docs/index.html.markdown @@ -186,6 +186,7 @@ The following arguments are supported: * `aws_rds_iam_provider_role_arn` - (Optional) AWS IAM role to assume while using AWS RDS IAM Auth. * `azure_identity_auth` - (Optional) If set to `true`, call the Azure OAuth token endpoint for temporary token * `azure_tenant_id` - (Optional) (Required if `azure_identity_auth` is `true`) Azure tenant ID [read more](https://registry.terraform.io/providers/hashicorp/azurerm/latest/docs/data-sources/client_config.html) +* `azure_environment` - (Optional) The Azure Cloud environment. Possible values are `public`, `usgovernment` and `china`. Defaults to `public`. ## GoCloud