From 058e93a786d370a0bff27612ecdb0da8a5aa9bb1 Mon Sep 17 00:00:00 2001 From: Tom Bamford Date: Tue, 15 Dec 2020 14:58:14 +0000 Subject: [PATCH] Parse claims in access tokens to surface useful authentication metadata, use in azuread_client_config to retrieve authenticated object ID --- internal/clients/client.go | 22 +++++++-- .../client_config_data_source.go | 37 ++------------ .../client_config_data_source_aadgraph.go | 48 +++++++++++++++++++ .../client_config_data_source_msgraph.go | 37 ++++++++++++++ .../client_config_data_source_test.go | 1 + 5 files changed, 108 insertions(+), 37 deletions(-) create mode 100644 internal/services/serviceprincipals/client_config_data_source_aadgraph.go create mode 100644 internal/services/serviceprincipals/client_config_data_source_msgraph.go diff --git a/internal/clients/client.go b/internal/clients/client.go index 2c336532c1..280cb43866 100644 --- a/internal/clients/client.go +++ b/internal/clients/client.go @@ -2,9 +2,11 @@ package clients import ( "context" + "fmt" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure" + "github.com/manicminer/hamilton/auth" "github.com/terraform-providers/terraform-provider-azuread/internal/common" applications "github.com/terraform-providers/terraform-provider-azuread/internal/services/applications/client" @@ -16,9 +18,11 @@ import ( // Client contains the handles to all the specific Azure AD resource classes' respective clients type Client struct { - ClientID string - ObjectID string - TenantID string + ClientID string + ObjectID string + TenantID string + Claims auth.Claims + TerraformVersion string Environment azure.Environment @@ -44,5 +48,17 @@ func (client *Client) build(ctx context.Context, o *common.ClientOptions) error client.ServicePrincipals = serviceprincipals.NewClient(o) client.Users = users.NewClient(o) + if client.EnableMsGraphBeta { + // Acquire an access token upfront so we can decode and populate the JWT claims + token, err := o.MsGraphAuthorizer.Token() + if err != nil { + return fmt.Errorf("unable to obtain access token: %v", err) + } + client.Claims, err = auth.ParseClaims(token) + if err != nil { + return fmt.Errorf("unable to parse claims in access token: %v", err) + } + } + return nil } diff --git a/internal/services/serviceprincipals/client_config_data_source.go b/internal/services/serviceprincipals/client_config_data_source.go index 09d7c51ad4..5053979466 100644 --- a/internal/services/serviceprincipals/client_config_data_source.go +++ b/internal/services/serviceprincipals/client_config_data_source.go @@ -2,14 +2,12 @@ package serviceprincipals import ( "context" - "fmt" "time" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/terraform-providers/terraform-provider-azuread/internal/clients" - "github.com/terraform-providers/terraform-provider-azuread/internal/tf" ) func clientConfigDataSource() *schema.Resource { @@ -40,37 +38,8 @@ func clientConfigDataSource() *schema.Resource { } func clientConfigDataSourceRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client := meta.(*clients.Client) - - if client.AuthenticatedAsAServicePrincipal { - spClient := client.ServicePrincipals.AadClient - // Application & Service Principal is 1:1 per tenant. Since we know the appId (client_id) - // here, we can query for the Service Principal whose appId matches. - filter := fmt.Sprintf("appId eq '%s'", client.ClientID) - result, err := spClient.List(ctx, filter) - - if err != nil { - return tf.ErrorDiagF(err, "Listing Service Principals") - } - - if result.Values() == nil || len(result.Values()) != 1 { - return tf.ErrorDiagF(fmt.Errorf("%#v", result.Values()), "Unexpected Service Principal query result") - } - } - - d.SetId(fmt.Sprintf("%s-%s-%s", client.TenantID, client.ObjectID, client.ClientID)) - - if dg := tf.Set(d, "client_id", client.ClientID); dg != nil { - return dg + if useMsGraph := meta.(*clients.Client).EnableMsGraphBeta; useMsGraph { + return clientConfigDataSourceReadMsGraph(ctx, d, meta) } - - if dg := tf.Set(d, "object_id", client.ObjectID); dg != nil { - return dg - } - - if dg := tf.Set(d, "tenant_id", client.TenantID); dg != nil { - return dg - } - - return nil + return clientConfigDataSourceReadAadGraph(ctx, d, meta) } diff --git a/internal/services/serviceprincipals/client_config_data_source_aadgraph.go b/internal/services/serviceprincipals/client_config_data_source_aadgraph.go new file mode 100644 index 0000000000..3fb0a968ad --- /dev/null +++ b/internal/services/serviceprincipals/client_config_data_source_aadgraph.go @@ -0,0 +1,48 @@ +package serviceprincipals + +import ( + "context" + "fmt" + + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + + "github.com/terraform-providers/terraform-provider-azuread/internal/clients" + "github.com/terraform-providers/terraform-provider-azuread/internal/tf" +) + +func clientConfigDataSourceReadAadGraph(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*clients.Client) + + if client.AuthenticatedAsAServicePrincipal { + spClient := client.ServicePrincipals.AadClient + // Application & Service Principal is 1:1 per tenant. Since we know the appId (client_id) + // here, we can query for the Service Principal whose appId matches. + filter := fmt.Sprintf("appId eq '%s'", client.ClientID) + result, err := spClient.List(ctx, filter) + + if err != nil { + return tf.ErrorDiagF(err, "Listing Service Principals") + } + + if result.Values() == nil || len(result.Values()) != 1 { + return tf.ErrorDiagF(fmt.Errorf("%#v", result.Values()), "Unexpected Service Principal query result") + } + } + + d.SetId(fmt.Sprintf("%s-%s-%s", client.TenantID, client.ObjectID, client.ClientID)) + + if dg := tf.Set(d, "client_id", client.ClientID); dg != nil { + return dg + } + + if dg := tf.Set(d, "object_id", client.ObjectID); dg != nil { + return dg + } + + if dg := tf.Set(d, "tenant_id", client.TenantID); dg != nil { + return dg + } + + return nil +} diff --git a/internal/services/serviceprincipals/client_config_data_source_msgraph.go b/internal/services/serviceprincipals/client_config_data_source_msgraph.go new file mode 100644 index 0000000000..b752e1b6a7 --- /dev/null +++ b/internal/services/serviceprincipals/client_config_data_source_msgraph.go @@ -0,0 +1,37 @@ +package serviceprincipals + +import ( + "context" + "fmt" + + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + + "github.com/terraform-providers/terraform-provider-azuread/internal/clients" + "github.com/terraform-providers/terraform-provider-azuread/internal/tf" +) + +func clientConfigDataSourceReadMsGraph(_ context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*clients.Client) + + objectId := "" + if client.Claims.ObjectId != "" { + objectId = client.Claims.ObjectId + } + + d.SetId(fmt.Sprintf("%s-%s-%s", client.TenantID, client.ClientID, objectId)) + + if dg := tf.Set(d, "tenant_id", client.TenantID); dg != nil { + return dg + } + + if dg := tf.Set(d, "client_id", client.ClientID); dg != nil { + return dg + } + + if dg := tf.Set(d, "object_id", objectId); dg != nil { + return dg + } + + return nil +} diff --git a/internal/services/serviceprincipals/client_config_data_source_test.go b/internal/services/serviceprincipals/client_config_data_source_test.go index e2aa1ce34d..7b1588ffac 100644 --- a/internal/services/serviceprincipals/client_config_data_source_test.go +++ b/internal/services/serviceprincipals/client_config_data_source_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/terraform-providers/terraform-provider-azuread/internal/acceptance" "github.com/terraform-providers/terraform-provider-azuread/internal/acceptance/check" )