Skip to content

Commit

Permalink
Adds callback mode that is direct to vault
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Dykstra <[email protected]>
  • Loading branch information
DrDaveD committed May 3, 2024
1 parent b8833ce commit 0b8ca06
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 85 deletions.
1 change: 1 addition & 0 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func backend() *jwtAuthBackend {
"login",
"oidc/auth_url",
"oidc/callback",
"oidc/poll",

// Uncomment to mount simple UI handler for local development
// "ui",
Expand Down
130 changes: 108 additions & 22 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path"
Expand All @@ -27,9 +28,11 @@ const (
defaultPort = "8250"
defaultCallbackHost = "localhost"
defaultCallbackMethod = "http"
defaultCallbackMode = "client"

FieldCallbackHost = "callbackhost"
FieldCallbackMethod = "callbackmethod"
FieldCallbackMode = "callbackmode"
FieldListenAddress = "listenaddress"
FieldPort = "port"
FieldCallbackPort = "callbackport"
Expand Down Expand Up @@ -69,19 +72,42 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
port = defaultPort
}

var vaultURL *url.URL
callbackMode, ok := m[FieldCallbackMode]
if !ok {
callbackMode = defaultCallbackMode
} else if callbackMode == "direct" {
vaultAddr := os.Getenv("VAULT_ADDR")
if vaultAddr != "" {
vaultURL, _ = url.Parse(vaultAddr)
}
}

callbackHost, ok := m[FieldCallbackHost]
if !ok {
callbackHost = defaultCallbackHost
if vaultURL != nil {
callbackHost = vaultURL.Hostname()
} else {
callbackHost = defaultCallbackHost
}
}

callbackMethod, ok := m[FieldCallbackMethod]
if !ok {
callbackMethod = defaultCallbackMethod
if vaultURL != nil {
callbackMethod = vaultURL.Scheme
} else {
callbackMethod = defaultCallbackMethod
}
}

callbackPort, ok := m[FieldCallbackPort]
if !ok {
callbackPort = port
if vaultURL != nil {
callbackPort = vaultURL.Port() + "/v1/auth/" + mount
} else {
callbackPort = port
}
}

parseBool := func(f string, d bool) (bool, error) {
Expand Down Expand Up @@ -115,20 +141,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro

role := m["role"]

authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
if err != nil {
return nil, err
}

// Set up callback handler
doneCh := make(chan loginResp)
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err := net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
var pollInterval string
var interval int
var state string
var listener net.Listener

if secret != nil {
pollInterval, _ = secret.Data["poll_interval"].(string)
state, _ = secret.Data["state"].(string)
}
if callbackMode == "direct" {
if state == "" {
return nil, errors.New("no state returned in direct callback mode")
}
if pollInterval == "" {
return nil, errors.New("no poll_interval returned in direct callback mode")
}
interval, err = strconv.Atoi(pollInterval)
if err != nil {
return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer")
}
} else {
if state != "" {
return nil, errors.New("state returned in client callback mode, try direct")
}
if pollInterval != "" {
return nil, errors.New("poll_interval returned in client callback mode")
}
// Set up callback handler
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err := net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
}
defer listener.Close()
}
defer listener.Close()

// Open the default browser to the callback URL.
if !skipBrowserLaunch {
Expand All @@ -144,6 +199,26 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
}
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")

if callbackMode == "direct" {
data := map[string]interface{}{
"state": state,
"client_nonce": clientNonce,
}
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
for {
time.Sleep(time.Duration(interval) * time.Second)

secret, err := c.Logical().Write(pollUrl, data)
if err == nil {
return secret, nil
}
if !strings.HasSuffix(err.Error(), "authorization_pending") {
return nil, err
}
// authorization is pending, try again
}
}

// Start local server
go func() {
err := http.Serve(listener, nil)
Expand Down Expand Up @@ -210,12 +285,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
}
}

func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) {
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
var authURL string

clientNonce, err := base62.Random(20)
if err != nil {
return "", "", err
return "", "", nil, err
}

redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
Expand All @@ -227,18 +302,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho

secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
if err != nil {
return "", "", err
return "", "", nil, err
}

if secret != nil {
authURL = secret.Data["auth_url"].(string)
}

if authURL == "" {
return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
}

return authURL, clientNonce, nil
return authURL, clientNonce, secret, nil
}

// parseError converts error from the API into summary and detailed portions.
Expand Down Expand Up @@ -292,35 +367,46 @@ Usage: vault login -method=oidc [CONFIG K=V...]
https://accounts.google.com/o/oauth2/v2/...
The default browser will be opened for the user to complete the login. Alternatively,
the user may visit the provided URL directly.
The default browser will be opened for the user to complete the login.
Alternatively, the user may visit the provided URL directly.
Configuration:
role=<string>
Vault role of type "OIDC" to use for authentication.
%s=<string>
Optional address to bind the OIDC callback listener to (default: localhost).
Mode of callback: "direct" for direct connection to Vault or "client"
for connection to command line client (default: client).
%s=<string>
Optional address to bind the OIDC callback listener to in client callback
mode (default: localhost).
%s=<string>
Optional localhost port to use for OIDC callback (default: 8250).
Optional localhost port to use for OIDC callback in client callback mode
(default: 8250).
%s=<string>
Optional method to to use in OIDC redirect_uri (default: http).
Optional method to use in OIDC redirect_uri (default: the method from
$VAULT_ADDR in direct callback mode, else http)
%s=<string>
Optional callback host address to use in OIDC redirect_uri (default: localhost).
Optional callback host address to use in OIDC redirect_uri (default:
the host from $VAULT_ADDR in direct callback mode, else localhost).
%s=<string>
Optional port to to use in OIDC redirect_uri (default: the value set for port).
Optional port to use in OIDC redirect_uri (default: the value set for
port in client callback mode, else the port from $VAULT_ADDR with an
added /v1/auth/<path> where <path> is from the login -path option).
%s=<bool>
Toggle the automatic launching of the default browser to the login URL. (default: false).
%s=<bool>
Abort on any error. (default: false).
`,
FieldCallbackMode,
FieldListenAddress, FieldPort, FieldCallbackMethod,
FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser,
FieldAbortOnError,
Expand Down
File renamed without changes.
Loading

0 comments on commit 0b8ca06

Please sign in to comment.