Skip to content

Commit

Permalink
pass keyvault env too
Browse files Browse the repository at this point in the history
  • Loading branch information
wuxu92 committed Jul 12, 2024
1 parent 50417f7 commit 35f4094
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
13 changes: 7 additions & 6 deletions internal/customermanagedkeys/key_vault_or_managed_hsm_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (k *KeyVaultOrManagedHSMKey) BaseUri() string {
return ""
}

func parseKeyvauleID(keyRaw string, requireVersion VersionType) (*parse.NestedItemId, error) {
func parseKeyvaultID(keyRaw string, requireVersion VersionType, _ environments.Api) (*parse.NestedItemId, error) {
keyID, err := parse.ParseOptionallyVersionedNestedKeyID(keyRaw)
if err != nil {
return nil, err
Expand Down Expand Up @@ -123,14 +123,14 @@ func parseManagedHSMKey(keyRaw string, requireVersion VersionType, hsmEnv enviro
return versioned, versionless, err
}

func ExpandKeyVaultOrManagedHSMKey(d interface{}, requireVersion VersionType, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
return ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d, requireVersion, "key_vault_key_id", "managed_hsm_key_id", hsmEnv)
func ExpandKeyVaultOrManagedHSMKey(d interface{}, requireVersion VersionType, keyVaultEnv, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
return ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d, requireVersion, "key_vault_key_id", "managed_hsm_key_id", keyVaultEnv, hsmEnv)
}

// ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey
// d: should be one of *pluginsdk.ResourceData or map[string]interface{}
// if return nil, nil, it means no key_vault_key_id or managed_hsm_key_id is specified
func ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d interface{}, requireVersion VersionType, keyVaultFieldName, hsmFieldName string, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
func ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d interface{}, requireVersion VersionType, keyVaultFieldName, hsmFieldName string, keyVaultEnv, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
key := &KeyVaultOrManagedHSMKey{}
var err error
var vaultKeyStr, hsmKeyStr string
Expand All @@ -153,7 +153,7 @@ func ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d interface{}, requireVersi

switch {
case vaultKeyStr != "":
if key.KeyVaultKeyId, err = parseKeyvauleID(vaultKeyStr, requireVersion); err != nil {
if key.KeyVaultKeyId, err = parseKeyvaultID(vaultKeyStr, requireVersion, keyVaultEnv); err != nil {
return nil, err
}
case hsmKeyStr != "":
Expand All @@ -167,7 +167,8 @@ func ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d interface{}, requireVersi
}

// FlattenKeyVaultOrManagedHSMID uses `KeyVaultOrManagedHSMKey.SetState()` to save the state, which this function is designed not to do.
func FlattenKeyVaultOrManagedHSMID(id string, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
func FlattenKeyVaultOrManagedHSMID(id string, keyVaultEnv, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
_ = keyVaultEnv
if id == "" {
return nil, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/hashicorp/go-azure-sdk/sdk/environments"
"github.com/hashicorp/terraform-provider-azurerm/internal/customermanagedkeys"
cmk "github.com/hashicorp/terraform-provider-azurerm/internal/customermanagedkeys"
"github.com/hashicorp/terraform-provider-azurerm/internal/services/keyvault/parse"
hsmParse "github.com/hashicorp/terraform-provider-azurerm/internal/services/managedhsm/parse"
)
Expand Down Expand Up @@ -43,7 +42,7 @@ func TestExpandKeyVaultOrManagedHSMKeyKey(t *testing.T) {
tests := []struct {
name string
args args
want *cmk.KeyVaultOrManagedHSMKey
want *customermanagedkeys.KeyVaultOrManagedHSMKey
wantErr bool
}{
{
Expand All @@ -52,7 +51,7 @@ func TestExpandKeyVaultOrManagedHSMKeyKey(t *testing.T) {
d: buildKeyVaultData("key_vault_key_id", "https://test.keyvault.azure.net/keys/test-key-name"),
keyVaultFieldName: "key_vault_key_id",
},
want: &cmk.KeyVaultOrManagedHSMKey{
want: &customermanagedkeys.KeyVaultOrManagedHSMKey{
KeyVaultKeyId: &parse.NestedItemId{
KeyVaultBaseUrl: "https://test.keyvault.azure.net/",
NestedItemType: "keys",
Expand Down Expand Up @@ -94,7 +93,7 @@ func TestExpandKeyVaultOrManagedHSMKeyKey(t *testing.T) {
hsmFieldName: "managed_hsm_key_id",
hasVersion: customermanagedkeys.VersionTypeVersionless,
},
want: &cmk.KeyVaultOrManagedHSMKey{
want: &customermanagedkeys.KeyVaultOrManagedHSMKey{
ManagedHSMKeyVersionlessId: &hsmParse.ManagedHSMDataPlaneVersionlessKeyId{
ManagedHSMName: "test",
DomainSuffix: "managedhsm.azure.net",
Expand All @@ -105,7 +104,7 @@ func TestExpandKeyVaultOrManagedHSMKeyKey(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t2 *testing.T) {
got, err := cmk.ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(tt.args.d, tt.args.hasVersion, tt.args.keyVaultFieldName, tt.args.hsmFieldName, tt.args.hsmEnv)
got, err := customermanagedkeys.ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(tt.args.d, tt.args.hasVersion, tt.args.keyVaultFieldName, tt.args.hsmFieldName, nil, tt.args.hsmEnv)
if (err != nil) != tt.wantErr {
t2.Errorf("ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
12 changes: 7 additions & 5 deletions internal/services/cosmos/cosmosdb_account_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,8 @@ func resourceCosmosDbAccount() *pluginsdk.Resource {
func resourceCosmosDbAccountCreate(d *pluginsdk.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).Cosmos.CosmosDBClient
databaseClient := meta.(*clients.Client).Cosmos.DatabaseClient
subscriptionId := meta.(*clients.Client).Account.SubscriptionId
accountClient := meta.(*clients.Client).Account
subscriptionId := accountClient.SubscriptionId
ctx, cancel := timeouts.ForCreate(meta.(*clients.Client).StopContext, d)
defer cancel()
log.Printf("[INFO] Preparing arguments for AzureRM Cosmos DB Account creation")
Expand Down Expand Up @@ -1021,7 +1022,7 @@ func resourceCosmosDbAccountCreate(d *pluginsdk.ResourceData, meta interface{})
return fmt.Errorf("`create_mode` only works when `backup.type` is `Continuous`")
}

if key, err := customermanagedkeys.ExpandKeyVaultOrManagedHSMKey(d, customermanagedkeys.VersionTypeAny, meta.(*clients.Client).Account.Environment.ManagedHSM); err != nil {
if key, err := customermanagedkeys.ExpandKeyVaultOrManagedHSMKey(d, customermanagedkeys.VersionTypeAny, accountClient.Environment.KeyVault, accountClient.Environment.ManagedHSM); err != nil {
return fmt.Errorf("parse key vault key id: %+v", err)
} else if key != nil {
account.Properties.KeyVaultKeyUri = pointer.To(key.ID())
Expand Down Expand Up @@ -1059,6 +1060,7 @@ func resourceCosmosDbAccountCreate(d *pluginsdk.ResourceData, meta interface{})

func resourceCosmosDbAccountUpdate(d *pluginsdk.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).Cosmos.CosmosDBClient
apiEnvs := meta.(*clients.Client).Account.Environment
// subscriptionId := meta.(*clients.Client).Account.SubscriptionId
ctx, cancel := timeouts.ForUpdate(meta.(*clients.Client).StopContext, d)
defer cancel()
Expand Down Expand Up @@ -1233,7 +1235,7 @@ func resourceCosmosDbAccountUpdate(d *pluginsdk.ResourceData, meta interface{})
Tags: t,
}

if key, err := customermanagedkeys.ExpandKeyVaultOrManagedHSMKey(d, customermanagedkeys.VersionTypeAny, meta.(*clients.Client).Account.Environment.ManagedHSM); err != nil {
if key, err := customermanagedkeys.ExpandKeyVaultOrManagedHSMKey(d, customermanagedkeys.VersionTypeAny, apiEnvs.KeyVault, apiEnvs.ManagedHSM); err != nil {
return err
} else if key != nil {
account.Properties.KeyVaultKeyUri = pointer.To(key.ID())
Expand Down Expand Up @@ -1481,7 +1483,6 @@ func resourceCosmosDbAccountUpdate(d *pluginsdk.ResourceData, meta interface{})

func resourceCosmosDbAccountRead(d *pluginsdk.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).Cosmos.CosmosDBClient
hsmEnv := meta.(*clients.Client).Account.Environment.ManagedHSM
ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
defer cancel()

Expand Down Expand Up @@ -1552,7 +1553,8 @@ func resourceCosmosDbAccountRead(d *pluginsdk.ResourceData, meta interface{}) er
}

if v := props.KeyVaultKeyUri; v != nil {
if key, err := customermanagedkeys.FlattenKeyVaultOrManagedHSMID(*v, hsmEnv); err != nil {
envs := meta.(*clients.Client).Account.Environment
if key, err := customermanagedkeys.FlattenKeyVaultOrManagedHSMID(*v, envs.KeyVault, envs.ManagedHSM); err != nil {
return fmt.Errorf("flatten key vault uri: %+v", err)
} else if key.IsSet() {
if key.KeyVaultKeyId != nil {
Expand Down

0 comments on commit 35f4094

Please sign in to comment.