Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use configured TLS certs for OIDC operations #40

Merged
merged 1 commit into from
Apr 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (b *jwtAuthBackend) reset() {
b.l.Unlock()
}

func (b *jwtAuthBackend) getProvider(ctx context.Context, config *jwtConfig) (*oidc.Provider, error) {
func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error) {
b.l.RLock()
unlockFunc := b.l.RUnlock
defer func() { unlockFunc() }()
Expand Down
37 changes: 25 additions & 12 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,29 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
}

func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) {
var certPool *x509.CertPool
if config.OIDCDiscoveryCAPEM != "" {
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok {
return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully")
}
oidcCtx, err := b.createOIDCContext(b.providerCtx, config)
if err != nil {
return nil, errwrap.Wrapf("error creating provider: {{err}}", err)
}

provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
if err != nil {
return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err)
}

return provider, nil
}

// createOIDCContext returns a context with custom TLS client, configured with the root certificates
// from oidc_discovery_ca_pem. If no certificates are configured, the original context is returned.
func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfig) (context.Context, error) {
if config.OIDCDiscoveryCAPEM == "" {
return ctx, nil
}

certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok {
return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully")
}

tr := cleanhttp.DefaultPooledTransport()
Expand All @@ -216,14 +233,10 @@ func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, erro
tc := &http.Client{
Transport: tr,
}
oidcCtx := context.WithValue(b.providerCtx, oauth2.HTTPClient, tc)

provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
if err != nil {
return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err)
}
oidcCtx := context.WithValue(ctx, oauth2.HTTPClient, tc)

return provider, nil
return oidcCtx, nil
}

type jwtConfig struct {
Expand Down
2 changes: 1 addition & 1 deletion path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (b *jwtAuthBackend) pathLoginRenew(ctx context.Context, req *logical.Reques
func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig, role *jwtRole, rawToken string) (map[string]interface{}, error) {
allClaims := make(map[string]interface{})

provider, err := b.getProvider(ctx, config)
provider, err := b.getProvider(config)
if err != nil {
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
}
Expand Down
15 changes: 10 additions & 5 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,14 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("request originated from invalid CIDR"), nil
}

provider, err := b.getProvider(ctx, config)
provider, err := b.getProvider(config)
if err != nil {
return nil, errwrap.Wrapf(errLoginFailed+" Error getting provider for login operation: {{err}}", err)
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
}

oidcCtx, err := b.createOIDCContext(ctx, config)
if err != nil {
return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err)
}

var oauth2Config = oauth2.Config{
Expand All @@ -121,7 +126,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(errLoginFailed + " OAuth code parameter not provided"), nil
}

oauth2Token, err := oauth2Config.Exchange(ctx, code)
oauth2Token, err := oauth2Config.Exchange(oidcCtx, code)
if err != nil {
return logical.ErrorResponse(errLoginFailed+" Error exchanging oidc code: %q.", err.Error()), nil
}
Expand All @@ -146,7 +151,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
// Attempt to fetch information from the /userinfo endpoint and merge it with
// the existing claims data. A failure to fetch additional information from this
// endpoint will not invalidate the authorization flow.
if userinfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
if userinfo, err := provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
_ = userinfo.Claims(&allClaims)
} else {
logFunc := b.Logger().Warn
Expand Down Expand Up @@ -246,7 +251,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
return resp, nil
}

provider, err := b.getProvider(ctx, config)
provider, err := b.getProvider(config)
if err != nil {
logger.Warn("error getting provider for login operation", "error", err)
return resp, nil
Expand Down
28 changes: 21 additions & 7 deletions path_oidc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwtauth

import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
Expand Down Expand Up @@ -217,14 +218,27 @@ func TestOIDC_Callback(t *testing.T) {
s.clientID = "abc"
s.clientSecret = "def"

// save test server root cert to config in PEM format
cert := s.server.Certificate()
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}

pemBuf := new(bytes.Buffer)
if err := pem.Encode(pemBuf, block); err != nil {
t.Fatal(err)
}

// Configure backend
data := map[string]interface{}{
"oidc_discovery_url": s.server.URL,
"oidc_client_id": "abc",
"oidc_client_secret": "def",
"default_role": "test",
"bound_issuer": "http://vault.example.com/",
"jwt_supported_algs": []string{"ES256"},
"oidc_discovery_url": s.server.URL,
"oidc_client_id": "abc",
"oidc_client_secret": "def",
"oidc_discovery_ca_pem": pemBuf.String(),
"default_role": "test",
"bound_issuer": "http://vault.example.com/",
"jwt_supported_algs": []string{"ES256"},
}

// basic configuration
Expand Down Expand Up @@ -758,7 +772,7 @@ type oidcProvider struct {
func newOIDCProvider(t *testing.T) *oidcProvider {
o := new(oidcProvider)
o.t = t
o.server = httptest.NewServer(o)
o.server = httptest.NewTLSServer(o)

return o
}
Expand Down
4 changes: 2 additions & 2 deletions scripts/local_dev.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env bash
set -e

MNT_PATH="jwt"
MNT_PATH="oidc"
PLUGIN_NAME="vault-plugin-auth-jwt"
PLUGIN_CATALOG_NAME="jwt"
PLUGIN_CATALOG_NAME="oidc"

#
# Helper script for local development. Automatically builds and registers the
Expand Down