Skip to content

Commit

Permalink
Merge pull request #21041 from DrFaust92/r/sagemaker_studio_lifecycle…
Browse files Browse the repository at this point in the history
…_config

r/sagemaker_studio_lifecycle_config - new resource + usage in domain and user profile
  • Loading branch information
ewbankkit authored Sep 27, 2021
2 parents 98c17ba + 3ccd60b commit acf2c8b
Show file tree
Hide file tree
Showing 13 changed files with 689 additions and 12 deletions.
11 changes: 11 additions & 0 deletions .changelog/21041.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
```release-note:new-resource
aws_sagemaker_studio_lifecycle_config
```

```release-note:enhancement
resource/aws_sagemaker_domain: Add `default_user_settings.jupyter_server_app_settings.lifecycle_config_arns` and `default_user_settings.kernel_gateway_app_settings.lifecycle_config_arns` arguments
```

```release-note:enhancement
resource/aws_user_profile: Add `user_settings.jupyter_server_app_settings.lifecycle_config_arns` and `user_settings.kernel_gateway_app_settings.lifecycle_config_arns` arguments
```
25 changes: 25 additions & 0 deletions aws/internal/service/sagemaker/finder/finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,28 @@ func FlowDefinitionByName(conn *sagemaker.SageMaker, name string) (*sagemaker.De

return output, nil
}

func StudioLifecycleConfigByName(conn *sagemaker.SageMaker, name string) (*sagemaker.DescribeStudioLifecycleConfigOutput, error) {
input := &sagemaker.DescribeStudioLifecycleConfigInput{
StudioLifecycleConfigName: aws.String(name),
}

output, err := conn.DescribeStudioLifecycleConfig(input)

if tfawserr.ErrCodeEquals(err, sagemaker.ErrCodeResourceNotFound) {
return nil, &resource.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

if output == nil {
return nil, tfresource.NewEmptyResultError(input)
}

return output, nil
}
1 change: 1 addition & 0 deletions aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,7 @@ func Provider() *schema.Provider {
"aws_sagemaker_model_package_group": resourceAwsSagemakerModelPackageGroup(),
"aws_sagemaker_notebook_instance_lifecycle_configuration": resourceAwsSagemakerNotebookInstanceLifeCycleConfiguration(),
"aws_sagemaker_notebook_instance": resourceAwsSagemakerNotebookInstance(),
"aws_sagemaker_studio_lifecycle_config": resourceAwsSagemakerStudioLifecycleConfig(),
"aws_sagemaker_user_profile": resourceAwsSagemakerUserProfile(),
"aws_sagemaker_workforce": resourceAwsSagemakerWorkforce(),
"aws_sagemaker_workteam": resourceAwsSagemakerWorkteam(),
Expand Down
32 changes: 32 additions & 0 deletions aws/resource_aws_sagemaker_domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ func resourceAwsSagemakerDomain() *schema.Resource {
},
},
},
"lifecycle_config_arns": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validateArn,
},
},
},
},
},
Expand Down Expand Up @@ -197,6 +205,14 @@ func resourceAwsSagemakerDomain() *schema.Resource {
},
},
},
"lifecycle_config_arns": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validateArn,
},
},
"custom_image": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -478,6 +494,10 @@ func expandSagemakerDomainJupyterServerAppSettings(l []interface{}) *sagemaker.J
config.DefaultResourceSpec = expandSagemakerDomainDefaultResourceSpec(v)
}

if v, ok := m["lifecycle_config_arns"].(*schema.Set); ok && v.Len() > 0 {
config.LifecycleConfigArns = expandStringSet(v)
}

return config
}

Expand All @@ -494,6 +514,10 @@ func expandSagemakerDomainKernelGatewayAppSettings(l []interface{}) *sagemaker.K
config.DefaultResourceSpec = expandSagemakerDomainDefaultResourceSpec(v)
}

if v, ok := m["lifecycle_config_arns"].(*schema.Set); ok && v.Len() > 0 {
config.LifecycleConfigArns = expandStringSet(v)
}

if v, ok := m["custom_image"].([]interface{}); ok && len(v) > 0 {
config.CustomImages = expandSagemakerDomainCustomImages(v)
}
Expand Down Expand Up @@ -657,6 +681,10 @@ func flattenSagemakerDomainJupyterServerAppSettings(config *sagemaker.JupyterSer
m["default_resource_spec"] = flattenSagemakerDomainDefaultResourceSpec(config.DefaultResourceSpec)
}

if config.LifecycleConfigArns != nil {
m["lifecycle_config_arns"] = flattenStringSet(config.LifecycleConfigArns)
}

return []map[string]interface{}{m}
}

Expand All @@ -671,6 +699,10 @@ func flattenSagemakerDomainKernelGatewayAppSettings(config *sagemaker.KernelGate
m["default_resource_spec"] = flattenSagemakerDomainDefaultResourceSpec(config.DefaultResourceSpec)
}

if config.LifecycleConfigArns != nil {
m["lifecycle_config_arns"] = flattenStringSet(config.LifecycleConfigArns)
}

if config.CustomImages != nil {
m["custom_image"] = flattenSagemakerDomainCustomImages(config.CustomImages)
}
Expand Down
65 changes: 65 additions & 0 deletions aws/resource_aws_sagemaker_domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,38 @@ func testAccAWSSagemakerDomain_kernelGatewayAppSettings(t *testing.T) {
})
}

func testAccAWSSagemakerDomain_kernelGatewayAppSettings_lifecycleConfig(t *testing.T) {
var domain sagemaker.DescribeDomainOutput
rName := acctest.RandomWithPrefix("tf-acc-test")
resourceName := "aws_sagemaker_domain.test"

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
ErrorCheck: testAccErrorCheck(t, sagemaker.EndpointsID),
Providers: testAccProviders,
CheckDestroy: testAccCheckAWSSagemakerDomainDestroy,
Steps: []resource.TestStep{
{
Config: testAccAWSSagemakerDomainConfigKernelGatewayAppSettingsLifecycleConfig(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSagemakerDomainExists(resourceName, &domain),
resource.TestCheckResourceAttr(resourceName, "default_user_settings.#", "1"),
resource.TestCheckResourceAttr(resourceName, "default_user_settings.0.kernel_gateway_app_settings.#", "1"),
resource.TestCheckResourceAttr(resourceName, "default_user_settings.0.kernel_gateway_app_settings.0.lifecycle_config_arns.#", "1"),
resource.TestCheckResourceAttr(resourceName, "default_user_settings.0.kernel_gateway_app_settings.0.default_resource_spec.#", "1"),
resource.TestCheckResourceAttr(resourceName, "default_user_settings.0.kernel_gateway_app_settings.0.default_resource_spec.0.instance_type", "ml.t3.micro"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
ImportStateVerifyIgnore: []string{"retention_policy"},
},
},
})
}

func testAccAWSSagemakerDomain_kernelGatewayAppSettings_customImage(t *testing.T) {

if os.Getenv("SAGEMAKER_IMAGE_VERSION_BASE_IMAGE") == "" {
Expand Down Expand Up @@ -820,6 +852,39 @@ resource "aws_sagemaker_domain" "test" {
`, rName)
}

func testAccAWSSagemakerDomainConfigKernelGatewayAppSettingsLifecycleConfig(rName string) string {
return testAccAWSSagemakerDomainConfigBase(rName) + fmt.Sprintf(`
resource "aws_sagemaker_studio_lifecycle_config" "test" {
studio_lifecycle_config_name = %[1]q
studio_lifecycle_config_app_type = "JupyterServer"
studio_lifecycle_config_content = base64encode("echo Hello")
}
resource "aws_sagemaker_domain" "test" {
domain_name = %[1]q
auth_mode = "IAM"
vpc_id = aws_vpc.test.id
subnet_ids = [aws_subnet.test.id]
default_user_settings {
execution_role = aws_iam_role.test.arn
kernel_gateway_app_settings {
default_resource_spec {
instance_type = "ml.t3.micro"
}
lifecycle_config_arns = [aws_sagemaker_studio_lifecycle_config.test.arn]
}
}
retention_policy {
home_efs_file_system = "Delete"
}
}
`, rName)
}

func testAccAWSSagemakerDomainConfigKernelGatewayAppSettingsCustomImage(rName, baseImage string) string {
return testAccAWSSagemakerDomainConfigBase(rName) + fmt.Sprintf(`
resource "aws_sagemaker_image" "test" {
Expand Down
164 changes: 164 additions & 0 deletions aws/resource_aws_sagemaker_studio_lifecycle_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package aws

import (
"fmt"
"log"
"regexp"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/aws-sdk-go-base/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/keyvaluetags"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/service/sagemaker/finder"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource"
)

func resourceAwsSagemakerStudioLifecycleConfig() *schema.Resource {
return &schema.Resource{
Create: resourceAwsSagemakerStudioLifecycleConfigCreate,
Read: resourceAwsSagemakerStudioLifecycleConfigRead,
Update: resourceAwsSagemakerStudioLifecycleConfigUpdate,
Delete: resourceAwsSagemakerStudioLifecycleConfigDelete,
Importer: &schema.ResourceImporter{
State: schema.ImportStatePassthrough,
},

Schema: map[string]*schema.Schema{
"arn": {
Type: schema.TypeString,
Computed: true,
},
"studio_lifecycle_config_app_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.StudioLifecycleConfigAppType_Values(), false),
},
"studio_lifecycle_config_content": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringLenBetween(1, 16384),
},
"studio_lifecycle_config_name": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.All(
validation.StringLenBetween(1, 63),
validation.StringMatch(regexp.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9])*$`), "Valid characters are a-z, A-Z, 0-9, and - (hyphen)."),
),
},
"tags": tagsSchema(),
"tags_all": tagsSchemaComputed(),
},

CustomizeDiff: SetTagsDiff,
}
}

func resourceAwsSagemakerStudioLifecycleConfigCreate(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*AWSClient).sagemakerconn
defaultTagsConfig := meta.(*AWSClient).DefaultTagsConfig
tags := defaultTagsConfig.MergeTags(keyvaluetags.New(d.Get("tags").(map[string]interface{})))

name := d.Get("studio_lifecycle_config_name").(string)
input := &sagemaker.CreateStudioLifecycleConfigInput{
StudioLifecycleConfigName: aws.String(name),
StudioLifecycleConfigAppType: aws.String(d.Get("studio_lifecycle_config_app_type").(string)),
StudioLifecycleConfigContent: aws.String(d.Get("studio_lifecycle_config_content").(string)),
}

if len(tags) > 0 {
input.Tags = tags.IgnoreAws().SagemakerTags()
}

log.Printf("[DEBUG] Creating SageMaker Studio Lifecycle Config : %s", input)
_, err := conn.CreateStudioLifecycleConfig(input)

if err != nil {
return fmt.Errorf("error creating SageMaker Studio Lifecycle Config (%s): %w", name, err)
}

d.SetId(name)

return resourceAwsSagemakerStudioLifecycleConfigRead(d, meta)
}

func resourceAwsSagemakerStudioLifecycleConfigRead(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*AWSClient).sagemakerconn
defaultTagsConfig := meta.(*AWSClient).DefaultTagsConfig
ignoreTagsConfig := meta.(*AWSClient).IgnoreTagsConfig

image, err := finder.StudioLifecycleConfigByName(conn, d.Id())

if !d.IsNewResource() && tfresource.NotFound(err) {
log.Printf("[WARN] SageMaker Studio Lifecycle Config (%s) not found, removing from state", d.Id())
d.SetId("")
return nil
}

if err != nil {
return fmt.Errorf("error reading SageMaker Studio Lifecycle Config (%s): %w", d.Id(), err)
}

arn := aws.StringValue(image.StudioLifecycleConfigArn)
d.Set("studio_lifecycle_config_name", image.StudioLifecycleConfigName)
d.Set("studio_lifecycle_config_app_type", image.StudioLifecycleConfigAppType)
d.Set("studio_lifecycle_config_content", image.StudioLifecycleConfigContent)
d.Set("arn", arn)

tags, err := keyvaluetags.SagemakerListTags(conn, arn)

if err != nil {
return fmt.Errorf("error listing tags for SageMaker Studio Lifecycle Config (%s): %w", d.Id(), err)
}

tags = tags.IgnoreAws().IgnoreConfig(ignoreTagsConfig)

//lintignore:AWSR002
if err := d.Set("tags", tags.RemoveDefaultConfig(defaultTagsConfig).Map()); err != nil {
return fmt.Errorf("error setting tags: %w", err)
}

if err := d.Set("tags_all", tags.Map()); err != nil {
return fmt.Errorf("error setting tags_all: %w", err)
}

return nil
}

func resourceAwsSagemakerStudioLifecycleConfigUpdate(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*AWSClient).sagemakerconn

if d.HasChange("tags_all") {
o, n := d.GetChange("tags_all")

if err := keyvaluetags.SagemakerUpdateTags(conn, d.Get("arn").(string), o, n); err != nil {
return fmt.Errorf("error updating Studio Lifecycle Config (%s) tags: %w", d.Id(), err)
}
}

return resourceAwsSagemakerStudioLifecycleConfigRead(d, meta)
}

func resourceAwsSagemakerStudioLifecycleConfigDelete(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*AWSClient).sagemakerconn

input := &sagemaker.DeleteStudioLifecycleConfigInput{
StudioLifecycleConfigName: aws.String(d.Id()),
}

log.Printf("[DEBUG] Deleting SageMaker Studio Lifecycle Config: (%s)", d.Id())
if _, err := conn.DeleteStudioLifecycleConfig(input); err != nil {
if tfawserr.ErrMessageContains(err, sagemaker.ErrCodeResourceNotFound, "does not exist") {
return nil
}

return fmt.Errorf("error deleting SageMaker Studio Lifecycle Config (%s): %w", d.Id(), err)
}

return nil
}
Loading

0 comments on commit acf2c8b

Please sign in to comment.