Skip to content

Commit

Permalink
Merge pull request #15957 from DrFaust92/r/sagemaker_model_refactor
Browse files Browse the repository at this point in the history
r/sagemaker_model - add support for `image_config` + refactor tests
  • Loading branch information
breathingdust authored Nov 5, 2020
2 parents ba639a1 + 0ae0891 commit f7cef75
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 368 deletions.
112 changes: 83 additions & 29 deletions aws/resource_aws_sagemaker_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
Expand Down Expand Up @@ -80,6 +79,21 @@ func resourceAwsSagemakerModel() *schema.Resource {
ValidateFunc: validateSagemakerEnvironment,
Elem: &schema.Schema{Type: schema.TypeString},
},
"image_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"repository_access_mode": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RepositoryAccessMode_Values(), false),
},
},
},
},
},
},
},
Expand All @@ -95,22 +109,21 @@ func resourceAwsSagemakerModel() *schema.Resource {
Type: schema.TypeSet,
Required: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
},
"security_group_ids": {
Type: schema.TypeSet,
Required: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
},
},
},
},

"execution_role_arn": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validateArn,
},

"enable_network_isolation": {
Expand Down Expand Up @@ -160,6 +173,21 @@ func resourceAwsSagemakerModel() *schema.Resource {
ValidateFunc: validateSagemakerEnvironment,
Elem: &schema.Schema{Type: schema.TypeString},
},
"image_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"repository_access_mode": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RepositoryAccessMode_Values(), false),
},
},
},
},
},
},
},
Expand Down Expand Up @@ -213,7 +241,7 @@ func resourceAwsSagemakerModelCreate(d *schema.ResourceData, meta interface{}) e
})

if err != nil {
return fmt.Errorf("error creating Sagemaker model: %s", err)
return fmt.Errorf("error creating Sagemaker model: %w", err)
}
d.SetId(name)

Expand Down Expand Up @@ -243,12 +271,12 @@ func resourceAwsSagemakerModelRead(d *schema.ResourceData, meta interface{}) err

model, err := conn.DescribeModel(request)
if err != nil {
if sagemakerErr, ok := err.(awserr.Error); ok && sagemakerErr.Code() == "ValidationException" {
if isAWSErr(err, "ValidationException", "") {
log.Printf("[INFO] unable to find the sagemaker model resource and therefore it is removed from the state: %s", d.Id())
d.SetId("")
return nil
}
return fmt.Errorf("error reading Sagemaker model %s: %s", d.Id(), err)
return fmt.Errorf("error reading Sagemaker model %s: %w", d.Id(), err)
}

if err := d.Set("arn", model.ModelArn); err != nil {
Expand All @@ -270,16 +298,16 @@ func resourceAwsSagemakerModelRead(d *schema.ResourceData, meta interface{}) err
return err
}
if err := d.Set("vpc_config", flattenSageMakerVpcConfigResponse(model.VpcConfig)); err != nil {
return fmt.Errorf("error setting vpc_config: %s", err)
return fmt.Errorf("error setting vpc_config: %w", err)
}

tags, err := keyvaluetags.SagemakerListTags(conn, aws.StringValue(model.ModelArn))
if err != nil {
return fmt.Errorf("error listing tags for Sagemaker Model (%s): %s", d.Id(), err)
return fmt.Errorf("error listing tags for Sagemaker Model (%s): %w", d.Id(), err)
}

if err := d.Set("tags", tags.IgnoreAws().IgnoreConfig(ignoreTagsConfig).Map()); err != nil {
return fmt.Errorf("error setting tags: %s", err)
return fmt.Errorf("error setting tags: %w", err)
}

return nil
Expand All @@ -291,8 +319,8 @@ func flattenSageMakerVpcConfigResponse(vpcConfig *sagemaker.VpcConfig) []map[str
}

m := map[string]interface{}{
"security_group_ids": schema.NewSet(schema.HashString, flattenStringList(vpcConfig.SecurityGroupIds)),
"subnets": schema.NewSet(schema.HashString, flattenStringList(vpcConfig.Subnets)),
"security_group_ids": flattenStringSet(vpcConfig.SecurityGroupIds),
"subnets": flattenStringSet(vpcConfig.Subnets),
}

return []map[string]interface{}{m}
Expand All @@ -305,7 +333,7 @@ func resourceAwsSagemakerModelUpdate(d *schema.ResourceData, meta interface{}) e
o, n := d.GetChange("tags")

if err := keyvaluetags.SagemakerUpdateTags(conn, d.Get("arn").(string), o, n); err != nil {
return fmt.Errorf("error updating Sagemaker Model (%s) tags: %s", d.Id(), err)
return fmt.Errorf("error updating Sagemaker Model (%s) tags: %w", d.Id(), err)
}
}

Expand Down Expand Up @@ -335,7 +363,7 @@ func resourceAwsSagemakerModelDelete(d *schema.ResourceData, meta interface{}) e
_, err = conn.DeleteModel(deleteOpts)
}
if err != nil {
return fmt.Errorf("Error deleting sagemaker model: %s", err)
return fmt.Errorf("Error deleting sagemaker model: %w", err)
}
return nil
}
Expand All @@ -359,9 +387,27 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
container.Environment = stringMapToPointers(v.(map[string]interface{}))
}

if v, ok := m["image_config"]; ok {
container.ImageConfig = expandSagemakerModelImageConfig(v.([]interface{}))
}

return &container
}

func expandSagemakerModelImageConfig(l []interface{}) *sagemaker.ImageConfig {
if len(l) == 0 {
return nil
}

m := l[0].(map[string]interface{})

imageConfig := &sagemaker.ImageConfig{
RepositoryAccessMode: aws.String(m["repository_access_mode"].(string)),
}

return imageConfig
}

func expandContainers(a []interface{}) []*sagemaker.ContainerDefinition {
containers := make([]*sagemaker.ContainerDefinition, 0, len(a))

Expand All @@ -379,22 +425,38 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {

cfg := make(map[string]interface{})

cfg["image"] = *container.Image
cfg["image"] = aws.StringValue(container.Image)

if container.Mode != nil {
cfg["mode"] = *container.Mode
cfg["mode"] = aws.StringValue(container.Mode)
}

if container.ContainerHostname != nil {
cfg["container_hostname"] = *container.ContainerHostname
cfg["container_hostname"] = aws.StringValue(container.ContainerHostname)
}
if container.ModelDataUrl != nil {
cfg["model_data_url"] = *container.ModelDataUrl
cfg["model_data_url"] = aws.StringValue(container.ModelDataUrl)
}
if container.Environment != nil {
cfg["environment"] = flattenEnvironment(container.Environment)
cfg["environment"] = aws.StringValueMap(container.Environment)
}

if container.ImageConfig != nil {
cfg["image_config"] = flattenSagemakerImageConfig(container.ImageConfig)
}

return []interface{}{cfg}
}

func flattenSagemakerImageConfig(imageConfig *sagemaker.ImageConfig) []interface{} {
if imageConfig == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

cfg["repository_access_mode"] = aws.StringValue(imageConfig.RepositoryAccessMode)

return []interface{}{cfg}
}

Expand All @@ -405,11 +467,3 @@ func flattenContainers(containers []*sagemaker.ContainerDefinition) []interface{
}
return fContainers
}

func flattenEnvironment(env map[string]*string) map[string]string {
m := map[string]string{}
for k, v := range env {
m[k] = *v
}
return m
}
Loading

0 comments on commit f7cef75

Please sign in to comment.