diff --git a/auth_oidc.go b/auth_oidc.go new file mode 100644 index 0000000..9a87a16 --- /dev/null +++ b/auth_oidc.go @@ -0,0 +1,122 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" +) + +var _ azcore.TokenCredential = &OidcCredential{} + +type OidcCredential struct { + requestToken string + requestUrl string + + token string + tokenFilePath string + + cred *azidentity.ClientAssertionCredential +} + +type OidcCredentialOptions struct { + azcore.ClientOptions + TenantID string + ClientID string + RequestToken string + RequestUrl string + Token string + TokenFilePath string +} + +func NewOidcCredential(options *OidcCredentialOptions) (*OidcCredential, error) { + w := &OidcCredential{ + requestToken: options.RequestToken, + requestUrl: options.RequestUrl, + token: options.Token, + tokenFilePath: options.TokenFilePath, + } + + cred, err := azidentity.NewClientAssertionCredential(options.TenantID, options.ClientID, w.getAssertion, &azidentity.ClientAssertionCredentialOptions{ClientOptions: options.ClientOptions}) + if err != nil { + return nil, err + } + + w.cred = cred + return w, nil +} + +func (w *OidcCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + return w.cred.GetToken(ctx, opts) +} + +func (w *OidcCredential) getAssertion(ctx context.Context) (string, error) { + if w.token != "" { + return w.token, nil + } + + if w.tokenFilePath != "" { + idTokenData, err := os.ReadFile(w.tokenFilePath) + if err != nil { + return "", fmt.Errorf("reading token file: %v", err) + } + + return string(idTokenData), nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.requestUrl, http.NoBody) + if err != nil { + return "", fmt.Errorf("getAssertion: failed to build request") + } + + query, err := url.ParseQuery(req.URL.RawQuery) + if err != nil { + return "", fmt.Errorf("getAssertion: cannot parse URL query") + } + + if query.Get("audience") == "" { + query.Set("audience", "api://AzureADTokenExchange") + req.URL.RawQuery = query.Encode() + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", w.requestToken)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("getAssertion: cannot request token: %v", err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("getAssertion: cannot parse response: %v", err) + } + + if c := resp.StatusCode; c < 200 || c > 299 { + return "", fmt.Errorf("getAssertion: received HTTP status %d with response: %s", resp.StatusCode, body) + } + + var tokenRes struct { + Count *int `json:"count"` + Value *string `json:"value"` + } + if err := json.Unmarshal(body, &tokenRes); err != nil { + return "", fmt.Errorf("getAssertion: cannot unmarshal response: %v", err) + } + + if tokenRes.Value == nil { + return "", fmt.Errorf("getAssertion: nil JWT assertion received from OIDC provider") + } + + return *tokenRes.Value, nil +} diff --git a/command_before_func.go b/command_before_func.go index 26c08f4..5f78429 100644 --- a/command_before_func.go +++ b/command_before_func.go @@ -60,13 +60,14 @@ func commandBeforeFunc(fset *FlagSet) func(ctx *cli.Context) error { fset.flagUseEnvironmentCred, fset.flagUseManagedIdentityCred, fset.flagUseAzureCLICred, + fset.flagUseOIDCCred, } { if ok { occur += 1 } } if occur > 1 { - return fmt.Errorf("only one of `--use-environment-cred`, `--use-managed-identity-cred` and `--use-azure-cli-cred` can be specified") + return fmt.Errorf("only one of `--use-environment-cred`, `--use-managed-identity-cred`, `--use-azure-cli-cred` and `--use-oidc-cred` can be specified") } // Initialize output directory diff --git a/command_before_func_test.go b/command_before_func_test.go index cf99c0f..7980f67 100644 --- a/command_before_func_test.go +++ b/command_before_func_test.go @@ -224,7 +224,7 @@ func TestCommondBeforeFunc(t *testing.T) { flagUseEnvironmentCred: true, flagUseAzureCLICred: true, }, - err: "only one of `--use-environment-cred`, `--use-managed-identity-cred` and `--use-azure-cli-cred` can be specified", + err: "only one of `--use-environment-cred`, `--use-managed-identity-cred`, `--use-azure-cli-cred` and `--use-oidc-cred` can be specified", }, } diff --git a/flag.go b/flag.go index a90a6e2..c8246e2 100644 --- a/flag.go +++ b/flag.go @@ -11,26 +11,33 @@ var flagset FlagSet type FlagSet struct { // common flags - flagEnv string - flagSubscriptionId string + flagEnv string + flagSubscriptionId string + flagOutputDir string + flagOverwrite bool + flagAppend bool + flagDevProvider bool + flagProviderVersion string + flagBackendType string + flagBackendConfig cli.StringSlice + flagFullConfig bool + flagParallelism int + flagContinue bool + flagNonInteractive bool + flagPlainUI bool + flagGenerateMappingFile bool + flagHCLOnly bool + flagModulePath string + + // common flags (auth) flagUseEnvironmentCred bool flagUseManagedIdentityCred bool flagUseAzureCLICred bool - flagOutputDir string - flagOverwrite bool - flagAppend bool - flagDevProvider bool - flagProviderVersion string - flagBackendType string - flagBackendConfig cli.StringSlice - flagFullConfig bool - flagParallelism int - flagContinue bool - flagNonInteractive bool - flagPlainUI bool - flagGenerateMappingFile bool - flagHCLOnly bool - flagModulePath string + flagUseOIDCCred bool + flagOIDCRequestToken string + flagOIDCRequestURL string + flagOIDCTokenFilePath string + flagOIDCToken string // common flags (hidden) hflagMockClient bool @@ -81,15 +88,6 @@ func (flag FlagSet) DescribeCLI(mode string) string { if flag.flagOverwrite { args = append(args, "--overwrite=true") } - if flag.flagUseEnvironmentCred { - args = append(args, "--use-environment-cred=true") - } - if flag.flagUseManagedIdentityCred { - args = append(args, "--use-managed-identity-cred=true") - } - if flag.flagUseAzureCLICred { - args = append(args, "--use-azure-cli-cred=true") - } if flag.flagAppend { args = append(args, "--append=true") } @@ -120,12 +118,38 @@ func (flag FlagSet) DescribeCLI(mode string) string { if flag.flagHCLOnly { args = append(args, "--hcl-only=true") } - if flag.hflagTFClientPluginPath != "" { - args = append(args, "--tfclient-plugin-path="+flag.hflagTFClientPluginPath) - } if flag.flagModulePath != "" { args = append(args, "--module-path="+flag.flagModulePath) } + + if flag.flagUseEnvironmentCred { + args = append(args, "--use-environment-cred=true") + } + if flag.flagUseManagedIdentityCred { + args = append(args, "--use-managed-identity-cred=true") + } + if flag.flagUseAzureCLICred { + args = append(args, "--use-azure-cli-cred=true") + } + if flag.flagUseOIDCCred { + args = append(args, "--use-oidc-cred=true") + } + if flag.flagOIDCRequestToken != "" { + args = append(args, "--oidc-request-token=*") + } + if flag.flagOIDCRequestURL != "" { + args = append(args, "--oidc-request-url="+flag.flagOIDCRequestURL) + } + if flag.flagOIDCTokenFilePath != "" { + args = append(args, "--oidc-token-file-path="+flag.flagOIDCTokenFilePath) + } + if flag.flagOIDCToken != "" { + args = append(args, "--oidc-token=*") + } + + if flag.hflagTFClientPluginPath != "" { + args = append(args, "--tfclient-plugin-path="+flag.hflagTFClientPluginPath) + } switch mode { case ModeResource: if flag.flagResName != "" { diff --git a/main.go b/main.go index 8c575f4..49f99a5 100644 --- a/main.go +++ b/main.go @@ -126,24 +126,6 @@ func main() { Usage: "The subscription id", Destination: &flagset.flagSubscriptionId, }, - &cli.BoolFlag{ - Name: "use-environment-cred", - EnvVars: []string{"AZTFEXPORT_USE_ENVIRONMENT_CRED"}, - Usage: "Explicitly use the environment variables to do authentication", - Destination: &flagset.flagUseEnvironmentCred, - }, - &cli.BoolFlag{ - Name: "use-managed-identity-cred", - EnvVars: []string{"AZTFEXPORT_USE_MANAGED_IDENTITY_CRED"}, - Usage: "Explicitly use the managed identity that is provided by the Azure host to do authentication", - Destination: &flagset.flagUseManagedIdentityCred, - }, - &cli.BoolFlag{ - Name: "use-azure-cli-cred", - EnvVars: []string{"AZTFEXPORT_USE_AZURE_CLI_CRED"}, - Usage: "Explicitly use the Azure CLI to do authentication", - Destination: &flagset.flagUseAzureCLICred, - }, &cli.StringFlag{ Name: "output-dir", EnvVars: []string{"AZTFEXPORT_OUTPUT_DIR"}, @@ -259,6 +241,56 @@ func main() { Value: "INFO", }, + // Common flags (auth) + &cli.BoolFlag{ + Name: "use-environment-cred", + EnvVars: []string{"AZTFEXPORT_USE_ENVIRONMENT_CRED"}, + Usage: "Explicitly use the environment variables to do authentication", + Destination: &flagset.flagUseEnvironmentCred, + }, + &cli.BoolFlag{ + Name: "use-managed-identity-cred", + EnvVars: []string{"AZTFEXPORT_USE_MANAGED_IDENTITY_CRED"}, + Usage: "Explicitly use the managed identity that is provided by the Azure host to do authentication", + Destination: &flagset.flagUseManagedIdentityCred, + }, + &cli.BoolFlag{ + Name: "use-azure-cli-cred", + EnvVars: []string{"AZTFEXPORT_USE_AZURE_CLI_CRED"}, + Usage: "Explicitly use the Azure CLI to do authentication", + Destination: &flagset.flagUseAzureCLICred, + }, + &cli.BoolFlag{ + Name: "use-oidc-cred", + EnvVars: []string{"AZTFEXPORT_USE_OIDC_CRED"}, + Usage: "Explicitly use the OIDC to do authentication", + Destination: &flagset.flagUseOIDCCred, + }, + &cli.StringFlag{ + Name: "oidc-request-token", + EnvVars: []string{"AZTFEXPORT_OIDC_REQUEST_TOKEN", "ARM_OIDC_REQUEST_TOKEN", "ACTIONS_ID_TOKEN_REQUEST_TOKEN"}, + Usage: "The bearer token for the request to the OIDC provider", + Destination: &flagset.flagOIDCRequestToken, + }, + &cli.StringFlag{ + Name: "oidc-request-url", + EnvVars: []string{"AZTFEXPORT_OIDC_REQUEST_URL", "ARM_OIDC_REQUEST_URL", "ACTIONS_ID_TOKEN_REQUEST_URL"}, + Usage: "The URL for the OIDC provider from which to request an ID token", + Destination: &flagset.flagOIDCRequestURL, + }, + &cli.StringFlag{ + Name: "oidc-token-file-path", + EnvVars: []string{"AZTFEXPORT_OIDC_TOKEN_FILE_PATH", "ARM_OIDC_TOKEN_FILE_PATH"}, + Usage: "The path to a file containing an ID token when authenticating using OIDC", + Destination: &flagset.flagOIDCTokenFilePath, + }, + &cli.StringFlag{ + Name: "oidc-token", + EnvVars: []string{"AZTFEXPORT_OIDC_TOKEN", "ARM_OIDC_TOKEN"}, + Usage: "The ID token when authenticating using OIDC", + Destination: &flagset.flagOIDCToken, + }, + // Hidden flags &cli.BoolFlag{ Name: "mock-client", @@ -414,7 +446,7 @@ func main() { return fmt.Errorf("invalid resource id: %v", err) } - cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset.flagEnv, NewAuthMethodFromFlagSet(flagset)) + cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset) if err != nil { return err } @@ -478,7 +510,7 @@ func main() { rg := c.Args().First() - cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset.flagEnv, NewAuthMethodFromFlagSet(flagset)) + cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset) if err != nil { return err } @@ -541,7 +573,7 @@ func main() { predicate := c.Args().First() - cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset.flagEnv, NewAuthMethodFromFlagSet(flagset)) + cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset) if err != nil { return err } @@ -605,7 +637,7 @@ func main() { mapFile := c.Args().First() - cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset.flagEnv, NewAuthMethodFromFlagSet(flagset)) + cred, clientOpt, err := buildAzureSDKCredAndClientOpt(flagset) if err != nil { return err } @@ -741,33 +773,10 @@ func initTelemetryClient(subscriptionId string) telemetry.Client { return telemetry.NewAppInsight(subscriptionId, installId, sessionId) } -// At most one of below is true -type authMethod int - -const ( - authMethodDefault authMethod = iota - authMethodEnvironment - authMethodManagedIdentity - authMethodAzureCLI -) - -func NewAuthMethodFromFlagSet(fset FlagSet) authMethod { - if fset.flagUseEnvironmentCred { - return authMethodEnvironment - } - if fset.flagUseManagedIdentityCred { - return authMethodManagedIdentity - } - if fset.flagUseAzureCLICred { - return authMethodAzureCLI - } - return authMethodDefault -} - // buildAzureSDKCredAndClientOpt builds the Azure SDK credential and client option from multiple sources (i.e. environment variables, MSI, Azure CLI). -func buildAzureSDKCredAndClientOpt(env string, authMethod authMethod) (azcore.TokenCredential, *arm.ClientOptions, error) { +func buildAzureSDKCredAndClientOpt(fset FlagSet) (azcore.TokenCredential, *arm.ClientOptions, error) { var cloudCfg cloud.Configuration - switch strings.ToLower(env) { + switch env := fset.flagEnv; strings.ToLower(env) { case "public": cloudCfg = cloud.AzurePublic case "usgovernment": @@ -814,8 +823,8 @@ func buildAzureSDKCredAndClientOpt(env string, authMethod authMethod) (azcore.To cred azcore.TokenCredential err error ) - switch authMethod { - case authMethodEnvironment: + switch { + case fset.flagUseEnvironmentCred: cred, err = azidentity.NewEnvironmentCredential(&azidentity.EnvironmentCredentialOptions{ ClientOptions: clientOpt.ClientOptions, }) @@ -823,7 +832,7 @@ func buildAzureSDKCredAndClientOpt(env string, authMethod authMethod) (azcore.To return nil, nil, fmt.Errorf("failed to new Environment credential: %v", err) } return cred, clientOpt, nil - case authMethodManagedIdentity: + case fset.flagUseManagedIdentityCred: cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ ClientOptions: clientOpt.ClientOptions, }) @@ -831,7 +840,7 @@ func buildAzureSDKCredAndClientOpt(env string, authMethod authMethod) (azcore.To return nil, nil, fmt.Errorf("failed to new Managed Identity credential: %v", err) } return cred, clientOpt, nil - case authMethodAzureCLI: + case fset.flagUseAzureCLICred: cred, err = azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ TenantID: tenantId, }) @@ -839,7 +848,21 @@ func buildAzureSDKCredAndClientOpt(env string, authMethod authMethod) (azcore.To return nil, nil, fmt.Errorf("failed to new Azure CLI credential: %v", err) } return cred, clientOpt, nil - case authMethodDefault: + case fset.flagUseOIDCCred: + cred, err = NewOidcCredential(&OidcCredentialOptions{ + ClientOptions: clientOpt.ClientOptions, + TenantID: tenantId, + ClientID: os.Getenv("ARM_CLIENT_ID"), + RequestToken: fset.flagOIDCRequestToken, + RequestUrl: fset.flagOIDCRequestURL, + Token: fset.flagOIDCToken, + TokenFilePath: fset.flagOIDCTokenFilePath, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to new OIDC credential: %v", err) + } + return cred, clientOpt, nil + default: opt := &azidentity.DefaultAzureCredentialOptions{ ClientOptions: clientOpt.ClientOptions, TenantID: tenantId, @@ -849,8 +872,6 @@ func buildAzureSDKCredAndClientOpt(env string, authMethod authMethod) (azcore.To return nil, nil, fmt.Errorf("failed to new Default credential: %v", err) } return cred, clientOpt, nil - default: - return nil, nil, fmt.Errorf("unknown auth method: %v", authMethod) } }