diff --git a/claims.go b/claims.go index 473349bc..ad062464 100644 --- a/claims.go +++ b/claims.go @@ -1,9 +1,12 @@ package jwtauth import ( + "errors" "fmt" "strings" + "github.com/hashicorp/vault/helper/strutil" + log "github.com/hashicorp/go-hclog" "github.com/mitchellh/pointerstructure" ) @@ -63,3 +66,40 @@ func extractMetadata(logger log.Logger, allClaims map[string]interface{}, claimM } return metadata, nil } + +// validateAudience checks whether any of the audiences in audClaim match those +// in boundAudiences. If strict is true and there are no bound audiences, then the +// presence of any audience in the received claim is considered an error. +func validateAudience(boundAudiences, audClaim []string, strict bool) error { + if strict && len(boundAudiences) == 0 && len(audClaim) > 0 { + return errors.New("audience claim found in JWT but no audiences bound to the role") + } + + if len(boundAudiences) > 0 { + for _, v := range boundAudiences { + if strutil.StrListContains(audClaim, v) { + return nil + } + } + return errors.New("aud claim does not match any bound audience") + } + + return nil +} + +// validateBoundClaims checks that all of the claim:value requirements in boundClaims are +// met in allClaims. +func validateBoundClaims(logger log.Logger, boundClaims, allClaims map[string]interface{}) error { + for claim, expValue := range boundClaims { + actValue := getClaim(logger, allClaims, claim) + if actValue == nil { + return fmt.Errorf("claim %q is missing", claim) + } + + if expValue != actValue { + return fmt.Errorf("claim %q does not match associated bound claim", claim) + } + } + + return nil +} diff --git a/claims_test.go b/claims_test.go index 31cefd4c..dd522a24 100644 --- a/claims_test.go +++ b/claims_test.go @@ -2,10 +2,10 @@ package jwtauth import ( "encoding/json" - "github.com/hashicorp/go-hclog" "testing" "github.com/go-test/deep" + "github.com/hashicorp/go-hclog" ) func TestGetClaim(t *testing.T) { @@ -150,3 +150,135 @@ func TestExtractMetadata(t *testing.T) { } } } + +func TestValidateAudience(t *testing.T) { + tests := []struct { + boundAudiences []string + audience []string + strict bool + errExpected bool + }{ + {[]string{"a"}, []string{"a"}, false, false}, + {[]string{"a"}, []string{"b"}, false, true}, + {[]string{"a"}, []string{""}, false, true}, + {[]string{}, []string{"a"}, false, false}, + {[]string{}, []string{"a"}, true, true}, + {[]string{"a", "b"}, []string{"a"}, false, false}, + {[]string{"a", "b"}, []string{"b"}, false, false}, + {[]string{"a", "b"}, []string{"a", "b", "c"}, false, false}, + {[]string{"a", "b"}, []string{"c", "d"}, false, true}, + } + + for _, test := range tests { + err := validateAudience(test.boundAudiences, test.audience, test.strict) + if test.errExpected != (err != nil) { + t.Fatalf("unexpected error result: boundAudiences %v, audience %v, strict %t, err: %v", + test.boundAudiences, test.audience, test.strict, err) + } + } +} + +func TestValidateBoundClaims(t *testing.T) { + tests := []struct { + name string + boundClaims map[string]interface{} + allClaims map[string]interface{} + errExpected bool + }{ + { + name: "valid", + boundClaims: map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + allClaims: map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + errExpected: false, + }, + { + name: "valid - extra data", + boundClaims: map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + allClaims: map[string]interface{}{ + "foo": "a", + "bar": "b", + "color": "green", + }, + errExpected: false, + }, + { + name: "mismatched value", + boundClaims: map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + allClaims: map[string]interface{}{ + "foo": "a", + "bar": "wrong", + }, + errExpected: true, + }, + { + name: "missing claim", + boundClaims: map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + allClaims: map[string]interface{}{ + "foo": "a", + }, + errExpected: true, + }, + { + name: "valid - JSONPointer", + boundClaims: map[string]interface{}{ + "foo": "a", + "/bar/baz/1": "y", + }, + allClaims: map[string]interface{}{ + "foo": "a", + "bar": map[string]interface{}{ + "baz": []string{"x", "y", "z"}, + }, + }, + errExpected: false, + }, + { + name: "invalid - JSONPointer value mismatch", + boundClaims: map[string]interface{}{ + "foo": "a", + "/bar/baz/1": "q", + }, + allClaims: map[string]interface{}{ + "foo": "a", + "bar": map[string]interface{}{ + "baz": []string{"x", "y", "z"}, + }, + }, + errExpected: true, + }, + { + name: "invalid - JSONPointer not found", + boundClaims: map[string]interface{}{ + "foo": "a", + "/bar/XXX/1243": "q", + }, + allClaims: map[string]interface{}{ + "foo": "a", + "bar": map[string]interface{}{ + "baz": []string{"x", "y", "z"}, + }, + }, + errExpected: true, + }, + } + for _, tt := range tests { + if err := validateBoundClaims(hclog.NewNullLogger(), tt.boundClaims, tt.allClaims); (err != nil) != tt.errExpected { + t.Errorf("validateBoundClaims(%s) error = %v, wantErr %v", tt.name, err, tt.errExpected) + } + } +} diff --git a/path_login.go b/path_login.go index 06ad379e..a67b7b17 100644 --- a/path_login.go +++ b/path_login.go @@ -6,10 +6,9 @@ import ( "fmt" "time" - oidc "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/cidrutil" - "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "gopkg.in/square/go-jose.v2/jwt" @@ -127,16 +126,19 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d } expected := jwt.Expected{ - Issuer: config.BoundIssuer, - Subject: role.BoundSubject, - Audience: jwt.Audience(role.BoundAudiences), - Time: time.Now(), + Issuer: config.BoundIssuer, + Subject: role.BoundSubject, + Time: time.Now(), } if err := claims.Validate(expected); err != nil { return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil } + if err := validateAudience(role.BoundAudiences, claims.Audience, true); err != nil { + return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil + } + case config.OIDCDiscoveryURL != "": allClaims, err = b.verifyOIDCToken(ctx, config, role, token) if err != nil { @@ -147,6 +149,10 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d return nil, errors.New("unhandled case during login") } + if err := validateBoundClaims(b.Logger(), role.BoundClaims, allClaims); err != nil { + return logical.ErrorResponse("error validating claims: %s", err.Error()), nil + } + alias, groupAliases, err := b.createIdentity(allClaims, role) if err != nil { return logical.ErrorResponse(err.Error()), nil @@ -234,36 +240,15 @@ func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig, if role.BoundSubject != "" && role.BoundSubject != idToken.Subject { return nil, errors.New("sub claim does not match bound subject") } - if len(role.BoundAudiences) > 0 { - var found bool - for _, v := range role.BoundAudiences { - if strutil.StrListContains(idToken.Audience, v) { - found = true - break - } - } - if !found { - return nil, errors.New("aud claim does not match any bound audience") - } - } - if len(role.BoundClaims) > 0 { - for claim, expValue := range role.BoundClaims { - actValue := getClaim(b.Logger(), allClaims, claim) - if actValue == nil { - return nil, fmt.Errorf("claim is missing: %s", claim) - } - - if expValue != actValue { - return nil, fmt.Errorf("claim '%s' does not match associated bound claim", claim) - } - } + if err := validateAudience(role.BoundAudiences, idToken.Audience, false); err != nil { + return nil, errwrap.Wrapf("error validating claims: {{err}}", err) } return allClaims, nil } -// createIdentity creates an alias and set of groups aliass based on the role +// createIdentity creates an alias and set of groups aliases based on the role // definition and received claims. func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) { userClaimRaw, ok := allClaims[role.UserClaim] diff --git a/path_login_test.go b/path_login_test.go index f6859cab..fc4ce3f6 100644 --- a/path_login_test.go +++ b/path_login_test.go @@ -19,7 +19,7 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) -func setupBackend(t *testing.T, oidc, audience bool) (logical.Backend, logical.Storage) { +func setupBackend(t *testing.T, oidc, audience bool, boundClaims bool) (logical.Backend, logical.Storage) { b, storage := getBackend(t) var data map[string]interface{} @@ -63,7 +63,12 @@ func setupBackend(t *testing.T, oidc, audience bool) (logical.Backend, logical.S }, } if audience { - data["bound_audiences"] = "https://vault.plugin.auth.jwt.test" + data["bound_audiences"] = []string{"https://vault.plugin.auth.jwt.test", "another_audience"} + } + if boundClaims { + data["bound_claims"] = map[string]interface{}{ + "color": "green", + } } req = &logical.Request{ @@ -141,7 +146,7 @@ func getTestOIDC(t *testing.T) string { func TestLogin_JWT(t *testing.T) { // Test missing audience { - b, storage := setupBackend(t, false, false) + b, storage := setupBackend(t, false, false, false) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", Issuer: "https://team-vault.auth0.com/", @@ -186,7 +191,7 @@ func TestLogin_JWT(t *testing.T) { } } - b, storage := setupBackend(t, false, true) + b, storage := setupBackend(t, false, true, true) // test valid inputs { @@ -206,11 +211,13 @@ func TestLogin_JWT(t *testing.T) { Groups []string `json:"https://vault/groups"` FirstName string `json:"first_name"` Org orgs `json:"org"` + Color string `json:"color"` }{ "jeff", []string{"foo", "bar"}, "jeff2", orgs{"engineering"}, + "green", } jwtData, _ := getTestJWT(t, ecdsaPrivKey, cl, privateCl) @@ -272,6 +279,54 @@ func TestLogin_JWT(t *testing.T) { } + // test invalid bound claim + { + cl := jwt.Claims{ + Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", + Issuer: "https://team-vault.auth0.com/", + NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)), + Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"}, + } + + type orgs struct { + Primary string `json:"primary"` + } + + privateCl := struct { + User string `json:"https://vault/user"` + Groups []string `json:"https://vault/groups"` + FirstName string `json:"first_name"` + Org orgs `json:"org"` + }{ + "jeff", + []string{"foo", "bar"}, + "jeff2", + orgs{"engineering"}, + } + + jwtData, _ := getTestJWT(t, ecdsaPrivKey, cl, privateCl) + + data := map[string]interface{}{ + "role": "plugin-test", + "jwt": jwtData, + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "login", + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if !resp.IsError() { + t.Fatalf("expected error, got: %v", resp.Data) + } + } + // test bad signature { cl := jwt.Claims{ @@ -543,9 +598,11 @@ func TestLogin_JWT(t *testing.T) { privateCl := struct { User string `json:"https://vault/user"` Groups []string `json:"https://vault/groups"` + Color string `json:"color"` }{ "jeff", []string{"foo", "bar"}, + "green", } jwtData, _ := getTestJWT(t, ecdsaPrivKey, cl, privateCl) @@ -638,7 +695,7 @@ func TestLogin_JWT(t *testing.T) { } func TestLogin_OIDC(t *testing.T) { - b, storage := setupBackend(t, true, true) + b, storage := setupBackend(t, true, true, false) jwtData := getTestOIDC(t) diff --git a/path_oidc.go b/path_oidc.go index f1ac6242..51084a5e 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -133,6 +133,11 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return logical.ErrorResponse("%s %s", errTokenVerification, err.Error()), nil } + if allClaims["nonce"] != state.nonce { + return logical.ErrorResponse(errTokenVerification + " Invalid ID token nonce."), nil + } + delete(allClaims, "nonce") + // Attempt to fetch information from the /userinfo endpoint and merge it with // the existing claims data. A failure to fetch additional information from this // endpoint will not invalidate the authorization flow. @@ -146,10 +151,9 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, logFunc("error reading /userinfo endpoint", "error", err) } - if allClaims["nonce"] != state.nonce { - return logical.ErrorResponse(errTokenVerification + " Invalid ID token nonce."), nil + if err := validateBoundClaims(b.Logger(), role.BoundClaims, allClaims); err != nil { + return logical.ErrorResponse("error validating claims: %s", err.Error()), nil } - delete(allClaims, "nonce") alias, groupAliases, err := b.createIdentity(allClaims, role) if err != nil { diff --git a/path_oidc_test.go b/path_oidc_test.go index d53215d6..e428ac6d 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -16,7 +16,7 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/vault/logical" - jose "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -254,6 +254,7 @@ func TestOIDC_Callback(t *testing.T) { "password": "foo", "sk": "42", "/nested/secret_code": "bar", + "temperature": "76", }, } @@ -365,6 +366,134 @@ func TestOIDC_Callback(t *testing.T) { } }) + t.Run("failed login - bad nonce", func(t *testing.T) { + b, storage, s := getBackendAndServer(t) + defer s.server.Close() + + // get auth_url + data := map[string]interface{}{ + "role": "test", + "redirect_uri": "https://example.com", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "oidc/auth_url", + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + authURL := resp.Data["auth_url"].(string) + + state := getQueryParam(t, authURL, "state") + + // set provider claims that will be returned by the mock server + s.customClaims = map[string]interface{}{ + "nonce": "notgonnamatch", + "email": "bob@example.com", + "COLOR": "green", + "sk": "42", + "nested": map[string]interface{}{ + "Size": "medium", + "Groups": []string{"a", "b"}, + "secret_code": "bar", + }, + "password": "foo", + } + + // set mock provider's expected code + s.code = "abc" + + // invoke the callback, which will in to try to exchange the code + // with the mock provider. + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "oidc/callback", + Storage: storage, + Data: map[string]interface{}{ + "state": state, + "code": "abc", + }, + } + + resp, err = b.HandleRequest(context.Background(), req) + + if err != nil { + t.Fatal(err) + } + if !resp.IsError() { + t.Fatalf("expected error response, got: %v", resp.Data) + } + }) + + t.Run("failed login - bound claim mismatch", func(t *testing.T) { + b, storage, s := getBackendAndServer(t) + defer s.server.Close() + + // get auth_url + data := map[string]interface{}{ + "role": "test", + "redirect_uri": "https://example.com", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "oidc/auth_url", + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + authURL := resp.Data["auth_url"].(string) + + state := getQueryParam(t, authURL, "state") + nonce := getQueryParam(t, authURL, "nonce") + + // set provider claims that will be returned by the mock server + s.customClaims = map[string]interface{}{ + "nonce": nonce, + "email": "bob@example.com", + "COLOR": "green", + "sk": "43", // the pre-configured role has a bound claim of "sk"=="42" + "nested": map[string]interface{}{ + "Size": "medium", + "Groups": []string{"a", "b"}, + "secret_code": "bar", + }, + "password": "foo", + } + + // set mock provider's expected code + s.code = "abc" + + // invoke the callback, which will in to try to exchange the code + // with the mock provider. + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "oidc/callback", + Storage: storage, + Data: map[string]interface{}{ + "state": state, + "code": "abc", + }, + } + + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if !resp.IsError() { + t.Fatalf("expected error response, got: %v", resp.Data) + } + }) + t.Run("missing state", func(t *testing.T) { b, storage, s := getBackendAndServer(t) defer s.server.Close() @@ -565,18 +694,14 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/.well-known/openid-configuration": - w.Write([]byte(fmt.Sprintf(` + w.Write([]byte(strings.Replace(` { "issuer": "%s", "authorization_endpoint": "%s/auth", "token_endpoint": "%s/token", - "jwks_uri": "%s/certs" - }`, - o.server.URL, - o.server.URL, - o.server.URL, - o.server.URL, - ))) + "jwks_uri": "%s/certs", + "userinfo_endpoint": "%s/userinfo" + }`, "%s", o.server.URL, -1))) case "/certs": a := getTestJWKS(o.t, ecdsaPubKey) w.Write(a) @@ -605,6 +730,12 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { jwtData, jwtData, ))) + case "/userinfo": + w.Write([]byte(` + { + "color":"red", + "temperature":"76" + }`)) default: o.t.Fatalf("unexpected path: %q", r.URL.Path) diff --git a/scripts/local_dev.sh b/scripts/local_dev.sh index ba67fe52..53c40425 100755 --- a/scripts/local_dev.sh +++ b/scripts/local_dev.sh @@ -49,7 +49,7 @@ function cleanup { trap cleanup EXIT echo " Authing" -vault auth root &>/dev/null +vault login root &>/dev/null echo "--> Building" go build -o "$SCRATCH/plugins/$PLUGIN_NAME" "./cmd/$PLUGIN_NAME"