diff --git a/docs/data-sources/groups.md b/docs/data-sources/groups.md index 7f21b29d9c..b9f0391551 100644 --- a/docs/data-sources/groups.md +++ b/docs/data-sources/groups.md @@ -18,25 +18,44 @@ When authenticated with a user principal, this data source does not require any *Look up by group name* ```terraform -data "azuread_groups" "groups" { +data "azuread_groups" "example" { display_names = ["group-a", "group-b"] } ``` *Look up all groups* ```terraform -data "azuread_groups" "allGroups" { +data "azuread_groups" "example" { return_all = true } ``` +*Look up all mail-enabled groups* +```terraform +data "azuread_groups" "example" { + mail_enabled = true + return_all = true +} +``` + +*Look up all security-enabled groups that are not mail-enabled* +```terraform +data "azuread_groups" "example" { + mail_enabled = false + return_all = true + security_enabled = true +} +``` + ## Argument Reference The following arguments are supported: * `display_names` - (Optional) The display names of the groups. +* `mail_enabled` - (Optional) Whether the returned groups should be mail-enabled. By itself this does not exclude security-enabled groups. Setting this to `true` ensures all groups are mail-enabled, and setting to `false` ensures that all groups are _not_ mail-enabled. To ignore this filter, omit the property or set it to null. Cannot be specified together with `object_ids`. * `object_ids` - (Optional) The object IDs of the groups. * `return_all` - (Optional) A flag to denote if all groups should be fetched and returned. +* `security_enabled` - (Optional) Whether the returned groups should be security-enabled. By itself this does not exclude mail-enabled groups. Setting this to `true` ensures all groups are security-enabled, and setting to `false` ensures that all groups are _not_ security-enabled. To ignore this filter, omit the property or set it to null. Cannot be specified together with `object_ids`. ~> One of `display_names`, `object_ids` or `return_all` should be specified. Either of the first two _may_ be specified as an empty list, in which case no results will be returned. diff --git a/internal/acceptance/check/that.go b/internal/acceptance/check/that.go index 3241f12828..4d7c77d07a 100644 --- a/internal/acceptance/check/that.go +++ b/internal/acceptance/check/that.go @@ -1,7 +1,10 @@ package check import ( + "context" + "fmt" "regexp" + "strings" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" @@ -87,3 +90,29 @@ func (t thatWithKeyType) MatchesOtherKey(other thatWithKeyType) resource.TestChe func (t thatWithKeyType) MatchesRegex(r *regexp.Regexp) resource.TestCheckFunc { return resource.TestMatchResourceAttr(t.resourceName, t.key, r) } + +func (t thatWithKeyType) ValidatesWith(validationFunc KeyValidationFunc) resource.TestCheckFunc { + return func(state *terraform.State) error { + ms := state.RootModule() + rs, ok := ms.Resources[t.resourceName] + if !ok { + return fmt.Errorf("Not found: %s in %s", t.resourceName, ms.Path) + } + is := rs.Primary + if is == nil { + return fmt.Errorf("No primary instance: %s in %s", t.resourceName, ms.Path) + } + + var values []interface{} + for attr, val := range is.Attributes { + if attrParts := strings.Split(attr, "."); len(attrParts) == 2 && attrParts[0] == t.key && attrParts[1] != "#" && attrParts[1] != "%" { + values = append(values, val) + } + } + + clients := acceptance.AzureADProvider.Meta().(*clients.Client) + return validationFunc(clients.StopContext, clients, values) + } +} + +type KeyValidationFunc func(context.Context, *clients.Client, []interface{}) error diff --git a/internal/services/groups/groups_data_source.go b/internal/services/groups/groups_data_source.go index 1c461fe00c..e7d2864683 100644 --- a/internal/services/groups/groups_data_source.go +++ b/internal/services/groups/groups_data_source.go @@ -53,12 +53,28 @@ func groupsDataSource() *schema.Resource { }, }, + "mail_enabled": { + Description: "Whether the groups are mail-enabled", + Type: schema.TypeBool, + Optional: true, + Computed: true, + ConflictsWith: []string{"object_ids"}, + }, + "return_all": { Description: "Retrieve all groups with no filter", Type: schema.TypeBool, Optional: true, ExactlyOneOf: []string{"display_names", "object_ids", "return_all"}, }, + + "security_enabled": { + Description: "Whether the groups are security-enabled", + Type: schema.TypeBool, + Optional: true, + Computed: true, + ConflictsWith: []string{"object_ids"}, + }, }, } } @@ -76,8 +92,17 @@ func groupsDataSourceRead(ctx context.Context, d *schema.ResourceData, meta inte displayNames = v.([]interface{}) } + var filter []string + + if v, ok := d.GetOkExists("mail_enabled"); ok { //nolint:staticcheck // needed to detect unset booleans + filter = append(filter, fmt.Sprintf("mailEnabled eq %t", v.(bool))) + } + if v, ok := d.GetOkExists("security_enabled"); ok { //nolint:staticcheck // needed to detect unset booleans + filter = append(filter, fmt.Sprintf("securityEnabled eq %t", v.(bool))) + } + if returnAll { - result, _, err := client.List(ctx, odata.Query{}) + result, _, err := client.List(ctx, odata.Query{Filter: strings.Join(filter, " and ")}) if err != nil { return tf.ErrorDiagF(err, "Could not retrieve groups") } @@ -93,9 +118,7 @@ func groupsDataSourceRead(ctx context.Context, d *schema.ResourceData, meta inte expectedCount = len(displayNames) for _, v := range displayNames { displayName := v.(string) - query := odata.Query{ - Filter: fmt.Sprintf("displayName eq '%s'", displayName), - } + query := odata.Query{Filter: strings.Join(append(filter, fmt.Sprintf("displayName eq '%s'", displayName)), " and ")} result, _, err := client.List(ctx, query) if err != nil { return tf.ErrorDiagPathF(err, "display_names", "No group found with display name: %q", displayName) diff --git a/internal/services/groups/groups_data_source_test.go b/internal/services/groups/groups_data_source_test.go index 722d5e8cc3..813124300b 100644 --- a/internal/services/groups/groups_data_source_test.go +++ b/internal/services/groups/groups_data_source_test.go @@ -1,13 +1,16 @@ package groups_test import ( + "context" "fmt" "testing" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/manicminer/hamilton/odata" "github.com/hashicorp/terraform-provider-azuread/internal/acceptance" "github.com/hashicorp/terraform-provider-azuread/internal/acceptance/check" + "github.com/hashicorp/terraform-provider-azuread/internal/clients" ) type GroupsDataSource struct{} @@ -19,7 +22,7 @@ func TestAccGroupsDataSource_byDisplayNames(t *testing.T) { data.DataSourceTest(t, []resource.TestStep{ { Config: r.byDisplayNames(data), - Check: resource.ComposeTestCheckFunc( + Check: resource.ComposeAggregateTestCheckFunc( check.That(data.ResourceName).Key("display_names.#").HasValue("2"), check.That(data.ResourceName).Key("object_ids.#").HasValue("2"), ), @@ -34,7 +37,7 @@ func TestAccGroupsDataSource_byObjectIds(t *testing.T) { data.DataSourceTest(t, []resource.TestStep{ { Config: r.byObjectIds(data), - Check: resource.ComposeTestCheckFunc( + Check: resource.ComposeAggregateTestCheckFunc( check.That(data.ResourceName).Key("display_names.#").HasValue("2"), check.That(data.ResourceName).Key("object_ids.#").HasValue("2"), ), @@ -48,7 +51,7 @@ func TestAccGroupsDataSource_noNames(t *testing.T) { data.DataSourceTest(t, []resource.TestStep{ { Config: GroupsDataSource{}.noNames(), - Check: resource.ComposeTestCheckFunc( + Check: resource.ComposeAggregateTestCheckFunc( check.That(data.ResourceName).Key("display_names.#").HasValue("0"), check.That(data.ResourceName).Key("object_ids.#").HasValue("0"), ), @@ -62,7 +65,7 @@ func TestAccGroupsDataSource_returnAll(t *testing.T) { data.DataSourceTest(t, []resource.TestStep{ { Config: GroupsDataSource{}.returnAll(), - Check: resource.ComposeTestCheckFunc( + Check: resource.ComposeAggregateTestCheckFunc( check.That(data.ResourceName).Key("display_names.#").Exists(), check.That(data.ResourceName).Key("object_ids.#").Exists(), ), @@ -70,6 +73,119 @@ func TestAccGroupsDataSource_returnAll(t *testing.T) { }) } +func TestAccGroupsDataSource_returnAllMailEnabled(t *testing.T) { + data := acceptance.BuildTestData(t, "data.azuread_groups", "test") + + data.DataSourceTest(t, []resource.TestStep{ + { + Config: GroupsDataSource{}.returnAllMailEnabled(data), + Check: resource.ComposeAggregateTestCheckFunc( + check.That(data.ResourceName).Key("display_names.#").Exists(), + check.That(data.ResourceName).Key("object_ids.#").Exists(), + check.That(data.ResourceName).Key("object_ids").ValidatesWith(testCheckHasOnlyMailEnabledGroups()), + ), + }, + }) +} + +func TestAccGroupsDataSource_returnAllSecurityEnabled(t *testing.T) { + data := acceptance.BuildTestData(t, "data.azuread_groups", "test") + + data.DataSourceTest(t, []resource.TestStep{ + { + Config: GroupsDataSource{}.returnAllSecurityEnabled(data), + Check: resource.ComposeAggregateTestCheckFunc( + check.That(data.ResourceName).Key("display_names.#").Exists(), + check.That(data.ResourceName).Key("object_ids.#").Exists(), + check.That(data.ResourceName).Key("object_ids").ValidatesWith(testCheckHasOnlySecurityEnabledGroups()), + ), + }, + }) +} + +func TestAccGroupsDataSource_returnAllMailNotSecurityEnabled(t *testing.T) { + data := acceptance.BuildTestData(t, "data.azuread_groups", "test") + + data.DataSourceTest(t, []resource.TestStep{ + { + Config: GroupsDataSource{}.returnAllMailNotSecurityEnabled(data), + Check: resource.ComposeAggregateTestCheckFunc( + check.That(data.ResourceName).Key("display_names.#").Exists(), + check.That(data.ResourceName).Key("object_ids.#").Exists(), + check.That(data.ResourceName).Key("object_ids").ValidatesWith(testCheckHasOnlyMailEnabledGroupsNotSecurityEnabledGroups()), + ), + }, + }) +} + +func TestAccGroupsDataSource_returnAllSecurityNotMailEnabled(t *testing.T) { + data := acceptance.BuildTestData(t, "data.azuread_groups", "test") + + data.DataSourceTest(t, []resource.TestStep{ + { + Config: GroupsDataSource{}.returnAllSecurityNotMailEnabled(data), + Check: resource.ComposeAggregateTestCheckFunc( + check.That(data.ResourceName).Key("display_names.#").Exists(), + check.That(data.ResourceName).Key("object_ids.#").Exists(), + check.That(data.ResourceName).Key("object_ids").ValidatesWith(testCheckHasOnlySecurityEnabledGroupsNotMailEnabledGroups()), + ), + }, + }) +} + +func testCheckHasOnlyMailEnabledGroups() check.KeyValidationFunc { + return testCheckGroupsDataSource(true, false, false, false) +} + +func testCheckHasOnlySecurityEnabledGroups() check.KeyValidationFunc { + return testCheckGroupsDataSource(false, true, false, false) +} + +func testCheckHasOnlyMailEnabledGroupsNotSecurityEnabledGroups() check.KeyValidationFunc { + return testCheckGroupsDataSource(true, false, false, true) +} + +func testCheckHasOnlySecurityEnabledGroupsNotMailEnabledGroups() check.KeyValidationFunc { + return testCheckGroupsDataSource(false, true, true, false) +} + +func testCheckGroupsDataSource(hasMailGroupsOnly, hasSecurityGroupsOnly, hasNoMailGroups, hasNoSecurityGroups bool) check.KeyValidationFunc { + return func(ctx context.Context, clients *clients.Client, values []interface{}) error { + client := clients.Groups.GroupsClient + + for _, v := range values { + oid := v.(string) + group, _, err := client.Get(ctx, oid, odata.Query{}) + if err != nil { + return fmt.Errorf("retrieving group with object ID %q: %+oid", oid, err) + } + if group == nil { + return fmt.Errorf("retrieving group with object ID %q: group was nil", oid) + } + if group.ID == nil { + return fmt.Errorf("retrieving group with object ID %q: ID was nil", oid) + } + if group.DisplayName == nil { + return fmt.Errorf("retrieving group with object ID %q: DisplayName was nil", oid) + } + if hasMailGroupsOnly && group.MailEnabled != nil && !*group.MailEnabled { + return fmt.Errorf("expected only mail-enabled groups, encountered group %q (object ID: %q) which is not mail-enabled", *group.DisplayName, *group.ID) + } + if hasSecurityGroupsOnly && group.SecurityEnabled != nil && !*group.SecurityEnabled { + return fmt.Errorf("expected only security-enabled groups, encountered group %q (object ID: %q) which is not security-enabled", *group.DisplayName, *group.ID) + } + if hasNoMailGroups && group.MailEnabled != nil && *group.MailEnabled { + return fmt.Errorf("expected no mail-enabled groups, encountered group %q (object ID: %q) which is mail-enabled", *group.DisplayName, *group.ID) + } + if hasNoSecurityGroups && group.SecurityEnabled != nil && *group.SecurityEnabled { + return fmt.Errorf("expected no security-enabled groups, encountered group %q (object ID: %q) which is security-enabled", *group.DisplayName, *group.ID) + } + } + + return nil + } +} + func (GroupsDataSource) template(data acceptance.TestData) string { return fmt.Sprintf(` resource "azuread_group" "testA" { @@ -83,6 +199,14 @@ resource "azuread_group" "testB" { mail_nickname = "acctestGroupB-%[1]d" types = ["Unified"] } + +resource "azuread_group" "testC" { + display_name = "acctestGroupC-%[1]d" + mail_enabled = true + mail_nickname = "acctestGroupC%[1]d" + types = ["Unified"] + security_enabled = true +} `, data.RandomInteger) } @@ -121,3 +245,49 @@ data "azuread_groups" "test" { } ` } + +func (r GroupsDataSource) returnAllMailEnabled(data acceptance.TestData) string { + return fmt.Sprintf(` +%[1]s + +data "azuread_groups" "test" { + mail_enabled = true + return_all = true +} +`, r.template(data)) +} + +func (r GroupsDataSource) returnAllSecurityEnabled(data acceptance.TestData) string { + return fmt.Sprintf(` +%[1]s + +data "azuread_groups" "test" { + return_all = true + security_enabled = true +} +`, r.template(data)) +} + +func (r GroupsDataSource) returnAllMailNotSecurityEnabled(data acceptance.TestData) string { + return fmt.Sprintf(` +%[1]s + +data "azuread_groups" "test" { + mail_enabled = true + return_all = true + security_enabled = false +} +`, r.template(data)) +} + +func (r GroupsDataSource) returnAllSecurityNotMailEnabled(data acceptance.TestData) string { + return fmt.Sprintf(` +%[1]s + +data "azuread_groups" "test" { + mail_enabled = false + return_all = true + security_enabled = true +} +`, r.template(data)) +}