Skip to content

Commit

Permalink
Adding audience field to device flow token request (#314)
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
pmahindrakar-oss authored and eapolinario committed Sep 13, 2023
1 parent d040332 commit 3f50098
Show file tree
Hide file tree
Showing 15 changed files with 489 additions and 94 deletions.
2 changes: 2 additions & 0 deletions flyteidl/clients/go/admin/deviceflow/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ type DeviceAuthorizationRequest struct {
ClientID string `json:"client_id"`
// Scope is the scope parameter of the access request
Scope string `json:"scope"`
// Audience defines at which endpoints the token can be used.
Audience string `json:"audience"`
}

// DeviceAuthorizationResponse contains the information that the end user would use to authorize the app requesting the
Expand Down
14 changes: 9 additions & 5 deletions flyteidl/clients/go/admin/deviceflow/token_orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
)

const (
cliendID = "client_id"
audience = "audience"
clientID = "client_id"
deviceCode = "device_code"
grantType = "grant_type"
scope = "scope"
Expand All @@ -44,13 +45,15 @@ type TokenOrchestrator struct {

// StartDeviceAuthorization will initiate the OAuth2 device authorization flow.
func (t TokenOrchestrator) StartDeviceAuthorization(ctx context.Context, dareq DeviceAuthorizationRequest) (*DeviceAuthorizationResponse, error) {
v := url.Values{cliendID: {dareq.ClientID}, scope: {dareq.Scope}}
v := url.Values{clientID: {dareq.ClientID}, scope: {dareq.Scope}, audience: {dareq.Audience}}
httpReq, err := http.NewRequest("POST", t.ClientConfig.DeviceEndpoint, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

logger.Debugf(ctx, "Sending the following request to start device authorization %v with body %v", httpReq.URL, v.Encode())

httpResp, err := ctxhttp.Do(ctx, nil, httpReq)
if err != nil {
return nil, err
Expand Down Expand Up @@ -86,19 +89,20 @@ func (t TokenOrchestrator) StartDeviceAuthorization(ctx context.Context, dareq D
// PollTokenEndpoint polls the token endpoint until the user authorizes/ denies the app or an error occurs other than slow_down or authorization_pending
func (t TokenOrchestrator) PollTokenEndpoint(ctx context.Context, tokReq DeviceAccessTokenRequest, pollInterval time.Duration) (*oauth2.Token, error) {
v := url.Values{
cliendID: {tokReq.ClientID},
clientID: {tokReq.ClientID},
grantType: {grantTypeValue},
deviceCode: {tokReq.DeviceCode},
}

for {

httpReq, err := http.NewRequest("POST", t.ClientConfig.Endpoint.TokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")

logger.Debugf(ctx, "Sending the following request to fetch the token %v with body %v", httpReq.URL, v.Encode())

httpResp, err := ctxhttp.Do(ctx, nil, httpReq)
if err != nil {
return nil, err
Expand Down Expand Up @@ -150,7 +154,7 @@ func (t TokenOrchestrator) FetchTokenFromAuthFlow(ctx context.Context) (*oauth2.
if len(t.ClientConfig.Scopes) > 0 {
scopes = strings.Join(t.ClientConfig.Scopes, " ")
}
daReq := DeviceAuthorizationRequest{ClientID: t.ClientConfig.ClientID, Scope: scopes}
daReq := DeviceAuthorizationRequest{ClientID: t.ClientConfig.ClientID, Scope: scopes, Audience: t.ClientConfig.Audience}
daResp, err := t.StartDeviceAuthorization(ctx, daReq)
if err != nil {
return nil, err
Expand Down
26 changes: 24 additions & 2 deletions flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -46,9 +47,18 @@ func TestFetchFromAuthFlow(t *testing.T) {
body, err := io.ReadAll(r.Body)
assert.Nil(t, err)
isDeviceReq := strings.Contains(string(body), scope)
isTokReq := strings.Contains(string(body), deviceCode) && strings.Contains(string(body), grantType) && strings.Contains(string(body), cliendID)
isTokReq := strings.Contains(string(body), deviceCode) && strings.Contains(string(body), grantType) && strings.Contains(string(body), clientID)

if isDeviceReq {
for _, urlParm := range strings.Split(string(body), "&") {
paramKeyValue := strings.Split(urlParm, "=")
switch paramKeyValue[0] {
case audience:
assert.Equal(t, "abcd", paramKeyValue[1])
case clientID:
assert.Equal(t, clientID, paramKeyValue[1])
}
}
dar := DeviceAuthorizationResponse{
DeviceCode: "e1db31fe-3b23-4fce-b759-82bf8ea323d6",
UserCode: "RPBQZNRX",
Expand All @@ -62,6 +72,17 @@ func TestFetchFromAuthFlow(t *testing.T) {
assert.Nil(t, err)
return
} else if isTokReq {
for _, urlParm := range strings.Split(string(body), "&") {
paramKeyValue := strings.Split(urlParm, "=")
switch paramKeyValue[0] {
case grantType:
assert.Equal(t, url.QueryEscape(grantTypeValue), paramKeyValue[1])
case deviceCode:
assert.Equal(t, "e1db31fe-3b23-4fce-b759-82bf8ea323d6", paramKeyValue[1])
case clientID:
assert.Equal(t, clientID, paramKeyValue[1])
}
}
dar := DeviceAccessTokenResponse{
Token: oauth2.Token{
AccessToken: "access_token",
Expand All @@ -81,14 +102,15 @@ func TestFetchFromAuthFlow(t *testing.T) {
orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Config: &oauth2.Config{
ClientID: cliendID,
ClientID: clientID,
RedirectURL: "http://localhost:8089/redirect",
Scopes: []string{"code", "all"},
Endpoint: oauth2.Endpoint{
TokenURL: fakeServer.URL,
},
},
DeviceEndpoint: fakeServer.URL,
Audience: "abcd",
},
TokenCache: tokenCache,
}, Config{
Expand Down
3 changes: 3 additions & 0 deletions flyteidl/clients/go/admin/oauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
type Config struct {
*oauth2.Config
DeviceEndpoint string
// Audience value to be passed when requesting access token using device flow.This needs to be passed in the first request of the device flow currently and is configured in admin public client config.Required when auth server hasn't been configured with default audience"`
Audience string
}

// BuildConfigFromMetadataService builds OAuth2 config from information retrieved through the anonymous auth metadata service.
Expand All @@ -37,6 +39,7 @@ func BuildConfigFromMetadataService(ctx context.Context, authMetadataClient serv
},
},
DeviceEndpoint: oauthMetaResp.DeviceAuthorizationEndpoint,
Audience: clientResp.Audience,
}

return clientConf, nil
Expand Down
2 changes: 2 additions & 0 deletions flyteidl/clients/go/admin/oauth/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestGenerateClientConfig(t *testing.T) {
ClientId: "dummyClient",
RedirectUri: "dummyRedirectUri",
Scopes: []string{"dummyScopes"},
Audience: "dummyAudience",
}
oauthMetaDataResp := &service.OAuth2MetadataResponse{
Issuer: "dummyIssuer",
Expand All @@ -36,4 +37,5 @@ func TestGenerateClientConfig(t *testing.T) {
assert.Equal(t, "dummyTokenEndpoint", oauthConfig.Endpoint.TokenURL)
assert.Equal(t, "dummyAuthEndPoint", oauthConfig.Endpoint.AuthURL)
assert.Equal(t, "dummyDeviceEndpoint", oauthConfig.DeviceEndpoint)
assert.Equal(t, "dummyAudience", oauthConfig.Audience)
}
110 changes: 92 additions & 18 deletions flyteidl/gen/pb-cpp/flyteidl/service/auth.pb.cc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 3f50098

Please sign in to comment.