diff --git a/.changelog/20809.txt b/.changelog/20809.txt new file mode 100644 index 00000000000..22b161423df --- /dev/null +++ b/.changelog/20809.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +resource/aws_sagemaker_endpoint_configuration: Add `async_inference_config` argument +``` diff --git a/aws/internal/service/sagemaker/finder/finder.go b/aws/internal/service/sagemaker/finder/finder.go index ca9be2fdd14..9f2775eb719 100644 --- a/aws/internal/service/sagemaker/finder/finder.go +++ b/aws/internal/service/sagemaker/finder/finder.go @@ -5,6 +5,7 @@ import ( "github.com/aws/aws-sdk-go/service/sagemaker" "github.com/hashicorp/aws-sdk-go-base/tfawserr" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource" ) // CodeRepositoryByName returns the code repository corresponding to the specified name. @@ -103,10 +104,7 @@ func DeviceFleetByName(conn *sagemaker.SageMaker, id string) (*sagemaker.Describ } if output == nil { - return nil, &resource.NotFoundError{ - Message: "Empty result", - LastRequest: input, - } + return nil, tfresource.NewEmptyResultError(input) } return output, nil @@ -150,10 +148,7 @@ func FeatureGroupByName(conn *sagemaker.SageMaker, name string) (*sagemaker.Desc } if output == nil { - return nil, &resource.NotFoundError{ - Message: "Empty result", - LastRequest: input, - } + return nil, tfresource.NewEmptyResultError(input) } return output, nil @@ -239,10 +234,7 @@ func WorkforceByName(conn *sagemaker.SageMaker, name string) (*sagemaker.Workfor } if output == nil || output.Workforce == nil { - return nil, &resource.NotFoundError{ - Message: "Empty result", - LastRequest: input, - } + return nil, tfresource.NewEmptyResultError(input) } return output.Workforce, nil @@ -267,10 +259,7 @@ func WorkteamByName(conn *sagemaker.SageMaker, name string) (*sagemaker.Workteam } if output == nil || output.Workteam == nil { - return nil, &resource.NotFoundError{ - Message: "Empty result", - LastRequest: input, - } + return nil, tfresource.NewEmptyResultError(input) } return output.Workteam, nil @@ -295,11 +284,33 @@ func HumanTaskUiByName(conn *sagemaker.SageMaker, name string) (*sagemaker.Descr } if output == nil { + return nil, tfresource.NewEmptyResultError(input) + } + + return output, nil +} + +func EndpointConfigByName(conn *sagemaker.SageMaker, name string) (*sagemaker.DescribeEndpointConfigOutput, error) { + input := &sagemaker.DescribeEndpointConfigInput{ + EndpointConfigName: aws.String(name), + } + + output, err := conn.DescribeEndpointConfig(input) + + if tfawserr.ErrMessageContains(err, "ValidationException", "Could not find endpoint configuration") { return nil, &resource.NotFoundError{ - Message: "Empty result", + LastError: err, LastRequest: input, } } + if err != nil { + return nil, err + } + + if output == nil { + return nil, tfresource.NewEmptyResultError(input) + } + return output, nil } diff --git a/aws/resource_aws_sagemaker_endpoint_configuration.go b/aws/resource_aws_sagemaker_endpoint_configuration.go index 347b5edc593..33ecbff2875 100644 --- a/aws/resource_aws_sagemaker_endpoint_configuration.go +++ b/aws/resource_aws_sagemaker_endpoint_configuration.go @@ -7,10 +7,13 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sagemaker" + "github.com/hashicorp/aws-sdk-go-base/tfawserr" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "github.com/terraform-providers/terraform-provider-aws/aws/internal/keyvaluetags" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/sagemaker/finder" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource" ) func resourceAwsSagemakerEndpointConfiguration() *schema.Resource { @@ -28,6 +31,79 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource { Type: schema.TypeString, Computed: true, }, + "async_inference_config": { + Type: schema.TypeList, + MaxItems: 1, + Optional: true, + ForceNew: true, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "client_config": { + Type: schema.TypeList, + Optional: true, + MaxItems: 1, + ForceNew: true, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "max_concurrent_invocations_per_instance": { + Type: schema.TypeInt, + Optional: true, + ForceNew: true, + ValidateFunc: validation.IntBetween(1, 1000), + }, + }, + }, + }, + "output_config": { + Type: schema.TypeList, + Required: true, + MaxItems: 1, + ForceNew: true, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "kms_key_id": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: validateArn, + }, + "notification_config": { + Type: schema.TypeList, + Optional: true, + MaxItems: 1, + ForceNew: true, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "error_topic": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: validateArn, + }, + "success_topic": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: validateArn, + }, + }, + }, + }, + "s3_output_path": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + ValidateFunc: validation.All( + validation.StringMatch(regexp.MustCompile(`^(https|s3)://([^/])/?(.*)$`), ""), + validation.StringLenBetween(1, 512), + ), + }, + }, + }, + }, + }, + }, + }, "name": { Type: schema.TypeString, @@ -227,10 +303,14 @@ func resourceAwsSagemakerEndpointConfigurationCreate(d *schema.ResourceData, met createOpts.DataCaptureConfig = expandSagemakerDataCaptureConfig(v.([]interface{})) } + if v, ok := d.GetOk("async_inference_config"); ok { + createOpts.AsyncInferenceConfig = expandSagemakerEndpointConfigAsyncInferenceConfig(v.([]interface{})) + } + log.Printf("[DEBUG] SageMaker Endpoint Configuration create config: %#v", *createOpts) _, err := conn.CreateEndpointConfig(createOpts) if err != nil { - return fmt.Errorf("error creating SageMaker Endpoint Configuration: %s", err) + return fmt.Errorf("error creating SageMaker Endpoint Configuration: %w", err) } d.SetId(name) @@ -242,39 +322,37 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta defaultTagsConfig := meta.(*AWSClient).DefaultTagsConfig ignoreTagsConfig := meta.(*AWSClient).IgnoreTagsConfig - request := &sagemaker.DescribeEndpointConfigInput{ - EndpointConfigName: aws.String(d.Id()), + endpointConfig, err := finder.EndpointConfigByName(conn, d.Id()) + + if !d.IsNewResource() && tfresource.NotFound(err) { + log.Printf("[WARN] SageMaker Endpoint Configuration (%s) not found, removing from state", d.Id()) + d.SetId("") + return nil } - endpointConfig, err := conn.DescribeEndpointConfig(request) if err != nil { - if isAWSErr(err, "ValidationException", "Could not find endpoint configuration") { - log.Printf("[INFO] unable to find the SageMaker Endpoint Configuration resource and therefore it is removed from the state: %s", d.Id()) - d.SetId("") - return nil - } - return fmt.Errorf("error reading SageMaker Endpoint Configuration %s: %s", d.Id(), err) + return fmt.Errorf("error reading SageMaker Endpoint Configuration (%s): %w", d.Id(), err) } - if err := d.Set("arn", endpointConfig.EndpointConfigArn); err != nil { - return err - } - if err := d.Set("name", endpointConfig.EndpointConfigName); err != nil { - return err - } + d.Set("arn", endpointConfig.EndpointConfigArn) + d.Set("name", endpointConfig.EndpointConfigName) + d.Set("kms_key_arn", endpointConfig.KmsKeyId) + if err := d.Set("production_variants", flattenProductionVariants(endpointConfig.ProductionVariants)); err != nil { - return err - } - if err := d.Set("kms_key_arn", endpointConfig.KmsKeyId); err != nil { - return err + return fmt.Errorf("error setting production_variants for SageMaker Endpoint Configuration (%s): %w", d.Id(), err) } + if err := d.Set("data_capture_config", flattenSagemakerDataCaptureConfig(endpointConfig.DataCaptureConfig)); err != nil { - return err + return fmt.Errorf("error setting data_capture_config for SageMaker Endpoint Configuration (%s): %w", d.Id(), err) + } + + if err := d.Set("async_inference_config", flattenSagemakerEndpointConfigAsyncInferenceConfig(endpointConfig.AsyncInferenceConfig)); err != nil { + return fmt.Errorf("error setting async_inference_config for SageMaker Endpoint Configuration (%s): %w", d.Id(), err) } tags, err := keyvaluetags.SagemakerListTags(conn, aws.StringValue(endpointConfig.EndpointConfigArn)) if err != nil { - return fmt.Errorf("error listing tags for Sagemaker Endpoint Configuration (%s): %s", d.Id(), err) + return fmt.Errorf("error listing tags for Sagemaker Endpoint Configuration (%s): %w", d.Id(), err) } tags = tags.IgnoreAws().IgnoreConfig(ignoreTagsConfig) @@ -298,7 +376,7 @@ func resourceAwsSagemakerEndpointConfigurationUpdate(d *schema.ResourceData, met o, n := d.GetChange("tags_all") if err := keyvaluetags.SagemakerUpdateTags(conn, d.Get("arn").(string), o, n); err != nil { - return fmt.Errorf("error updating Sagemaker Endpoint Configuration (%s) tags: %s", d.Id(), err) + return fmt.Errorf("error updating Sagemaker Endpoint Configuration (%s) tags: %w", d.Id(), err) } } return resourceAwsSagemakerEndpointConfigurationRead(d, meta) @@ -314,12 +392,12 @@ func resourceAwsSagemakerEndpointConfigurationDelete(d *schema.ResourceData, met _, err := conn.DeleteEndpointConfig(deleteOpts) - if isAWSErr(err, sagemaker.ErrCodeResourceNotFound, "") { + if tfawserr.ErrMessageContains(err, "ValidationException", "Could not find endpoint configuration") { return nil } if err != nil { - return fmt.Errorf("error deleting SageMaker Endpoint Configuration (%s): %s", d.Id(), err) + return fmt.Errorf("error deleting SageMaker Endpoint Configuration (%s): %w", d.Id(), err) } return nil @@ -487,3 +565,151 @@ func flattenSagemakerCaptureContentTypeHeader(contentTypeHeader *sagemaker.Captu return []map[string]interface{}{l} } + +func expandSagemakerEndpointConfigAsyncInferenceConfig(configured []interface{}) *sagemaker.AsyncInferenceConfig { + if len(configured) == 0 { + return nil + } + + m := configured[0].(map[string]interface{}) + + c := &sagemaker.AsyncInferenceConfig{} + + if v, ok := m["client_config"]; ok && (len(v.([]interface{})) > 0) { + c.ClientConfig = expandSagemakerEndpointConfigClientConfig(v.([]interface{})) + } + + if v, ok := m["output_config"]; ok && (len(v.([]interface{})) > 0) { + c.OutputConfig = expandSagemakerEndpointConfigOutputConfig(v.([]interface{})) + } + + return c +} + +func expandSagemakerEndpointConfigClientConfig(configured []interface{}) *sagemaker.AsyncInferenceClientConfig { + if len(configured) == 0 { + return nil + } + + m := configured[0].(map[string]interface{}) + + c := &sagemaker.AsyncInferenceClientConfig{} + + if v, ok := m["max_concurrent_invocations_per_instance"]; ok { + c.MaxConcurrentInvocationsPerInstance = aws.Int64(int64(v.(int))) + } + + return c +} + +func expandSagemakerEndpointConfigOutputConfig(configured []interface{}) *sagemaker.AsyncInferenceOutputConfig { + if len(configured) == 0 { + return nil + } + + m := configured[0].(map[string]interface{}) + + c := &sagemaker.AsyncInferenceOutputConfig{ + S3OutputPath: aws.String(m["s3_output_path"].(string)), + } + + if v, ok := m["kms_key_id"]; ok { + c.KmsKeyId = aws.String(v.(string)) + } + + if v, ok := m["notification_config"]; ok && (len(v.([]interface{})) > 0) { + c.NotificationConfig = expandSagemakerEndpointConfigNotificationConfig(v.([]interface{})) + } + + return c +} + +func expandSagemakerEndpointConfigNotificationConfig(configured []interface{}) *sagemaker.AsyncInferenceNotificationConfig { + if len(configured) == 0 { + return nil + } + + m := configured[0].(map[string]interface{}) + + c := &sagemaker.AsyncInferenceNotificationConfig{} + + if v, ok := m["error_topic"]; ok { + c.ErrorTopic = aws.String(v.(string)) + } + + if v, ok := m["success_topic"]; ok { + c.SuccessTopic = aws.String(v.(string)) + } + + return c +} + +func flattenSagemakerEndpointConfigAsyncInferenceConfig(config *sagemaker.AsyncInferenceConfig) []map[string]interface{} { + if config == nil { + return []map[string]interface{}{} + } + + cfg := map[string]interface{}{} + + if config.ClientConfig != nil { + cfg["client_config"] = flattenSagemakerEndpointConfigClientConfig(config.ClientConfig) + } + + if config.OutputConfig != nil { + cfg["output_config"] = flattenSagemakerEndpointConfigOutputConfig(config.OutputConfig) + } + + return []map[string]interface{}{cfg} +} + +func flattenSagemakerEndpointConfigClientConfig(config *sagemaker.AsyncInferenceClientConfig) []map[string]interface{} { + if config == nil { + return []map[string]interface{}{} + } + + cfg := map[string]interface{}{} + + if config.MaxConcurrentInvocationsPerInstance != nil { + cfg["max_concurrent_invocations_per_instance"] = aws.Int64Value(config.MaxConcurrentInvocationsPerInstance) + } + + return []map[string]interface{}{cfg} +} + +func flattenSagemakerEndpointConfigOutputConfig(config *sagemaker.AsyncInferenceOutputConfig) []map[string]interface{} { + if config == nil { + return []map[string]interface{}{} + } + + cfg := map[string]interface{}{ + "s3_output_path": aws.StringValue(config.S3OutputPath), + } + + if config.KmsKeyId != nil { + cfg["kms_key_id"] = aws.StringValue(config.KmsKeyId) + } + + if config.NotificationConfig != nil { + cfg["notification_config"] = flattenSagemakerEndpointConfigNotificationConfig(config.NotificationConfig) + } + + return []map[string]interface{}{cfg} +} + +func flattenSagemakerEndpointConfigNotificationConfig(config *sagemaker.AsyncInferenceNotificationConfig) []map[string]interface{} { + if config == nil { + return []map[string]interface{}{} + } + + cfg := map[string]interface{}{} + + if config.ErrorTopic != nil { + cfg["error_topic"] = aws.StringValue(config.ErrorTopic) + } + + if config.SuccessTopic != nil { + cfg["success_topic"] = aws.StringValue(config.SuccessTopic) + } + + return []map[string]interface{}{cfg} +} diff --git a/aws/resource_aws_sagemaker_endpoint_configuration_test.go b/aws/resource_aws_sagemaker_endpoint_configuration_test.go index 19f36510b65..bbec282f058 100644 --- a/aws/resource_aws_sagemaker_endpoint_configuration_test.go +++ b/aws/resource_aws_sagemaker_endpoint_configuration_test.go @@ -7,9 +7,12 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sagemaker" + multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/sagemaker/finder" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource" ) func init() { @@ -25,35 +28,41 @@ func init() { func testSweepSagemakerEndpointConfigurations(region string) error { client, err := sharedClientForRegion(region) if err != nil { - return fmt.Errorf("error getting client: %s", err) + return fmt.Errorf("error getting client: %w", err) } conn := client.(*AWSClient).sagemakerconn + var sweeperErrs *multierror.Error req := &sagemaker.ListEndpointConfigsInput{ NameContains: aws.String("tf-acc-test"), } - resp, err := conn.ListEndpointConfigs(req) - if err != nil { - return fmt.Errorf("error listing endpoint configs: %s", err) - } + err = conn.ListEndpointConfigsPages(req, func(page *sagemaker.ListEndpointConfigsOutput, lastPage bool) bool { + for _, endpointConfig := range page.EndpointConfigs { + + r := resourceAwsSagemakerEndpointConfiguration() + d := r.Data(nil) + d.SetId(aws.StringValue(endpointConfig.EndpointConfigName)) + err := r.Delete(d, client) + if err != nil { + log.Printf("[ERROR] %s", err) + sweeperErrs = multierror.Append(sweeperErrs, err) + continue + } + } - if len(resp.EndpointConfigs) == 0 { - log.Print("[DEBUG] No SageMaker endpoint config to sweep") - return nil + return !lastPage + }) + + if testSweepSkipSweepError(err) { + log.Printf("[WARN] Skipping SageMaker Endpoint Config sweep for %s: %s", region, err) + return sweeperErrs.ErrorOrNil() } - for _, endpointConfig := range resp.EndpointConfigs { - _, err := conn.DeleteEndpointConfig(&sagemaker.DeleteEndpointConfigInput{ - EndpointConfigName: endpointConfig.EndpointConfigName, - }) - if err != nil { - return fmt.Errorf( - "failed to delete SageMaker endpoint config (%s): %s", - *endpointConfig.EndpointConfigName, err) - } + if err != nil { + sweeperErrs = multierror.Append(sweeperErrs, fmt.Errorf("error retrieving Sagemaker Endpoint Configs: %w", err)) } - return nil + return sweeperErrs.ErrorOrNil() } func TestAccAWSSagemakerEndpointConfiguration_basic(t *testing.T) { @@ -77,7 +86,9 @@ func TestAccAWSSagemakerEndpointConfiguration_basic(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "production_variants.0.initial_instance_count", "2"), resource.TestCheckResourceAttr(resourceName, "production_variants.0.instance_type", "ml.t2.medium"), resource.TestCheckResourceAttr(resourceName, "production_variants.0.initial_variant_weight", "1"), + resource.TestCheckResourceAttr(resourceName, "production_variants.0.code_dump_config.#", "0"), resource.TestCheckResourceAttr(resourceName, "data_capture_config.#", "0"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "0"), ), }, { @@ -260,6 +271,7 @@ func TestAccAWSSagemakerEndpointConfiguration_disappears(t *testing.T) { Check: resource.ComposeTestCheckFunc( testAccCheckSagemakerEndpointConfigurationExists(resourceName), testAccCheckResourceDisappears(testAccProvider, resourceAwsSagemakerEndpointConfiguration(), resourceName), + testAccCheckResourceDisappears(testAccProvider, resourceAwsSagemakerEndpointConfiguration(), resourceName), ), ExpectNonEmptyPlan: true, }, @@ -267,6 +279,102 @@ func TestAccAWSSagemakerEndpointConfiguration_disappears(t *testing.T) { }) } +func TestAccAWSSagemakerEndpointConfiguration_async(t *testing.T) { + rName := acctest.RandomWithPrefix("tf-acc-test") + resourceName := "aws_sagemaker_endpoint_configuration.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + ErrorCheck: testAccErrorCheck(t, sagemaker.EndpointsID), + Providers: testAccProviders, + CheckDestroy: testAccCheckSagemakerEndpointConfigurationDestroy, + Steps: []resource.TestStep{ + { + Config: testAccSagemakerEndpointConfigurationConfigAsyncConfig(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckSagemakerEndpointConfigurationExists(resourceName), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.#", "0"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "async_inference_config.0.output_config.0.s3_output_path"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.#", "0"), + resource.TestCheckResourceAttrPair(resourceName, "async_inference_config.0.output_config.0.kms_key_id", "aws_kms_key.test", "arn"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccAWSSagemakerEndpointConfiguration_async_notifConfig(t *testing.T) { + rName := acctest.RandomWithPrefix("tf-acc-test") + resourceName := "aws_sagemaker_endpoint_configuration.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + ErrorCheck: testAccErrorCheck(t, sagemaker.EndpointsID), + Providers: testAccProviders, + CheckDestroy: testAccCheckSagemakerEndpointConfigurationDestroy, + Steps: []resource.TestStep{ + { + Config: testAccSagemakerEndpointConfigurationConfigAsyncNotifConfig(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckSagemakerEndpointConfigurationExists(resourceName), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.#", "0"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "async_inference_config.0.output_config.0.s3_output_path"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.#", "1"), + resource.TestCheckResourceAttrPair(resourceName, "async_inference_config.0.output_config.0.notification_config.0.error_topic", "aws_sns_topic.test", "arn"), + resource.TestCheckResourceAttrPair(resourceName, "async_inference_config.0.output_config.0.notification_config.0.success_topic", "aws_sns_topic.test", "arn"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccAWSSagemakerEndpointConfiguration_async_client(t *testing.T) { + rName := acctest.RandomWithPrefix("tf-acc-test") + resourceName := "aws_sagemaker_endpoint_configuration.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + ErrorCheck: testAccErrorCheck(t, sagemaker.EndpointsID), + Providers: testAccProviders, + CheckDestroy: testAccCheckSagemakerEndpointConfigurationDestroy, + Steps: []resource.TestStep{ + { + Config: testAccSagemakerEndpointConfigurationConfigAsyncClientConfig(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckSagemakerEndpointConfigurationExists(resourceName), + resource.TestCheckResourceAttr(resourceName, "name", rName), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.0.max_concurrent_invocations_per_instance", "1"), + resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.#", "1"), + resource.TestCheckResourceAttrSet(resourceName, "async_inference_config.0.output_config.0.s3_output_path"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func testAccCheckSagemakerEndpointConfigurationDestroy(s *terraform.State) error { conn := testAccProvider.Meta().(*AWSClient).sagemakerconn @@ -275,13 +383,9 @@ func testAccCheckSagemakerEndpointConfigurationDestroy(s *terraform.State) error continue } - input := &sagemaker.DescribeEndpointConfigInput{ - EndpointConfigName: aws.String(rs.Primary.ID), - } - - _, err := conn.DescribeEndpointConfig(input) + _, err := finder.EndpointConfigByName(conn, rs.Primary.ID) - if isAWSErr(err, "ValidationException", "") { + if tfresource.NotFound(err) { continue } @@ -289,7 +393,7 @@ func testAccCheckSagemakerEndpointConfigurationDestroy(s *terraform.State) error return err } - return fmt.Errorf("SageMaker Endpoint Configuration (%s) still exists", rs.Primary.ID) + return fmt.Errorf("SageMaker Endpoint Configuration %s still exists", rs.Primary.ID) } return nil @@ -307,10 +411,8 @@ func testAccCheckSagemakerEndpointConfigurationExists(n string) resource.TestChe } conn := testAccProvider.Meta().(*AWSClient).sagemakerconn - opts := &sagemaker.DescribeEndpointConfigInput{ - EndpointConfigName: aws.String(rs.Primary.ID), - } - _, err := conn.DescribeEndpointConfig(opts) + _, err := finder.EndpointConfigByName(conn, rs.Primary.ID) + if err != nil { return err } @@ -412,7 +514,7 @@ resource "aws_sagemaker_endpoint_configuration" "test" { func testAccSagemakerEndpointConfiguration_Config_KmsKeyId(rName string) string { return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(` resource "aws_sagemaker_endpoint_configuration" "test" { - name = %q + name = %[1]q kms_key_arn = aws_kms_key.test.arn production_variants { @@ -425,10 +527,10 @@ resource "aws_sagemaker_endpoint_configuration" "test" { } resource "aws_kms_key" "test" { - description = %q + description = %[1]q deletion_window_in_days = 10 } -`, rName, rName) +`, rName) } func testAccSagemakerEndpointConfigurationConfigTags1(rName, tagKey1, tagValue1 string) string { @@ -511,3 +613,118 @@ resource "aws_sagemaker_endpoint_configuration" "test" { } `, rName) } + +func testAccSagemakerEndpointConfigurationConfigAsyncConfig(rName string) string { + return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(` +resource "aws_s3_bucket" "test" { + bucket = %[1]q + acl = "private" + force_destroy = true +} + +resource "aws_kms_key" "test" { + description = %[1]q + deletion_window_in_days = 7 +} + +resource "aws_sagemaker_endpoint_configuration" "test" { + name = %[1]q + + production_variants { + variant_name = "variant-1" + model_name = aws_sagemaker_model.test.name + initial_instance_count = 2 + instance_type = "ml.t2.medium" + initial_variant_weight = 1 + } + + async_inference_config { + output_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/" + kms_key_id = aws_kms_key.test.arn + } + } +} +`, rName) +} + +func testAccSagemakerEndpointConfigurationConfigAsyncNotifConfig(rName string) string { + return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(` +resource "aws_s3_bucket" "test" { + bucket = %[1]q + acl = "private" + force_destroy = true +} + +resource "aws_sns_topic" "test" { + name = %[1]q +} + +resource "aws_kms_key" "test" { + description = %[1]q + deletion_window_in_days = 7 +} + +resource "aws_sagemaker_endpoint_configuration" "test" { + name = %[1]q + + production_variants { + variant_name = "variant-1" + model_name = aws_sagemaker_model.test.name + initial_instance_count = 2 + instance_type = "ml.t2.medium" + initial_variant_weight = 1 + } + + async_inference_config { + output_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/" + kms_key_id = aws_kms_key.test.arn + + notification_config { + error_topic = aws_sns_topic.test.arn + success_topic = aws_sns_topic.test.arn + } + } + } +} +`, rName) +} + +func testAccSagemakerEndpointConfigurationConfigAsyncClientConfig(rName string) string { + return testAccSagemakerEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(` +resource "aws_s3_bucket" "test" { + bucket = %[1]q + acl = "private" + force_destroy = true +} + +resource "aws_kms_key" "test" { + description = %[1]q + deletion_window_in_days = 7 +} + +resource "aws_sagemaker_endpoint_configuration" "test" { + name = %[1]q + + production_variants { + variant_name = "variant-1" + model_name = aws_sagemaker_model.test.name + initial_instance_count = 2 + instance_type = "ml.t2.medium" + initial_variant_weight = 1 + } + + async_inference_config { + client_config { + max_concurrent_invocations_per_instance = 1 + } + + output_config { + s3_output_path = "s3://${aws_s3_bucket.test.bucket}/" + kms_key_id = aws_kms_key.test.arn + } + } +} +`, rName) +} diff --git a/aws/resource_aws_sagemaker_test.go b/aws/resource_aws_sagemaker_test.go index 5cb14a8c43c..5c15b16016b 100644 --- a/aws/resource_aws_sagemaker_test.go +++ b/aws/resource_aws_sagemaker_test.go @@ -2,8 +2,22 @@ package aws import ( "testing" + + "github.com/aws/aws-sdk-go/service/sagemaker" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" ) +func init() { + RegisterServiceErrorCheckFunc(sagemaker.EndpointsID, testAccErrorCheckSkipSagemaker) +} + +func testAccErrorCheckSkipSagemaker(t *testing.T) resource.ErrorCheckFunc { + return testAccErrorCheckSkipMessagesContaining(t, + "is not supported in region", + "is not supported for the chosen region", + ) +} + // Tests are serialized as SagmMaker Domain resources are limited to 1 per account by default. // SageMaker UserProfile and App depend on the Domain resources and as such are also part of the serialized test suite. // Sagemaker Workteam tests must also be serialized diff --git a/website/docs/r/sagemaker_endpoint_configuration.html.markdown b/website/docs/r/sagemaker_endpoint_configuration.html.markdown index 1134bb51a0f..0afde0b3795 100644 --- a/website/docs/r/sagemaker_endpoint_configuration.html.markdown +++ b/website/docs/r/sagemaker_endpoint_configuration.html.markdown @@ -41,6 +41,7 @@ The following arguments are supported: * `name` - (Optional) The name of the endpoint configuration. If omitted, Terraform will assign a random, unique name. * `tags` - (Optional) A mapping of tags to assign to the resource. If configured with a provider [`default_tags` configuration block](/docs/providers/aws/index.html#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level. * `data_capture_config` - (Optional) Specifies the parameters to capture input/output of Sagemaker models endpoints. Fields are documented below. +* `async_inference_config` - (Optional) Specifies configuration for how an endpoint performs asynchronous inference. The `production_variants` block supports: @@ -69,6 +70,26 @@ The `capture_content_type_header` block supports: * `csv_content_types` - (Optional) The CSV content type headers to capture. * `json_content_types` - (Optional) The JSON content type headers to capture. +The `async_inference_config` block supports: + +* `output_config` - (Required) Specifies the configuration for asynchronous inference invocation outputs. +* `client_config` - (Optional) Configures the behavior of the client used by Amazon SageMaker to interact with the model container during asynchronous inference. + +The `client_config` block supports: + +* `max_concurrent_invocations_per_instance` - (Optional) The maximum number of concurrent requests sent by the SageMaker client to the model container. If no value is provided, Amazon SageMaker will choose an optimal value for you. + +The `output_config` block supports: + +* `s3_output_path` - (Required) The Amazon S3 location to upload inference responses to. +* `kms_key_id` - (Optional) The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the asynchronous inference output in Amazon S3. +* `notification_config` - (Optional) Specifies the configuration for notifications of inference results for asynchronous inference. + +The `notification_config` block supports: + +* `error_topic` - (Optional) Amazon SNS topic to post a notification to when inference fails. If no topic is provided, no notification is sent on failure. +* `success_topic` - (Optional) Amazon SNS topic to post a notification to when inference completes successfully. If no topic is provided, no notification is sent on success. + ## Attributes Reference In addition to all arguments above, the following attributes are exported: