From f7b66bab6feac450a625be1e7a77e8e3fb3f1fb4 Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Thu, 9 May 2019 13:30:35 -0700 Subject: [PATCH] Add JWKS support (#43) --- backend.go | 30 ++++++- path_config.go | 119 +++++++++++++++++++++------ path_config_test.go | 140 ++++++++++++++++++++++++++++++++ path_login.go | 61 +++++++++----- path_login_test.go | 92 ++++++++++++++++++--- path_oidc.go | 20 ++--- path_oidc_test.go | 190 +++++++++++++++++++++++--------------------- path_role.go | 4 +- path_role_test.go | 4 +- 9 files changed, 498 insertions(+), 162 deletions(-) diff --git a/backend.go b/backend.go index c2a6f640..b9bc0799 100644 --- a/backend.go +++ b/backend.go @@ -2,13 +2,15 @@ package jwtauth import ( "context" + "errors" "sync" "time" - oidc "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc" + "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" - cache "github.com/patrickmn/go-cache" + "github.com/patrickmn/go-cache" ) const ( @@ -30,6 +32,7 @@ type jwtAuthBackend struct { l sync.RWMutex provider *oidc.Provider + keySet oidc.KeySet cachedConfig *jwtConfig oidcStates *cache.Cache @@ -126,6 +129,29 @@ func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error) return provider, nil } +// getKeySet returns a new JWKS KeySet based on the provided config. +func (b *jwtAuthBackend) getKeySet(config *jwtConfig) (oidc.KeySet, error) { + b.l.Lock() + defer b.l.Unlock() + + if b.keySet != nil { + return b.keySet, nil + } + + if config.JWKSURL == "" { + return nil, errors.New("keyset error: jwks_url not configured") + } + + ctx, err := b.createCAContext(b.providerCtx, config.JWKSCAPEM) + if err != nil { + return nil, errwrap.Wrapf("error parsing jwks_ca_pem: {{err}}", err) + } + + b.keySet = oidc.NewRemoteKeySet(ctx, config.JWKSURL) + + return b.keySet, nil +} + const ( backendHelp = ` The JWT backend plugin allows authentication using JWTs (including OIDC). diff --git a/path_config.go b/path_config.go index d431c8a6..b7b68f56 100644 --- a/path_config.go +++ b/path_config.go @@ -1,17 +1,17 @@ package jwtauth import ( + "context" "crypto/tls" "crypto/x509" "errors" "fmt" "net/http" + "strings" - "context" - - oidc "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc" "github.com/hashicorp/errwrap" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/logical" @@ -24,11 +24,11 @@ func pathConfig(b *jwtAuthBackend) *framework.Path { Fields: map[string]*framework.FieldSchema{ "oidc_discovery_url": { Type: framework.TypeString, - Description: `OIDC Discovery URL, without any .well-known component (base path). Cannot be used with "jwt_validation_pubkeys".`, + Description: `OIDC Discovery URL, without any .well-known component (base path). Cannot be used with "jwks_url" or "jwt_validation_pubkeys".`, }, "oidc_discovery_ca_pem": { Type: framework.TypeString, - Description: "The CA certificate or chain of certificates, in PEM format, to use to validate conections to the OIDC Discovery URL. If not set, system certificates are used.", + Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the OIDC Discovery URL. If not set, system certificates are used.", }, "oidc_client_id": { Type: framework.TypeString, @@ -39,13 +39,21 @@ func pathConfig(b *jwtAuthBackend) *framework.Path { Description: "The OAuth Client Secret configured with your OIDC provider.", DisplaySensitive: true, }, + "jwks_url": { + Type: framework.TypeString, + Description: `JWKS URL to use to authenticate signatures. Cannot be used with "oidc_discovery_url" or "jwt_validation_pubkeys".`, + }, + "jwks_ca_pem": { + Type: framework.TypeString, + Description: "The CA certificate or chain of certificates, in PEM format, to use to validate connections to the JWKS URL. If not set, system certificates are used.", + }, "default_role": { Type: framework.TypeString, Description: "The default role to use if none is provided during login. If not set, a role is required during login.", }, "jwt_validation_pubkeys": { Type: framework.TypeCommaStringSlice, - Description: `A list of PEM-encoded public keys to use to authenticate signatures locally. Cannot be used with "oidc_discovery_url".`, + Description: `A list of PEM-encoded public keys to use to authenticate signatures locally. Cannot be used with "jwks_url" or "oidc_discovery_url".`, }, "jwt_supported_algs": { Type: framework.TypeCommaStringSlice, @@ -76,8 +84,8 @@ func pathConfig(b *jwtAuthBackend) *framework.Path { } func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtConfig, error) { - b.l.RLock() - defer b.l.RUnlock() + b.l.Lock() + defer b.l.Unlock() if b.cachedConfig != nil { return b.cachedConfig, nil @@ -92,10 +100,8 @@ func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtCon } result := &jwtConfig{} - if entry != nil { - if err := entry.DecodeJSON(result); err != nil { - return nil, err - } + if err := entry.DecodeJSON(result); err != nil { + return nil, err } for _, v := range result.JWTValidationPubKeys { @@ -128,6 +134,8 @@ func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Reques "default_role": config.DefaultRole, "jwt_validation_pubkeys": config.JWTValidationPubKeys, "jwt_supported_algs": config.JWTSupportedAlgs, + "jwks_url": config.JWKSURL, + "jwks_ca_pem": config.JWKSCAPEM, "bound_issuer": config.BoundIssuer, }, } @@ -141,6 +149,8 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque OIDCDiscoveryCAPEM: d.Get("oidc_discovery_ca_pem").(string), OIDCClientID: d.Get("oidc_client_id").(string), OIDCClientSecret: d.Get("oidc_client_secret").(string), + JWKSURL: d.Get("jwks_url").(string), + JWKSCAPEM: d.Get("jwks_ca_pem").(string), DefaultRole: d.Get("default_role").(string), JWTValidationPubKeys: d.Get("jwt_validation_pubkeys").([]string), JWTSupportedAlgs: d.Get("jwt_supported_algs").([]string), @@ -148,10 +158,20 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque } // Run checks on values + methodCount := 0 + if config.OIDCDiscoveryURL != "" { + methodCount++ + } + if len(config.JWTValidationPubKeys) != 0 { + methodCount++ + } + if config.JWKSURL != "" { + methodCount++ + } + switch { - case config.OIDCDiscoveryURL == "" && len(config.JWTValidationPubKeys) == 0, - config.OIDCDiscoveryURL != "" && len(config.JWTValidationPubKeys) != 0: - return logical.ErrorResponse("exactly one of 'oidc_discovery_url' and 'jwt_validation_pubkeys' must be set"), nil + case methodCount != 1: + return logical.ErrorResponse("exactly one of 'jwt_validation_pubkeys', 'jwks_url' or 'oidc_discovery_url' must be set"), nil case config.OIDCClientID != "" && config.OIDCClientSecret == "", config.OIDCClientID == "" && config.OIDCClientSecret != "": @@ -160,12 +180,32 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque case config.OIDCDiscoveryURL != "": _, err := b.createProvider(config) if err != nil { - return logical.ErrorResponse(errwrap.Wrapf("error checking discovery URL: {{err}}", err).Error()), nil + return logical.ErrorResponse(errwrap.Wrapf("error checking oidc discovery URL: {{err}}", err).Error()), nil } case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "": return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil + case config.JWKSURL != "": + ctx, err := b.createCAContext(context.Background(), config.JWKSCAPEM) + if err != nil { + return logical.ErrorResponse(errwrap.Wrapf("error checking jwks_ca_pem: {{err}}", err).Error()), nil + } + + keyset := oidc.NewRemoteKeySet(ctx, config.JWKSURL) + + // Try to verify a correctly formatted JWT. The signature will fail to match, but other + // errors with fetching the remote keyset should be reported. + testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk" + _, err = keyset.VerifySignature(ctx, testJWT) + if err == nil { + err = errors.New("unexpected verification of JWT") + } + + if !strings.Contains(err.Error(), "failed to verify id token signature") { + return logical.ErrorResponse(errwrap.Wrapf("error checking jwks URL: {{err}}", err).Error()), nil + } + case len(config.JWTValidationPubKeys) != 0: for _, v := range config.JWTValidationPubKeys { if _, err := certutil.ParsePublicKeyPEM([]byte(v)); err != nil { @@ -199,7 +239,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque } func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) { - oidcCtx, err := b.createOIDCContext(b.providerCtx, config) + oidcCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM) if err != nil { return nil, errwrap.Wrapf("error creating provider: {{err}}", err) } @@ -212,16 +252,16 @@ func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, erro return provider, nil } -// createOIDCContext returns a context with custom TLS client, configured with the root certificates -// from oidc_discovery_ca_pem. If no certificates are configured, the original context is returned. -func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfig) (context.Context, error) { - if config.OIDCDiscoveryCAPEM == "" { +// createCAContext returns a context with custom TLS client, configured with the root certificates +// from caPEM. If no certificates are configured, the original context is returned. +func (b *jwtAuthBackend) createCAContext(ctx context.Context, caPEM string) (context.Context, error) { + if caPEM == "" { return ctx, nil } certPool := x509.NewCertPool() - if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok { - return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully") + if ok := certPool.AppendCertsFromPEM([]byte(caPEM)); !ok { + return nil, errors.New("could not parse CA PEM value successfully") } tr := cleanhttp.DefaultPooledTransport() @@ -234,9 +274,9 @@ func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfi Transport: tr, } - oidcCtx := context.WithValue(ctx, oauth2.HTTPClient, tc) + caCtx := context.WithValue(ctx, oauth2.HTTPClient, tc) - return oidcCtx, nil + return caCtx, nil } type jwtConfig struct { @@ -244,6 +284,8 @@ type jwtConfig struct { OIDCDiscoveryCAPEM string `json:"oidc_discovery_ca_pem"` OIDCClientID string `json:"oidc_client_id"` OIDCClientSecret string `json:"oidc_client_secret"` + JWKSURL string `json:"jwks_url"` + JWKSCAPEM string `json:"jwks_ca_pem"` JWTValidationPubKeys []string `json:"jwt_validation_pubkeys"` JWTSupportedAlgs []string `json:"jwt_supported_algs"` BoundIssuer string `json:"bound_issuer"` @@ -252,6 +294,31 @@ type jwtConfig struct { ParsedJWTPubKeys []interface{} `json:"-"` } +const ( + StaticKeys = iota + JWKS + OIDCDiscovery + OIDCFlow + unconfigured +) + +// authType classifies the authorization type/flow based on config parameters. +func (c jwtConfig) authType() int { + switch { + case len(c.ParsedJWTPubKeys) > 0: + return StaticKeys + case c.JWKSURL != "": + return JWKS + case c.OIDCDiscoveryURL != "": + if c.OIDCClientID != "" && c.OIDCClientSecret != "" { + return OIDCFlow + } + return OIDCDiscovery + } + + return unconfigured +} + const ( confHelpSyn = ` Configures the JWT authentication backend. diff --git a/path_config_test.go b/path_config_test.go index ff2d71fa..7cb93261 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -21,6 +21,8 @@ func TestConfig_JWT_Read(t *testing.T) { "default_role": "", "jwt_validation_pubkeys": []string{testJWTPubKey}, "jwt_supported_algs": []string{}, + "jwks_url": "", + "jwks_ca_pem": "", "bound_issuer": "http://vault.example.com/", } @@ -56,9 +58,11 @@ func TestConfig_JWT_Read(t *testing.T) { func TestConfig_JWT_Write(t *testing.T) { b, storage := getBackend(t) + // Create a config with too many token verification schemes data := map[string]interface{}{ "oidc_discovery_url": "http://fake.example.com", "jwt_validation_pubkeys": []string{testJWTPubKey}, + "jwks_url": "http://fake.anotherexample.com", "bound_issuer": "http://vault.example.com/", } @@ -80,8 +84,30 @@ func TestConfig_JWT_Write(t *testing.T) { t.Fatalf("got unexpected error: %v", resp.Error()) } + // remove oidc_discovery_url, but this still leaves too many delete(data, "oidc_discovery_url") + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + + if resp == nil || !resp.IsError() { + t.Fatal("expected error") + } + if !strings.HasPrefix(resp.Error().Error(), "exactly one of") { + t.Fatalf("got unexpected error: %v", resp.Error()) + } + + // remove jwks_url so the config is now valid + delete(data, "jwks_url") + req = &logical.Request{ Operation: logical.UpdateOperation, Path: configPath, @@ -116,6 +142,120 @@ func TestConfig_JWT_Write(t *testing.T) { } } +func TestConfig_JWKS_Update(t *testing.T) { + b, storage := getBackend(t) + + s := newOIDCProvider(t) + defer s.server.Close() + + cert, err := s.getTLSCert() + if err != nil { + t.Fatal(err) + } + + data := map[string]interface{}{ + "jwks_url": s.server.URL + "/certs", + "jwks_ca_pem": cert, + "oidc_discovery_url": "", + "oidc_discovery_ca_pem": "", + "oidc_client_id": "", + "default_role": "", + "jwt_validation_pubkeys": []string{}, + "jwt_supported_algs": []string{}, + "bound_issuer": "", + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: configPath, + Storage: storage, + Data: nil, + } + + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if diff := deep.Equal(resp.Data, data); diff != nil { + t.Fatalf("Expected did not equal actual: %v", diff) + } +} + +func TestConfig_JWKS_Update_Invalid(t *testing.T) { + b, storage := getBackend(t) + + s := newOIDCProvider(t) + defer s.server.Close() + + cert, err := s.getTLSCert() + if err != nil { + t.Fatal(err) + } + + data := map[string]interface{}{ + "jwks_url": s.server.URL + "/certs_missing", + "jwks_ca_pem": cert, + "oidc_discovery_url": "", + "oidc_discovery_ca_pem": "", + "oidc_client_id": "", + "default_role": "", + "jwt_validation_pubkeys": []string{}, + "jwt_supported_algs": []string{}, + "bound_issuer": "", + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatal("expected error") + } + if !strings.Contains(resp.Error().Error(), "get keys failed") { + t.Fatalf("got unexpected error: %v", resp.Error()) + } + + data["jwks_url"] = s.server.URL + "/certs_invalid" + + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatal("expected error") + } + if !strings.Contains(resp.Error().Error(), "failed to decode keys") { + t.Fatalf("got unexpected error: %v", resp.Error()) + } +} + func TestConfig_OIDC_Write(t *testing.T) { b, storage := getBackend(t) diff --git a/path_login.go b/path_login.go index 13e0542e..7223999b 100644 --- a/path_login.go +++ b/path_login.go @@ -2,6 +2,7 @@ package jwtauth import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -81,29 +82,51 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d return logical.ErrorResponse("request originated from invalid CIDR"), nil } - // Here is where things diverge. If it is using OIDC Discovery, validate - // that way; otherwise validate against the locally configured keys. Once - // things are validated, we re-unify the request path when evaluating the - // claims. + // Here is where things diverge. If it is using OIDC Discovery, validate that way; + // otherwise validate against the locally configured or JWKS keys. Once things are + // validated, we re-unify the request path when evaluating the claims. allClaims := map[string]interface{}{} - switch { - case len(config.ParsedJWTPubKeys) != 0: - parsedJWT, err := jwt.ParseSigned(token) - if err != nil { - return logical.ErrorResponse(errwrap.Wrapf("error parsing token: {{err}}", err).Error()), nil - } + configType := config.authType() + switch { + case configType == StaticKeys || configType == JWKS: claims := jwt.Claims{} + if configType == JWKS { + keySet, err := b.getKeySet(config) + if err != nil { + return logical.ErrorResponse(errwrap.Wrapf("error fetching jwks keyset: {{err}}", err).Error()), nil + } - var valid bool - for _, key := range config.ParsedJWTPubKeys { - if err := parsedJWT.Claims(key, &claims, &allClaims); err == nil { - valid = true - break + // Verify signature (and only signature... other elements are checked later) + payload, err := keySet.VerifySignature(ctx, token) + if err != nil { + return logical.ErrorResponse(errwrap.Wrapf("error verifying token: {{err}}", err).Error()), nil + } + + // Unmarshal payload into two copies: public claims for library verification, and a set + // of all received claims. + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %v", err) + } + if err := json.Unmarshal(payload, &allClaims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %v", err) + } + } else { + parsedJWT, err := jwt.ParseSigned(token) + if err != nil { + return logical.ErrorResponse(errwrap.Wrapf("error parsing token: {{err}}", err).Error()), nil + } + + var valid bool + for _, key := range config.ParsedJWTPubKeys { + if err := parsedJWT.Claims(key, &claims, &allClaims); err == nil { + valid = true + break + } + } + if !valid { + return logical.ErrorResponse("no known key successfully validated the token signature"), nil } - } - if !valid { - return logical.ErrorResponse("no known key successfully validated the token signature"), nil } // We require notbefore or expiry; if only one is provided, we allow 5 minutes of leeway. @@ -152,7 +175,7 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil } - case config.OIDCDiscoveryURL != "": + case configType == OIDCDiscovery: allClaims, err = b.verifyOIDCToken(ctx, config, role, token) if err != nil { return logical.ErrorResponse(err.Error()), nil diff --git a/path_login_test.go b/path_login_test.go index 2d4e2df8..0b3206ce 100644 --- a/path_login_test.go +++ b/path_login_test.go @@ -19,7 +19,7 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) -func setupBackend(t *testing.T, oidc, role_type_oidc, audience bool, boundClaims bool, boundCIDRs bool) (logical.Backend, logical.Storage) { +func setupBackend(t *testing.T, oidc, role_type_oidc, audience bool, boundClaims bool, boundCIDRs bool, jwks bool) (logical.Backend, logical.Storage) { b, storage := getBackend(t) var data map[string]interface{} @@ -29,9 +29,22 @@ func setupBackend(t *testing.T, oidc, role_type_oidc, audience bool, boundClaims "oidc_discovery_url": "https://team-vault.auth0.com/", } } else { - data = map[string]interface{}{ - "bound_issuer": "https://team-vault.auth0.com/", - "jwt_validation_pubkeys": ecdsaPubKey, + if !jwks { + data = map[string]interface{}{ + "bound_issuer": "https://team-vault.auth0.com/", + "jwt_validation_pubkeys": ecdsaPubKey, + } + } else { + p := newOIDCProvider(t) + cert, err := p.getTLSCert() + if err != nil { + t.Fatal(err) + } + + data = map[string]interface{}{ + "jwks_url": p.server.URL + "/certs", + "jwks_ca_pem": cert, + } } } @@ -151,9 +164,15 @@ func getTestOIDC(t *testing.T) string { } func TestLogin_JWT(t *testing.T) { + testLogin_JWT(t, false) + testLogin_JWT(t, true) +} + +func testLogin_JWT(t *testing.T, jwks bool) { // Test role_type oidc { - b, storage := setupBackend(t, false, true, true, false, false) + b, storage := setupBackend(t, false, true, true, false, false, jwks) + cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", Issuer: "https://team-vault.auth0.com/", @@ -200,7 +219,8 @@ func TestLogin_JWT(t *testing.T) { // Test missing audience { - b, storage := setupBackend(t, false, false, false, false, false) + b, storage := setupBackend(t, false, false, false, false, false, jwks) + cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", Issuer: "https://team-vault.auth0.com/", @@ -249,7 +269,7 @@ func TestLogin_JWT(t *testing.T) { { // run test with and without bound_cidrs configured for _, useBoundCIDRs := range []bool{false, true} { - b, storage := setupBackend(t, false, false, true, true, useBoundCIDRs) + b, storage := setupBackend(t, false, false, true, true, useBoundCIDRs, jwks) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", @@ -338,7 +358,7 @@ func TestLogin_JWT(t *testing.T) { } } - b, storage := setupBackend(t, false, false, true, true, false) + b, storage := setupBackend(t, false, false, true, true, false, jwks) // test invalid bound claim { @@ -729,7 +749,7 @@ func TestLogin_JWT(t *testing.T) { // test invalid address { - b, storage := setupBackend(t, false, false, false, false, true) + b, storage := setupBackend(t, false, false, false, false, true, jwks) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", @@ -805,7 +825,7 @@ func TestLogin_JWT(t *testing.T) { } func TestLogin_OIDC(t *testing.T) { - b, storage := setupBackend(t, true, false, true, false, false) + b, storage := setupBackend(t, true, false, true, false, false, false) jwtData := getTestOIDC(t) @@ -961,6 +981,58 @@ func TestLogin_NestedGroups(t *testing.T) { } } +func TestLogin_JWKS_Concurrent(t *testing.T) { + b, storage := setupBackend(t, false, false, true, false, false, true) + + cl := jwt.Claims{ + Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", + Issuer: "https://team-vault.auth0.com/", + NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)), + Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"}, + } + + privateCl := struct { + User string `json:"https://vault/user"` + Groups []string `json:"https://vault/groups"` + }{ + "jeff", + []string{"foo", "bar"}, + } + + jwtData, _ := getTestJWT(t, ecdsaPrivKey, cl, privateCl) + + data := map[string]interface{}{ + "role": "plugin-test", + "jwt": jwtData, + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "login", + Storage: storage, + Data: data, + } + + for i := 0; i < 100; i++ { + t.Run("", func(t *testing.T) { + t.Parallel() + + for i := 0; i < 100; i++ { + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("got nil response") + } + if resp.IsError() { + t.Fatalf("got error: %v", resp.Error()) + } + } + }) + } +} + const ( ecdsaPrivKey string = `-----BEGIN EC PRIVATE KEY----- MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49 diff --git a/path_oidc.go b/path_oidc.go index 871b8eca..c633daaa 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -108,7 +108,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err) } - oidcCtx, err := b.createOIDCContext(ctx, config) + oidcCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) if err != nil { return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err) } @@ -215,13 +215,14 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f config, err := b.config(ctx, req.Storage) if err != nil { - logger.Warn("error loading configuration", "error", err) - return resp, nil + return nil, err } - if config == nil { - logger.Warn("nil configuration") - return resp, nil + return logical.ErrorResponse("could not load configuration"), nil + } + + if config.authType() != OIDCFlow { + return logical.ErrorResponse("OIDC login is not configured for this mount"), nil } roleName := d.Get("role").(string) @@ -239,11 +240,10 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f role, err := b.role(ctx, req.Storage, roleName) if err != nil { - return resp, nil + return nil, err } - - if role == nil || role.RoleType != "oidc" { - return resp, nil + if role == nil { + return logical.ErrorResponse("role %q could not be found", roleName), nil } if !validRedirect(redirectURI, role.AllowedRedirectURIs) { diff --git a/path_oidc_test.go b/path_oidc_test.go index 01530c23..859082e2 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -51,7 +51,6 @@ func TestOIDC_AuthURL(t *testing.T) { // set up test role data = map[string]interface{}{ - "role_type": "oidc", "user_claim": "email", "bound_audiences": "vault", "allowed_redirect_uris": []string{"https://example.com"}, @@ -130,15 +129,13 @@ func TestOIDC_AuthURL(t *testing.T) { } resp, err := b.HandleRequest(context.Background(), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v\n", err, resp) + if err != nil { + t.Fatal(err) } - authURL := resp.Data["auth_url"].(string) - if authURL != "" { - t.Fatalf(`expected: "", actual: %s\n`, authURL) + if !resp.IsError() { + t.Fatalf("expected error response, got: %v", resp) } - }) // create limited role with restricted redirect_uris @@ -173,7 +170,7 @@ func TestOIDC_AuthURL(t *testing.T) { Data: data, } - resp, err = b.HandleRequest(context.Background(), req) + resp, err := b.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v\n", err, resp) } @@ -192,7 +189,7 @@ func TestOIDC_AuthURL(t *testing.T) { "role": "limited_uris", "redirect_uri": "http://bitc0in-4-less.cx", } - req = &logical.Request{ + req := &logical.Request{ Operation: logical.UpdateOperation, Path: "oidc/auth_url", Storage: storage, @@ -212,87 +209,6 @@ func TestOIDC_AuthURL(t *testing.T) { } func TestOIDC_Callback(t *testing.T) { - getBackendAndServer := func(t *testing.T, boundCIDRs bool) (logical.Backend, logical.Storage, *oidcProvider) { - b, storage := getBackend(t) - s := newOIDCProvider(t) - s.clientID = "abc" - s.clientSecret = "def" - - // save test server root cert to config in PEM format - cert := s.server.Certificate() - block := &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - } - - pemBuf := new(bytes.Buffer) - if err := pem.Encode(pemBuf, block); err != nil { - t.Fatal(err) - } - - // Configure backend - data := map[string]interface{}{ - "oidc_discovery_url": s.server.URL, - "oidc_client_id": "abc", - "oidc_client_secret": "def", - "oidc_discovery_ca_pem": pemBuf.String(), - "default_role": "test", - "bound_issuer": "http://vault.example.com/", - "jwt_supported_algs": []string{"ES256"}, - } - - // basic configuration - req := &logical.Request{ - Operation: logical.UpdateOperation, - Path: configPath, - Storage: storage, - Data: data, - } - - resp, err := b.HandleRequest(context.Background(), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v\n", err, resp) - } - - // set up test role - data = map[string]interface{}{ - "role_type": "oidc", - "user_claim": "email", - "allowed_redirect_uris": []string{"https://example.com"}, - "claim_mappings": map[string]string{ - "COLOR": "color", - "/nested/Size": "size", - }, - "groups_claim": "/nested/Groups", - "ttl": "3m", - "max_ttl": "5m", - "bound_claims": map[string]interface{}{ - "password": "foo", - "sk": "42", - "/nested/secret_code": "bar", - "temperature": "76", - }, - } - - if boundCIDRs { - data["bound_cidrs"] = "127.0.0.42" - } - - req = &logical.Request{ - Operation: logical.CreateOperation, - Path: "role/test", - Storage: storage, - Data: data, - } - - resp, err = b.HandleRequest(context.Background(), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v\n", err, resp) - } - - return b, storage, s - } - t.Run("successful login", func(t *testing.T) { // run test with and without bound_cidrs configured @@ -793,7 +709,10 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/certs": a := getTestJWKS(o.t, ecdsaPubKey) w.Write(a) - + case "/certs_missing": + w.WriteHeader(404) + case "/certs_invalid": + w.Write([]byte("It's not a keyset!")) case "/token": code := r.FormValue("code") @@ -830,6 +749,22 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +// getTLSCert returns the certificate for this provider in PEM format +func (o *oidcProvider) getTLSCert() (string, error) { + cert := o.server.Certificate() + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + } + + pemBuf := new(bytes.Buffer) + if err := pem.Encode(pemBuf, block); err != nil { + return "", err + } + + return pemBuf.String(), nil +} + func getQueryParam(t *testing.T, inputURL, param string) string { t.Helper() @@ -905,3 +840,76 @@ func TestOIDC_ValidRedirect(t *testing.T) { } } } + +func getBackendAndServer(t *testing.T, boundCIDRs bool) (logical.Backend, logical.Storage, *oidcProvider) { + b, storage := getBackend(t) + s := newOIDCProvider(t) + s.clientID = "abc" + s.clientSecret = "def" + + cert, err := s.getTLSCert() + if err != nil { + t.Fatal(err) + } + + // Configure backend + data := map[string]interface{}{ + "oidc_discovery_url": s.server.URL, + "oidc_client_id": "abc", + "oidc_client_secret": "def", + "oidc_discovery_ca_pem": cert, + "default_role": "test", + "bound_issuer": "http://vault.example.com/", + "jwt_supported_algs": []string{"ES256"}, + } + + // basic configuration + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + // set up test role + data = map[string]interface{}{ + "user_claim": "email", + "allowed_redirect_uris": []string{"https://example.com"}, + "claim_mappings": map[string]string{ + "COLOR": "color", + "/nested/Size": "size", + }, + "groups_claim": "/nested/Groups", + "ttl": "3m", + "max_ttl": "5m", + "bound_claims": map[string]interface{}{ + "password": "foo", + "sk": "42", + "/nested/secret_code": "bar", + "temperature": "76", + }, + } + + if boundCIDRs { + data["bound_cidrs"] = "127.0.0.42" + } + + req = &logical.Request{ + Operation: logical.CreateOperation, + Path: "role/test", + Storage: storage, + Data: data, + } + + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + return b, storage, s +} diff --git a/path_role.go b/path_role.go index ec304cda..59a0bdc8 100644 --- a/path_role.go +++ b/path_role.go @@ -360,11 +360,11 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. targets := make(map[string]bool) for _, metadataKey := range claimMappings { if strutil.StrListContains(reservedMetadata, metadataKey) { - return logical.ErrorResponse("metadata key '%s' is reserved and may not be a mapping destination", metadataKey), nil + return logical.ErrorResponse("metadata key %q is reserved and may not be a mapping destination", metadataKey), nil } if targets[metadataKey] { - return logical.ErrorResponse("multiple keys are mapped to metadata key '%s'", metadataKey), nil + return logical.ErrorResponse("multiple keys are mapped to metadata key %q", metadataKey), nil } targets[metadataKey] = true } diff --git a/path_role_test.go b/path_role_test.go index a5e5fd3a..15657214 100644 --- a/path_role_test.go +++ b/path_role_test.go @@ -232,7 +232,7 @@ func TestPath_OIDCCreate(t *testing.T) { if resp != nil && !resp.IsError() { t.Fatalf("expected error") } - if !strings.Contains(resp.Error().Error(), "metadata key 'role' is reserved") { + if !strings.Contains(resp.Error().Error(), `metadata key "role" is reserved`) { t.Fatalf("unexpected err: %v", resp) } @@ -256,7 +256,7 @@ func TestPath_OIDCCreate(t *testing.T) { if resp != nil && !resp.IsError() { t.Fatalf("expected error") } - if !strings.Contains(resp.Error().Error(), "multiple keys are mapped to metadata key 'a'") { + if !strings.Contains(resp.Error().Error(), `multiple keys are mapped to metadata key "a"`) { t.Fatalf("unexpected err: %v", resp) } }