Skip to content

Commit

Permalink
Merge pull request #32740 from jeremychauvet/b-2988-rds-filter-instan…
Browse files Browse the repository at this point in the history
…ces-by-tags

enhancement: add support of tags to filter RDS instances (#2988)
  • Loading branch information
ewbankkit authored Aug 3, 2023
2 parents 170aa7a + 35302ef commit 609022c
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 102 deletions.
7 changes: 7 additions & 0 deletions .changelog/32740.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:enhancement
data-source/aws_db_instances: Add ability to filter by `tags`
```

```release-note:enhancement
data-source/aws_db_instance: Add ability to filter by `tags`
```
53 changes: 43 additions & 10 deletions internal/service/rds/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/hashicorp/terraform-provider-aws/internal/errs"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
"github.com/hashicorp/terraform-provider-aws/internal/flex"
tfslices "github.com/hashicorp/terraform-provider-aws/internal/slices"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
"github.com/hashicorp/terraform-provider-aws/internal/verify"
Expand Down Expand Up @@ -2368,9 +2369,10 @@ func parseDBInstanceARN(s string) (dbInstanceARN, error) {
// "db-BE6UI2KLPQP3OVDYD74ZEV6NUM" rather than a DB identifier. However, in some cases only
// the identifier is available, and can be used.
func findDBInstanceByIDSDKv1(ctx context.Context, conn *rds.RDS, id string) (*rds.DBInstance, error) {
idLooksLikeDbiResourceId := regexp.MustCompile(`^db-[a-zA-Z0-9]{2,255}$`).MatchString(id)
input := &rds.DescribeDBInstancesInput{}

if regexp.MustCompile(`^db-[a-zA-Z0-9]{2,255}$`).MatchString(id) {
if idLooksLikeDbiResourceId {
input.Filters = []*rds.Filter{
{
Name: aws.String("dbi-resource-id"),
Expand All @@ -2381,16 +2383,51 @@ func findDBInstanceByIDSDKv1(ctx context.Context, conn *rds.RDS, id string) (*rd
input.DBInstanceIdentifier = aws.String(id)
}

output, err := conn.DescribeDBInstancesWithContext(ctx, input)
output, err := findDBInstanceSDKv1(ctx, conn, input)

// in case a DB has an *identifier* starting with "db-""
if regexp.MustCompile(`^db-[a-zA-Z0-9]{2,255}$`).MatchString(id) && (output == nil || len(output.DBInstances) == 0) {
input = &rds.DescribeDBInstancesInput{
if idLooksLikeDbiResourceId && tfresource.NotFound(err) {
input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(id),
}
output, err = conn.DescribeDBInstancesWithContext(ctx, input)

output, err = findDBInstanceSDKv1(ctx, conn, input)
}

if err != nil {
return nil, err
}

return output, nil
}

func findDBInstanceSDKv1(ctx context.Context, conn *rds.RDS, input *rds.DescribeDBInstancesInput) (*rds.DBInstance, error) {
output, err := findDBInstancesSDKv1(ctx, conn, input, tfslices.PredicateTrue[*rds.DBInstance]())

if err != nil {
return nil, err
}

return tfresource.AssertSinglePtrResult(output)
}

func findDBInstancesSDKv1(ctx context.Context, conn *rds.RDS, input *rds.DescribeDBInstancesInput, filter tfslices.Predicate[*rds.DBInstance]) ([]*rds.DBInstance, error) {
var output []*rds.DBInstance

err := conn.DescribeDBInstancesPagesWithContext(ctx, input, func(page *rds.DescribeDBInstancesOutput, lastPage bool) bool {
if page == nil {
return !lastPage
}

for _, v := range page.DBInstances {
if v != nil && filter(v) {
output = append(output, v)
}
}

return !lastPage
})

if tfawserr.ErrCodeEquals(err, rds.ErrCodeDBInstanceNotFoundFault) {
return nil, &retry.NotFoundError{
LastError: err,
Expand All @@ -2402,11 +2439,7 @@ func findDBInstanceByIDSDKv1(ctx context.Context, conn *rds.RDS, id string) (*rd
return nil, err
}

if output == nil {
return nil, tfresource.NewEmptyResultError(input)
}

return tfresource.AssertSinglePtrResult(output.DBInstances)
return output, nil
}

// findDBInstanceByIDSDKv2 in general should be called with a DbiResourceId of the form
Expand Down
155 changes: 92 additions & 63 deletions internal/service/rds/instance_data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@ import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
tfslices "github.com/hashicorp/terraform-provider-aws/internal/slices"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
"github.com/hashicorp/terraform-provider-aws/names"
)

// @SDKDataSource("aws_db_instance")
// @SDKDataSource("aws_db_instance", name="DB Instance")
// @Tags
func DataSourceInstance() *schema.Resource {
return &schema.Resource{
ReadWithoutTimeout: dataSourceInstanceRead,
Expand Down Expand Up @@ -60,9 +63,9 @@ func DataSourceInstance() *schema.Resource {
Computed: true,
},
"db_instance_identifier": {
Type: schema.TypeString,
Required: true,
ValidateFunc: validation.StringIsNotEmpty,
Type: schema.TypeString,
Optional: true,
Computed: true,
},
"db_instance_port": {
Type: schema.TypeInt,
Expand Down Expand Up @@ -199,7 +202,7 @@ func DataSourceInstance() *schema.Resource {
Type: schema.TypeString,
Computed: true,
},
"tags": tftags.TagsSchemaComputed(),
names.AttrTags: tftags.TagsSchemaComputed(),
"timezone": {
Type: schema.TypeString,
Computed: true,
Expand All @@ -216,74 +219,104 @@ func DataSourceInstance() *schema.Resource {
func dataSourceInstanceRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).RDSConn(ctx)
ignoreTagsConfig := meta.(*conns.AWSClient).IgnoreTagsConfig

v, err := findDBInstanceByIDSDKv1(ctx, conn, d.Get("db_instance_identifier").(string))
if err != nil {
return diag.FromErr(tfresource.SingularDataSourceFindError("RDS DB Instance", err))
var instance *rds.DBInstance

filter := tfslices.PredicateTrue[*rds.DBInstance]()
if tags := getTagsIn(ctx); len(tags) > 0 {
filter = func(v *rds.DBInstance) bool {
return KeyValueTags(ctx, v.TagList).ContainsAll(KeyValueTags(ctx, tags))
}
}

d.SetId(aws.StringValue(v.DBInstanceIdentifier))
d.Set("allocated_storage", v.AllocatedStorage)
d.Set("auto_minor_version_upgrade", v.AutoMinorVersionUpgrade)
d.Set("availability_zone", v.AvailabilityZone)
d.Set("backup_retention_period", v.BackupRetentionPeriod)
d.Set("ca_cert_identifier", v.CACertificateIdentifier)
d.Set("db_cluster_identifier", v.DBClusterIdentifier)
d.Set("db_instance_arn", v.DBInstanceArn)
d.Set("db_instance_class", v.DBInstanceClass)
d.Set("db_instance_port", v.DbInstancePort)
d.Set("db_name", v.DBName)
var parameterGroupNames []string
for _, v := range v.DBParameterGroups {
parameterGroupNames = append(parameterGroupNames, aws.StringValue(v.DBParameterGroupName))
if v, ok := d.GetOk("db_instance_identifier"); ok {
id := v.(string)
output, err := findDBInstanceByIDSDKv1(ctx, conn, id)

if err != nil {
return sdkdiag.AppendErrorf(diags, "reading RDS DB Instance (%s): %s", id, err)
}

if !filter(output) {
return sdkdiag.AppendErrorf(diags, "Your query returned no results. Please change your search criteria and try again.")
}

instance = output
} else {
input := &rds.DescribeDBInstancesInput{}
instances, err := findDBInstancesSDKv1(ctx, conn, input, filter)

if err != nil {
return sdkdiag.AppendErrorf(diags, "reading RDS DB Instances: %s", err)
}

output, err := tfresource.AssertSinglePtrResult(instances)

if err != nil {
return sdkdiag.AppendFromErr(diags, tfresource.SingularDataSourceFindError("RDS DB Instance", err))
}

instance = output
}

d.SetId(aws.StringValue(instance.DBInstanceIdentifier))
d.Set("allocated_storage", instance.AllocatedStorage)
d.Set("auto_minor_version_upgrade", instance.AutoMinorVersionUpgrade)
d.Set("availability_zone", instance.AvailabilityZone)
d.Set("backup_retention_period", instance.BackupRetentionPeriod)
d.Set("ca_cert_identifier", instance.CACertificateIdentifier)
d.Set("db_cluster_identifier", instance.DBClusterIdentifier)
d.Set("db_instance_arn", instance.DBInstanceArn)
d.Set("db_instance_class", instance.DBInstanceClass)
d.Set("db_instance_port", instance.DbInstancePort)
d.Set("db_name", instance.DBName)
parameterGroupNames := tfslices.ApplyToAll(instance.DBParameterGroups, func(v *rds.DBParameterGroupStatus) string {
return aws.StringValue(v.DBParameterGroupName)
})
d.Set("db_parameter_groups", parameterGroupNames)
if v.DBSubnetGroup != nil {
d.Set("db_subnet_group", v.DBSubnetGroup.DBSubnetGroupName)
if instance.DBSubnetGroup != nil {
d.Set("db_subnet_group", instance.DBSubnetGroup.DBSubnetGroupName)
} else {
d.Set("db_subnet_group", "")
}
d.Set("enabled_cloudwatch_logs_exports", aws.StringValueSlice(v.EnabledCloudwatchLogsExports))
d.Set("engine", v.Engine)
d.Set("engine_version", v.EngineVersion)
d.Set("iops", v.Iops)
d.Set("kms_key_id", v.KmsKeyId)
d.Set("license_model", v.LicenseModel)
d.Set("master_username", v.MasterUsername)
if v.MasterUserSecret != nil {
if err := d.Set("master_user_secret", []interface{}{flattenManagedMasterUserSecret(v.MasterUserSecret)}); err != nil {
d.Set("enabled_cloudwatch_logs_exports", aws.StringValueSlice(instance.EnabledCloudwatchLogsExports))
d.Set("engine", instance.Engine)
d.Set("engine_version", instance.EngineVersion)
d.Set("iops", instance.Iops)
d.Set("kms_key_id", instance.KmsKeyId)
d.Set("license_model", instance.LicenseModel)
d.Set("master_username", instance.MasterUsername)
if instance.MasterUserSecret != nil {
if err := d.Set("master_user_secret", []interface{}{flattenManagedMasterUserSecret(instance.MasterUserSecret)}); err != nil {
return sdkdiag.AppendErrorf(diags, "setting master_user_secret: %s", err)
}
}
d.Set("max_allocated_storage", v.MaxAllocatedStorage)
d.Set("monitoring_interval", v.MonitoringInterval)
d.Set("monitoring_role_arn", v.MonitoringRoleArn)
d.Set("multi_az", v.MultiAZ)
d.Set("network_type", v.NetworkType)
var optionGroupNames []string
for _, v := range v.OptionGroupMemberships {
optionGroupNames = append(optionGroupNames, aws.StringValue(v.OptionGroupName))
}
d.Set("max_allocated_storage", instance.MaxAllocatedStorage)
d.Set("monitoring_interval", instance.MonitoringInterval)
d.Set("monitoring_role_arn", instance.MonitoringRoleArn)
d.Set("multi_az", instance.MultiAZ)
d.Set("network_type", instance.NetworkType)
optionGroupNames := tfslices.ApplyToAll(instance.OptionGroupMemberships, func(v *rds.OptionGroupMembership) string {
return aws.StringValue(v.OptionGroupName)
})
d.Set("option_group_memberships", optionGroupNames)
d.Set("preferred_backup_window", v.PreferredBackupWindow)
d.Set("preferred_maintenance_window", v.PreferredMaintenanceWindow)
d.Set("publicly_accessible", v.PubliclyAccessible)
d.Set("replicate_source_db", v.ReadReplicaSourceDBInstanceIdentifier)
d.Set("resource_id", v.DbiResourceId)
d.Set("storage_encrypted", v.StorageEncrypted)
d.Set("storage_throughput", v.StorageThroughput)
d.Set("storage_type", v.StorageType)
d.Set("timezone", v.Timezone)
var vpcSecurityGroupIDs []string
for _, v := range v.VpcSecurityGroups {
vpcSecurityGroupIDs = append(vpcSecurityGroupIDs, aws.StringValue(v.VpcSecurityGroupId))
}
d.Set("preferred_backup_window", instance.PreferredBackupWindow)
d.Set("preferred_maintenance_window", instance.PreferredMaintenanceWindow)
d.Set("publicly_accessible", instance.PubliclyAccessible)
d.Set("replicate_source_db", instance.ReadReplicaSourceDBInstanceIdentifier)
d.Set("resource_id", instance.DbiResourceId)
d.Set("storage_encrypted", instance.StorageEncrypted)
d.Set("storage_throughput", instance.StorageThroughput)
d.Set("storage_type", instance.StorageType)
d.Set("timezone", instance.Timezone)
vpcSecurityGroupIDs := tfslices.ApplyToAll(instance.VpcSecurityGroups, func(v *rds.VpcSecurityGroupMembership) string {
return aws.StringValue(v.VpcSecurityGroupId)
})
d.Set("vpc_security_groups", vpcSecurityGroupIDs)

// Per AWS SDK Go docs:
// The endpoint might not be shown for instances whose status is creating.
if dbEndpoint := v.Endpoint; dbEndpoint != nil {
if dbEndpoint := instance.Endpoint; dbEndpoint != nil {
d.Set("address", dbEndpoint.Address)
d.Set("endpoint", fmt.Sprintf("%s:%d", aws.StringValue(dbEndpoint.Address), aws.Int64Value(dbEndpoint.Port)))
d.Set("hosted_zone_id", dbEndpoint.HostedZoneId)
Expand All @@ -295,11 +328,7 @@ func dataSourceInstanceRead(ctx context.Context, d *schema.ResourceData, meta in
d.Set("port", nil)
}

tags := KeyValueTags(ctx, v.TagList)

if err := d.Set("tags", tags.IgnoreAWS().IgnoreConfig(ignoreTagsConfig).Map()); err != nil {
return sdkdiag.AppendErrorf(diags, "setting tags: %s", err)
}
setTagsOut(ctx, instance.TagList)

return diags
}
Loading

0 comments on commit 609022c

Please sign in to comment.