generated from TBD54566975/tbd-project-template
-
Notifications
You must be signed in to change notification settings - Fork 8
/
jwt.go
286 lines (235 loc) · 7.92 KB
/
jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
package jwt
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/decentralized-identity/web5-go/dids/did"
"github.com/decentralized-identity/web5-go/dids/didcore"
"github.com/decentralized-identity/web5-go/jws"
)
// Decode splits a compact-serialized JWT on "." and decodes each of its three
// segments: the header (via jws.DecodeHeader), the claims (base64url, then JSON)
// and the signature (via jws.DecodeSignature). The header's KID is additionally
// parsed as a DID and returned as SignerDID.
//
// Decode does not verify anything; call Verify on the result for that.
// An error is returned as soon as any step fails, together with a zero Decoded.
func Decode(jwt string) (Decoded, error) {
	segments := strings.Split(jwt, ".")
	if n := len(segments); n != 3 {
		return Decoded{}, fmt.Errorf("malformed JWT. Expected 3 parts, got %d", n)
	}

	hdr, err := jws.DecodeHeader(segments[0])
	if err != nil {
		return Decoded{}, fmt.Errorf("malformed JWT. Failed to decode header: %w", err)
	}

	rawClaims, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return Decoded{}, fmt.Errorf("malformed JWT. Failed to decode claims: %w", err)
	}

	var parsedClaims Claims
	if err := json.Unmarshal(rawClaims, &parsedClaims); err != nil {
		return Decoded{}, fmt.Errorf("malformed JWT. Failed to unmarshal claims: %w", err)
	}

	sig, err := jws.DecodeSignature(segments[2])
	if err != nil {
		return Decoded{}, fmt.Errorf("malformed JWT. Failed to decode signature: %w", err)
	}

	// The KID header is treated as a DID URL identifying the signing key.
	signer, err := did.Parse(hdr.KID)
	if err != nil {
		return Decoded{}, fmt.Errorf("malformed JWT. Failed to parse signer DID: %w", err)
	}

	result := Decoded{
		Header:    hdr,
		Claims:    parsedClaims,
		Signature: sig,
		Parts:     segments,
		SignerDID: signer,
	}
	return result, nil
}
// signOpts collects the settings that SignOpt functions apply before Sign
// builds the options it forwards to jws.Sign.
type signOpts struct {
// selector is set by the Purpose option. Sign forwards it (as jws.VMSelector)
// only when it is non-nil.
selector didcore.VMSelector
// typ is set by the Type option. Sign forwards it (as jws.Type) only when it
// is non-empty.
typ string
}
// SignOpt is a functional option for Sign. Each option is a function that
// mutates the signOpts value Sign passes to it; see Purpose and Type.
type SignOpt func(opts *signOpts)
// Purpose is a Sign option that requests a signing key from the given DID
// Document verification relationship (e.g. "authentication").
//
// The string is converted with didcore.Purpose and stored as the selector that
// Sign later forwards to jws.Sign.
func Purpose(p string) SignOpt {
	apply := func(cfg *signOpts) {
		cfg.selector = didcore.Purpose(p)
	}
	return apply
}
// Type is a Sign option that sets the `typ` header of the resulting JWT.
// An empty string leaves the header unset, because Sign only forwards a
// non-empty typ.
func Type(t string) SignOpt {
	apply := func(cfg *signOpts) {
		cfg.typ = t
	}
	return apply
}
// Sign serializes the provided claims to JSON and signs them with the provided
// BearerDID by delegating to jws.Sign, returning whatever jws.Sign returns.
//
// Options:
//   - Purpose: choose the DID Document verification relationship whose key is
//     used. When omitted, no selector is forwarded and jws.Sign applies its own
//     default (documented upstream as assertionMethod — confirm in package jws).
//   - Type: set the `typ` header.
//
// # Note
//
// claims.Issuer is overwritten with did.URI before marshalling. claims is
// passed by value, so the caller's copy is not modified.
func Sign(claims Claims, did did.BearerDID, opts ...SignOpt) (string, error) {
	var cfg signOpts
	for _, apply := range opts {
		apply(&cfg)
	}

	// Translate the collected settings into jws options, skipping unset ones.
	forwarded := make([]jws.SignOpt, 0, 2)
	if cfg.typ != "" {
		forwarded = append(forwarded, jws.Type(cfg.typ))
	}
	if cfg.selector != nil {
		forwarded = append(forwarded, jws.VMSelector(cfg.selector))
	}

	// `iss` is required to be equal to the DID's URI
	claims.Issuer = did.URI

	body, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("failed to marshal jwt claims: %w", err)
	}

	return jws.Sign(body, did, forwarded...)
}
// Verify decodes the compact-serialized JWT and then runs Decoded.Verify on it.
// Spec: https://datatracker.ietf.org/doc/html/rfc7519
//
// If decoding fails, a zero Decoded and the decode error are returned.
// Otherwise the decoded JWT is returned together with the result of
// verification — note that the decoded value is populated even when the
// returned error is non-nil, so callers must check the error before trusting it.
func Verify(jwt string) (Decoded, error) {
	decoded, err := Decode(jwt)
	if err != nil {
		return Decoded{}, err
	}

	return decoded, decoded.Verify()
}
// Header is the JWT header. It is a type alias for jws.Header (not a new type),
// provided so that users of this package can refer to the header without
// importing package jws themselves.
type Header = jws.Header
// Decoded represents a JWT decoded into its relevant parts, as produced by Decode.
type Decoded struct {
// Header is the decoded first segment of the JWT.
Header Header
// Claims is the decoded and unmarshalled second segment of the JWT.
Claims Claims
// Signature is the decoded third segment of the JWT.
Signature []byte
// Parts holds the raw, still-encoded segments obtained by splitting the JWT on ".".
Parts []string
// SignerDID is the result of parsing the header's KID as a DID.
SignerDID did.DID
}
// Verify checks the decoded JWT's time-based claims, the binding between its
// issuer and signing key, and finally its signature.
//
// Checks are ordered cheapest first. Per the original author's note, signature
// verification requires DID resolution, which is a network call, so everything
// that can be rejected locally is rejected before that call is made:
//
//  1. `exp`: if set, the current time must be strictly before it
//     (RFC 7519 §4.1.4: not accepted "on or after" the expiration time).
//  2. `nbf`: if set, the current time must be at or after it (RFC 7519 §4.1.5).
//  3. `iss` must be non-empty and the header's KID must be a DID URL rooted at
//     exactly that DID (see kidMatchesIssuer).
//  4. The JWS signature must verify (delegated to jws.Decoded.Verify).
//
// A nil return means every check passed.
func (jwt Decoded) Verify() error {
	now := time.Now().Unix()

	if jwt.Claims.Expiration != 0 && now >= jwt.Claims.Expiration {
		return errors.New("JWT has expired")
	}

	if jwt.Claims.NotBefore != 0 && now < jwt.Claims.NotBefore {
		return errors.New("JWT is not yet valid")
	}

	// check to ensure that issuer has been set and that it matches the did used to sign.
	// the value of KID should always be ${did}#${verificationMethodID} (aka did url)
	if !kidMatchesIssuer(jwt.Header.KID, jwt.Claims.Issuer) {
		return errors.New("JWT issuer does not match the did url provided as KID")
	}

	// Decode guarantees three parts, but a Decoded can also be built by hand;
	// fail with an error instead of panicking on the index below.
	if len(jwt.Parts) < 2 {
		return fmt.Errorf("malformed JWT. Expected 3 parts, got %d", len(jwt.Parts))
	}

	claimsBytes, err := base64.RawURLEncoding.DecodeString(jwt.Parts[1])
	if err != nil {
		return fmt.Errorf("malformed JWT. Failed to decode claims: %w", err)
	}

	decodedJWS := jws.Decoded{
		Header:    jwt.Header,
		Payload:   claimsBytes,
		Signature: jwt.Signature,
		Parts:     jwt.Parts,
	}

	if err := decodedJWS.Verify(); err != nil {
		return fmt.Errorf("JWT signature verification failed: %w", err)
	}

	return nil
}

// kidMatchesIssuer reports whether kid is a DID URL whose DID is exactly issuer.
//
// A bare prefix test is not sufficient: issuer "did:web:example.com" is a
// string prefix of "did:web:example.com.evil.io#key-1", which belongs to a
// different DID. The text following the issuer must therefore be empty or
// start with one of the characters that terminate the DID inside a DID URL
// ('#' fragment, '/' path, '?' query).
func kidMatchesIssuer(kid, issuer string) bool {
	if issuer == "" || !strings.HasPrefix(kid, issuer) {
		return false
	}

	rest := kid[len(issuer):]
	if rest == "" {
		return true
	}

	switch rest[0] {
	case '#', '/', '?':
		return true
	default:
		return false
	}
}
// Claims represents JWT (JSON Web Token) Claims.
//
// The seven registered claims have dedicated fields; every other claim lives
// in Misc and is flattened into / lifted out of the top-level JSON object by
// MarshalJSON and UnmarshalJSON.
//
// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4
type Claims struct {
	// The "iss" (issuer) claim identifies the principal that issued the
	// JWT.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1
	Issuer string `json:"iss,omitempty"`
	// The "sub" (subject) claim identifies the principal that is the
	// subject of the JWT.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2
	Subject string `json:"sub,omitempty"`
	// The "aud" (audience) claim identifies the recipients that the JWT is
	// intended for.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
	Audience string `json:"aud,omitempty"`
	// The "exp" (expiration time) claim identifies the expiration time on
	// or after which the JWT must not be accepted for processing.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
	Expiration int64 `json:"exp,omitempty"`
	// The "nbf" (not before) claim identifies the time before which the JWT
	// must not be accepted for processing.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.5
	NotBefore int64 `json:"nbf,omitempty"`
	// The "iat" (issued at) claim identifies the time at which the JWT was
	// issued.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6
	IssuedAt int64 `json:"iat,omitempty"`
	// The "jti" (JWT ID) claim provides a unique identifier for the JWT.
	//
	// Spec: https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.7
	JTI string `json:"jti,omitempty"`

	// Misc holds all claims other than the registered ones above. It is
	// excluded from default struct encoding (`json:"-"`) and handled by the
	// custom MarshalJSON / UnmarshalJSON methods instead.
	Misc map[string]any `json:"-"`
}

// MarshalJSON overrides default json.Marshal behavior to include misc claims as
// flattened properties of the top-level object.
//
// The registered fields are encoded first (zero values are omitted), then each
// Misc entry is encoded and added under its key. A Misc key that equals a
// registered claim name replaces the value produced from the struct field.
//
// Already-encoded values are carried as json.RawMessage rather than being
// decoded into interface values: decoding would turn the int64 time claims
// into float64 and silently corrupt values above 2^53.
func (c Claims) MarshalJSON() ([]byte, error) {
	// cpy has no MarshalJSON method, so this uses the default struct encoding
	// instead of recursing into this function.
	registered, err := json.Marshal(cpy(c))
	if err != nil {
		return nil, err
	}

	var combined map[string]json.RawMessage
	if err := json.Unmarshal(registered, &combined); err != nil {
		return nil, fmt.Errorf("failed to unmarshal jwt claims: %w", err)
	}

	// Add private claims to the map
	for key, value := range c.Misc {
		raw, err := json.Marshal(value)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal jwt claim %q: %w", key, err)
		}
		combined[key] = raw
	}

	return json.Marshal(combined)
}

// UnmarshalJSON overrides default json.Unmarshal behavior to place every
// top-level property that is not a registered claim name into Misc.
//
// The input is decoded twice: once into a generic map to discover the
// non-registered properties, and once into the struct fields. On success the
// receiver is replaced wholesale; Misc is always a non-nil (possibly empty) map.
// On error the receiver is left untouched.
func (c *Claims) UnmarshalJSON(b []byte) error {
	var all map[string]any
	if err := json.Unmarshal(b, &all); err != nil {
		return err
	}

	misc := make(map[string]any)
	for name, value := range all {
		switch name {
		case "iss", "sub", "aud", "exp", "nbf", "iat", "jti":
			// Registered claim: populated through the struct fields below.
		default:
			misc[name] = value
		}
	}

	// Decode into cpy (which has no UnmarshalJSON method) to avoid recursing
	// into this function.
	var decoded cpy
	if err := json.Unmarshal(b, &decoded); err != nil {
		return err
	}
	decoded.Misc = misc

	*c = Claims(decoded)
	return nil
}

// cpy is a copy of Claims that is used to marshal/unmarshal the claims without
// infinitely looping: as a distinct defined type it has the same fields and
// struct tags as Claims but none of its methods.
type cpy Claims