diff --git a/pkg/doc/sdjwt/common/common.go b/pkg/doc/sdjwt/common/common.go index 394aa67dd5..1a6bb3540d 100644 --- a/pkg/doc/sdjwt/common/common.go +++ b/pkg/doc/sdjwt/common/common.go @@ -256,11 +256,18 @@ func GetCryptoHash(sdAlg string) (crypto.Hash, error) { var cryptoHash crypto.Hash + // From spec: the hash algorithms MD2, MD4, MD5, RIPEMD-160, and SHA-1 revealed fundamental weaknesses + // and they MUST NOT be used. + switch strings.ToUpper(sdAlg) { case crypto.SHA256.String(): cryptoHash = crypto.SHA256 + case crypto.SHA384.String(): + cryptoHash = crypto.SHA384 + case crypto.SHA512.String(): + cryptoHash = crypto.SHA512 default: - err = fmt.Errorf("%s '%s 'not supported", SDAlgorithmKey, sdAlg) + err = fmt.Errorf("%s '%s' not supported", SDAlgorithmKey, sdAlg) } return cryptoHash, err diff --git a/pkg/doc/sdjwt/common/common_test.go b/pkg/doc/sdjwt/common/common_test.go index 2dcac8cb0d..75730d7066 100644 --- a/pkg/doc/sdjwt/common/common_test.go +++ b/pkg/doc/sdjwt/common/common_test.go @@ -237,7 +237,7 @@ func TestVerifyDisclosuresInSDJWT(t *testing.T) { err = VerifyDisclosuresInSDJWT(nil, signedJWT) r.Error(err) - r.Contains(err.Error(), "_sd_alg 'SHA-XXX 'not supported") + r.Contains(err.Error(), "_sd_alg 'SHA-XXX' not supported") }) t.Run("error - algorithm is not a string", func(t *testing.T) { @@ -446,6 +446,31 @@ func TestGetDisclosedClaims(t *testing.T) { }) } +func TestGetCryptoHash(t *testing.T) { + r := require.New(t) + + t.Run("success", func(t *testing.T) { + hash, err := GetCryptoHash("sha-256") + r.NoError(err) + r.Equal(crypto.SHA256, hash) + + hash, err = GetCryptoHash("sha-384") + r.NoError(err) + r.Equal(crypto.SHA384, hash) + + hash, err = GetCryptoHash("sha-512") + r.NoError(err) + r.Equal(crypto.SHA512, hash) + }) + + t.Run("error - not supported", func(t *testing.T) { + hash, err := GetCryptoHash("invalid") + r.Error(err) + r.Equal(crypto.Hash(0), hash) + r.Contains(err.Error(), "_sd_alg 'invalid' not supported") + }) +} + func TestGetSDAlg(t *testing.T) { r := require.New(t) diff --git a/pkg/doc/sdjwt/issuer/issuer.go b/pkg/doc/sdjwt/issuer/issuer.go index 1c0146f576..b27b8161cf 100644 --- a/pkg/doc/sdjwt/issuer/issuer.go +++ b/pkg/doc/sdjwt/issuer/issuer.go @@ -191,6 +191,12 @@ func New(issuer string, claims interface{}, headers jose.Headers, return nil, fmt.Errorf("convert payload to map: %w", err) } + // check for the presence of the _sd claim in claims map + found := keyExistsInMap(common.SDKey, claimsMap) + if found { + return nil, fmt.Errorf("key '%s' cannot be present in the claims", common.SDKey) + } + disclosures, digests, err := createDisclosuresAndDigests("", claimsMap, nOpts) if err != nil { return nil, err @@ -439,6 +445,23 @@ func generateSalt() (string, error) { return base64.RawURLEncoding.EncodeToString(salt), nil } +func keyExistsInMap(key string, claims map[string]interface{}) bool { + for k, v := range claims { + if k == key { + return true + } + + if obj, ok := v.(map[string]interface{}); ok { + exists := keyExistsInMap(key, obj) + if exists { + return true + } + } + } + + return false +} + // payload represents SD-JWT payload. type payload struct { Issuer string `json:"iss,omitempty"` diff --git a/pkg/doc/sdjwt/issuer/issuer_test.go b/pkg/doc/sdjwt/issuer/issuer_test.go index 638f506ef6..a00e54b93e 100644 --- a/pkg/doc/sdjwt/issuer/issuer_test.go +++ b/pkg/doc/sdjwt/issuer/issuer_test.go @@ -403,6 +403,41 @@ func TestNew(t *testing.T) { fmt.Println(prettyJSON) }) + t.Run("error - claims contain _sd key (top level object)", func(t *testing.T) { + r := require.New(t) + + _, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + complexClaims := map[string]interface{}{ + "_sd": "whatever", + } + + token, err := New(issuer, complexClaims, nil, afjwt.NewEd25519Signer(privKey)) + r.Error(err) + r.Nil(token) + r.Contains(err.Error(), "key '_sd' cannot be present in the claims") + }) + + t.Run("error - claims contain _sd key (inner object)", func(t *testing.T) { + r := require.New(t) + + _, privKey, err := ed25519.GenerateKey(rand.Reader) + r.NoError(err) + + complexClaims := map[string]interface{}{ + "degree": map[string]interface{}{ + "_sd": "whatever", + "type": "BachelorDegree", + }, + } + + token, err := New(issuer, complexClaims, nil, afjwt.NewEd25519Signer(privKey)) + r.Error(err) + r.Nil(token) + r.Contains(err.Error(), "key '_sd' cannot be present in the claims") + }) + t.Run("error - invalid holder public key", func(t *testing.T) { r := require.New(t)