diff --git a/pkg/doc/sdjwt/common/common.go b/pkg/doc/sdjwt/common/common.go index 394aa67dd..1a6bb3540 100644 --- a/pkg/doc/sdjwt/common/common.go +++ b/pkg/doc/sdjwt/common/common.go @@ -256,11 +256,18 @@ func GetCryptoHash(sdAlg string) (crypto.Hash, error) { var cryptoHash crypto.Hash + // From spec: the hash algorithms MD2, MD4, MD5, RIPEMD-160, and SHA-1 revealed fundamental weaknesses + // and they MUST NOT be used. + switch strings.ToUpper(sdAlg) { case crypto.SHA256.String(): cryptoHash = crypto.SHA256 + case crypto.SHA384.String(): + cryptoHash = crypto.SHA384 + case crypto.SHA512.String(): + cryptoHash = crypto.SHA512 default: - err = fmt.Errorf("%s '%s 'not supported", SDAlgorithmKey, sdAlg) + err = fmt.Errorf("%s '%s' not supported", SDAlgorithmKey, sdAlg) } return cryptoHash, err diff --git a/pkg/doc/sdjwt/common/common_test.go b/pkg/doc/sdjwt/common/common_test.go index 2dcac8cb0..75730d706 100644 --- a/pkg/doc/sdjwt/common/common_test.go +++ b/pkg/doc/sdjwt/common/common_test.go @@ -237,7 +237,7 @@ func TestVerifyDisclosuresInSDJWT(t *testing.T) { err = VerifyDisclosuresInSDJWT(nil, signedJWT) r.Error(err) - r.Contains(err.Error(), "_sd_alg 'SHA-XXX 'not supported") + r.Contains(err.Error(), "_sd_alg 'SHA-XXX' not supported") }) t.Run("error - algorithm is not a string", func(t *testing.T) { @@ -446,6 +446,31 @@ func TestGetDisclosedClaims(t *testing.T) { }) } +func TestGetCryptoHash(t *testing.T) { + r := require.New(t) + + t.Run("success", func(t *testing.T) { + hash, err := GetCryptoHash("sha-256") + r.NoError(err) + r.Equal(crypto.SHA256, hash) + + hash, err = GetCryptoHash("sha-384") + r.NoError(err) + r.Equal(crypto.SHA384, hash) + + hash, err = GetCryptoHash("sha-512") + r.NoError(err) + r.Equal(crypto.SHA512, hash) + }) + + t.Run("error - not supported", func(t *testing.T) { + hash, err := GetCryptoHash("invalid") + r.Error(err) + r.Equal(crypto.Hash(0), hash) + r.Contains(err.Error(), "_sd_alg 'invalid' not supported") + }) +} + func TestGetSDAlg(t *testing.T) { r := require.New(t) diff --git a/pkg/doc/sdjwt/integration_test.go b/pkg/doc/sdjwt/integration_test.go index 226940a19..2a334e4c8 100644 --- a/pkg/doc/sdjwt/integration_test.go +++ b/pkg/doc/sdjwt/integration_test.go @@ -8,6 +8,7 @@ package sdjwt import ( "bytes" + "crypto" "crypto/ed25519" "crypto/rand" "encoding/json" @@ -111,6 +112,7 @@ func TestSDJWTFlow(t *testing.T) { // Issuer will issue SD-JWT for specified claims and holder public key. token, err := issuer.New(testIssuer, claims, nil, signer, + issuer.WithHashAlgorithm(crypto.SHA512), issuer.WithNotBefore(jwt.NewNumericDate(now)), issuer.WithIssuedAt(jwt.NewNumericDate(now)), issuer.WithExpiry(jwt.NewNumericDate(now.Add(year))), diff --git a/pkg/doc/sdjwt/issuer/issuer.go b/pkg/doc/sdjwt/issuer/issuer.go index 1c0146f57..b27b8161c 100644 --- a/pkg/doc/sdjwt/issuer/issuer.go +++ b/pkg/doc/sdjwt/issuer/issuer.go @@ -191,6 +191,12 @@ func New(issuer string, claims interface{}, headers jose.Headers, return nil, fmt.Errorf("convert payload to map: %w", err) } + // check for the presence of the _sd claim in claims map + found := keyExistsInMap(common.SDKey, claimsMap) + if found { + return nil, fmt.Errorf("key '%s' cannot be present in the claims", common.SDKey) + } + disclosures, digests, err := createDisclosuresAndDigests("", claimsMap, nOpts) if err != nil { return nil, err @@ -439,6 +445,23 @@ func generateSalt() (string, error) { return base64.RawURLEncoding.EncodeToString(salt), nil } +func keyExistsInMap(key string, claims map[string]interface{}) bool { + for k, v := range claims { + if k == key { + return true + } + + if obj, ok := v.(map[string]interface{}); ok { + exists := keyExistsInMap(key, obj) + if exists { + return true + } + } + } + + return false +} + // payload represents SD-JWT payload. type payload struct { Issuer string `json:"iss,omitempty"` diff --git a/pkg/doc/sdjwt/issuer/issuer_test.go b/pkg/doc/sdjwt/issuer/issuer_test.go index 638f506ef..67325d264 100644 --- a/pkg/doc/sdjwt/issuer/issuer_test.go +++ b/pkg/doc/sdjwt/issuer/issuer_test.go @@ -285,7 +285,7 @@ func TestNew(t *testing.T) { r.Contains(err.Error(), "unknown key id") }) - t.Run("Create Mixed (SD + non-SD) JWS with flat claims flag", func(t *testing.T) { + t.Run("Create Mixed (SD + non-SD) JWS with flat claims flag, SHA-512", func(t *testing.T) { r := require.New(t) _, privKey, err := ed25519.GenerateKey(rand.Reader) @@ -306,6 +306,7 @@ func TestNew(t *testing.T) { newOpts = append(newOpts, WithNonSelectivelyDisclosableClaims([]string{"id", "degree.type"}), + WithHashAlgorithm(crypto.SHA512), ) token, err := New(issuer, complexClaims, nil, afjwt.NewEd25519Signer(privKey), newOpts...) @@ -315,6 +316,8 @@ func TestNew(t *testing.T) { err = token.DecodeClaims(&tokenClaims) r.NoError(err) + r.Equal("sha-512", tokenClaims[common.SDAlgorithmKey]) + printObject(t, "Token Claims", tokenClaims) combinedFormatForIssuance, err := token.Serialize(false) @@ -403,6 +406,41 @@ func TestNew(t *testing.T) { fmt.Println(prettyJSON) }) + t.Run("error - claims contain _sd key (top level object)", func(t *testing.T) { + r := require.New(t) + + _, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + complexClaims := map[string]interface{}{ + "_sd": "whatever", + } + + token, err := New(issuer, complexClaims, nil, afjwt.NewEd25519Signer(privKey)) + r.Error(err) + r.Nil(token) + r.Contains(err.Error(), "key '_sd' cannot be present in the claims") + }) + + t.Run("error - claims contain _sd key (inner object)", func(t *testing.T) { + r := require.New(t) + + _, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + complexClaims := map[string]interface{}{ + "degree": map[string]interface{}{ + "_sd": "whatever", + "type": "BachelorDegree", + }, + } + + token, err := New(issuer, complexClaims, nil, afjwt.NewEd25519Signer(privKey)) + r.Error(err) + r.Nil(token) + r.Contains(err.Error(), "key '_sd' cannot be present in the claims") + }) + t.Run("error - invalid holder public key", func(t *testing.T) { r := require.New(t)