Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New auth method: Federated Workload Identity (a.k.a OIDC) #438

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions auth_oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package main

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
)

var _ azcore.TokenCredential = &OidcCredential{}

type OidcCredential struct {
requestToken string
requestUrl string

token string
tokenFilePath string

cred *azidentity.ClientAssertionCredential
}

type OidcCredentialOptions struct {
azcore.ClientOptions
TenantID string
ClientID string
RequestToken string
RequestUrl string
Token string
TokenFilePath string
}

func NewOidcCredential(options *OidcCredentialOptions) (*OidcCredential, error) {
w := &OidcCredential{
requestToken: options.RequestToken,
requestUrl: options.RequestUrl,
token: options.Token,
tokenFilePath: options.TokenFilePath,
}

cred, err := azidentity.NewClientAssertionCredential(options.TenantID, options.ClientID, w.getAssertion, &azidentity.ClientAssertionCredentialOptions{ClientOptions: options.ClientOptions})
if err != nil {
return nil, err
}

w.cred = cred
return w, nil
}

func (w *OidcCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
return w.cred.GetToken(ctx, opts)
}

func (w *OidcCredential) getAssertion(ctx context.Context) (string, error) {
if w.token != "" {
return w.token, nil
}

if w.tokenFilePath != "" {
idTokenData, err := os.ReadFile(w.tokenFilePath)
if err != nil {
return "", fmt.Errorf("reading token file: %v", err)
}

return string(idTokenData), nil
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.requestUrl, http.NoBody)
if err != nil {
return "", fmt.Errorf("getAssertion: failed to build request")
}

query, err := url.ParseQuery(req.URL.RawQuery)
if err != nil {
return "", fmt.Errorf("getAssertion: cannot parse URL query")
}

if query.Get("audience") == "" {
query.Set("audience", "api://AzureADTokenExchange")
req.URL.RawQuery = query.Encode()
}

req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", w.requestToken))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("getAssertion: cannot request token: %v", err)
}

defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", fmt.Errorf("getAssertion: cannot parse response: %v", err)
}

if c := resp.StatusCode; c < 200 || c > 299 {
return "", fmt.Errorf("getAssertion: received HTTP status %d with response: %s", resp.StatusCode, body)
}

var tokenRes struct {
Count *int `json:"count"`
Value *string `json:"value"`
}
if err := json.Unmarshal(body, &tokenRes); err != nil {
return "", fmt.Errorf("getAssertion: cannot unmarshal response: %v", err)
}

if tokenRes.Value == nil {
return "", fmt.Errorf("getAssertion: nil JWT assertion received from OIDC provider")
}

return *tokenRes.Value, nil
}
3 changes: 2 additions & 1 deletion command_before_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ func commandBeforeFunc(fset *FlagSet) func(ctx *cli.Context) error {
fset.flagUseEnvironmentCred,
fset.flagUseManagedIdentityCred,
fset.flagUseAzureCLICred,
fset.flagUseOIDCCred,
} {
if ok {
occur += 1
}
}
if occur > 1 {
return fmt.Errorf("only one of `--use-environment-cred`, `--use-managed-identity-cred` and `--use-azure-cli-cred` can be specified")
return fmt.Errorf("only one of `--use-environment-cred`, `--use-managed-identity-cred`, `--use-azure-cli-cred` and `--use-oidc-cred` can be specified")
}

// Initialize output directory
Expand Down
2 changes: 1 addition & 1 deletion command_before_func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestCommondBeforeFunc(t *testing.T) {
flagUseEnvironmentCred: true,
flagUseAzureCLICred: true,
},
err: "only one of `--use-environment-cred`, `--use-managed-identity-cred` and `--use-azure-cli-cred` can be specified",
err: "only one of `--use-environment-cred`, `--use-managed-identity-cred`, `--use-azure-cli-cred` and `--use-oidc-cred` can be specified",
},
}

Expand Down
82 changes: 53 additions & 29 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,33 @@ var flagset FlagSet

type FlagSet struct {
// common flags
flagEnv string
flagSubscriptionId string
flagEnv string
flagSubscriptionId string
flagOutputDir string
flagOverwrite bool
flagAppend bool
flagDevProvider bool
flagProviderVersion string
flagBackendType string
flagBackendConfig cli.StringSlice
flagFullConfig bool
flagParallelism int
flagContinue bool
flagNonInteractive bool
flagPlainUI bool
flagGenerateMappingFile bool
flagHCLOnly bool
flagModulePath string

// common flags (auth)
flagUseEnvironmentCred bool
flagUseManagedIdentityCred bool
flagUseAzureCLICred bool
flagOutputDir string
flagOverwrite bool
flagAppend bool
flagDevProvider bool
flagProviderVersion string
flagBackendType string
flagBackendConfig cli.StringSlice
flagFullConfig bool
flagParallelism int
flagContinue bool
flagNonInteractive bool
flagPlainUI bool
flagGenerateMappingFile bool
flagHCLOnly bool
flagModulePath string
flagUseOIDCCred bool
flagOIDCRequestToken string
flagOIDCRequestURL string
flagOIDCTokenFilePath string
flagOIDCToken string

// common flags (hidden)
hflagMockClient bool
Expand Down Expand Up @@ -81,15 +88,6 @@ func (flag FlagSet) DescribeCLI(mode string) string {
if flag.flagOverwrite {
args = append(args, "--overwrite=true")
}
if flag.flagUseEnvironmentCred {
args = append(args, "--use-environment-cred=true")
}
if flag.flagUseManagedIdentityCred {
args = append(args, "--use-managed-identity-cred=true")
}
if flag.flagUseAzureCLICred {
args = append(args, "--use-azure-cli-cred=true")
}
if flag.flagAppend {
args = append(args, "--append=true")
}
Expand Down Expand Up @@ -120,12 +118,38 @@ func (flag FlagSet) DescribeCLI(mode string) string {
if flag.flagHCLOnly {
args = append(args, "--hcl-only=true")
}
if flag.hflagTFClientPluginPath != "" {
args = append(args, "--tfclient-plugin-path="+flag.hflagTFClientPluginPath)
}
if flag.flagModulePath != "" {
args = append(args, "--module-path="+flag.flagModulePath)
}

if flag.flagUseEnvironmentCred {
args = append(args, "--use-environment-cred=true")
}
if flag.flagUseManagedIdentityCred {
args = append(args, "--use-managed-identity-cred=true")
}
if flag.flagUseAzureCLICred {
args = append(args, "--use-azure-cli-cred=true")
}
if flag.flagUseOIDCCred {
args = append(args, "--use-oidc-cred=true")
}
if flag.flagOIDCRequestToken != "" {
args = append(args, "--oidc-request-token=*")
}
if flag.flagOIDCRequestURL != "" {
args = append(args, "--oidc-request-url="+flag.flagOIDCRequestURL)
}
if flag.flagOIDCTokenFilePath != "" {
args = append(args, "--oidc-token-file-path="+flag.flagOIDCTokenFilePath)
}
if flag.flagOIDCToken != "" {
args = append(args, "--oidc-token=*")
}

if flag.hflagTFClientPluginPath != "" {
args = append(args, "--tfclient-plugin-path="+flag.hflagTFClientPluginPath)
}
switch mode {
case ModeResource:
if flag.flagResName != "" {
Expand Down
Loading