diff --git a/.ci/semgrep/stdlib/sort.yml b/.ci/semgrep/stdlib/sort.yml index 50ed6566b98..221abc72adf 100644 --- a/.ci/semgrep/stdlib/sort.yml +++ b/.ci/semgrep/stdlib/sort.yml @@ -1,8 +1,14 @@ rules: - id: prefer-slices-sortfunc languages: [go] - message: Prefer slices.SortFunc to sort.Slice - pattern: sort.Slice(...) + message: Prefer slices.SortFunc to sort.$FUNC + patterns: + - pattern: sort.$FUNC(...) + - metavariable-pattern: + metavariable: $FUNC + pattern-either: + - pattern: Sort + - pattern: Slice severity: WARNING - id: prefer-slices-sortstablefunc @@ -38,3 +44,9 @@ rules: - pattern: StringsAreSorted fix: slices.IsSorted($X) severity: WARNING + + - id: prefer-slices-issortedfunc + languages: [go] + message: Prefer slices.IsSortedFunc to sort.IsSorted + pattern: sort.IsSorted($X) + severity: WARNING diff --git a/internal/service/dax/cluster.go b/internal/service/dax/cluster.go index e81a9d67b57..92ae1b38c7d 100644 --- a/internal/service/dax/cluster.go +++ b/internal/service/dax/cluster.go @@ -4,11 +4,12 @@ package dax import ( + "cmp" "context" "fmt" "log" "reflect" - "sort" + "slices" "strings" "time" @@ -472,9 +473,9 @@ func resourceClusterUpdate(ctx context.Context, d *schema.ResourceData, meta int } func setClusterNodeData(d *schema.ResourceData, c awstypes.Cluster) error { - sortedNodes := make([]awstypes.Node, len(c.Nodes)) - copy(sortedNodes, c.Nodes) - sort.Sort(byNodeId(sortedNodes)) + sortedNodes := slices.SortedFunc(slices.Values(c.Nodes), func(a, b awstypes.Node) int { + return cmp.Compare(aws.ToString(a.NodeId), aws.ToString(b.NodeId)) + }) nodeData := make([]map[string]interface{}, 0, len(sortedNodes)) @@ -490,15 +491,6 @@ func setClusterNodeData(d *schema.ResourceData, c awstypes.Cluster) error { return d.Set("nodes", nodeData) } -type byNodeId []awstypes.Node - -func (b byNodeId) Len() int { return len(b) } -func (b byNodeId) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func (b byNodeId) Less(i, j int) bool { - return b[i].NodeId != nil && b[j].NodeId != nil && - aws.ToString(b[i].NodeId) < aws.ToString(b[j].NodeId) -} - func resourceClusterDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { var diags diag.Diagnostics conn := meta.(*conns.AWSClient).DAXClient(ctx) diff --git a/internal/service/ec2/ebs_volume_data_source.go b/internal/service/ec2/ebs_volume_data_source.go index 31a120d8575..db58735ed49 100644 --- a/internal/service/ec2/ebs_volume_data_source.go +++ b/internal/service/ec2/ebs_volume_data_source.go @@ -6,7 +6,7 @@ package ec2 import ( "context" "fmt" - "sort" + "slices" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -159,18 +159,8 @@ func dataSourceEBSVolumeRead(ctx context.Context, d *schema.ResourceData, meta i return diags } -type volumeSort []awstypes.Volume - -func (a volumeSort) Len() int { return len(a) } -func (a volumeSort) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a volumeSort) Less(i, j int) bool { - itime := aws.ToTime(a[i].CreateTime) - jtime := aws.ToTime(a[j].CreateTime) - return itime.Unix() < jtime.Unix() -} - func mostRecentVolume(volumes []awstypes.Volume) awstypes.Volume { - sortedVolumes := volumes - sort.Sort(volumeSort(sortedVolumes)) - return sortedVolumes[len(sortedVolumes)-1] + return slices.MaxFunc(volumes, func(a, b awstypes.Volume) int { + return a.CreateTime.Compare(aws.ToTime(b.CreateTime)) + }) } diff --git a/internal/service/ec2/vpc_security_group_rule.go b/internal/service/ec2/vpc_security_group_rule.go index 05c305805c3..3ed24dc2af3 100644 --- a/internal/service/ec2/vpc_security_group_rule.go +++ b/internal/service/ec2/vpc_security_group_rule.go @@ -5,11 +5,12 @@ package ec2 import ( "bytes" + "cmp" "context" + "errors" "fmt" "log" "slices" - "sort" "strconv" "strings" "time" @@ -173,7 +174,10 @@ func resourceSecurityGroupRuleCreate(ctx context.Context, d *schema.ResourceData ipPermission := expandIPPermission(d, sg) ruleType := securityGroupRuleType(d.Get(names.AttrType).(string)) - id := securityGroupRuleCreateID(securityGroupID, string(ruleType), &ipPermission) + id, err := securityGroupRuleCreateID(securityGroupID, string(ruleType), &ipPermission) + if err != nil { + return sdkdiag.AppendErrorf(diags, "reading Security Group (%s): %s", securityGroupID, err) + } switch ruleType { case securityGroupRuleTypeIngress: @@ -307,7 +311,10 @@ func resourceSecurityGroupRuleRead(ctx context.Context, d *schema.ResourceData, if strings.Contains(d.Id(), securityGroupRuleIDSeparator) { // import so fix the id - id := securityGroupRuleCreateID(securityGroupID, string(ruleType), &ipPermission) + id, err := securityGroupRuleCreateID(securityGroupID, string(ruleType), &ipPermission) + if err != nil { + return sdkdiag.AppendErrorf(diags, "reading Security Group (%s) Rule (%s): %s", securityGroupID, d.Id(), err) + } d.SetId(id) } @@ -674,25 +681,7 @@ func findSecurityGroupRuleMatch(p awstypes.IpPermission, securityGroupRules []aw const securityGroupRuleIDSeparator = "_" -// byGroupPair implements sort.Interface for []*ec2.UserIDGroupPairs based on -// GroupID or GroupName field (only one should be set). -type byGroupPair []awstypes.UserIdGroupPair - -func (b byGroupPair) Len() int { return len(b) } -func (b byGroupPair) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func (b byGroupPair) Less(i, j int) bool { - if b[i].GroupId != nil && b[j].GroupId != nil { - return aws.ToString(b[i].GroupId) < aws.ToString(b[j].GroupId) - } - if b[i].GroupName != nil && b[j].GroupName != nil { - return aws.ToString(b[i].GroupName) < aws.ToString(b[j].GroupName) - } - - //lintignore:R009 - panic("mismatched security group rules, may be a terraform bug") -} - -func securityGroupRuleCreateID(securityGroupID, ruleType string, ip *awstypes.IpPermission) string { +func securityGroupRuleCreateID(securityGroupID, ruleType string, ip *awstypes.IpPermission) (string, error) { var buf bytes.Buffer buf.WriteString(fmt.Sprintf("%s-", securityGroupID)) @@ -744,22 +733,35 @@ func securityGroupRuleCreateID(securityGroupID, ruleType string, ip *awstypes.Ip } if len(ip.UserIdGroupPairs) > 0 { - sort.Sort(byGroupPair(ip.UserIdGroupPairs)) + var err error + slices.SortFunc(ip.UserIdGroupPairs, func(a, b awstypes.UserIdGroupPair) int { + if a.GroupId != nil && b.GroupId != nil { + return cmp.Compare(aws.ToString(a.GroupId), aws.ToString(b.GroupId)) + } + if a.GroupName != nil && b.GroupName != nil { + return cmp.Compare(aws.ToString(a.GroupName), aws.ToString(b.GroupName)) + } + err = errors.New("mismatched security group rules: contains both GroupId and GroupName") + return 0 + }) + if err != nil { + return "", err + } for _, pair := range ip.UserIdGroupPairs { if pair.GroupId != nil { - buf.WriteString(fmt.Sprintf("%s-", *pair.GroupId)) + buf.WriteString(fmt.Sprintf("%s-", aws.ToString(pair.GroupId))) } else { buf.WriteString("-") } if pair.GroupName != nil { - buf.WriteString(fmt.Sprintf("%s-", *pair.GroupName)) + buf.WriteString(fmt.Sprintf("%s-", aws.ToString(pair.GroupName))) } else { buf.WriteString("-") } } } - return fmt.Sprintf("sgrule-%d", create.StringHashcode(buf.String())) + return fmt.Sprintf("sgrule-%d", create.StringHashcode(buf.String())), nil } func expandIPPermission(d *schema.ResourceData, sg *awstypes.SecurityGroup) awstypes.IpPermission { // nosemgrep:ci.caps5-in-func-name diff --git a/internal/service/ec2/vpc_security_group_rule_migrate.go b/internal/service/ec2/vpc_security_group_rule_migrate.go index eff9e9bccd5..30c48d3309c 100644 --- a/internal/service/ec2/vpc_security_group_rule_migrate.go +++ b/internal/service/ec2/vpc_security_group_rule_migrate.go @@ -39,13 +39,15 @@ func migrateSGRuleStateV0toV1(is *terraform.InstanceState) (*terraform.InstanceS } perm, err := migrateExpandIPPerm(is.Attributes) - if err != nil { return nil, fmt.Errorf("making new IP Permission in Security Group migration") } log.Printf("[DEBUG] Attributes before migration: %#v", is.Attributes) - newID := securityGroupRuleCreateID(is.Attributes["security_group_id"], is.Attributes[names.AttrType], perm) + newID, err := securityGroupRuleCreateID(is.Attributes["security_group_id"], is.Attributes[names.AttrType], perm) + if err != nil { + return nil, err + } is.Attributes[names.AttrID] = newID is.ID = newID log.Printf("[DEBUG] Attributes after migration: %#v, new id: %s", is.Attributes, newID) diff --git a/internal/service/ec2/vpc_security_group_rule_test.go b/internal/service/ec2/vpc_security_group_rule_test.go index 518d2428912..1014ad90c2f 100644 --- a/internal/service/ec2/vpc_security_group_rule_test.go +++ b/internal/service/ec2/vpc_security_group_rule_test.go @@ -108,7 +108,10 @@ func TestSecurityGroupRuleCreateID(t *testing.T) { } for _, tc := range cases { - actual := tfec2.SecurityGroupRuleCreateID("sg-12345", tc.Type, &tc.Input) + actual, err := tfec2.SecurityGroupRuleCreateID("sg-12345", tc.Type, &tc.Input) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } if actual != tc.Output { t.Errorf("input: %s - %#v\noutput: %s", tc.Type, tc.Input, actual) } diff --git a/internal/service/ec2/vpnsite_connection.go b/internal/service/ec2/vpnsite_connection.go index 5e5936d6e2a..eaa9d44ab95 100644 --- a/internal/service/ec2/vpnsite_connection.go +++ b/internal/service/ec2/vpnsite_connection.go @@ -4,12 +4,13 @@ package ec2 import ( + "cmp" "context" "encoding/xml" "fmt" "log" "net" - "sort" + "slices" "strconv" "time" @@ -1565,18 +1566,6 @@ type tunnelInfo struct { Tunnel2VgwInsideAddress string } -func (slice xmlVpnConnectionConfig) Len() int { - return len(slice.Tunnels) -} - -func (slice xmlVpnConnectionConfig) Less(i, j int) bool { - return slice.Tunnels[i].OutsideAddress < slice.Tunnels[j].OutsideAddress -} - -func (slice xmlVpnConnectionConfig) Swap(i, j int) { - slice.Tunnels[i], slice.Tunnels[j] = slice.Tunnels[j], slice.Tunnels[i] -} - // customerGatewayConfigurationToTunnelInfo converts the configuration information for the // VPN connection's customer gateway (in the native XML format) to a tunnelInfo structure. // The tunnel1 parameters are optionally used to correctly order tunnel configurations. @@ -1616,7 +1605,9 @@ func customerGatewayConfigurationToTunnelInfo(xmlConfig string, tunnel1PreShared } } } else { - sort.Sort(vpnConfig) + slices.SortFunc(vpnConfig.Tunnels, func(a, b xmlIpsecTunnel) int { + return cmp.Compare(a.OutsideAddress, b.OutsideAddress) + }) } tunnelInfo := &tunnelInfo{ diff --git a/internal/service/elasticache/cluster.go b/internal/service/elasticache/cluster.go index f697c94627e..895b5f3d4da 100644 --- a/internal/service/elasticache/cluster.go +++ b/internal/service/elasticache/cluster.go @@ -4,11 +4,12 @@ package elasticache import ( + "cmp" "context" "errors" "fmt" "log" - "sort" + "slices" "strconv" "strings" "time" @@ -944,9 +945,9 @@ func getCacheNodesToRemove(oldNumberOfNodes int, cacheNodesToRemove int) []strin } func setCacheNodeData(d *schema.ResourceData, c *awstypes.CacheCluster) error { - sortedCacheNodes := make([]awstypes.CacheNode, len(c.CacheNodes)) - copy(sortedCacheNodes, c.CacheNodes) - sort.Sort(byCacheNodeId(sortedCacheNodes)) + sortedCacheNodes := slices.SortedFunc(slices.Values(c.CacheNodes), func(a, b awstypes.CacheNode) int { + return cmp.Compare(aws.ToString(a.CacheNodeId), aws.ToString(b.CacheNodeId)) + }) cacheNodeData := make([]map[string]interface{}, 0, len(sortedCacheNodes)) @@ -966,15 +967,6 @@ func setCacheNodeData(d *schema.ResourceData, c *awstypes.CacheCluster) error { return d.Set("cache_nodes", cacheNodeData) } -type byCacheNodeId []awstypes.CacheNode - -func (b byCacheNodeId) Len() int { return len(b) } -func (b byCacheNodeId) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func (b byCacheNodeId) Less(i, j int) bool { - return b[i].CacheNodeId != nil && b[j].CacheNodeId != nil && - aws.ToString(b[i].CacheNodeId) < aws.ToString(b[j].CacheNodeId) -} - func setFromCacheCluster(d *schema.ResourceData, c *awstypes.CacheCluster) error { d.Set("node_type", c.CacheNodeType) diff --git a/internal/service/iam/policy_model.go b/internal/service/iam/policy_model.go index b9292764c01..dfec584d3e4 100644 --- a/internal/service/iam/policy_model.go +++ b/internal/service/iam/policy_model.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "slices" - "sort" "strconv" "github.com/YakDriver/regexache" @@ -86,7 +85,7 @@ func (s *IAMPolicyDoc) Merge(newDoc *IAMPolicyDoc) { func (ps IAMPolicyStatementPrincipalSet) MarshalJSON() ([]byte, error) { raw := map[string]interface{}{} - // Although IAM documentation says, that "*" and {"AWS": "*"} are equivalent + // Although IAM documentation says that "*" and {"AWS": "*"} are equivalent // (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html), // in practice they are not for IAM roles. IAM will return an error if trust // policy have "*" or {"*": "*"} as principal, but will accept {"AWS": "*"}. @@ -115,7 +114,8 @@ func (ps IAMPolicyStatementPrincipalSet) MarshalJSON() ([]byte, error) { raw[p.Type] = make([]string, 0, len(i)+1) raw[p.Type] = append(raw[p.Type].([]string), v) } - sort.Sort(sort.Reverse(sort.StringSlice(i))) + slices.Sort(i) + slices.Reverse(i) raw[p.Type] = append(raw[p.Type].([]string), i...) case string: switch v := raw[p.Type].(type) { @@ -243,7 +243,8 @@ func policyDecodeConfigStringList(lI []interface{}) interface{} { for i, vI := range lI { ret[i] = vI.(string) } - sort.Sort(sort.Reverse(sort.StringSlice(ret))) + slices.Sort(ret) + slices.Reverse(ret) return ret } diff --git a/internal/service/iam/server_certificate_data_source.go b/internal/service/iam/server_certificate_data_source.go index 7f26db48156..a42397d5dab 100644 --- a/internal/service/iam/server_certificate_data_source.go +++ b/internal/service/iam/server_certificate_data_source.go @@ -6,7 +6,7 @@ package iam import ( "context" "log" - "sort" + "slices" "strings" "time" @@ -87,20 +87,6 @@ func dataSourceServerCertificate() *schema.Resource { } } -type CertificateByExpiration []awstypes.ServerCertificateMetadata - -func (m CertificateByExpiration) Len() int { - return len(m) -} - -func (m CertificateByExpiration) Swap(i, j int) { - m[i], m[j] = m[j], m[i] -} - -func (m CertificateByExpiration) Less(i, j int) bool { - return m[i].Expiration.After(*m[j].Expiration) -} - func dataSourceServerCertificateRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { var diags diag.Diagnostics conn := meta.(*conns.AWSClient).IAMClient(ctx) @@ -136,15 +122,13 @@ func dataSourceServerCertificateRead(ctx context.Context, d *schema.ResourceData if len(metadatas) == 0 { return sdkdiag.AppendErrorf(diags, "Search for AWS IAM server certificate returned no results") } - if len(metadatas) > 1 { - if !d.Get("latest").(bool) { - return sdkdiag.AppendErrorf(diags, "Search for AWS IAM server certificate returned too many results") - } - - sort.Sort(CertificateByExpiration(metadatas)) + if len(metadatas) > 1 && !d.Get("latest").(bool) { + return sdkdiag.AppendErrorf(diags, "Search for AWS IAM server certificate returned too many results") } - metadata := metadatas[0] + metadata := slices.MaxFunc(metadatas, func(a, b awstypes.ServerCertificateMetadata) int { + return a.Expiration.Compare(aws.ToTime(b.Expiration)) + }) d.SetId(aws.ToString(metadata.ServerCertificateId)) d.Set(names.AttrARN, metadata.Arn) d.Set(names.AttrPath, metadata.Path) diff --git a/internal/service/iam/server_certificate_data_source_test.go b/internal/service/iam/server_certificate_data_source_test.go index 1f0db464631..8ce7f7006f5 100644 --- a/internal/service/iam/server_certificate_data_source_test.go +++ b/internal/service/iam/server_certificate_data_source_test.go @@ -5,43 +5,15 @@ package iam_test import ( "fmt" - "sort" "testing" - "time" "github.com/YakDriver/regexache" - "github.com/aws/aws-sdk-go-v2/aws" - awstypes "github.com/aws/aws-sdk-go-v2/service/iam/types" sdkacctest "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" "github.com/hashicorp/terraform-provider-aws/internal/acctest" - tfiam "github.com/hashicorp/terraform-provider-aws/internal/service/iam" "github.com/hashicorp/terraform-provider-aws/names" ) -func TestResourceSortByExpirationDate(t *testing.T) { - t.Parallel() - - certs := []awstypes.ServerCertificateMetadata{ - { - ServerCertificateName: aws.String("oldest"), - Expiration: aws.Time(time.Now()), - }, - { - ServerCertificateName: aws.String("latest"), - Expiration: aws.Time(time.Now().Add(3 * time.Hour)), - }, - { - ServerCertificateName: aws.String("in between"), - Expiration: aws.Time(time.Now().Add(2 * time.Hour)), - }, - } - sort.Sort(tfiam.CertificateByExpiration(certs)) - if aws.ToString(certs[0].ServerCertificateName) != "latest" { - t.Fatalf("Expected first item to be %q, but was %q", "latest", *certs[0].ServerCertificateName) - } -} - func TestAccIAMServerCertificateDataSource_basic(t *testing.T) { ctx := acctest.Context(t) rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) diff --git a/internal/service/lambda/policy_model.go b/internal/service/lambda/policy_model.go index 05ea91b59db..f21f7c0f12a 100644 --- a/internal/service/lambda/policy_model.go +++ b/internal/service/lambda/policy_model.go @@ -6,7 +6,7 @@ package lambda import ( "encoding/json" "fmt" - "sort" + "slices" ) const ( @@ -80,7 +80,7 @@ func (s *IAMPolicyDoc) Merge(newDoc *IAMPolicyDoc) { func (ps IAMPolicyStatementPrincipalSet) MarshalJSON() ([]byte, error) { raw := map[string]interface{}{} - // Although IAM documentation says, that "*" and {"AWS": "*"} are equivalent + // Although IAM documentation says that "*" and {"AWS": "*"} are equivalent // (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html), // in practice they are not for IAM roles. IAM will return an error if trust // policy have "*" or {"*": "*"} as principal, but will accept {"AWS": "*"}. @@ -109,7 +109,8 @@ func (ps IAMPolicyStatementPrincipalSet) MarshalJSON() ([]byte, error) { raw[p.Type] = make([]string, 0, len(i)+1) raw[p.Type] = append(raw[p.Type].([]string), v) } - sort.Sort(sort.Reverse(sort.StringSlice(i))) + slices.Sort(i) + slices.Reverse(i) raw[p.Type] = append(raw[p.Type].([]string), i...) case string: switch v := raw[p.Type].(type) { diff --git a/internal/service/meta/ip_ranges_data_source_test.go b/internal/service/meta/ip_ranges_data_source_test.go index 24ec1384230..461f8d113a5 100644 --- a/internal/service/meta/ip_ranges_data_source_test.go +++ b/internal/service/meta/ip_ranges_data_source_test.go @@ -6,7 +6,7 @@ package meta_test import ( "fmt" "net" - "sort" + "slices" "strconv" "strings" "testing" @@ -169,7 +169,6 @@ func testAccIPRangesCheckCIDRBlocksAttribute(name, attribute string) resource.Te var ( cidrBlockSize int - cidrBlocks sort.StringSlice err error ) @@ -181,8 +180,7 @@ func testAccIPRangesCheckCIDRBlocksAttribute(name, attribute string) resource.Te return fmt.Errorf("%s for eu-west-1 seem suspiciously low: %d", attribute, cidrBlockSize) // lintignore:AWSAT003 } - cidrBlocks = make([]string, cidrBlockSize) - + cidrBlocks := make([]string, cidrBlockSize) for i := range cidrBlocks { cidrBlock := a[fmt.Sprintf("%s.%d", attribute, i)] @@ -194,8 +192,8 @@ func testAccIPRangesCheckCIDRBlocksAttribute(name, attribute string) resource.Te cidrBlocks[i] = cidrBlock } - if !sort.IsSorted(cidrBlocks) { - return fmt.Errorf("unexpected order of %s: %s", attribute, cidrBlocks) + if !slices.IsSorted(cidrBlocks) { + return fmt.Errorf("expected %s to be sorted: %s", attribute, cidrBlocks) } return nil diff --git a/internal/service/networkmanager/core_network_policy_model.go b/internal/service/networkmanager/core_network_policy_model.go index 2d15fa877e1..835342e5050 100644 --- a/internal/service/networkmanager/core_network_policy_model.go +++ b/internal/service/networkmanager/core_network_policy_model.go @@ -5,7 +5,7 @@ package networkmanager import ( "encoding/json" - "sort" + "slices" "github.com/hashicorp/terraform-provider-aws/internal/flex" ) @@ -140,7 +140,8 @@ func (c coreNetworkPolicySegmentAction) MarshalJSON() ([]byte, error) { func coreNetworkPolicyExpandStringList(configured []interface{}) interface{} { vs := flex.ExpandStringValueList(configured) - sort.Sort(sort.Reverse(sort.StringSlice(vs))) + slices.Sort(vs) + slices.Reverse(vs) return vs } diff --git a/internal/service/rds/cluster_snapshot_data_source.go b/internal/service/rds/cluster_snapshot_data_source.go index 521847f2914..3b5058a3f95 100644 --- a/internal/service/rds/cluster_snapshot_data_source.go +++ b/internal/service/rds/cluster_snapshot_data_source.go @@ -5,7 +5,7 @@ package rds import ( "context" - "sort" + "slices" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -151,17 +151,17 @@ func dataSourceClusterSnapshotRead(ctx context.Context, d *schema.ResourceData, return sdkdiag.AppendErrorf(diags, "Your query returned no results. Please change your search criteria and try again.") } - var snapshot *types.DBClusterSnapshot - if len(snapshots) > 1 { - if d.Get(names.AttrMostRecent).(bool) { - snapshot = mostRecentClusterSnapshot(snapshots) - } else { - return sdkdiag.AppendErrorf(diags, "Your query returned more than one result. Please try a more specific search criteria.") - } - } else { - snapshot = &snapshots[0] + if len(snapshots) > 1 && !d.Get(names.AttrMostRecent).(bool) { + return sdkdiag.AppendErrorf(diags, "Your query returned more than one result. Please try a more specific search criteria.") } + snapshot := slices.MaxFunc(snapshots, func(a, b types.DBClusterSnapshot) int { + if a.SnapshotCreateTime == nil || b.SnapshotCreateTime == nil { + return 0 + } + return a.SnapshotCreateTime.Compare(aws.ToTime(b.SnapshotCreateTime)) + }) + d.SetId(aws.ToString(snapshot.DBClusterSnapshotIdentifier)) d.Set(names.AttrAllocatedStorage, snapshot.AllocatedStorage) d.Set(names.AttrAvailabilityZones, snapshot.AvailabilityZones) @@ -186,25 +186,3 @@ func dataSourceClusterSnapshotRead(ctx context.Context, d *schema.ResourceData, return diags } - -type rdsClusterSnapshotSort []types.DBClusterSnapshot - -func (a rdsClusterSnapshotSort) Len() int { return len(a) } -func (a rdsClusterSnapshotSort) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a rdsClusterSnapshotSort) Less(i, j int) bool { - // Snapshot creation can be in progress - if a[i].SnapshotCreateTime == nil { - return true - } - if a[j].SnapshotCreateTime == nil { - return false - } - - return (aws.ToTime(a[i].SnapshotCreateTime)).Before(aws.ToTime(a[j].SnapshotCreateTime)) -} - -func mostRecentClusterSnapshot(snapshots []types.DBClusterSnapshot) *types.DBClusterSnapshot { - sortedSnapshots := snapshots - sort.Sort(rdsClusterSnapshotSort(sortedSnapshots)) - return &sortedSnapshots[len(sortedSnapshots)-1] -} diff --git a/internal/service/rds/snapshot_data_source.go b/internal/service/rds/snapshot_data_source.go index 9fae90b18cc..85c5ce99d60 100644 --- a/internal/service/rds/snapshot_data_source.go +++ b/internal/service/rds/snapshot_data_source.go @@ -5,7 +5,7 @@ package rds import ( "context" - "sort" + "slices" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -170,17 +170,17 @@ func dataSourceSnapshotRead(ctx context.Context, d *schema.ResourceData, meta in return sdkdiag.AppendErrorf(diags, "Your query returned no results. Please change your search criteria and try again.") } - var snapshot *types.DBSnapshot - if len(snapshots) > 1 { - if d.Get(names.AttrMostRecent).(bool) { - snapshot = mostRecentDBSnapshot(snapshots) - } else { - return sdkdiag.AppendErrorf(diags, "Your query returned more than one result. Please try a more specific search criteria.") - } - } else { - snapshot = &snapshots[0] + if len(snapshots) > 1 && !d.Get(names.AttrMostRecent).(bool) { + return sdkdiag.AppendErrorf(diags, "Your query returned more than one result. Please try a more specific search criteria.") } + snapshot := slices.MaxFunc(snapshots, func(a, b types.DBSnapshot) int { + if a.SnapshotCreateTime == nil || b.SnapshotCreateTime == nil { + return 0 + } + return a.SnapshotCreateTime.Compare(aws.ToTime(b.SnapshotCreateTime)) + }) + d.SetId(aws.ToString(snapshot.DBSnapshotIdentifier)) d.Set(names.AttrAllocatedStorage, snapshot.AllocatedStorage) d.Set(names.AttrAvailabilityZone, snapshot.AvailabilityZone) @@ -212,25 +212,3 @@ func dataSourceSnapshotRead(ctx context.Context, d *schema.ResourceData, meta in return diags } - -type rdsSnapshotSort []types.DBSnapshot - -func (a rdsSnapshotSort) Len() int { return len(a) } -func (a rdsSnapshotSort) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a rdsSnapshotSort) Less(i, j int) bool { - // Snapshot creation can be in progress - if a[i].SnapshotCreateTime == nil { - return true - } - if a[j].SnapshotCreateTime == nil { - return false - } - - return (aws.ToTime(a[i].SnapshotCreateTime)).Before(aws.ToTime(a[j].SnapshotCreateTime)) -} - -func mostRecentDBSnapshot(snapshots []types.DBSnapshot) *types.DBSnapshot { - sortedSnapshots := snapshots - sort.Sort(rdsSnapshotSort(sortedSnapshots)) - return &sortedSnapshots[len(sortedSnapshots)-1] -}