Skip to content

Commit

Permalink
Merge pull request #173 from Sean-Q-Sun/verifyEndpoints
Browse files Browse the repository at this point in the history
Verify distributed claim endpoints
  • Loading branch information
ericchiang authored Jun 4, 2018
2 parents 1bddd0c + 1790296 commit 8ae1da5
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
language: go

go:
- 1.7.5
- 1.8
- "1.9"
- "1.10"

install:
- go get -v -t github.com/coreos/go-oidc/...
Expand Down
49 changes: 49 additions & 0 deletions verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -118,6 +120,53 @@ func contains(sli []string, ele string) bool {
return false
}

// Returns the Claims from the distributed JWT token
func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src claimSource) ([]byte, error) {
req, err := http.NewRequest("GET", src.Endpoint, nil)
if err != nil {
return nil, fmt.Errorf("malformed request: %v", err)
}
if src.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+src.AccessToken)
}

resp, err := doRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("oidc: Request to endpoint failed: %v", err)
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("oidc: request failed: %v", resp.StatusCode)
}

token, err := verifier.Verify(ctx, string(body))
if err != nil {
return nil, fmt.Errorf("malformed response body: %v", err)
}

return token.claims, nil
}

func parseClaim(raw []byte, name string, v interface{}) error {
var parsed map[string]json.RawMessage
if err := json.Unmarshal(raw, &parsed); err != nil {
return err
}

val, ok := parsed[name]
if !ok {
return fmt.Errorf("claim doesn't exist: %s", name)
}

return json.Unmarshal([]byte(val), v)
}

// Verify parses a raw ID Token, verifies it's been signed by the provider, preforms
// any additional checks depending on the Config, and returns the payload.
//
Expand Down
147 changes: 147 additions & 0 deletions verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package oidc
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"
Expand Down Expand Up @@ -342,6 +345,150 @@ func TestDistributedClaims(t *testing.T) {
}
}

func TestDistClaimResolver(t *testing.T) {
tests := []resolverTest{
{
name: "noAccessToken",
payload: `{"iss":"https://foo","aud":"client1",
"email":"[email protected]",
"shipping_address": {
"street_address": "1234 Hollywood Blvd.",
"locality": "Los Angeles",
"region": "CA",
"postal_code": "90210",
"country": "US"}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
issuer: "https://foo",

want: map[string]claimSource{},
},
{
name: "rightAccessToken",
payload: `{"iss":"https://foo","aud":"client1",
"email":"[email protected]",
"shipping_address": {
"street_address": "1234 Hollywood Blvd.",
"locality": "Los Angeles",
"region": "CA",
"postal_code": "90210",
"country": "US"}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
accessToken: "1234",
issuer: "https://foo",

want: map[string]claimSource{},
},
{
name: "wrongAccessToken",
payload: `{"iss":"https://foo","aud":"client1",
"email":"[email protected]",
"shipping_address": {
"street_address": "1234 Hollywood Blvd.",
"locality": "Los Angeles",
"region": "CA",
"postal_code": "90210",
"country": "US"}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
accessToken: "12345",
issuer: "https://foo",
wantErr: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
claims, err := test.testEndpoint(t)
if err != nil {
if !test.wantErr {
t.Errorf("%v", err)
}
return
}
if test.wantErr {
t.Errorf("expected error receiving response")
return
}
if !reflect.DeepEqual(string(claims), test.payload) {
t.Errorf("expected dist claim: %#v, got: %v", test.payload, string(claims))
}
})
}

}

type resolverTest struct {
// Name of the subtest.
name string

// issuer will be the endpoint server url
issuer string

// just the payload
payload string

// Key to sign the ID Token with.
signKey *signingKey

// If not provided defaults to signKey. Only useful when
// testing invalid signatures.
verificationKey *signingKey

config Config
wantErr bool
want map[string]claimSource

//this is the access token that the testEndpoint will accept
accessToken string
}

func (v resolverTest) testEndpoint(t *testing.T) ([]byte, error) {
token := v.signKey.sign(t, []byte(v.payload))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("Authorization")
if got != "" && got != "Bearer 1234" {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
io.WriteString(w, token)
}))
defer s.Close()

issuer := v.issuer
var ks KeySet
if v.verificationKey == nil {
ks = &testVerifier{v.signKey.jwk()}
} else {
ks = &testVerifier{v.verificationKey.jwk()}
}
verifier := NewVerifier(issuer, ks, &v.config)

ctx = ClientContext(ctx, s.Client())

src := claimSource{
Endpoint: s.URL + "/",
AccessToken: v.accessToken,
}
return resolveDistributedClaim(ctx, verifier, src)
}

type verificationTest struct {
// Name of the subtest.
name string
Expand Down

0 comments on commit 8ae1da5

Please sign in to comment.