Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve claims verification #31

Merged
merged 1 commit into from
Mar 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions claims.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package jwtauth

import (
"errors"
"fmt"
"strings"

"github.com/hashicorp/vault/helper/strutil"

log "github.com/hashicorp/go-hclog"
"github.com/mitchellh/pointerstructure"
)
Expand Down Expand Up @@ -63,3 +66,40 @@ func extractMetadata(logger log.Logger, allClaims map[string]interface{}, claimM
}
return metadata, nil
}

// validateAudience checks whether any of the audiences in audClaim match those
// in boundAudiences. If strict is true and there are no bound audiences, then the
// presence of any audience in the received claim is considered an error.
func validateAudience(boundAudiences, audClaim []string, strict bool) error {
if strict && len(boundAudiences) == 0 && len(audClaim) > 0 {
return errors.New("audience claim found in JWT but no audiences bound to the role")
}

if len(boundAudiences) > 0 {
for _, v := range boundAudiences {
if strutil.StrListContains(audClaim, v) {
return nil
}
}
return errors.New("aud claim does not match any bound audience")
}

return nil
}

// validateBoundClaims checks that all of the claim:value requirements in boundClaims are
// met in allClaims.
func validateBoundClaims(logger log.Logger, boundClaims, allClaims map[string]interface{}) error {
for claim, expValue := range boundClaims {
actValue := getClaim(logger, allClaims, claim)
if actValue == nil {
return fmt.Errorf("claim %q is missing", claim)
}

if expValue != actValue {
return fmt.Errorf("claim %q does not match associated bound claim", claim)
}
}

return nil
}
134 changes: 133 additions & 1 deletion claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package jwtauth

import (
"encoding/json"
"github.com/hashicorp/go-hclog"
"testing"

"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
)

func TestGetClaim(t *testing.T) {
Expand Down Expand Up @@ -150,3 +150,135 @@ func TestExtractMetadata(t *testing.T) {
}
}
}

func TestValidateAudience(t *testing.T) {
tests := []struct {
boundAudiences []string
audience []string
strict bool
errExpected bool
}{
{[]string{"a"}, []string{"a"}, false, false},
{[]string{"a"}, []string{"b"}, false, true},
{[]string{"a"}, []string{""}, false, true},
{[]string{}, []string{"a"}, false, false},
{[]string{}, []string{"a"}, true, true},
{[]string{"a", "b"}, []string{"a"}, false, false},
{[]string{"a", "b"}, []string{"b"}, false, false},
{[]string{"a", "b"}, []string{"a", "b", "c"}, false, false},
{[]string{"a", "b"}, []string{"c", "d"}, false, true},
}

for _, test := range tests {
err := validateAudience(test.boundAudiences, test.audience, test.strict)
if test.errExpected != (err != nil) {
t.Fatalf("unexpected error result: boundAudiences %v, audience %v, strict %t, err: %v",
test.boundAudiences, test.audience, test.strict, err)
}
}
}

func TestValidateBoundClaims(t *testing.T) {
tests := []struct {
name string
boundClaims map[string]interface{}
allClaims map[string]interface{}
errExpected bool
}{
{
name: "valid",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
errExpected: false,
},
{
name: "valid - extra data",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
"color": "green",
},
errExpected: false,
},
{
name: "mismatched value",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": "wrong",
},
errExpected: true,
},
{
name: "missing claim",
boundClaims: map[string]interface{}{
"foo": "a",
"bar": "b",
},
allClaims: map[string]interface{}{
"foo": "a",
},
errExpected: true,
},
{
name: "valid - JSONPointer",
boundClaims: map[string]interface{}{
"foo": "a",
"/bar/baz/1": "y",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
errExpected: false,
},
{
name: "invalid - JSONPointer value mismatch",
boundClaims: map[string]interface{}{
"foo": "a",
"/bar/baz/1": "q",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
errExpected: true,
},
{
name: "invalid - JSONPointer not found",
boundClaims: map[string]interface{}{
"foo": "a",
"/bar/XXX/1243": "q",
},
allClaims: map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
errExpected: true,
},
}
for _, tt := range tests {
if err := validateBoundClaims(hclog.NewNullLogger(), tt.boundClaims, tt.allClaims); (err != nil) != tt.errExpected {
t.Errorf("validateBoundClaims(%s) error = %v, wantErr %v", tt.name, err, tt.errExpected)
}
}
}
45 changes: 15 additions & 30 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ import (
"fmt"
"time"

oidc "github.com/coreos/go-oidc"
"github.com/coreos/go-oidc"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/cidrutil"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"gopkg.in/square/go-jose.v2/jwt"
Expand Down Expand Up @@ -127,16 +126,19 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
}

expected := jwt.Expected{
Issuer: config.BoundIssuer,
Subject: role.BoundSubject,
Audience: jwt.Audience(role.BoundAudiences),
Time: time.Now(),
Issuer: config.BoundIssuer,
Subject: role.BoundSubject,
Time: time.Now(),
}

if err := claims.Validate(expected); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil
}

if err := validateAudience(role.BoundAudiences, claims.Audience, true); err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil
}

case config.OIDCDiscoveryURL != "":
allClaims, err = b.verifyOIDCToken(ctx, config, role, token)
if err != nil {
Expand All @@ -147,6 +149,10 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return nil, errors.New("unhandled case during login")
}

if err := validateBoundClaims(b.Logger(), role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

alias, groupAliases, err := b.createIdentity(allClaims, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
Expand Down Expand Up @@ -234,36 +240,15 @@ func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig,
if role.BoundSubject != "" && role.BoundSubject != idToken.Subject {
return nil, errors.New("sub claim does not match bound subject")
}
if len(role.BoundAudiences) > 0 {
var found bool
for _, v := range role.BoundAudiences {
if strutil.StrListContains(idToken.Audience, v) {
found = true
break
}
}
if !found {
return nil, errors.New("aud claim does not match any bound audience")
}
}

if len(role.BoundClaims) > 0 {
for claim, expValue := range role.BoundClaims {
actValue := getClaim(b.Logger(), allClaims, claim)
if actValue == nil {
return nil, fmt.Errorf("claim is missing: %s", claim)
}

if expValue != actValue {
return nil, fmt.Errorf("claim '%s' does not match associated bound claim", claim)
}
}
if err := validateAudience(role.BoundAudiences, idToken.Audience, false); err != nil {
return nil, errwrap.Wrapf("error validating claims: {{err}}", err)
}

return allClaims, nil
}

// createIdentity creates an alias and set of groups aliass based on the role
// createIdentity creates an alias and set of groups aliases based on the role
// definition and received claims.
func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
userClaimRaw, ok := allClaims[role.UserClaim]
Expand Down
Loading