Skip to content

Commit

Permalink
Improve claims verification (#31)
Browse files Browse the repository at this point in the history
- verify bound_claims in all paths
- verify bound_claims against /userinfo data
- verify bound_audiences in all paths
- review bound_audiences check when using static keys to conform to our docs
- more tests
  • Loading branch information
kalafut authored Mar 14, 2019
1 parent b9b5cad commit 86b4467
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 49 deletions.
40 changes: 40 additions & 0 deletions claims.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package jwtauth

import (
"errors"
"fmt"
"strings"

"github.com/hashicorp/vault/helper/strutil"

log "github.com/hashicorp/go-hclog"
"github.com/mitchellh/pointerstructure"
)
Expand Down Expand Up @@ -63,3 +66,40 @@ func extractMetadata(logger log.Logger, allClaims map[string]interface{}, claimM
}
return metadata, nil
}

// validateAudience checks whether any of the audiences in audClaim match those
// in boundAudiences. If strict is true and there are no bound audiences, then the
// presence of any audience in the received claim is considered an error.
func validateAudience(boundAudiences, audClaim []string, strict bool) error {
if strict && len(boundAudiences) == 0 && len(audClaim) > 0 {
return errors.New("audience claim found in JWT but no audiences bound to the role")
}

if len(boundAudiences) > 0 {
for _, v := range boundAudiences {
if strutil.StrListContains(audClaim, v) {
return nil
}
}
return errors.New("aud claim does not match any bound audience")
}

return nil
}

// validateBoundClaims checks that all of the claim:value requirements in boundClaims are
// met in allClaims.
func validateBoundClaims(logger log.Logger, boundClaims, allClaims map[string]interface{}) error {
for claim, expValue := range boundClaims {
actValue := getClaim(logger, allClaims, claim)
if actValue == nil {
return fmt.Errorf("claim %q is missing", claim)
}

if expValue != actValue {
return fmt.Errorf("claim %q does not match associated bound claim", claim)
}
}

return nil
}
134 changes: 133 additions & 1 deletion claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package jwtauth

import (
"encoding/json"
"github.com/hashicorp/go-hclog"
"testing"

"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
)

func TestGetClaim(t *testing.T) {
Expand Down Expand Up @@ -150,3 +150,135 @@ func TestExtractMetadata(t *testing.T) {
}
}
}

func TestValidateAudience(t *testing.T) {
tests := []struct {
boundAudiences []string
audience []string
strict bool
errExpected bool
}{
{[]string{"a"}, []string{"a"}, false, false},
{[]string{"a"}, []string{"b"}, false, true},
{[]string{"a"}, []string{""}, false, true},
{[]string{}, []string{"a"}, false, false},
{[]string{}, []string{"a"}, true, true},
{[]string{"a", "b"}, []string{"a"}, false, false},
{[]string{"a", "b"}, []string{"b"}, false, false},
{[]string{"a", "b"}, []string{"a", "b", "c"}, false, false},
{[]string{"a", "b"}, []string{"c", "d"}, false, true},
}

for _, test := range tests {
err := validateAudience(test.boundAudiences, test.audience, test.strict)
if test.errExpected != (err != nil) {
t.Fatalf("unexpected error result: boundAudiences %v, audience %v, strict %t, err: %v",
test.boundAudiences, test.audience, test.strict, err)
}
}
}

func TestValidateBoundClaims(t *testing.T) {
tests := []struct {
name string
boundClaims map[string]interface{}
allClaims map[string]interface{}
errExpected bool
}{
{
name: "valid",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
errExpected: false,
},
{
name: "valid - extra data",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
"color": "green",
},
errExpected: false,
},
{
name: "mismatched value",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": "wrong",
},
errExpected: true,
},
{
name: "missing claim",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
},
errExpected: true,
},
{
name: "valid - JSONPointer",
boundClaims: map[string]interface{}{
"foo": "a",
"/bar/baz/1": "y",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
errExpected: false,
},
{
name: "invalid - JSONPointer value mismatch",
boundClaims: map[string]interface{}{
"foo": "a",
"/bar/baz/1": "q",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
errExpected: true,
},
{
name: "invalid - JSONPointer not found",
boundClaims: map[string]interface{}{
"foo": "a",
"/bar/XXX/1243": "q",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
errExpected: true,
},
}
for _, tt := range tests {
if err := validateBoundClaims(hclog.NewNullLogger(), tt.boundClaims, tt.allClaims); (err != nil) != tt.errExpected {
t.Errorf("validateBoundClaims(%s) error = %v, wantErr %v", tt.name, err, tt.errExpected)
}
}
}
45 changes: 15 additions & 30 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ import (
"fmt"
"time"

oidc "github.com/coreos/go-oidc"
"github.com/coreos/go-oidc"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/cidrutil"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"gopkg.in/square/go-jose.v2/jwt"
Expand Down Expand Up @@ -127,16 +126,19 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
}

expected := jwt.Expected{
Issuer: config.BoundIssuer,
Subject: role.BoundSubject,
Audience: jwt.Audience(role.BoundAudiences),
Time: time.Now(),
Issuer: config.BoundIssuer,
Subject: role.BoundSubject,
Time: time.Now(),
}

if err := claims.Validate(expected); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil
}

if err := validateAudience(role.BoundAudiences, claims.Audience, true); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil
}

case config.OIDCDiscoveryURL != "":
allClaims, err = b.verifyOIDCToken(ctx, config, role, token)
if err != nil {
Expand All @@ -147,6 +149,10 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return nil, errors.New("unhandled case during login")
}

if err := validateBoundClaims(b.Logger(), role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

alias, groupAliases, err := b.createIdentity(allClaims, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
Expand Down Expand Up @@ -234,36 +240,15 @@ func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig,
if role.BoundSubject != "" && role.BoundSubject != idToken.Subject {
return nil, errors.New("sub claim does not match bound subject")
}
if len(role.BoundAudiences) > 0 {
var found bool
for _, v := range role.BoundAudiences {
if strutil.StrListContains(idToken.Audience, v) {
found = true
break
}
}
if !found {
return nil, errors.New("aud claim does not match any bound audience")
}
}

if len(role.BoundClaims) > 0 {
for claim, expValue := range role.BoundClaims {
actValue := getClaim(b.Logger(), allClaims, claim)
if actValue == nil {
return nil, fmt.Errorf("claim is missing: %s", claim)
}

if expValue != actValue {
return nil, fmt.Errorf("claim '%s' does not match associated bound claim", claim)
}
}
if err := validateAudience(role.BoundAudiences, idToken.Audience, false); err != nil {
return nil, errwrap.Wrapf("error validating claims: {{err}}", err)
}

return allClaims, nil
}

// createIdentity creates an alias and set of groups aliass based on the role
// createIdentity creates an alias and set of groups aliases based on the role
// definition and received claims.
func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
userClaimRaw, ok := allClaims[role.UserClaim]
Expand Down
Loading

0 comments on commit 86b4467

Please sign in to comment.