From 45d873f9dcd1453e1496e765fecb718d2b268410 Mon Sep 17 00:00:00 2001
From: arcturusZhang <ufo54153@gmail.com>
Date: Wed, 25 Dec 2019 11:24:51 +0800
Subject: [PATCH] Add SSE-CMK feature for managed disks

---
 .../compute/data_source_managed_disk.go       | 30 ++++++--
 .../compute/resource_arm_managed_disk.go      | 68 +++++++++++++------
 website/docs/d/managed_disk.html.markdown     |  2 +
 website/docs/r/managed_disk.html.markdown     |  6 ++
 4 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/azurerm/internal/services/compute/data_source_managed_disk.go b/azurerm/internal/services/compute/data_source_managed_disk.go
index bf3d6ff3aedbc..3ae3cef0e88a6 100644
--- a/azurerm/internal/services/compute/data_source_managed_disk.go
+++ b/azurerm/internal/services/compute/data_source_managed_disk.go
@@ -71,6 +71,16 @@ func dataSourceArmManagedDisk() *schema.Resource {
 				Computed: true,
 			},
 
+			"encryption_type": {
+				Type:     schema.TypeString,
+				Computed: true,
+			},
+
+			"managed_disk_encryption_set_id": {
+				Type:     schema.TypeString,
+				Computed: true,
+			},
+
 			"tags": tags.Schema(),
 		},
 	}
@@ -94,22 +104,32 @@ func dataSourceArmManagedDiskRead(d *schema.ResourceData, meta interface{}) erro
 
 	d.SetId(*resp.ID)
 
+	d.Set("name", name)
+	d.Set("resource_group_name", resGroup)
+
+	if location := resp.Location; location != nil {
+		d.Set("location", azure.NormalizeLocation(*location))
+	}
+
 	if sku := resp.Sku; sku != nil {
 		d.Set("storage_account_type", string(sku.Name))
 	}
 
 	if props := resp.DiskProperties; props != nil {
+		if creationData := props.CreationData; creationData != nil {
+			flattenAzureRmManagedDiskCreationData(d, creationData)
+		}
 		d.Set("disk_size_gb", props.DiskSizeGB)
 		d.Set("disk_iops_read_write", props.DiskIOPSReadWrite)
 		d.Set("disk_mbps_read_write", props.DiskMBpsReadWrite)
 		d.Set("os_type", props.OsType)
+		if encryption := props.Encryption; encryption != nil {
+			d.Set("encryption_type", string(encryption.Type))
+			d.Set("managed_disk_encryption_set_id", encryption.DiskEncryptionSetID)
+		}
 	}
 
-	if resp.CreationData != nil {
-		flattenAzureRmManagedDiskCreationData(d, resp.CreationData)
-	}
-
-	d.Set("zones", resp.Zones)
+	d.Set("zones", utils.FlattenStringSlice(resp.Zones))
 
 	return tags.FlattenAndSet(d, resp.Tags)
 }
diff --git a/azurerm/internal/services/compute/resource_arm_managed_disk.go b/azurerm/internal/services/compute/resource_arm_managed_disk.go
index 55baffd4599df..fd2a151d35386 100644
--- a/azurerm/internal/services/compute/resource_arm_managed_disk.go
+++ b/azurerm/internal/services/compute/resource_arm_managed_disk.go
@@ -125,6 +125,22 @@ func resourceArmManagedDisk() *schema.Resource {
 
 			"encryption_settings": encryptionSettingsSchema(),
 
+			"encryption_type": {
+				Type:     schema.TypeString,
+				Optional: true,
+				ValidateFunc: validation.StringInSlice([]string{
+					string(compute.EncryptionAtRestWithPlatformKey),
+					string(compute.EncryptionAtRestWithCustomerKey),
+				}, false),
+				Default: string(compute.EncryptionAtRestWithPlatformKey),
+			},
+
+			"managed_disk_encryption_set_id": {
+				Type:         schema.TypeString,
+				Optional:     true,
+				ValidateFunc: azure.ValidateResourceID,
+			},
+
 			"tags": tags.Schema(),
 		},
 	}
@@ -169,17 +185,6 @@ func resourceArmManagedDiskCreateUpdate(d *schema.ResourceData, meta interface{}
 	expandedTags := tags.Expand(t)
 	zones := azure.ExpandZones(d.Get("zones").([]interface{}))
 
-	var skuName compute.DiskStorageAccountTypes
-	if strings.EqualFold(storageAccountType, string(compute.PremiumLRS)) {
-		skuName = compute.PremiumLRS
-	} else if strings.EqualFold(storageAccountType, string(compute.StandardLRS)) {
-		skuName = compute.StandardLRS
-	} else if strings.EqualFold(storageAccountType, string(compute.StandardSSDLRS)) {
-		skuName = compute.StandardSSDLRS
-	} else if strings.EqualFold(storageAccountType, string(compute.UltraSSDLRS)) {
-		skuName = compute.UltraSSDLRS
-	}
-
 	createDisk := compute.Disk{
 		Name:     &name,
 		Location: &location,
@@ -187,7 +192,7 @@ func resourceArmManagedDiskCreateUpdate(d *schema.ResourceData, meta interface{}
 			OsType: compute.OperatingSystemTypes(osType),
 		},
 		Sku: &compute.DiskSku{
-			Name: skuName,
+			Name: compute.DiskStorageAccountTypes(storageAccountType),
 		},
 		Tags:  expandedTags,
 		Zones: zones,
@@ -249,6 +254,25 @@ func resourceArmManagedDiskCreateUpdate(d *schema.ResourceData, meta interface{}
 		createDisk.EncryptionSettingsCollection = expandManagedDiskEncryptionSettings(settings)
 	}
 
+	encryption := compute.Encryption{}
+
+	if v, ok := d.GetOk("encryption_type"); ok {
+		encryption.Type = compute.EncryptionType(v.(string))
+		if strings.EqualFold(v.(string), string(compute.EncryptionAtRestWithPlatformKey)) {
+			if _, ok := d.GetOk("managed_disk_encryption_set_id"); ok {
+				return fmt.Errorf("[Error] `managed_disk_encryption_set_id` should not be set when `encryption_type` is `%s`", compute.EncryptionAtRestWithPlatformKey)
+			}
+		} else if strings.EqualFold(v.(string), string(compute.EncryptionAtRestWithCustomerKey)) {
+			if v, ok := d.GetOk("managed_disk_encryption_set_id"); ok {
+				encryption.DiskEncryptionSetID = utils.String(v.(string))
+			} else {
+				return fmt.Errorf("[Error] `managed_disk_encryption_set_id` must be set when `encryption_type` is `%s`", compute.EncryptionAtRestWithCustomerKey)
+			}
+		}
+	}
+
+	createDisk.Encryption = &encryption
+
 	future, err := client.CreateOrUpdate(ctx, resGroup, name, createDisk)
 	if err != nil {
 		return err
@@ -286,6 +310,7 @@ func resourceArmManagedDiskRead(d *schema.ResourceData, meta interface{}) error
 	resp, err := client.Get(ctx, resGroup, name)
 	if err != nil {
 		if utils.ResponseWasNotFound(resp.Response) {
+			log.Printf("[INFO] Disk %q does not exist - removing from state", d.Id())
 			d.SetId("")
 			return nil
 		}
@@ -294,7 +319,7 @@ func resourceArmManagedDiskRead(d *schema.ResourceData, meta interface{}) error
 
 	d.Set("name", resp.Name)
 	d.Set("resource_group_name", resGroup)
-	d.Set("zones", resp.Zones)
+	d.Set("zones", utils.FlattenStringSlice(resp.Zones))
 
 	if location := resp.Location; location != nil {
 		d.Set("location", azure.NormalizeLocation(*location))
@@ -305,19 +330,22 @@ func resourceArmManagedDiskRead(d *schema.ResourceData, meta interface{}) error
 	}
 
 	if props := resp.DiskProperties; props != nil {
+		if creationData := props.CreationData; creationData != nil {
+			flattenAzureRmManagedDiskCreationData(d, creationData)
+		}
 		d.Set("disk_size_gb", props.DiskSizeGB)
 		d.Set("os_type", props.OsType)
 		d.Set("disk_iops_read_write", props.DiskIOPSReadWrite)
 		d.Set("disk_mbps_read_write", props.DiskMBpsReadWrite)
-	}
 
-	if resp.CreationData != nil {
-		flattenAzureRmManagedDiskCreationData(d, resp.CreationData)
-	}
+		if encryption := props.Encryption; encryption != nil {
+			d.Set("encryption_type", string(encryption.Type))
+			d.Set("managed_disk_encryption_set_id", encryption.DiskEncryptionSetID)
+		}
 
-	flattened := flattenManagedDiskEncryptionSettings(resp.EncryptionSettingsCollection)
-	if err := d.Set("encryption_settings", flattened); err != nil {
-		return fmt.Errorf("Error setting encryption settings: %+v", err)
+		if err := d.Set("encryption_settings", flattenManagedDiskEncryptionSettings(props.EncryptionSettingsCollection)); err != nil {
+			return fmt.Errorf("Error setting `encryption_settings`: %+v", err)
+		}
 	}
 
 	return tags.FlattenAndSet(d, resp.Tags)
diff --git a/website/docs/d/managed_disk.html.markdown b/website/docs/d/managed_disk.html.markdown
index 0a0afd3754a9a..9478db0d8fc90 100644
--- a/website/docs/d/managed_disk.html.markdown
+++ b/website/docs/d/managed_disk.html.markdown
@@ -113,5 +113,7 @@ resource "azurerm_virtual_machine" "example" {
 * `disk_size_gb` - The size of the managed disk in gigabytes.
 * `disk_iops_read_write` - The number of IOPS allowed for this disk. One operation can transfer between 4k and 256k bytes.
 * `disk_mbps_read_write` - The bandwidth allowed for this disk. 
+* `encryption_type` - The type of key used to encrypt the data of the disk.
+* `managed_disk_encryption_set_id` - ID of an existing disk encryption set that the current resource is using for data encryption. 
 * `tags` - A mapping of tags assigned to the resource.
 * `zones` - A collection containing the availability zone the managed disk is allocated in.
diff --git a/website/docs/r/managed_disk.html.markdown b/website/docs/r/managed_disk.html.markdown
index e423f34457198..40cd2ab2a0064 100644
--- a/website/docs/r/managed_disk.html.markdown
+++ b/website/docs/r/managed_disk.html.markdown
@@ -113,6 +113,12 @@ The following arguments are supported:
 
 * `encryption_settings` - (Optional) an `encryption_settings` block as defined below.
 
+* `encryption_type` - (Optional) The type of key used to encrypt the data of the disk. Valid values are `EncryptionAtRestWithPlatformKey` or `EncryptionAtRestWithCustomerKey`. Default value is `EncryptionAtRestWithPlatformKey`. When set to `EncryptionAtRestWithPlatformKey`, the disk is encrypted with XStore managed key at rest. When set to `EncryptionAtRestWithCustomerKey`, the disk is encrypted with Customer managed key at rest, and the `managed_disk_encryption_set_id` must be set to a valid `azurerm_disk_encryption_set` ID.
+
+* `managed_disk_encryption_set_id` - (Optional) ID of the disk encryption set to use for enabling encryption at rest.
+
+-> **NOTE** To associate a custom Disk Encryption Set to a managed disk, you must grant access of the KeyVault for the Disk Encryption Set. For instructions, please refer to the doc of [Server side encryption of Azure managed disks](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/disk-encryption).
+
 * `tags` - (Optional) A mapping of tags to assign to the resource.
 
 * `zones` - (Optional) A collection containing the availability zone to allocate the Managed Disk in.