diff --git a/impl/go.mod b/impl/go.mod index 6632ca96..8a08caad 100644 --- a/impl/go.mod +++ b/impl/go.mod @@ -66,6 +66,7 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/edsrzf/mmap-go v1.1.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.4 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -133,6 +134,7 @@ require ( github.com/swaggo/swag v1.8.12 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect + github.com/x448/float16 v0.8.4 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 // indirect go.opentelemetry.io/otel/metric v1.28.0 // indirect go.opentelemetry.io/proto/otlp v1.3.1 // indirect diff --git a/impl/go.sum b/impl/go.sum index 4ad34051..ee4a9c11 100644 --- a/impl/go.sum +++ b/impl/go.sum @@ -119,6 +119,8 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gabriel-vasile/mimetype v1.4.4 h1:QjV6pZ7/XZ7ryI2KuyeEDE8wnh7fHP9YnQy+R0LnH8I= github.com/gabriel-vasile/mimetype v1.4.4/go.mod h1:JwLei5XPtWdGiMFB5Pjle1oEeoSeEuJfJE+TtfvdB/s= github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= @@ -453,6 +455,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.10/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ= diff --git a/impl/internal/did/cbor.go b/impl/internal/did/cbor.go new file mode 100644 index 00000000..e59ca64a --- /dev/null +++ b/impl/internal/did/cbor.go @@ -0,0 +1,378 @@ +package did + +import ( + "bytes" + "encoding/base64" + "fmt" + "strings" + + "github.com/TBD54566975/ssi-sdk/crypto" + "github.com/TBD54566975/ssi-sdk/crypto/jwx" + "github.com/TBD54566975/ssi-sdk/cryptosuite" + "github.com/TBD54566975/ssi-sdk/did" + "github.com/fxamacker/cbor/v2" +) + +const ( + // CBOR map keys + keyID byte = 1 + keyVerificationMethod byte = 2 + keyAuthentication byte = 3 + keyAssertionMethod byte = 4 + keyKeyAgreement byte = 5 + keyCapabilityInvocation byte = 6 + keyCapabilityDelegation byte = 7 + keyService byte = 8 + keyController byte = 9 + keyAlsoKnownAs byte = 10 + keyTypes byte = 11 + keyGateways byte = 12 + keyPreviousDID byte = 13 + + // VerificationMethod keys + keyVMID byte = 1 + keyVMType byte = 2 + keyVMController byte = 3 + keyVMPublicKey byte = 4 + + // PublicKey keys + keyPKType byte = 1 + keyPKData byte = 2 + keyPKAlg byte = 3 + keyPKCrv byte = 4 + keyPKKty byte = 5 + + // Service keys + keyServiceID byte = 1 + keyServiceType byte = 2 + keyServiceEndpoint byte = 3 + keyServiceEnc byte = 4 + keyServiceSig byte = 5 +) + +func (d DHT) ToCBOR(doc did.Document, types []TypeIndex, gateways []AuthoritativeGateway, previousDID *PreviousDID) ([]byte, error) { + em, err := cbor.EncOptions{Sort: cbor.SortCanonical}.EncMode() + if err != nil { + return nil, err + } + + cborMap := make(map[byte]any) + + // Extract the DID suffix + didSuffix := strings.TrimPrefix(doc.ID, "did:dht:") + + // Only include non-empty fields + if didSuffix != "" { + cborMap[keyID] = didSuffix + } + + if len(doc.VerificationMethod) > 0 { + vms := make([]any, len(doc.VerificationMethod)) + for i, vm := range doc.VerificationMethod { + vmMap := make(map[byte]any) + vmMap[keyVMID] = strings.TrimPrefix(vm.ID, doc.ID+"#") + + // Convert the public key to bytes + pubKey, err := vm.PublicKeyJWK.ToPublicKey() + if err != nil { + return nil, err + } + pubKeyBytes, err := crypto.PubKeyToBytes(pubKey, crypto.ECDSAMarshalCompressed) + if err != nil { + return nil, err + } + + keyType := keyTypeForJWK(*vm.PublicKeyJWK) + pkData := fmt.Sprintf("t=%d;k=%s", keyType, base64.RawURLEncoding.EncodeToString(pubKeyBytes)) + + // Only include the alg if it's not the default for the key type + if !algIsDefaultForJWK(*vm.PublicKeyJWK) { + pkData += fmt.Sprintf(";a=%s", vm.PublicKeyJWK.ALG) + } + + // Only include the controller if it's different from the DID + if vm.Controller != doc.ID { + pkData += fmt.Sprintf(";c=%s", vm.Controller) + } + + vmMap[keyVMPublicKey] = pkData + vms[i] = vmMap + } + cborMap[keyVerificationMethod] = vms + } + + addVerificationRelationship := func(key byte, relationships []did.VerificationMethodSet) { + if len(relationships) > 0 { + refs := make([]string, len(relationships)) + for i, r := range relationships { + refs[i] = strings.TrimPrefix(r.(string), doc.ID+"#") + } + cborMap[key] = refs + } + } + + addVerificationRelationship(keyAuthentication, doc.Authentication) + addVerificationRelationship(keyAssertionMethod, doc.AssertionMethod) + addVerificationRelationship(keyKeyAgreement, doc.KeyAgreement) + addVerificationRelationship(keyCapabilityInvocation, doc.CapabilityInvocation) + addVerificationRelationship(keyCapabilityDelegation, doc.CapabilityDelegation) + + if len(doc.Services) > 0 { + services := make([]any, len(doc.Services)) + for i, svc := range doc.Services { + svcMap := make(map[byte]any) + svcMap[keyServiceID] = svc.ID + svcMap[keyServiceType] = svc.Type + svcMap[keyServiceEndpoint] = svc.ServiceEndpoint + if svc.Enc != nil { + svcMap[keyServiceEnc] = svc.Enc + } + if svc.Sig != nil { + svcMap[keyServiceSig] = svc.Sig + } + services[i] = svcMap + } + cborMap[keyService] = services + } + + if doc.Controller != nil { + cborMap[keyController] = doc.Controller + } + + if doc.AlsoKnownAs != nil { + cborMap[keyAlsoKnownAs] = doc.AlsoKnownAs + } + + if len(types) > 0 { + cborMap[keyTypes] = types + } + + if len(gateways) > 0 { + cborMap[keyGateways] = gateways + } + + if previousDID != nil { + cborMap[keyPreviousDID] = map[string]any{ + "did": previousDID.PreviousDID, + "signature": previousDID.Signature, + } + } + + return em.Marshal(cborMap) +} + +func (d DHT) FromCBOR(cborData []byte) (*DIDDHTDocument, error) { + var cborMap map[any]any + if err := cbor.Unmarshal(cborData, &cborMap); err != nil { + return nil, fmt.Errorf("failed to unmarshal CBOR data: %w", err) + } + + var doc did.Document + var types []TypeIndex + var gateways []AuthoritativeGateway + var previousDID *PreviousDID + + if id, ok := getMapValue(cborMap, keyID).(string); ok { + doc.ID = "did:dht:" + id + } + + // Get the identity key from the DID + identityKey, err := DHT(doc.ID).IdentityKey() + if err != nil { + return nil, fmt.Errorf("failed to get identity key: %w", err) + } + + if vms, ok := getMapValue(cborMap, keyVerificationMethod).([]any); ok { + for _, vmInterface := range vms { + if vm, ok := vmInterface.(map[any]any); ok { + vmID := doc.ID + "#" + getMapValue(vm, keyVMID).(string) + pkData := getMapValue(vm, keyVMPublicKey).(string) + + pkParts := strings.Split(pkData, ";") + var keyType int + var keyData, alg, controller string + + for _, part := range pkParts { + kv := strings.SplitN(part, "=", 2) + if len(kv) != 2 { + continue + } + switch kv[0] { + case "t": + fmt.Sscanf(kv[1], "%d", &keyType) + case "k": + keyData = kv[1] + case "a": + alg = kv[1] + case "c": + controller = kv[1] + } + } + + if controller == "" { + controller = doc.ID + } + + keyBytes, err := base64.RawURLEncoding.DecodeString(keyData) + if err != nil { + return nil, err + } + + pubKey, err := crypto.BytesToPubKey(keyBytes, keyTypeLookUp(fmt.Sprintf("%d", keyType)), crypto.ECDSAUnmarshalCompressed) + if err != nil { + return nil, err + } + + jwk, err := jwx.PublicKeyToPublicKeyJWK(nil, pubKey) + if err != nil { + return nil, err + } + + if alg == "" { + alg = defaultAlgForJWK(*jwk) + } + jwk.ALG = alg + + // Check if this is the identity key + if bytes.Equal(keyBytes, identityKey) { + jwk.KID = "0" + vmID = doc.ID + "#0" + } else { + jwk.KID = strings.TrimPrefix(vmID, doc.ID+"#") + } + + verificationMethod := did.VerificationMethod{ + ID: vmID, + Type: cryptosuite.JSONWebKeyType, + Controller: controller, + PublicKeyJWK: jwk, + } + + doc.VerificationMethod = append(doc.VerificationMethod, verificationMethod) + } + } + } + + getVerificationRelationship := func(key byte) []did.VerificationMethodSet { + if relationships, ok := getMapValue(cborMap, key).([]any); ok { + var result []did.VerificationMethodSet + for _, r := range relationships { + if ref, ok := r.(string); ok { + result = append(result, doc.ID+"#"+ref) + } + } + return result + } + return nil + } + + doc.Authentication = getVerificationRelationship(keyAuthentication) + doc.AssertionMethod = getVerificationRelationship(keyAssertionMethod) + doc.KeyAgreement = getVerificationRelationship(keyKeyAgreement) + doc.CapabilityInvocation = getVerificationRelationship(keyCapabilityInvocation) + doc.CapabilityDelegation = getVerificationRelationship(keyCapabilityDelegation) + + if services, ok := getMapValue(cborMap, keyService).([]any); ok { + for _, svcInterface := range services { + if svc, ok := svcInterface.(map[any]any); ok { + service := did.Service{ + ID: getMapValue(svc, keyServiceID).(string), + Type: getMapValue(svc, keyServiceType).(string), + ServiceEndpoint: getMapValue(svc, keyServiceEndpoint), + } + if enc := getMapValue(svc, keyServiceEnc); enc != nil { + service.Enc = enc + } + if sig := getMapValue(svc, keyServiceSig); sig != nil { + service.Sig = sig + } + doc.Services = append(doc.Services, service) + } + } + } + + if controller := getMapValue(cborMap, keyController); controller != nil { + doc.Controller = controller + } + + if alsoKnownAs, ok := getMapValue(cborMap, keyAlsoKnownAs).([]any); ok { + doc.AlsoKnownAs = make([]string, len(alsoKnownAs)) + var akas []string + for _, aka := range alsoKnownAs { + akas = append(akas, aka.(string)) + } + doc.AlsoKnownAs = akas + } + + if typesInterface := getMapValue(cborMap, keyTypes); typesInterface != nil { + switch typedTypes := typesInterface.(type) { + case []any: + for _, t := range typedTypes { + switch typedT := t.(type) { + case uint64: + types = append(types, TypeIndex(typedT)) + case int64: + types = append(types, TypeIndex(typedT)) + case float64: + types = append(types, TypeIndex(typedT)) + } + } + case []uint64: + for _, t := range typedTypes { + types = append(types, TypeIndex(t)) + } + case []int64: + for _, t := range typedTypes { + types = append(types, TypeIndex(t)) + } + } + } + + if gatewaysInterface := getMapValue(cborMap, keyGateways); gatewaysInterface != nil { + if gws, ok := gatewaysInterface.([]any); ok { + for _, g := range gws { + if gatewayString, ok := g.(string); ok { + gateways = append(gateways, AuthoritativeGateway(gatewayString)) + } + } + } + } + + if prev := getMapValue(cborMap, keyPreviousDID); prev != nil { + if prevMap, ok := prev.(map[any]any); ok { + previousDID = &PreviousDID{ + PreviousDID: DHT(getMapValue(prevMap, "did").(string)), + Signature: getMapValue(prevMap, "signature").(string), + } + } + } + + return &DIDDHTDocument{ + Doc: doc, + Types: types, + Gateways: gateways, + PreviousDID: previousDID, + }, nil +} + +func getMapValue(m map[any]any, key any) any { + if v, ok := m[key]; ok { + return v + } + + // If the key is a byte, try to find it as a uint64 + if byteKey, ok := key.(byte); ok { + if v, ok := m[uint64(byteKey)]; ok { + return v + } + } + + // If the key is a uint64, try to find it as a byte + if uint64Key, ok := key.(uint64); ok { + if v, ok := m[byte(uint64Key)]; ok { + return v + } + } + + return nil +} diff --git a/impl/internal/did/did_test.go b/impl/internal/did/did_test.go index b8b2bdff..11595f51 100644 --- a/impl/internal/did/did_test.go +++ b/impl/internal/did/did_test.go @@ -2,6 +2,7 @@ package did import ( "crypto/ed25519" + "encoding/hex" "fmt" "testing" @@ -128,6 +129,9 @@ func TestToDNSPacket(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, packet) + pb, _ := packet.Pack() + println("1 - DNS Length: ", len(pb)) + didDHTDoc, err := didID.FromDNSPacket(packet) require.NoError(t, err) require.NotEmpty(t, didDHTDoc) @@ -145,6 +149,39 @@ func TestToDNSPacket(t *testing.T) { assert.JSONEq(t, string(jsonDoc), string(jsonDecodedDoc)) }) + t.Run("simple doc - test to dns packet round trip - cbor", func(t *testing.T) { + privKey, doc, err := GenerateDIDDHT(CreateDIDDHTOpts{}) + require.NoError(t, err) + require.NotEmpty(t, privKey) + require.NotEmpty(t, doc) + + didID := DHT(doc.ID) + packet, err := didID.ToCBOR(*doc, nil, nil, nil) + require.NoError(t, err) + require.NotEmpty(t, packet) + + println("1 - CBOR Length: ", len(packet)) + + didDHTDoc, err := didID.FromCBOR(packet) + require.NoError(t, err) + + jsonDecodedDoc, err := json.Marshal(didDHTDoc.Doc) + require.NoError(t, err) + require.NotEmpty(t, didDHTDoc) + require.NotEmpty(t, didDHTDoc.Doc) + require.Empty(t, didDHTDoc.Types) + require.Empty(t, didDHTDoc.Gateways) + require.Empty(t, didDHTDoc.PreviousDID) + + jsonDoc, err := json.Marshal(doc) + require.NoError(t, err) + + jsonDecodedDoc, err = json.Marshal(didDHTDoc.Doc) + require.NoError(t, err) + + assert.JSONEq(t, string(jsonDoc), string(jsonDecodedDoc)) + }) + t.Run("doc with types and a gateway - test to dns packet round trip", func(t *testing.T) { privKey, doc, err := GenerateDIDDHT(CreateDIDDHTOpts{}) require.NoError(t, err) @@ -156,6 +193,9 @@ func TestToDNSPacket(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, packet) + pb, _ := packet.Pack() + println("2 - DNS Length: ", len(pb)) + didDHTDoc, err := didID.FromDNSPacket(packet) require.NoError(t, err) require.NotEmpty(t, didDHTDoc) @@ -169,6 +209,32 @@ func TestToDNSPacket(t *testing.T) { assert.EqualValues(t, *doc, didDHTDoc.Doc) }) + t.Run("doc with types and a gateway - test to dns packet round trip - cbor", func(t *testing.T) { + privKey, doc, err := GenerateDIDDHT(CreateDIDDHTOpts{}) + require.NoError(t, err) + require.NotEmpty(t, privKey) + require.NotEmpty(t, doc) + + didID := DHT(doc.ID) + packet, err := didID.ToCBOR(*doc, []TypeIndex{1, 2, 3}, []AuthoritativeGateway{"gateway1.example-did-dht-gateway.com."}, nil) + require.NoError(t, err) + require.NotEmpty(t, packet) + + println("2 - CBOR Length: ", len(packet)) + + didDHTDoc, err := didID.FromCBOR(packet) + require.NoError(t, err) + require.NotEmpty(t, didDHTDoc) + require.NotEmpty(t, didDHTDoc.Doc) + require.NotEmpty(t, didDHTDoc.Types) + require.Empty(t, didDHTDoc.PreviousDID) + require.Equal(t, didDHTDoc.Types, []TypeIndex{1, 2, 3}) + require.NotEmpty(t, didDHTDoc.Gateways) + require.Equal(t, didDHTDoc.Gateways, []AuthoritativeGateway{"gateway1.example-did-dht-gateway.com."}) + + assert.EqualValues(t, *doc, didDHTDoc.Doc) + }) + t.Run("doc with multiple keys and services - test to dns packet round trip", func(t *testing.T) { pubKey, _, err := crypto.GenerateSECP256k1Key() require.NoError(t, err) @@ -210,6 +276,10 @@ func TestToDNSPacket(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, packet) + pb, _ := packet.Pack() + println("3 - DNS Length: ", len(pb)) + + println(string(pb)) didDHTDoc, err := didID.FromDNSPacket(packet) require.NoError(t, err) require.NotEmpty(t, didDHTDoc) @@ -226,6 +296,70 @@ func TestToDNSPacket(t *testing.T) { assert.JSONEq(t, string(docJSON), string(decodedJSON)) }) + + t.Run("doc with multiple keys and services - test to dns packet round trip - cbor", func(t *testing.T) { + pubKey, _, err := crypto.GenerateSECP256k1Key() + require.NoError(t, err) + pubKeyJWK, err := jwx.PublicKeyToPublicKeyJWK(nil, pubKey) + require.NoError(t, err) + + opts := CreateDIDDHTOpts{ + VerificationMethods: []VerificationMethod{ + { + VerificationMethod: did.VerificationMethod{ + Type: cryptosuite.JSONWebKeyType, + PublicKeyJWK: pubKeyJWK, + }, + Purposes: []did.PublicKeyPurpose{did.AssertionMethod, did.CapabilityInvocation}, + }, + }, + Services: []did.Service{ + { + ID: "vcs", + Type: "VerifiableCredentialService", + ServiceEndpoint: []string{"https://example.com/vc/"}, + Sig: []string{"1", "2"}, + Enc: "3", + }, + { + ID: "hub", + Type: "MessagingService", + ServiceEndpoint: []string{"https://example.com/hub/", "https://example.com/hub2/"}, + }, + }, + } + privKey, doc, err := GenerateDIDDHT(opts) + require.NoError(t, err) + require.NotEmpty(t, privKey) + require.NotEmpty(t, doc) + + didID := DHT(doc.ID) + packet, err := didID.ToCBOR(*doc, nil, nil, nil) + require.NoError(t, err) + require.NotEmpty(t, packet) + + println("3 - CBOR Length: ", len(packet)) + + // convert bytes to hex + s := hex.EncodeToString(packet) + println(s) + + didDHTDoc, err := didID.FromCBOR(packet) + require.NoError(t, err) + require.NotEmpty(t, didDHTDoc) + require.NotEmpty(t, didDHTDoc.Doc) + require.Empty(t, didDHTDoc.Types) + require.Empty(t, didDHTDoc.Gateways) + require.Empty(t, didDHTDoc.PreviousDID) + + decodedJSON, err := json.Marshal(didDHTDoc.Doc) + require.NoError(t, err) + + docJSON, err := json.Marshal(doc) + require.NoError(t, err) + + assert.JSONEq(t, string(docJSON), string(decodedJSON)) + }) } func TestDIDDHTFeatures(t *testing.T) {