Skip to content

Commit

Permalink
New Data Source: azuread_users (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
katbyte authored Jun 13, 2019
1 parent 03e89e1 commit 480d487
Show file tree
Hide file tree
Showing 13 changed files with 412 additions and 135 deletions.
38 changes: 5 additions & 33 deletions azuread/data_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/hashicorp/terraform/helper/schema"

"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/ar"
"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/graph"
"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/validate"
)

Expand Down Expand Up @@ -64,54 +64,26 @@ func dataSourceUserRead(d *schema.ResourceData, meta interface{}) error {
var user graphrbac.User

if upn, ok := d.Get("user_principal_name").(string); ok && upn != "" {

// use the object_id to find the Azure AD application
resp, err := client.Get(ctx, upn)
if err != nil {
if ar.ResponseWasNotFound(resp.Response) {
return fmt.Errorf("Error: AzureAD User with ID %q was not found", upn)
}

return fmt.Errorf("Error making Read request on AzureAD User with ID %q: %+v", upn, err)
}

user = resp
} else if oId, ok := d.Get("object_id").(string); ok && oId != "" {
filter := fmt.Sprintf("objectId eq '%s'", oId)

resp, err := client.ListComplete(ctx, filter)
u, err := graph.UserGetByObjectId(&client, ctx, oId)
if err != nil {
return fmt.Errorf("Error listing Azure AD Users for filter %q: %+v", filter, err)
}

values := resp.Response().Value
if values == nil {
return fmt.Errorf("nil values for AD Users matching %q", filter)
}
if len(*values) == 0 {
return fmt.Errorf("Found no AD Users matching %q", filter)
}
if len(*values) > 2 {
return fmt.Errorf("Found multiple AD Users matching %q", filter)
}

user = (*values)[0]
if user.DisplayName == nil {
return fmt.Errorf("nil DisplayName for AD Users matching %q", filter)
}
if *user.ObjectID != oId {
return fmt.Errorf("objectID for AD Users matching %q does is does not match(%q!=%q)", filter, *user.ObjectID, oId)
return fmt.Errorf("Error finding Azure AD User with object ID %q: %+v", oId, err)
}
user = *u
} else {
return fmt.Errorf("one of `object_id` or `user_principal_name` must be supplied")
}

if user.ObjectID == nil {
return fmt.Errorf("Group objectId is nil")
return fmt.Errorf("Azure AD User objectId is nil")
}
d.SetId(*user.ObjectID)

d.SetId(*user.ObjectID)
d.Set("object_id", user.ObjectID)
d.Set("user_principal_name", user.UserPrincipalName)
d.Set("account_enabled", user.AccountEnabled)
Expand Down
38 changes: 20 additions & 18 deletions azuread/data_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (

"github.com/hashicorp/terraform/helper/acctest"
"github.com/hashicorp/terraform/helper/resource"

"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/tf"
)

func TestAccDataSourceAzureADUser_byUserPrincipalName(t *testing.T) {
dataSourceName := "data.azuread_user.test"
id := acctest.RandStringFromCharSet(7, acctest.CharSetAlphaNum)
password := id + "p@$$wR2"
func TestAccAzureADUserDataSource_byUserPrincipalName(t *testing.T) {
dsn := "data.azuread_user.test"
id := tf.AccRandTimeInt()
password := "p@$$wR2" + acctest.RandStringFromCharSet(7, acctest.CharSetAlphaNum)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Expand All @@ -20,20 +22,20 @@ func TestAccDataSourceAzureADUser_byUserPrincipalName(t *testing.T) {
{
Config: testAccAzureADUserDataSource_byUserPrincipalName(id, password),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttrSet(dataSourceName, "user_principal_name"),
resource.TestCheckResourceAttrSet(dataSourceName, "account_enabled"),
resource.TestCheckResourceAttrSet(dataSourceName, "display_name"),
resource.TestCheckResourceAttrSet(dataSourceName, "mail_nickname"),
resource.TestCheckResourceAttrSet(dsn, "user_principal_name"),
resource.TestCheckResourceAttrSet(dsn, "account_enabled"),
resource.TestCheckResourceAttrSet(dsn, "display_name"),
resource.TestCheckResourceAttrSet(dsn, "mail_nickname"),
),
},
},
})
}

func TestAccDataSourceAzureADUser_byObjectId(t *testing.T) {
dataSourceName := "data.azuread_user.test"
id := acctest.RandStringFromCharSet(7, acctest.CharSetAlphaNum)
password := id + "p@$$wR2"
func TestAccAzureADUserDataSource_byObjectId(t *testing.T) {
dsn := "data.azuread_user.test"
id := tf.AccRandTimeInt()
password := "p@$$wR2" + acctest.RandStringFromCharSet(7, acctest.CharSetAlphaNum)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Expand All @@ -42,17 +44,17 @@ func TestAccDataSourceAzureADUser_byObjectId(t *testing.T) {
{
Config: testAccAzureADUserDataSource_byObjectId(id, password),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttrSet(dataSourceName, "user_principal_name"),
resource.TestCheckResourceAttrSet(dataSourceName, "account_enabled"),
resource.TestCheckResourceAttrSet(dataSourceName, "display_name"),
resource.TestCheckResourceAttrSet(dataSourceName, "mail_nickname"),
resource.TestCheckResourceAttrSet(dsn, "user_principal_name"),
resource.TestCheckResourceAttrSet(dsn, "account_enabled"),
resource.TestCheckResourceAttrSet(dsn, "display_name"),
resource.TestCheckResourceAttrSet(dsn, "mail_nickname"),
),
},
},
})
}

func testAccAzureADUserDataSource_byUserPrincipalName(id, password string) string {
func testAccAzureADUserDataSource_byUserPrincipalName(id int, password string) string {
return fmt.Sprintf(`
%s
Expand All @@ -62,7 +64,7 @@ data "azuread_user" "test" {
`, testAccADUser_basic(id, password))
}

func testAccAzureADUserDataSource_byObjectId(id, password string) string {
func testAccAzureADUserDataSource_byObjectId(id int, password string) string {
return fmt.Sprintf(`
%s
Expand Down
105 changes: 105 additions & 0 deletions azuread/data_users.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package azuread

import (
"crypto/sha1"
"encoding/base64"
"fmt"
"strings"

"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/hashicorp/terraform/helper/schema"

"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/graph"
"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/validate"
)

func dataSourceUsers() *schema.Resource {
return &schema.Resource{
Read: dataSourceUsersRead,

Importer: &schema.ResourceImporter{
State: schema.ImportStatePassthrough,
},

Schema: map[string]*schema.Schema{
"object_ids": {
Type: schema.TypeList,
Optional: true,
Computed: true,
MinItems: 1,
ConflictsWith: []string{"user_principal_names"},
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validate.UUID,
},
},

"user_principal_names": {
Type: schema.TypeList,
Optional: true,
Computed: true,
MinItems: 1,
ConflictsWith: []string{"object_ids"},
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validate.NoEmptyStrings,
},
},
},
}
}

func dataSourceUsersRead(d *schema.ResourceData, meta interface{}) error {
client := meta.(*ArmClient).usersClient
ctx := meta.(*ArmClient).StopContext

var users []graphrbac.User
expectedCount := 0

if upns, ok := d.Get("user_principal_names").([]interface{}); ok && len(upns) > 0 {
expectedCount = len(upns)
for _, v := range upns {
resp, err := client.Get(ctx, v.(string))
if err != nil {
return fmt.Errorf("Error making Read request on AzureAD User with ID %q: %+v", v.(string), err)
}

users = append(users, resp)
}
} else if oids, ok := d.Get("object_ids").([]interface{}); ok && len(oids) > 0 {
expectedCount = len(oids)
for _, v := range oids {
u, err := graph.UserGetByObjectId(&client, ctx, v.(string))
if err != nil {
return fmt.Errorf("Error finding Azure AD User with object ID %q: %+v", v.(string), err)
}
users = append(users, *u)
}
} else {
return fmt.Errorf("one of `object_ids` or `user_principal_names` must be supplied")
}

if len(users) != expectedCount {
return fmt.Errorf("Unexpected number of users returns (%d != %d)", len(users), expectedCount)
}

var upns, oids []string
for _, u := range users {
if u.ObjectID == nil || u.UserPrincipalName == nil {
return fmt.Errorf("User with nil ObjectId or UPN was found: %v", u)
}

oids = append(oids, *u.ObjectID)
upns = append(upns, *u.UserPrincipalName)
}

h := sha1.New()
if _, err := h.Write([]byte(strings.Join(upns, "-"))); err != nil {
return fmt.Errorf("Unable to compute hash for upns: %v", err)
}

d.SetId("users#" + base64.URLEncoding.EncodeToString(h.Sum(nil)))
d.Set("object_ids", oids)
d.Set("user_principal_names", upns)
return nil
}
71 changes: 71 additions & 0 deletions azuread/data_users_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package azuread

import (
"fmt"
"testing"

"github.com/hashicorp/terraform/helper/acctest"
"github.com/hashicorp/terraform/helper/resource"

"github.com/terraform-providers/terraform-provider-azuread/azuread/helpers/tf"
)

func TestAccAzureADUsersDataSource_byUserPrincipalNames(t *testing.T) {
dsn := "data.azuread_users.test"
id := tf.AccRandTimeInt()
password := "p@$$wR2" + acctest.RandStringFromCharSet(7, acctest.CharSetAlphaNum)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
Steps: []resource.TestStep{
{
Config: testAccAzureADUsersDataSource_byUserPrincipalNames(id, password),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr(dsn, "user_principal_names.#", "2"),
resource.TestCheckResourceAttr(dsn, "object_ids.#", "2"),
),
},
},
})
}

func TestAccAzureADUsersDataSource_byObjectIds(t *testing.T) {
dsn := "data.azuread_users.test"
id := tf.AccRandTimeInt()
password := "p@$$wR2" + acctest.RandStringFromCharSet(7, acctest.CharSetAlphaNum)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
Steps: []resource.TestStep{
{
Config: testAccAzureADUsersDataSource_byObjectIds(id, password),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr(dsn, "user_principal_names.#", "2"),
resource.TestCheckResourceAttr(dsn, "object_ids.#", "2"),
),
},
},
})
}

func testAccAzureADUsersDataSource_byUserPrincipalNames(id int, password string) string {
return fmt.Sprintf(`
%s
data "azuread_users" "test" {
user_principal_names = ["${azuread_user.testA.user_principal_name}", "${azuread_user.testB.user_principal_name}"]
}
`, testAccADUser_multiple(id, password))
}

func testAccAzureADUsersDataSource_byObjectIds(id int, password string) string {
return fmt.Sprintf(`
%s
data "azuread_users" "test" {
object_ids = ["${azuread_user.testA.object_id}", "${azuread_user.testB.object_id}"]
}
`, testAccADUser_multiple(id, password))
}
38 changes: 38 additions & 0 deletions azuread/helpers/graph/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package graph

import (
"context"
"fmt"

"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
)

func UserGetByObjectId(client *graphrbac.UsersClient, ctx context.Context, objectId string) (*graphrbac.User, error) {

filter := fmt.Sprintf("objectId eq '%s'", objectId)
resp, err := client.ListComplete(ctx, filter)
if err != nil {
return nil, fmt.Errorf("Error listing Azure AD Users for filter %q: %+v", filter, err)
}

values := resp.Response().Value
if values == nil {
return nil, fmt.Errorf("nil values for AD Users matching %q", filter)
}
if len(*values) == 0 {
return nil, fmt.Errorf("Found no AD Users matching %q", filter)
}
if len(*values) > 2 {
return nil, fmt.Errorf("Found multiple AD Users matching %q", filter)
}

user := (*values)[0]
if user.DisplayName == nil {
return nil, fmt.Errorf("nil DisplayName for AD Users matching %q", filter)
}
if *user.ObjectID != objectId {
return nil, fmt.Errorf("objectID for AD Users matching %q does is does not match(%q!=%q)", filter, *user.ObjectID, objectId)
}

return &user, nil
}
27 changes: 27 additions & 0 deletions azuread/helpers/tf/acctest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package tf

import (
"strconv"
"strings"
"time"

"github.com/hashicorp/terraform/helper/acctest"
)

func AccRandTimeInt() int {
// acctest.RantInt() returns a value of size:
// 000000000000000000
// YYMMddHHmmsshhRRRR

//go format: 2006-01-02 15:04:05.00

timeStr := strings.Replace(time.Now().Local().Format("060102150405.00"), ".", "", 1) //no way to not have a .?
postfix := acctest.RandStringFromCharSet(4, "0123456789")

i, err := strconv.Atoi(timeStr + postfix)
if err != nil {
panic(err)
}

return i
}
Loading

0 comments on commit 480d487

Please sign in to comment.