Skip to content

Commit

Permalink
Add device flow
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Dykstra <[email protected]>
  • Loading branch information
DrDaveD committed May 3, 2024
1 parent 0b8ca06 commit f4096a2
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 76 deletions.
38 changes: 33 additions & 5 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,20 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
var pollInterval string
var interval int
var state string
var userCode string
var listener net.Listener

if secret != nil {
pollInterval, _ = secret.Data["poll_interval"].(string)
state, _ = secret.Data["state"].(string)
userCode, _ = secret.Data["user_code"].(string)
}
if callbackMode == "direct" {
if callbackMode != "client" {
if state == "" {
return nil, errors.New("no state returned in direct callback mode")
return nil, errors.New("no state returned in " + callbackMode + " callback mode")
}
if pollInterval == "" {
return nil, errors.New("no poll_interval returned in direct callback mode")
return nil, errors.New("no poll_interval returned in " + callbackMode + " callback mode")
}
interval, err = strconv.Atoi(pollInterval)
if err != nil {
Expand Down Expand Up @@ -218,6 +220,31 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
// authorization is pending, try again
}
}
if userCode != "" {
fmt.Fprintf(os.Stderr, "When prompted, enter code %s\n\n", userCode)
}

if callbackMode != "client" {
data := map[string]interface{}{
"state": state,
"client_nonce": clientNonce,
}
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
for {
time.Sleep(time.Duration(interval) * time.Second)

secret, err := c.Logical().Write(pollUrl, data)
if err == nil {
return secret, nil
}
if strings.HasSuffix(err.Error(), "slow_down") {
interval *= 2
} else if !strings.HasSuffix(err.Error(), "authorization_pending") {
return nil, err
}
// authorization is pending, try again
}
}

// Start local server
go func() {
Expand Down Expand Up @@ -376,8 +403,9 @@ Configuration:
Vault role of type "OIDC" to use for authentication.
%s=<string>
Mode of callback: "direct" for direct connection to Vault or "client"
for connection to command line client (default: client).
Mode of callback: "direct" for direct connection to Vault, "client"
for connection to command line client, or "device" for device flow
which has no callback (default: client).
%s=<string>
Optional address to bind the OIDC callback listener to in client callback
Expand Down
91 changes: 91 additions & 0 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"

"github.com/hashicorp/cap/jwt"
Expand Down Expand Up @@ -168,6 +171,91 @@ func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtCon
return config, nil
}

func contactIssuer(ctx context.Context, uri string, data *url.Values, ignoreBad bool) ([]byte, error) {
var req *http.Request
var err error
if data == nil {
req, err = http.NewRequest("GET", uri, nil)
} else {
req, err = http.NewRequest("POST", uri, strings.NewReader(data.Encode()))
}
if err != nil {
return nil, nil
}
if data != nil {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
}

client, ok := ctx.Value(oauth2.HTTPClient).(*http.Client)
if !ok {
client = http.DefaultClient
}
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return nil, nil
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, nil
}

if resp.StatusCode != http.StatusOK && (!ignoreBad || resp.StatusCode != http.StatusBadRequest) {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}

return body, nil
}

// Discover the device_authorization_endpoint URL and store it in the config
// This should be in coreos/go-oidc but they don't yet support device flow
// At the same time, look up token_endpoint and store it as well
// Returns nil on success, otherwise returns an error
func (b *jwtAuthBackend) configDeviceAuthURL(ctx context.Context, s logical.Storage) error {
config, err := b.config(ctx, s)
if err != nil {
return err
}

b.l.Lock()
defer b.l.Unlock()

if config.OIDCDeviceAuthURL != "" {
if config.OIDCDeviceAuthURL == "N/A" {
return fmt.Errorf("no device auth endpoint url discovered")
}
return nil
}

caCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM)
if err != nil {
return errwrap.Wrapf("error creating context for device auth: {{err}}", err)
}

issuer := config.OIDCDiscoveryURL

wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
body, err := contactIssuer(caCtx, wellKnown, nil, false)
if err != nil {
return errwrap.Wrapf("error reading issuer config: {{err}}", err)
}

var daj struct {
DeviceAuthURL string `json:"device_authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
}
err = json.Unmarshal(body, &daj)
if err != nil || daj.DeviceAuthURL == "" {
b.cachedConfig.OIDCDeviceAuthURL = "N/A"
return fmt.Errorf("no device auth endpoint url discovered")
}

b.cachedConfig.OIDCDeviceAuthURL = daj.DeviceAuthURL
b.cachedConfig.OIDCTokenURL = daj.TokenURL
return nil
}

func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
config, err := b.config(ctx, req.Storage)
if err != nil {
Expand Down Expand Up @@ -457,6 +545,9 @@ type jwtConfig struct {
NamespaceInState bool `json:"namespace_in_state"`

ParsedJWTPubKeys []crypto.PublicKey `json:"-"`
// These are looked up from OIDCDiscoveryURL when needed
OIDCDeviceAuthURL string `json:"-"`
OIDCTokenURL string `json:"-"`
}

const (
Expand Down
Loading

0 comments on commit f4096a2

Please sign in to comment.