Skip to content

Commit

Permalink
Add PKCE to OIDC Auth (#188) (#191)
Browse files Browse the repository at this point in the history
* Add PKCE to OIDC authorization code logins

* Add tests

* fix comment typos

* check for code response type

Co-authored-by: Jim Kalafut <[email protected]>

Co-authored-by: Jim Kalafut <[email protected]>
  • Loading branch information
fairclothjm and Jim Kalafut authored Dec 8, 2021
1 parent 7e8c8eb commit 4096e1f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 11 deletions.
7 changes: 7 additions & 0 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,13 @@ func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rol

if config.hasType(responseTypeIDToken) {
options = append(options, oidc.WithImplicitFlow())
} else if config.hasType(responseTypeCode) {
v, err := oidc.NewCodeVerifier()
if err != nil {
return nil, fmt.Errorf("error creating code challenge: %w", err)
}

options = append(options, oidc.WithPKCE(v))
}

if role.MaxAge > 0 {
Expand Down
103 changes: 92 additions & 11 deletions path_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package jwtauth
import (
"bytes"
"context"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
Expand Down Expand Up @@ -101,6 +103,7 @@ func TestOIDC_AuthURL(t *testing.T) {
`state=st_\w{20}`,
`redirect_uri=https%3A%2F%2Fexample.com`,
`response_type=code`,
`code_challenge=\w+`,
`scope=openid`,
}

Expand Down Expand Up @@ -211,7 +214,6 @@ func TestOIDC_AuthURL(t *testing.T) {
}

func TestOIDC_AuthURL_namespace(t *testing.T) {

type testCase struct {
namespaceInState string
allowedRedirectURIs []string
Expand Down Expand Up @@ -362,7 +364,6 @@ func TestOIDC_AuthURL_namespace(t *testing.T) {
if !matchState {
t.Fatalf("expected state to match regex: %s, %s", test.expectedStateRegEx, state)
}

})
}
}
Expand Down Expand Up @@ -583,7 +584,6 @@ func TestOIDC_ResponseTypeIDToken(t *testing.T) {

func TestOIDC_Callback(t *testing.T) {
t.Run("successful login", func(t *testing.T) {

// run test with and without bound_cidrs configured
for _, useBoundCIDRs := range []bool{false, true} {
b, storage, s := getBackendAndServer(t, useBoundCIDRs)
Expand Down Expand Up @@ -617,7 +617,10 @@ func TestOIDC_Callback(t *testing.T) {
// set mock provider's expected code
s.code = "abc"

// invoke the callback, which will in to try to exchange the code
// save PKCE challenge
s.codeChallenge = getQueryParam(t, authURL, "code_challenge")

// invoke the callback, which will try to exchange the code
// with the mock provider.
req = &logical.Request{
Operation: logical.ReadOperation,
Expand Down Expand Up @@ -711,6 +714,9 @@ func TestOIDC_Callback(t *testing.T) {
// set mock provider's expected code
s.code = "abc"

// save PKCE challenge
s.codeChallenge = getQueryParam(t, authURL, "code_challenge")

// invoke the callback, which will in to try to exchange the code
// with the mock provider.
req = &logical.Request{
Expand Down Expand Up @@ -765,6 +771,9 @@ func TestOIDC_Callback(t *testing.T) {
// set mock provider's expected code
s.code = "abc"

// save PKCE challenge
s.codeChallenge = getQueryParam(t, authURL, "code_challenge")

// invoke the callback, which will in to try to exchange the code
// with the mock provider.
req = &logical.Request{
Expand Down Expand Up @@ -896,6 +905,10 @@ func TestOIDC_Callback(t *testing.T) {
// set mock provider's expected code
s.code = "abc"

// save PKCE challenge
s.codeChallenge = getQueryParam(t, authURL, "code_challenge")

// verify failure with wrong code
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "oidc/callback",
Expand All @@ -915,6 +928,58 @@ func TestOIDC_Callback(t *testing.T) {
}
})

t.Run("failed code exchange (PKCE)", func(t *testing.T) {
b, storage, s := getBackendAndServer(t, false)
defer s.server.Close()

// get auth_url
data := map[string]interface{}{
"role": "test",
"redirect_uri": "https://example.com",
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "oidc/auth_url",
Storage: storage,
Data: data,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}

authURL := resp.Data["auth_url"].(string)
state := getQueryParam(t, authURL, "state")

// set mock provider's expected code
s.code = "abc"

// Verify failure with failed PKCE verification
// The challenge on the request side is embedded in the cap library request state which
// is inaccessible. To cause a mismatch, adjust the mock.
s.codeChallenge = "wrong_challenge"

// verify failure with PKCE
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "oidc/callback",
Storage: storage,
Data: map[string]interface{}{
"state": state,
"code": "abc",
},
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}

if resp == nil || !strings.Contains(resp.Error().Error(), "cannot fetch token") {
t.Fatalf("expected code exchange error response, got: %#v", resp)
}
})

t.Run("no response from provider", func(t *testing.T) {
b, storage, s := getBackendAndServer(t, false)

Expand Down Expand Up @@ -1037,6 +1102,9 @@ func TestOIDC_Callback(t *testing.T) {
// set provider claims that will be returned by the mock server
s.customClaims = sampleClaims(nonce)

// save PKCE challenge
s.codeChallenge = getQueryParam(t, authURL, "code_challenge")

req = &logical.Request{
Operation: logical.ReadOperation,
Path: "oidc/callback",
Expand Down Expand Up @@ -1123,6 +1191,9 @@ func TestOIDC_Callback(t *testing.T) {
// set mock provider's expected code
s.code = "abc"

// save PKCE challenge
s.codeChallenge = getQueryParam(t, authURL, "code_challenge")

// invoke the callback, which will try to exchange the code
// with the mock provider.
req = &logical.Request{
Expand All @@ -1149,15 +1220,16 @@ func TestOIDC_Callback(t *testing.T) {
})
}

// oidcProvider is local server the mocks the basis endpoints used by the
// oidcProvider is a local server that mocks the basis endpoints used by the
// OIDC callback process.
type oidcProvider struct {
t *testing.T
server *httptest.Server
clientID string
clientSecret string
code string
customClaims map[string]interface{}
t *testing.T
server *httptest.Server
clientID string
clientSecret string
code string
codeChallenge string
customClaims map[string]interface{}
}

func newOIDCProvider(t *testing.T) *oidcProvider {
Expand Down Expand Up @@ -1190,12 +1262,21 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("It's not a keyset!"))
case "/token":
code := r.FormValue("code")
codeVerifier := r.FormValue("code_verifier")

if code != o.code {
w.WriteHeader(401)
break
}

sum := sha256.Sum256([]byte(codeVerifier))
computedChallenge := base64.RawURLEncoding.EncodeToString(sum[:])

if computedChallenge != o.codeChallenge {
w.WriteHeader(401)
break
}

stdClaims := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: o.server.URL,
Expand Down

0 comments on commit 4096e1f

Please sign in to comment.