Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the cap/oidc library for OIDC based authentication #158

Merged
merged 13 commits into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ import (
"errors"
"fmt"
"sync"
"time"

"github.com/coreos/go-oidc"
"github.com/hashicorp/cap/jwt"
"github.com/hashicorp/cap/oidc"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/patrickmn/go-cache"
Expand All @@ -35,7 +34,7 @@ type jwtAuthBackend struct {
provider *oidc.Provider
validator *jwt.Validator
cachedConfig *jwtConfig
oidcStates *cache.Cache
oidcRequests *cache.Cache

providerCtx context.Context
providerCtxCancel context.CancelFunc
Expand All @@ -44,7 +43,7 @@ type jwtAuthBackend struct {
func backend() *jwtAuthBackend {
b := new(jwtAuthBackend)
b.providerCtx, b.providerCtxCancel = context.WithCancel(context.Background())
b.oidcStates = cache.New(oidcStateTimeout, 1*time.Minute)
b.oidcRequests = cache.New(oidcRequestTimeout, oidcRequestCleanupInterval)

b.Backend = &framework.Backend{
AuthRenew: b.pathLoginRenew,
Expand Down Expand Up @@ -87,6 +86,9 @@ func (b *jwtAuthBackend) cleanup(_ context.Context) {
if b.providerCtxCancel != nil {
b.providerCtxCancel()
}
if b.provider != nil {
b.provider.Done()
}
b.l.Unlock()
}

Expand All @@ -99,24 +101,18 @@ func (b *jwtAuthBackend) invalidate(ctx context.Context, key string) {

func (b *jwtAuthBackend) reset() {
b.l.Lock()
if b.provider != nil {
b.provider.Done()
}
b.provider = nil
b.cachedConfig = nil
b.validator = nil
b.l.Unlock()
}

func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error) {
b.l.RLock()
unlockFunc := b.l.RUnlock
defer func() { unlockFunc() }()

if b.provider != nil {
return b.provider, nil
}

b.l.RUnlock()
b.l.Lock()
unlockFunc = b.l.Unlock
defer b.l.Unlock()

if b.provider != nil {
return b.provider, nil
Expand Down
4 changes: 1 addition & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ module github.com/hashicorp/vault-plugin-auth-jwt
go 1.14

require (
github.com/coreos/go-oidc v2.2.1+incompatible
github.com/go-test/deep v1.0.2-0.20181118220953-042da051cf31
github.com/hashicorp/cap v0.0.0-20210122190810-1e160503dd74
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7
github.com/hashicorp/errwrap v1.0.0
github.com/hashicorp/go-cleanhttp v0.5.1
github.com/hashicorp/go-hclog v0.12.0
github.com/hashicorp/go-sockaddr v1.0.2
github.com/hashicorp/go-uuid v1.0.2
github.com/hashicorp/go-version v1.2.0 // indirect
github.com/hashicorp/vault/api v1.0.5-0.20200215224050-f6547fa8e820
github.com/hashicorp/vault/sdk v0.1.14-0.20200215224050-f6547fa8e820
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/cap v0.0.0-20210122190810-1e160503dd74 h1:LhrPGNyZNGof4dYRjgAyfeRR4xhsTZATz/01lc7XEDA=
github.com/hashicorp/cap v0.0.0-20210122190810-1e160503dd74/go.mod h1:tIk5rB1nihW5+9bZjI7xlc8LGw8FYfiFMKOpHPbWgug=
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7 h1:6OHvaQs9ys66bR1yqHuoI231JAoalgGgxeqzQuVOfX0=
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7/go.mod h1:tIk5rB1nihW5+9bZjI7xlc8LGw8FYfiFMKOpHPbWgug=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
Expand Down
26 changes: 21 additions & 5 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"net/http"
"strings"

"github.com/coreos/go-oidc"
"github.com/hashicorp/cap/jwt"
"github.com/hashicorp/cap/oidc"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/sdk/framework"
Expand Down Expand Up @@ -236,9 +236,14 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
return logical.ErrorResponse("both 'oidc_client_id' and 'oidc_client_secret' must be set for OIDC"), nil

case config.OIDCDiscoveryURL != "":
_, err := b.createProvider(config)
var err error
if config.OIDCClientID != "" && config.OIDCClientSecret != "" {
_, err = b.createProvider(config)
} else {
_, err = jwt.NewOIDCDiscoveryKeySet(ctx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
}
if err != nil {
return logical.ErrorResponse(errwrap.Wrapf("error checking oidc discovery URL: {{err}}", err).Error()), nil
return logical.ErrorResponse("error checking oidc discovery URL: %s", err.Error()), nil
}

case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
Expand Down Expand Up @@ -315,12 +320,23 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
}

func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) {
oidcCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM)
supportedSigAlgs := make([]oidc.Alg, len(config.JWTSupportedAlgs))
for i, a := range config.JWTSupportedAlgs {
supportedSigAlgs[i] = oidc.Alg(a)
}

if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []oidc.Alg{oidc.RS256}
}

c, err := oidc.NewConfig(config.OIDCDiscoveryURL, config.OIDCClientID,
oidc.ClientSecret(config.OIDCClientSecret), supportedSigAlgs, []string{},
oidc.WithProviderCA(config.OIDCDiscoveryCAPEM))
if err != nil {
return nil, errwrap.Wrapf("error creating provider: {{err}}", err)
}

provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
provider, err := oidc.NewProvider(c)
if err != nil {
return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err)
}
Expand Down
4 changes: 4 additions & 0 deletions path_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ func TestConfig_OIDC_Write(t *testing.T) {
data := map[string]interface{}{
"oidc_discovery_url": "https://team-vault.auth0.com/",
"oidc_discovery_ca_pem": oidcBadCACerts,
"oidc_client_id": "abc",
"oidc_client_secret": "def",
}

req := &logical.Request{
Expand Down Expand Up @@ -345,6 +347,8 @@ func TestConfig_OIDC_Write(t *testing.T) {
JWTSupportedAlgs: []string{},
OIDCResponseTypes: []string{},
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
OIDCClientID: "abc",
OIDCClientSecret: "def",
ProviderConfig: map[string]interface{}{},
NamespaceInState: true,
}
Expand Down
52 changes: 6 additions & 46 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"errors"
"fmt"

"github.com/coreos/go-oidc"
"github.com/hashicorp/cap/jwt"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/oauth2"
)

func pathLogin(b *jwtAuthBackend) *framework.Path {
Expand Down Expand Up @@ -116,7 +116,7 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse("audience claim found in JWT but no audiences bound to the role"), nil
}

alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role, nil)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
Expand Down Expand Up @@ -169,49 +169,9 @@ func (b *jwtAuthBackend) pathLoginRenew(ctx context.Context, req *logical.Reques
return resp, nil
}

func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig, role *jwtRole, rawToken string) (map[string]interface{}, error) {
allClaims := make(map[string]interface{})

provider, err := b.getProvider(config)
if err != nil {
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
}

oidcConfig := &oidc.Config{
SupportedSigningAlgs: config.JWTSupportedAlgs,
}

if role.RoleType == "oidc" {
oidcConfig.ClientID = config.OIDCClientID
} else {
oidcConfig.SkipClientIDCheck = true
}

verifier := provider.Verifier(oidcConfig)

idToken, err := verifier.Verify(ctx, rawToken)
if err != nil {
return nil, errwrap.Wrapf("error validating signature: {{err}}", err)
}

if err := idToken.Claims(&allClaims); err != nil {
return nil, errwrap.Wrapf("unable to successfully parse all claims from token: {{err}}", err)
}

if role.BoundSubject != "" && role.BoundSubject != idToken.Subject {
return nil, errors.New("sub claim does not match bound subject")
}

if err := validateAudience(role.BoundAudiences, idToken.Audience, false); err != nil {
return nil, errwrap.Wrapf("error validating claims: {{err}}", err)
}

return allClaims, nil
}

// createIdentity creates an alias and set of groups aliases based on the role
// definition and received claims.
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (*logical.Alias, []*logical.Alias, error) {
userClaimRaw, ok := allClaims[role.UserClaim]
if !ok {
return nil, nil, fmt.Errorf("claim %q not found in token", role.UserClaim)
Expand Down Expand Up @@ -246,7 +206,7 @@ func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[strin
return alias, groupAliases, nil
}

groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role)
groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role, tokenSource)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch groups: %s", err)
}
Expand Down Expand Up @@ -285,12 +245,12 @@ func (b *jwtAuthBackend) fetchUserInfo(ctx context.Context, pConfig CustomProvid
}

// Checks if there's a custom provider_config and calls FetchGroups() if implemented
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) {
// If the custom provider implements interface GroupsFetcher, call it,
// otherwise fall through to the default method
if pConfig != nil {
if gf, ok := pConfig.(GroupsFetcher); ok {
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role)
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role, tokenSource)
if err != nil {
return nil, err
}
Expand Down
Loading