Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWKS support #43

Merged
merged 2 commits into from
May 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package jwtauth

import (
"context"
"errors"
"sync"
"time"

oidc "github.com/coreos/go-oidc"
"github.com/coreos/go-oidc"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
cache "github.com/patrickmn/go-cache"
"github.com/patrickmn/go-cache"
)

const (
Expand All @@ -30,6 +32,7 @@ type jwtAuthBackend struct {

l sync.RWMutex
provider *oidc.Provider
keySet oidc.KeySet
cachedConfig *jwtConfig
oidcStates *cache.Cache

Expand Down Expand Up @@ -126,6 +129,29 @@ func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error)
return provider, nil
}

// getKeySet returns a new JWKS KeySet based on the provided config.
func (b *jwtAuthBackend) getKeySet(config *jwtConfig) (oidc.KeySet, error) {
b.l.Lock()
defer b.l.Unlock()

if b.keySet != nil {
return b.keySet, nil
}

if config.JWKSURL == "" {
return nil, errors.New("keyset error: jwks_url not configured")
}

ctx, err := b.createCAContext(b.providerCtx, config.JWKSCAPEM)
if err != nil {
return nil, errwrap.Wrapf("error parsing jwks_ca_pem: {{err}}", err)
}

b.keySet = oidc.NewRemoteKeySet(ctx, config.JWKSURL)

return b.keySet, nil
}

const (
backendHelp = `
The JWT backend plugin allows authentication using JWTs (including OIDC).
Expand Down
119 changes: 93 additions & 26 deletions path_config.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package jwtauth

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"strings"

"context"

oidc "github.com/coreos/go-oidc"
"github.com/coreos/go-oidc"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
Expand All @@ -24,11 +24,11 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
Fields: map[string]*framework.FieldSchema{
"oidc_discovery_url": {
Type: framework.TypeString,
Description: `OIDC Discovery URL, without any .well-known component (base path). Cannot be used with "jwt_validation_pubkeys".`,
Description: `OIDC Discovery URL, without any .well-known component (base path). Cannot be used with "jwks_url" or "jwt_validation_pubkeys".`,
},
"oidc_discovery_ca_pem": {
Type: framework.TypeString,
Description: "The CA certificate or chain of certificates, in PEM format, to use to validate conections to the OIDC Discovery URL. If not set, system certificates are used.",
Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the OIDC Discovery URL. If not set, system certificates are used.",
},
"oidc_client_id": {
Type: framework.TypeString,
Expand All @@ -39,13 +39,21 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
Description: "The OAuth Client Secret configured with your OIDC provider.",
DisplaySensitive: true,
},
"jwks_url": {
Type: framework.TypeString,
Description: `JWKS URL to use to authenticate signatures. Cannot be used with "oidc_discovery_url" or "jwt_validation_pubkeys".`,
},
"jwks_ca_pem": {
Type: framework.TypeString,
Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the JWKS URL. If not set, system certificates are used.",
},
"default_role": {
Type: framework.TypeString,
Description: "The default role to use if none is provided during login. If not set, a role is required during login.",
},
"jwt_validation_pubkeys": {
Type: framework.TypeCommaStringSlice,
Description: `A list of PEM-encoded public keys to use to authenticate signatures locally. Cannot be used with "oidc_discovery_url".`,
Description: `A list of PEM-encoded public keys to use to authenticate signatures locally. Cannot be used with "jwks_url" or "oidc_discovery_url".`,
},
"jwt_supported_algs": {
Type: framework.TypeCommaStringSlice,
Expand Down Expand Up @@ -76,8 +84,8 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
}

func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtConfig, error) {
b.l.RLock()
defer b.l.RUnlock()
b.l.Lock()
defer b.l.Unlock()

if b.cachedConfig != nil {
return b.cachedConfig, nil
Expand All @@ -92,10 +100,8 @@ func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtCon
}

result := &jwtConfig{}
if entry != nil {
if err := entry.DecodeJSON(result); err != nil {
return nil, err
}
if err := entry.DecodeJSON(result); err != nil {
return nil, err
}

for _, v := range result.JWTValidationPubKeys {
Expand Down Expand Up @@ -128,6 +134,8 @@ func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Reques
"default_role": config.DefaultRole,
"jwt_validation_pubkeys": config.JWTValidationPubKeys,
"jwt_supported_algs": config.JWTSupportedAlgs,
"jwks_url": config.JWKSURL,
"jwks_ca_pem": config.JWKSCAPEM,
"bound_issuer": config.BoundIssuer,
},
}
Expand All @@ -141,17 +149,29 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
OIDCDiscoveryCAPEM: d.Get("oidc_discovery_ca_pem").(string),
OIDCClientID: d.Get("oidc_client_id").(string),
OIDCClientSecret: d.Get("oidc_client_secret").(string),
JWKSURL: d.Get("jwks_url").(string),
JWKSCAPEM: d.Get("jwks_ca_pem").(string),
DefaultRole: d.Get("default_role").(string),
JWTValidationPubKeys: d.Get("jwt_validation_pubkeys").([]string),
JWTSupportedAlgs: d.Get("jwt_supported_algs").([]string),
BoundIssuer: d.Get("bound_issuer").(string),
}

// Run checks on values
methodCount := 0
if config.OIDCDiscoveryURL != "" {
methodCount++
}
if len(config.JWTValidationPubKeys) != 0 {
methodCount++
}
if config.JWKSURL != "" {
methodCount++
}

switch {
case config.OIDCDiscoveryURL == "" && len(config.JWTValidationPubKeys) == 0,
config.OIDCDiscoveryURL != "" && len(config.JWTValidationPubKeys) != 0:
return logical.ErrorResponse("exactly one of 'oidc_discovery_url' and 'jwt_validation_pubkeys' must be set"), nil
case methodCount != 1:
return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url' or 'oidc_discovery_url' must be set"), nil

case config.OIDCClientID != "" && config.OIDCClientSecret == "",
config.OIDCClientID == "" && config.OIDCClientSecret != "":
Expand All @@ -160,12 +180,32 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
case config.OIDCDiscoveryURL != "":
_, err := b.createProvider(config)
if err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error checking discovery URL: {{err}}", err).Error()), nil
return logical.ErrorResponse(errwrap.Wrapf("error checking oidc discovery URL: {{err}}", err).Error()), nil
}

case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil

case config.JWKSURL != "":
ctx, err := b.createCAContext(context.Background(), config.JWKSCAPEM)
if err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error checking jwks_ca_pem: {{err}}", err).Error()), nil
}

keyset := oidc.NewRemoteKeySet(ctx, config.JWKSURL)

// Try to verify a correctly formatted JWT. The signature will fail to match, but other
// errors with fetching the remote keyset should be reported.
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
}

if !strings.Contains(err.Error(), "failed to verify id token signature") {
kalafut marked this conversation as resolved.
Show resolved Hide resolved
return logical.ErrorResponse(errwrap.Wrapf("error checking jwks URL: {{err}}", err).Error()), nil
}

case len(config.JWTValidationPubKeys) != 0:
for _, v := range config.JWTValidationPubKeys {
if _, err := certutil.ParsePublicKeyPEM([]byte(v)); err != nil {
Expand Down Expand Up @@ -199,7 +239,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
}

func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) {
oidcCtx, err := b.createOIDCContext(b.providerCtx, config)
oidcCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM)
if err != nil {
return nil, errwrap.Wrapf("error creating provider: {{err}}", err)
}
Expand All @@ -212,16 +252,16 @@ func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, erro
return provider, nil
}

// createOIDCContext returns a context with custom TLS client, configured with the root certificates
// from oidc_discovery_ca_pem. If no certificates are configured, the original context is returned.
func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfig) (context.Context, error) {
if config.OIDCDiscoveryCAPEM == "" {
// createCAContext returns a context with custom TLS client, configured with the root certificates
// from caPEM. If no certificates are configured, the original context is returned.
func (b *jwtAuthBackend) createCAContext(ctx context.Context, caPEM string) (context.Context, error) {
if caPEM == "" {
return ctx, nil
}

certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok {
return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully")
if ok := certPool.AppendCertsFromPEM([]byte(caPEM)); !ok {
return nil, errors.New("could not parse CA PEM value successfully")
}

tr := cleanhttp.DefaultPooledTransport()
Expand All @@ -234,16 +274,18 @@ func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfi
Transport: tr,
}

oidcCtx := context.WithValue(ctx, oauth2.HTTPClient, tc)
caCtx := context.WithValue(ctx, oauth2.HTTPClient, tc)

return oidcCtx, nil
return caCtx, nil
}

type jwtConfig struct {
OIDCDiscoveryURL string `json:"oidc_discovery_url"`
OIDCDiscoveryCAPEM string `json:"oidc_discovery_ca_pem"`
OIDCClientID string `json:"oidc_client_id"`
OIDCClientSecret string `json:"oidc_client_secret"`
JWKSURL string `json:"jwks_url"`
JWKSCAPEM string `json:"jwks_ca_pem"`
JWTValidationPubKeys []string `json:"jwt_validation_pubkeys"`
JWTSupportedAlgs []string `json:"jwt_supported_algs"`
BoundIssuer string `json:"bound_issuer"`
Expand All @@ -252,6 +294,31 @@ type jwtConfig struct {
ParsedJWTPubKeys []interface{} `json:"-"`
}

const (
StaticKeys = iota
JWKS
OIDCDiscovery
OIDCFlow
unconfigured
)

// authType classifies the authorization type/flow based on config parameters.
func (c jwtConfig) authType() int {
switch {
case len(c.ParsedJWTPubKeys) > 0:
return StaticKeys
case c.JWKSURL != "":
return JWKS
case c.OIDCDiscoveryURL != "":
if c.OIDCClientID != "" && c.OIDCClientSecret != "" {
return OIDCFlow
}
return OIDCDiscovery
}

return unconfigured
}

const (
confHelpSyn = `
Configures the JWT authentication backend.
Expand Down
Loading