diff --git a/backend.go b/backend.go index a9b4ac1a..d20f3ff4 100644 --- a/backend.go +++ b/backend.go @@ -100,7 +100,7 @@ func (b *jwtAuthBackend) reset() { b.l.Unlock() } -func (b *jwtAuthBackend) getProvider(ctx context.Context, config *jwtConfig) (*oidc.Provider, error) { +func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error) { b.l.RLock() unlockFunc := b.l.RUnlock defer func() { unlockFunc() }() diff --git a/path_config.go b/path_config.go index eede23d3..c3d62640 100644 --- a/path_config.go +++ b/path_config.go @@ -199,12 +199,29 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque } func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) { - var certPool *x509.CertPool - if config.OIDCDiscoveryCAPEM != "" { - certPool = x509.NewCertPool() - if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok { - return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully") - } + oidcCtx, err := b.createOIDCContext(b.providerCtx, config) + if err != nil { + return nil, errwrap.Wrapf("error creating provider: {{err}}", err) + } + + provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL) + if err != nil { + return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err) + } + + return provider, nil +} + +// createOIDCContext returns a context with custom TLS client, configured with the root certificates +// from oidc_discovery_ca_pem. If no certificates are configured, the original context is returned. +func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfig) (context.Context, error) { + if config.OIDCDiscoveryCAPEM == "" { + return ctx, nil + } + + certPool := x509.NewCertPool() + if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok { + return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully") } tr := cleanhttp.DefaultPooledTransport() @@ -216,14 +233,10 @@ func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, erro tc := &http.Client{ Transport: tr, } - oidcCtx := context.WithValue(b.providerCtx, oauth2.HTTPClient, tc) - provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL) - if err != nil { - return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err) - } + oidcCtx := context.WithValue(ctx, oauth2.HTTPClient, tc) - return provider, nil + return oidcCtx, nil } type jwtConfig struct { diff --git a/path_login.go b/path_login.go index 2cbed778..f7823b04 100644 --- a/path_login.go +++ b/path_login.go @@ -225,7 +225,7 @@ func (b *jwtAuthBackend) pathLoginRenew(ctx context.Context, req *logical.Reques func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig, role *jwtRole, rawToken string) (map[string]interface{}, error) { allClaims := make(map[string]interface{}) - provider, err := b.getProvider(ctx, config) + provider, err := b.getProvider(config) if err != nil { return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err) } diff --git a/path_oidc.go b/path_oidc.go index a5a72b7f..69c6cb5a 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -103,9 +103,14 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return logical.ErrorResponse("request originated from invalid CIDR"), nil } - provider, err := b.getProvider(ctx, config) + provider, err := b.getProvider(config) if err != nil { - return nil, errwrap.Wrapf(errLoginFailed+" Error getting provider for login operation: {{err}}", err) + return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err) + } + + oidcCtx, err := b.createOIDCContext(ctx, config) + if err != nil { + return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err) } var oauth2Config = oauth2.Config{ @@ -121,7 +126,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return logical.ErrorResponse(errLoginFailed + " OAuth code parameter not provided"), nil } - oauth2Token, err := oauth2Config.Exchange(ctx, code) + oauth2Token, err := oauth2Config.Exchange(oidcCtx, code) if err != nil { return logical.ErrorResponse(errLoginFailed+" Error exchanging oidc code: %q.", err.Error()), nil } @@ -146,7 +151,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, // Attempt to fetch information from the /userinfo endpoint and merge it with // the existing claims data. A failure to fetch additional information from this // endpoint will not invalidate the authorization flow. - if userinfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)); err == nil { + if userinfo, err := provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil { _ = userinfo.Claims(&allClaims) } else { logFunc := b.Logger().Warn @@ -246,7 +251,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return resp, nil } - provider, err := b.getProvider(ctx, config) + provider, err := b.getProvider(config) if err != nil { logger.Warn("error getting provider for login operation", "error", err) return resp, nil diff --git a/path_oidc_test.go b/path_oidc_test.go index 73fc018a..98bca1a7 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -1,6 +1,7 @@ package jwtauth import ( + "bytes" "context" "crypto/x509" "encoding/json" @@ -217,14 +218,27 @@ func TestOIDC_Callback(t *testing.T) { s.clientID = "abc" s.clientSecret = "def" + // save test server root cert to config in PEM format + cert := s.server.Certificate() + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + } + + pemBuf := new(bytes.Buffer) + if err := pem.Encode(pemBuf, block); err != nil { + t.Fatal(err) + } + // Configure backend data := map[string]interface{}{ - "oidc_discovery_url": s.server.URL, - "oidc_client_id": "abc", - "oidc_client_secret": "def", - "default_role": "test", - "bound_issuer": "http://vault.example.com/", - "jwt_supported_algs": []string{"ES256"}, + "oidc_discovery_url": s.server.URL, + "oidc_client_id": "abc", + "oidc_client_secret": "def", + "oidc_discovery_ca_pem": pemBuf.String(), + "default_role": "test", + "bound_issuer": "http://vault.example.com/", + "jwt_supported_algs": []string{"ES256"}, } // basic configuration @@ -758,7 +772,7 @@ type oidcProvider struct { func newOIDCProvider(t *testing.T) *oidcProvider { o := new(oidcProvider) o.t = t - o.server = httptest.NewServer(o) + o.server = httptest.NewTLSServer(o) return o } diff --git a/scripts/local_dev.sh b/scripts/local_dev.sh index 53c40425..9e2711ed 100755 --- a/scripts/local_dev.sh +++ b/scripts/local_dev.sh @@ -1,9 +1,9 @@ #!/usr/bin/env bash set -e -MNT_PATH="jwt" +MNT_PATH="oidc" PLUGIN_NAME="vault-plugin-auth-jwt" -PLUGIN_CATALOG_NAME="jwt" +PLUGIN_CATALOG_NAME="oidc" # # Helper script for local development. Automatically builds and registers the