From 1db39e62c65296eeffb1eeb35f9dcd12613f27b2 Mon Sep 17 00:00:00 2001 From: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:16:08 -0500 Subject: [PATCH] Add device flow Signed-off-by: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> --- cli.go | 23 +++-- path_config.go | 91 +++++++++++++++++ path_oidc.go | 252 ++++++++++++++++++++++++++++++++++++++++------ path_oidc_test.go | 102 ++++++++++++------- path_role.go | 19 +++- 5 files changed, 409 insertions(+), 78 deletions(-) diff --git a/cli.go b/cli.go index d84a6c0d..15bc04d0 100644 --- a/cli.go +++ b/cli.go @@ -151,18 +151,20 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro var pollInterval string var interval int var state string + var userCode string var listener net.Listener if secret != nil { pollInterval, _ = secret.Data["poll_interval"].(string) state, _ = secret.Data["state"].(string) + userCode, _ = secret.Data["user_code"].(string) } - if callbackMode == "direct" { + if callbackMode != "client" { if state == "" { - return nil, errors.New("no state returned in direct callback mode") + return nil, errors.New("no state returned in " + callbackMode + " callback mode") } if pollInterval == "" { - return nil, errors.New("no poll_interval returned in direct callback mode") + return nil, errors.New("no poll_interval returned in " + callbackMode + " callback mode") } interval, err = strconv.Atoi(pollInterval) if err != nil { @@ -199,7 +201,11 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro } fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n") - if callbackMode == "direct" { + if userCode != "" { + fmt.Fprintf(os.Stderr, "When prompted, enter code %s\n\n", userCode) + } + + if callbackMode != "client" { data := map[string]interface{}{ "state": state, "client_nonce": clientNonce, @@ -212,7 +218,9 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro if err == nil { return secret, nil } - if !strings.HasSuffix(err.Error(), "authorization_pending") { + if strings.HasSuffix(err.Error(), "slow_down") { + interval *= 2 + } else if !strings.HasSuffix(err.Error(), "authorization_pending") { return nil, err } // authorization is pending, try again @@ -376,8 +384,9 @@ Configuration: Vault role of type "OIDC" to use for authentication. %s= - Mode of callback: "direct" for direct connection to Vault or "client" - for connection to command line client (default: client). + Mode of callback: "direct" for direct connection to Vault, "client" + for connection to command line client, or "device" for device flow + which has no callback (default: client). %s= Optional address to bind the OIDC callback listener to in client callback diff --git a/path_config.go b/path_config.go index 6d137bac..38ae1bfb 100644 --- a/path_config.go +++ b/path_config.go @@ -9,9 +9,12 @@ import ( "crypto/tls" "crypto/x509" "encoding/asn1" + "encoding/json" "errors" "fmt" + "io/ioutil" "net/http" + "net/url" "strings" "github.com/hashicorp/cap/jwt" @@ -174,6 +177,91 @@ func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtCon return config, nil } +func contactIssuer(ctx context.Context, uri string, data *url.Values, ignoreBad bool) ([]byte, error) { + var req *http.Request + var err error + if data == nil { + req, err = http.NewRequest("GET", uri, nil) + } else { + req, err = http.NewRequest("POST", uri, strings.NewReader(data.Encode())) + } + if err != nil { + return nil, err + } + if data != nil { + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + } + + client, ok := ctx.Value(oauth2.HTTPClient).(*http.Client) + if !ok { + client = http.DefaultClient + } + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return nil, nil + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, nil + } + + if resp.StatusCode != http.StatusOK && (!ignoreBad || resp.StatusCode != http.StatusBadRequest) { + return nil, fmt.Errorf("%s: %s", resp.Status, body) + } + + return body, nil +} + +// Discover the device_authorization_endpoint URL and store it in the config +// This should be in coreos/go-oidc but they don't yet support device flow +// At the same time, look up token_endpoint and store it as well +// Returns nil on success, otherwise returns an error +func (b *jwtAuthBackend) configDeviceAuthURL(ctx context.Context, s logical.Storage) error { + config, err := b.config(ctx, s) + if err != nil { + return err + } + + b.l.Lock() + defer b.l.Unlock() + + if config.OIDCDeviceAuthURL != "" { + if config.OIDCDeviceAuthURL == "N/A" { + return fmt.Errorf("no device auth endpoint url discovered") + } + return nil + } + + caCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM) + if err != nil { + return errwrap.Wrapf("error creating context for device auth: {{err}}", err) + } + + issuer := config.OIDCDiscoveryURL + + wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" + body, err := contactIssuer(caCtx, wellKnown, nil, false) + if err != nil { + return errwrap.Wrapf("error reading issuer config: {{err}}", err) + } + + var daj struct { + DeviceAuthURL string `json:"device_authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + } + err = json.Unmarshal(body, &daj) + if err != nil || daj.DeviceAuthURL == "" { + b.cachedConfig.OIDCDeviceAuthURL = "N/A" + return fmt.Errorf("no device auth endpoint url discovered") + } + + b.cachedConfig.OIDCDeviceAuthURL = daj.DeviceAuthURL + b.cachedConfig.OIDCTokenURL = daj.TokenURL + return nil +} + func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { config, err := b.config(ctx, req.Storage) if err != nil { @@ -502,6 +590,9 @@ type jwtConfig struct { UnsupportedCriticalCertExtensions []string `json:"unsupported_critical_cert_extensions"` ParsedJWTPubKeys []crypto.PublicKey `json:"-"` + // These are looked up from OIDCDiscoveryURL when needed + OIDCDeviceAuthURL string `json:"-"` + OIDCTokenURL string `json:"-"` } const ( diff --git a/path_oidc.go b/path_oidc.go index 6f7304d9..5f81ece8 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -54,6 +54,9 @@ type oidcRequest struct { // this is for storing the response in direct callback mode auth *logical.Auth + + // the device flow code + deviceCode string } func pathOIDC(b *jwtAuthBackend) []*framework.Path { @@ -146,7 +149,7 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { }, "redirect_uri": { Type: framework.TypeString, - Description: "The OAuth redirect_uri to use in the authorization URL.", + Description: "The OAuth redirect_uri to use in the authorization URL. Not needed with device flow.", }, "client_nonce": { Type: framework.TypeString, @@ -241,6 +244,13 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil } + deleteRequest := true + defer func() { + if deleteRequest { + b.deleteOIDCRequest(stateID) + } + }() + roleName := oidcReq.rolename role, err := b.role(ctx, req.Storage, roleName) if err != nil { @@ -248,17 +258,14 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return nil, err } if role == nil { - b.deleteOIDCRequest(stateID) return logical.ErrorResponse(errLoginFailed + " Role could not be found"), nil } useHttp := false if role.CallbackMode == callbackModeDirect { useHttp = true - } - if !useHttp { - // state is only accessed once when not using direct callback - b.deleteOIDCRequest(stateID) + // save request for poll + deleteRequest = false } errorDescription := d.Get("error_description").(string) @@ -290,8 +297,13 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err) } - var rawToken oidc.IDToken + oidcCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) + if err != nil { + return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err) + } + var token *oidc.Tk + var tokenSource oauth2.TokenSource code := d.Get("code").(string) if code == noCode { @@ -304,10 +316,15 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, } // Verify the ID token received from the authentication response. - rawToken = oidc.IDToken(oidcReq.idToken) + rawToken := oidc.IDToken(oidcReq.idToken) if _, err := provider.VerifyIDToken(ctx, rawToken, oidcReq); err != nil { return logical.ErrorResponse("%s %s", errTokenVerification, err.Error()), nil } + + token, err = oidc.NewToken(rawToken, nil) + if err != nil { + return nil, errwrap.Wrapf("error creating oidc token: {{err}}", err) + } } else { // Exchange the authorization code for an ID token and access token. // ID token verification takes place in provider.Exchange. @@ -316,13 +333,19 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return loginFailedResponse(useHttp, fmt.Sprintf("Error exchanging oidc code: %q.", err.Error())), nil } - rawToken = token.IDToken() + tokenSource = token.StaticTokenSource() } + return b.processToken(ctx, config, oidcCtx, provider, roleName, role, token, tokenSource, stateID, oidcReq, useHttp) +} + +// Continue processing a token after it has been received from the +// OIDC provider from either code or device authorization flows +func (b *jwtAuthBackend) processToken(ctx context.Context, config *jwtConfig, oidcCtx context.Context, provider *oidc.Provider, roleName string, role *jwtRole, token *oidc.Tk, tokenSource oauth2.TokenSource, stateID string, oidcReq *oidcRequest, useHttp bool) (*logical.Response, error) { if role.VerboseOIDCLogging { loggedToken := "invalid token format" - parts := strings.Split(string(rawToken), ".") + parts := strings.Split(string(token.IDToken()), ".") if len(parts) == 3 { // strip signature from logged token loggedToken = fmt.Sprintf("%s.%s.xxxxxxxxxxx", parts[0], parts[1]) @@ -333,10 +356,16 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, // Parse claims from the ID token payload. var allClaims map[string]interface{} - if err := rawToken.Claims(&allClaims); err != nil { + if err := token.IDToken().Claims(&allClaims); err != nil { return nil, err } - delete(allClaims, "nonce") + + if claimNonce, ok := allClaims["nonce"]; ok { + if oidcReq != nil && claimNonce != oidcReq.Nonce() { + return loginFailedResponse(useHttp, "invalid ID token nonce."), nil + } + delete(allClaims, "nonce") + } // Get the subject claim for bound subject and user info validation var subject string @@ -348,15 +377,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return loginFailedResponse(useHttp, "sub claim does not match bound subject"), nil } - // Set the token source for the access token if it's available. It will only - // be available for the authorization code flow (oidc_response_types=code). - // The access token will be used for fetching additional user and group info. - var tokenSource oauth2.TokenSource - if token != nil { - tokenSource = token.StaticTokenSource() - } - - // If we have a token, attempt to fetch information from the /userinfo endpoint + // If we have a tokenSource, attempt to fetch information from the /userinfo endpoint // and merge it with the existing claims data. A failure to fetch additional information // from this endpoint will not invalidate the authorization flow. if tokenSource != nil { @@ -428,27 +449,116 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return resp, nil } +// second half of the client API for direct and device callback modes func (b *jwtAuthBackend) pathPoll(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { stateID := d.Get("state").(string) - oidcReq := b.getOIDCRequest(stateID) if oidcReq == nil { return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil } + deleteRequest := true + defer func() { + if deleteRequest { + b.deleteOIDCRequest(stateID) + } + }() + clientNonce := d.Get("client_nonce").(string) if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce { - b.deleteOIDCRequest(stateID) return logical.ErrorResponse("invalid client_nonce"), nil } + roleName := oidcReq.rolename + role, err := b.role(ctx, req.Storage, roleName) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(errLoginFailed + " Role could not be found"), nil + } + + if role.CallbackMode == callbackModeDevice { + config, err := b.config(ctx, req.Storage) + if err != nil { + return nil, err + } + if config == nil { + return logical.ErrorResponse(errLoginFailed + " Could not load configuration"), nil + } + + caCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) + if err != nil { + return nil, err + } + provider, err := b.getProvider(config) + if err != nil { + return nil, errwrap.Wrapf("error getting provider for poll operation: {{err}}", err) + } + + values := url.Values{ + "client_id": {config.OIDCClientID}, + "client_secret": {config.OIDCClientSecret}, + "device_code": {oidcReq.deviceCode}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + } + body, err := contactIssuer(caCtx, config.OIDCTokenURL, &values, true) + if err != nil { + return nil, errwrap.Wrapf("error polling for device authorization: {{err}}", err) + } + + var tokenOrError struct { + *oauth2.Token + Error string `json:"error,omitempty"` + } + err = json.Unmarshal(body, &tokenOrError) + if err != nil { + return nil, fmt.Errorf("error decoding issuer response while polling for token: %v; response: %v", err, string(body)) + } + + if tokenOrError.Error != "" { + if tokenOrError.Error == "authorization_pending" || tokenOrError.Error == "slow_down" { + // save request for another poll + deleteRequest = false + return logical.ErrorResponse(tokenOrError.Error), nil + } + return logical.ErrorResponse("authorization failed: %v", tokenOrError.Error), nil + } + + extra := make(map[string]interface{}) + err = json.Unmarshal(body, &extra) + if err != nil { + // already been unmarshalled once, unlikely + return nil, err + } + oauth2Token := tokenOrError.Token.WithExtra(extra) + + // idToken, ok := oauth2Token.Extra("id_token").(oidc.IDToken) + rawToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return logical.ErrorResponse(errTokenVerification + " No id_token found in response."), nil + } + idToken := oidc.IDToken(rawToken) + token, err := oidc.NewToken(idToken, tokenOrError.Token) + if err != nil { + return nil, errwrap.Wrapf("error creating oidc token: {{err}}", err) + } + + return b.processToken(ctx, config, caCtx, provider, roleName, role, token, oauth2.StaticTokenSource(oauth2Token), "", nil, false) + } + + // else it's the direct callback mode + if oidcReq.auth == nil { + // save request for another poll + deleteRequest = false + } + if oidcReq.auth == nil { // Return the same response as oauth 2.0 device flow in RFC8628 return logical.ErrorResponse("authorization_pending"), nil } - b.deleteOIDCRequest(stateID) resp := &logical.Response{ Auth: oidcReq.auth, } @@ -490,9 +600,6 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f } redirectURI := d.Get("redirect_uri").(string) - if redirectURI == "" { - return logical.ErrorResponse("missing redirect_uri"), nil - } role, err := b.role(ctx, req.Storage, roleName) if err != nil { @@ -503,10 +610,88 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f } clientNonce := d.Get("client_nonce").(string) - if clientNonce == "" && role.CallbackMode == callbackModeDirect { + if clientNonce == "" && + (role.CallbackMode == callbackModeDirect || + role.CallbackMode == callbackModeDevice) { return logical.ErrorResponse("missing client_nonce"), nil } + if role.CallbackMode == callbackModeDevice { + // start a device flow + caCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) + if err != nil { + return nil, err + } + + // Discover the device url endpoint if not already known + // This adds it to the cached config + err = b.configDeviceAuthURL(ctx, req.Storage) + if err != nil { + return nil, err + } + + // "openid" is a required scope for OpenID Connect flows + scopes := append([]string{"openid"}, role.OIDCScopes...) + + values := url.Values{ + "client_id": {config.OIDCClientID}, + "client_secret": {config.OIDCClientSecret}, + "scope": {strings.Join(scopes, " ")}, + } + body, err := contactIssuer(caCtx, config.OIDCDeviceAuthURL, &values, false) + if err != nil { + return nil, errwrap.Wrapf("error authorizing device: {{err}}", err) + } + + var deviceCode struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + // Google and other old implementations use url instead of uri + VerificationURL string `json:"verification_url"` + VerificationURLComplete string `json:"verification_url_complete"` + Interval int `json:"interval"` + } + err = json.Unmarshal(body, &deviceCode) + if err != nil { + return nil, fmt.Errorf("error decoding issuer response to device auth: %v; response: %v", err, string(body)) + } + // currently hashicorp/cap/oidc.NewRequest requires + // redirectURL to be non-empty so throw in place holder + oidcReq, err := b.createOIDCRequest(config, role, roleName, "-", deviceCode.DeviceCode, clientNonce) + if err != nil { + logger.Warn("error generating OAuth state", "error", err) + return resp, nil + } + + if deviceCode.VerificationURIComplete != "" { + resp.Data["auth_url"] = deviceCode.VerificationURIComplete + } else if deviceCode.VerificationURLComplete != "" { + resp.Data["auth_url"] = deviceCode.VerificationURLComplete + } else { + if deviceCode.VerificationURI != "" { + resp.Data["auth_url"] = deviceCode.VerificationURI + } else { + resp.Data["auth_url"] = deviceCode.VerificationURL + } + resp.Data["user_code"] = deviceCode.UserCode + } + resp.Data["state"] = oidcReq.State() + interval := 5 + if role.PollInterval != 0 { + interval = role.PollInterval + } else if deviceCode.Interval != 0 { + interval = deviceCode.Interval + } + resp.Data["poll_interval"] = fmt.Sprintf("%d", interval) + return resp, nil + } + + if redirectURI == "" { + return logical.ErrorResponse("missing redirect_uri"), nil + } + // If namespace will be passed around in oidcReq, and it has been provided as // a redirectURI query parameter, remove it from redirectURI, and append it // to the oidcReq (later in this function) @@ -545,7 +730,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return resp, nil } - oidcReq, err := b.createOIDCRequest(config, role, roleName, redirectURI, clientNonce) + oidcReq, err := b.createOIDCRequest(config, role, roleName, redirectURI, "", clientNonce) if err != nil { logger.Warn("error generating OAuth state", "error", err) return resp, nil @@ -566,7 +751,11 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f resp.Data["auth_url"] = urlStr if role.CallbackMode == callbackModeDirect { resp.Data["state"] = oidcReq.State() - resp.Data["poll_interval"] = "5" + interval := 5 + if role.PollInterval != 0 { + interval = role.PollInterval + } + resp.Data["poll_interval"] = fmt.Sprintf("%d", interval) } return resp, nil @@ -574,7 +763,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f // createOIDCRequest makes an expiring request object, associated with a random state ID // that is passed throughout the OAuth process. A nonce is also included in the auth process. -func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rolename, redirectURI, clientNonce string) (*oidcRequest, error) { +func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rolename, redirectURI, deviceCode string, clientNonce string) (*oidcRequest, error) { options := []oidc.Option{ oidc.WithAudiences(role.BoundAudiences...), oidc.WithScopes(role.OIDCScopes...), @@ -604,6 +793,7 @@ func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rol Request: request, rolename: rolename, clientNonce: clientNonce, + deviceCode: deviceCode, } b.oidcRequests.SetDefault(request.State(), oidcReq) diff --git a/path_oidc_test.go b/path_oidc_test.go index c78c7f49..2ab7f12c 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -774,14 +774,16 @@ func TestOIDC_Callback(t *testing.T) { t.Run("successful login", func(t *testing.T) { // run test with and without bound_cidrs configured // and with and without direct callback mode - for i := 1; i <= 3; i++ { + for i := 1; i <= 4; i++ { var useBoundCIDRs bool - var callbackMode string + callbackMode := "client" if i == 2 { useBoundCIDRs = true } else if i == 3 { callbackMode = "direct" + } else if i == 4 { + callbackMode = "device" } b, storage, s := getBackendAndServer(t, useBoundCIDRs, callbackMode) @@ -789,6 +791,9 @@ func TestOIDC_Callback(t *testing.T) { clientNonce := "456" + // set mock provider's expected code + s.code = "abc" + // get auth_url data := map[string]interface{}{ "role": "test", @@ -807,42 +812,45 @@ func TestOIDC_Callback(t *testing.T) { t.Fatalf("err:%v resp:%#v\n", err, resp) } - authURL := resp.Data["auth_url"].(string) - - state := getQueryParam(t, authURL, "state") - nonce := getQueryParam(t, authURL, "nonce") + var state string - // set provider claims that will be returned by the mock server - s.customClaims = sampleClaims(nonce) + if callbackMode == "device" { + state = resp.Data["state"].(string) + s.customClaims = sampleClaims("") + } else { + authURL := resp.Data["auth_url"].(string) + state = getQueryParam(t, authURL, "state") + nonce := getQueryParam(t, authURL, "nonce") - // set mock provider's expected code - s.code = "abc" + // set provider claims that will be returned by the mock server + s.customClaims = sampleClaims(nonce) - // save PKCE challenge - s.codeChallenge = getQueryParam(t, authURL, "code_challenge") + // save PKCE challenge + s.codeChallenge = getQueryParam(t, authURL, "code_challenge") - // invoke the callback, which will try to exchange the code - // with the mock provider. - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "oidc/callback", - Storage: storage, - Data: map[string]interface{}{ - "state": state, - "code": "abc", - "client_nonce": clientNonce, - }, - Connection: &logical.Connection{ - RemoteAddr: "127.0.0.42", - }, - } + // invoke the callback, which will try to exchange the code + // with the mock provider. + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "oidc/callback", + Storage: storage, + Data: map[string]interface{}{ + "state": state, + "code": "abc", + "client_nonce": clientNonce, + }, + Connection: &logical.Connection{ + RemoteAddr: "127.0.0.42", + }, + } - resp, err = b.HandleRequest(context.Background(), req) - if err != nil { - t.Fatal(err) + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } } - if callbackMode == "direct" { + if callbackMode != "client" { req = &logical.Request{ Operation: logical.UpdateOperation, Path: "oidc/poll", @@ -1466,6 +1474,7 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { { "issuer": "%s", "authorization_endpoint": "%s/auth", + "device_authorization_endpoint": "%s/device", "token_endpoint": "%s/token", "jwks_uri": "%s/certs", "userinfo_endpoint": "%s/userinfo" @@ -1477,21 +1486,38 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) case "/certs_invalid": w.Write([]byte("It's not a keyset!")) + case "/device": + values := map[string]interface{}{ + "device_code": o.code, + } + data, err := json.Marshal(values) + if err != nil { + o.t.Fatal(err) + } + w.Write(data) case "/token": - code := r.FormValue("code") - codeVerifier := r.FormValue("code_verifier") + var code string + grant_type := r.FormValue("grant_type") + if grant_type == "urn:ietf:params:oauth:grant-type:device_code" { + code = r.FormValue("device_code") + } else { + code = r.FormValue("code") + } if code != o.code { w.WriteHeader(401) break } - sum := sha256.Sum256([]byte(codeVerifier)) - computedChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) + if o.codeChallenge != "" { + codeVerifier := r.FormValue("code_verifier") + sum := sha256.Sum256([]byte(codeVerifier)) + computedChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) - if computedChallenge != o.codeChallenge { - w.WriteHeader(401) - break + if computedChallenge != o.codeChallenge { + w.WriteHeader(401) + break + } } stdClaims := jwt.Claims{ diff --git a/path_role.go b/path_role.go index 6e1b33e6..8c5f56c5 100644 --- a/path_role.go +++ b/path_role.go @@ -26,6 +26,7 @@ const ( boundClaimsTypeGlob = "glob" callbackModeDirect = "direct" callbackModeClient = "client" + callbackModeDevice = "device" ) func pathRoleList(b *jwtAuthBackend) *framework.Path { @@ -158,9 +159,14 @@ for referencing claims.`, }, "callback_mode": { Type: framework.TypeString, - Description: `OIDC callback mode from Authorization Server: allowed values are 'direct' to Vault or 'client', default 'client'`, + Description: `OIDC callback mode from Authorization Server: allowed values are 'device' for device flow, 'direct' to Vault, or 'client', default 'client'`, Default: callbackModeClient, }, + "poll_interval": { + Type: framework.TypeInt, + Description: `poll interval in seconds for device and direct flows, default value from Authorization Server for device flow, or '5'`, + // don't set Default here because server may set a default + }, "verbose_oidc_logging": { Type: framework.TypeBool, Description: `Log received OIDC tokens and claims when debug-level logging is active. @@ -230,6 +236,7 @@ type jwtRole struct { OIDCScopes []string `json:"oidc_scopes"` AllowedRedirectURIs []string `json:"allowed_redirect_uris"` CallbackMode string `json:"callback_mode"` + PollInterval int `json:"poll_interval"` VerboseOIDCLogging bool `json:"verbose_oidc_logging"` MaxAge time.Duration `json:"max_age"` UserClaimJSONPointer bool `json:"user_claim_json_pointer"` @@ -346,6 +353,10 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, role.PopulateTokenData(d) + if role.PollInterval > 0 { + d["poll_interval"] = role.PollInterval + } + if len(role.Policies) > 0 { d["policies"] = d["token_policies"] } @@ -564,9 +575,13 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role.AllowedRedirectURIs = allowedRedirectURIs.([]string) } + if pollInterval, ok := data.GetOk("poll_interval"); ok { + role.PollInterval = pollInterval.(int) + } + callbackMode := data.Get("callback_mode").(string) switch callbackMode { - case callbackModeDirect, callbackModeClient: + case callbackModeDevice, callbackModeDirect, callbackModeClient: role.CallbackMode = callbackMode default: return logical.ErrorResponse("invalid 'callback_mode': %s", callbackMode), nil