diff --git a/go.mod b/go.mod index aef6d8e..dae5029 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/auth0/go-jwt-middleware go 1.14 require ( + github.com/golang-jwt/jwt v3.2.1+incompatible github.com/google/go-cmp v0.5.5 github.com/stretchr/testify v1.7.0 // indirect golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect diff --git a/go.sum b/go.sum index e229ff5..2fcb062 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= +github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/validate/jwt-go/examples/main.go b/validate/jwt-go/examples/main.go new file mode 100644 index 0000000..f271786 --- /dev/null +++ b/validate/jwt-go/examples/main.go @@ -0,0 +1,97 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + + jwtmiddleware "github.com/auth0/go-jwt-middleware" + jwtgo "github.com/auth0/go-jwt-middleware/validate/jwt-go" + "github.com/golang-jwt/jwt" +) + +// CustomClaimsExample contains custom data we want from the token. +type CustomClaimsExample struct { + Username string `json:"username"` + ShouldReject bool `json:"shouldReject,omitempty"` + jwt.StandardClaims +} + +// Validate does nothing for this example +func (c *CustomClaimsExample) Validate(ctx context.Context) error { + if c.ShouldReject { + return errors.New("should reject was set to true") + } + return nil +} + +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := r.Context().Value(jwtmiddleware.ContextKey{}) + j, err := json.MarshalIndent(claims, "", "\t") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Println(err) + } + + fmt.Fprintf(w, "This is an authenticated request\n") + fmt.Fprintf(w, "Claim content: %s\n", string(j)) +}) + +func main() { + keyFunc := func(t *jwt.Token) (interface{}, error) { + // our token must be signed using this data + return []byte("secret"), nil + } + /*expectedClaims := func() jwt.Expected { + // By setting up expected claims we are saying a token must + // have the data we specify. + return jwt.Expected{ + Issuer: "josev2-example", + Time: time.Now(), + } + }*/ + customClaims := func() jwtgo.CustomClaims { + // we want this struct to be filled in with our custom claims + // from the token + return &CustomClaimsExample{} + } + + // setup the jwt-go validator + validator, err := jwtgo.New( + keyFunc, + "HS256", + //jwtgo.WithExpectedClaims(expectedClaims), + jwtgo.WithCustomClaims(customClaims), + //jwtgo.WithAllowedClockSkew(30*time.Second), + ) + + if err != nil { + // we'll panic in order to fail fast + panic(err) + } + + // setup the middleware + m := jwtmiddleware.New(validator.ValidateToken) + + http.ListenAndServe("0.0.0.0:3000", m.CheckJWT(handler)) + // try it out with eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJqd3Rnby1leGFtcGxlIiwic3ViIjoiMTIzNDU2Nzg5MCIsImlhdCI6MTUxNjIzOTAyMiwidXNlcm5hbWUiOiJ1c2VyMTIzIn0.ha_JgA29vSAb3HboPRXEi9Dm5zy7ARzd4P8AFoYP9t0 + // which is signed with 'secret' and has the data: + // { + // "iss": "jwtgo-example", + // "sub": "1234567890", + // "iat": 1516239022, + // "username": "user123" + // } + + // you can also try out the custom validation with eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJqd3Rnby1leGFtcGxlIiwic3ViIjoiMTIzNDU2Nzg5MCIsImlhdCI6MTUxNjIzOTAyMiwidXNlcm5hbWUiOiJ1c2VyMTIzIiwic2hvdWxkUmVqZWN0Ijp0cnVlfQ.awZ0DFpJ-hH5xn-q-sZHJWj7oTAOkPULwgFO4O6D67o + // which is signed with 'secret' and has the data: + // { + // "iss": "jwtgo-example", + // "sub": "1234567890", + // "iat": 1516239022, + // "username": "user123", + // "shouldReject": true + // } +} diff --git a/validate/jwt-go/jwtgo.go b/validate/jwt-go/jwtgo.go new file mode 100644 index 0000000..6bab2dc --- /dev/null +++ b/validate/jwt-go/jwtgo.go @@ -0,0 +1,93 @@ +package jwtgo + +import ( + "context" + "errors" + "fmt" + + "github.com/golang-jwt/jwt" +) + +// CustomClaims defines any custom data / claims wanted. The validator will +// call the Validate function which is where custom validation logic can be +// defined. +type CustomClaims interface { + jwt.Claims + Validate(context.Context) error +} + +// Option is how options for the validator are setup. +type Option func(*validator) + +// WithCustomClaims sets up a function that returns the object CustomClaims are +// unmarshalled into and the object which Validate is called on for custom +// validation. If this option is not used the validator will do nothing for +// custom claims. +func WithCustomClaims(f func() CustomClaims) Option { + return func(v *validator) { + v.customClaims = f + } +} + +// New sets up a new Validator. With the required keyFunc and +// signatureAlgorithm as well as options. +func New(keyFunc jwt.Keyfunc, + signatureAlgorithm string, + opts ...Option) (*validator, error) { + + if keyFunc == nil { + return nil, errors.New("keyFunc is required but was nil") + } + + v := &validator{ + keyFunc: keyFunc, + signatureAlgorithm: signatureAlgorithm, + customClaims: nil, + } + + for _, opt := range opts { + opt(v) + } + + return v, nil +} + +type validator struct { + // required options + + keyFunc func(*jwt.Token) (interface{}, error) + signatureAlgorithm string + + // optional options + customClaims func() CustomClaims +} + +// ValidateToken validates the passed in JWT using the jwt-go package. +func (v *validator) ValidateToken(ctx context.Context, token string) (interface{}, error) { + var claims jwt.Claims + + if v.customClaims != nil { + claims = v.customClaims() + } else { + claims = &jwt.StandardClaims{} + } + + p := new(jwt.Parser) + + if v.signatureAlgorithm != "" { + p.ValidMethods = []string{v.signatureAlgorithm} + } + + _, err := p.ParseWithClaims(token, claims, v.keyFunc) + if err != nil { + return nil, fmt.Errorf("could not parse the token: %w", err) + } + + if customClaims, ok := claims.(CustomClaims); ok { + if err = customClaims.Validate(ctx); err != nil { + return nil, fmt.Errorf("custom claims not validated: %w", err) + } + } + + return claims, nil +} diff --git a/validate/jwt-go/jwtgo_test.go b/validate/jwt-go/jwtgo_test.go new file mode 100644 index 0000000..14d4af1 --- /dev/null +++ b/validate/jwt-go/jwtgo_test.go @@ -0,0 +1,142 @@ +package jwtgo + +import ( + "context" + "errors" + "testing" + + "github.com/golang-jwt/jwt" + "github.com/google/go-cmp/cmp" +) + +type testingCustomClaims struct { + Foo string `json:"foo"` + ReturnError error + jwt.StandardClaims +} + +func (tcc *testingCustomClaims) Validate(ctx context.Context) error { + return tcc.ReturnError +} + +func equalErrors(actual error, expected string) bool { + if actual == nil { + return expected == "" + } + return actual.Error() == expected +} + +func Test_Validate(t *testing.T) { + testCases := []struct { + name string + signatureAlgorithm string + token string + keyFuncReturnError error + customClaims CustomClaims + expectedError string + expectedContext jwt.Claims + }{ + { + name: "happy path", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.Rq8IxqeX7eA6GgYxlcHdPFVRNFFZc5rEI3MQTZZbK3I`, + expectedContext: &jwt.StandardClaims{Subject: "1234567890"}, + }, + { + // we want to test that when it expects RSA but we send + // HMAC encrypted with the server public key it will + // error + name: "errors on wrong algorithm", + signatureAlgorithm: "PS256", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + expectedError: "could not parse the token: signing method HS256 is invalid", + }, + { + name: "errors on wrong token format errors", + expectedError: "could not parse the token: token contains an invalid number of segments", + }, + { + name: "errors when the key func errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + keyFuncReturnError: errors.New("key func error message"), + expectedError: "could not parse the token: key func error message", + }, + { + name: "errors when signature is invalid", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.hDyICUnkCrwFJnkJHRSkwMZNSYZ9LI6z2EFJdtwFurA`, + expectedError: "could not parse the token: signature is invalid", + }, + { + name: "errors when custom claims errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZm9vIjoiYmFyIiwiaWF0IjoxNTE2MjM5MDIyfQ.DFTWyYib4-xFdMaEZFAYx5AKMPNS7Hhl4kcyjQVinYc`, + customClaims: &testingCustomClaims{ReturnError: errors.New("custom claims error message")}, + expectedError: "custom claims not validated: custom claims error message", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + var customClaimsFunc func() CustomClaims = nil + if testCase.customClaims != nil { + customClaimsFunc = func() CustomClaims { return testCase.customClaims } + } + + v, _ := New(func(token *jwt.Token) (interface{}, error) { + return []byte("secret"), testCase.keyFuncReturnError + }, + testCase.signatureAlgorithm, + WithCustomClaims(customClaimsFunc), + ) + actualContext, err := v.ValidateToken(context.Background(), testCase.token) + if !equalErrors(err, testCase.expectedError) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", testCase.expectedError, err) + } + + if (testCase.expectedContext == nil && actualContext != nil) || (testCase.expectedContext != nil && actualContext == nil) { + t.Fatalf("wanted user context:\n%+v\ngot:\n%+v\n", testCase.expectedContext, actualContext) + } else if testCase.expectedContext != nil { + if diff := cmp.Diff(testCase.expectedContext, actualContext.(jwt.Claims)); diff != "" { + t.Errorf("user context mismatch (-want +got):\n%s", diff) + } + + } + }) + } +} + +func Test_New(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + keyFunc := func(t *jwt.Token) (interface{}, error) { return nil, nil } + customClaims := func() CustomClaims { return nil } + + v, err := New(keyFunc, "HS256", WithCustomClaims(customClaims)) + + if !equalErrors(err, "") { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", "", err) + } + + if v.keyFunc == nil { + t.Log("keyFunc was nil when it should not have been") + t.Fail() + } + + if v.signatureAlgorithm != "HS256" { + t.Logf("signatureAlgorithm was %q when it should have been %q", v.signatureAlgorithm, "HS256") + t.Fail() + } + + if v.customClaims == nil { + t.Log("customClaims was nil when it should not have been") + t.Fail() + } + }) + + t.Run("error on no keyFunc", func(t *testing.T) { + _, err := New(nil, "HS256") + + expectedErr := "keyFunc is required but was nil" + if !equalErrors(err, expectedErr) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", expectedErr, err) + } + }) + +}