diff --git a/internal/darwin/security/security_darwin.go b/internal/darwin/security/security_darwin.go index 70cd570a..2c0c2123 100644 --- a/internal/darwin/security/security_darwin.go +++ b/internal/darwin/security/security_darwin.go @@ -84,6 +84,7 @@ var ( KSecPublicKeyAttrs = cf.TypeRef(C.kSecPublicKeyAttrs) KSecPrivateKeyAttrs = cf.TypeRef(C.kSecPrivateKeyAttrs) KSecReturnRef = cf.TypeRef(C.kSecReturnRef) + KSecReturnAttributes = cf.TypeRef(C.kSecReturnAttributes) KSecValueRef = cf.TypeRef(C.kSecValueRef) KSecValueData = cf.TypeRef(C.kSecValueData) ) @@ -139,6 +140,20 @@ const ( KSecAccessControlOr = SecAccessControlCreateFlags(C.kSecAccessControlOr) ) +type SecKeychainItemRef struct { + Value C.SecKeychainItemRef +} + +func NewSecKeychainItemRef(ref cf.TypeRef) *SecKeychainItemRef { + return &SecKeychainItemRef{ + Value: C.SecKeychainItemRef(ref), + } +} + +func (v *SecKeychainItemRef) Release() { cf.Release(v) } +func (v *SecKeychainItemRef) TypeRef() cf.CFTypeRef { return cf.CFTypeRef(v.Value) } +func (v *SecKeychainItemRef) Retain() { cf.Retain(v) } + type SecKeyRef struct { Value C.SecKeyRef } @@ -312,12 +327,11 @@ func GetSecAttrApplicationLabel(v *cf.DictionaryRef) []byte { } func GetSecAttrApplicationTag(v *cf.DictionaryRef) string { - ref := C.CFStringRef(C.CFDictionaryGetValue(C.CFDictionaryRef(v.Value), unsafe.Pointer(C.kSecAttrApplicationTag))) - tag := "" - if cstr := C.CFStringGetCStringPtr(ref, C.kCFStringEncodingUTF8); cstr != nil { - tag = C.GoString(cstr) - } - return tag + data := C.CFDataRef(C.CFDictionaryGetValue(C.CFDictionaryRef(v.Value), unsafe.Pointer(C.kSecAttrApplicationTag))) + return string(C.GoBytes( + unsafe.Pointer(C.CFDataGetBytePtr(data)), + C.int(C.CFDataGetLength(data)), + )) } func GetSecAttrLabel(v *cf.DictionaryRef) string { diff --git a/kms/mackms/mackms.go b/kms/mackms/mackms.go index 4cbf1f36..ebfcab45 100644 --- a/kms/mackms/mackms.go +++ b/kms/mackms/mackms.go @@ -555,7 +555,7 @@ func (*MacKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return nil } -func (*MacKMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { +func (k *MacKMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysResponse, error) { if req.Query == "" { return nil, fmt.Errorf("searchKeysRequest 'query' cannot be empty") } @@ -572,37 +572,18 @@ func (*MacKMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysRespon results := make([]apiv1.SearchKeyResult, len(keys)) for i, key := range keys { - pub, hash, err := extractPublicKey(key) - if err != nil { - return nil, fmt.Errorf("failed extracting public key: %w", err) - } - - attrs := security.SecKeyCopyAttributes(key) - defer attrs.Release() - - h := security.GetSecAttrApplicationLabel(attrs) - fmt.Println("h", h, string(h)) // TODO: remove debugging - - a := security.GetSecAttrApplicationTag(attrs) - fmt.Println("a", a, string(a)) - - l := security.GetSecAttrLabel(attrs) - fmt.Println("l", l) + d := cf.NewDictionaryRef(cf.TypeRef(key.TypeRef())) + defer d.Release() - j := attrs.XML() - fmt.Println("j", string(j)) + fmt.Println(string(d.XML())) // TODO: remove debug name := uri.New(Scheme, url.Values{ - "hash": []string{hex.EncodeToString(hash)}, + "hash": []string{hex.EncodeToString(security.GetSecAttrApplicationLabel(d))}, + "label": []string{security.GetSecAttrLabel(d)}, + "tag": []string{security.GetSecAttrApplicationTag(d)}, }) - // TODO: should we rely on the values from u only? Or can we get them from key properties too? - if u.label != "" { - name.Values.Set("label", u.label) - } - if u.tag != "" { - name.Values.Set("tag", u.tag) - } + // TODO: extract those from the attributes too if u.useSecureEnclave { name.Values.Set("se", "true") } @@ -610,6 +591,15 @@ func (*MacKMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysRespon name.Values.Set("bio", "true") } + // obtain the public key by requesting it, as the current + // representation of the key are just the attributes. + pub, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: name.String(), + }) + if err != nil { + return nil, fmt.Errorf("failed getting public key: %w", err) + } + results[i] = apiv1.SearchKeyResult{ Name: name.String(), PublicKey: pub, @@ -617,9 +607,6 @@ func (*MacKMS) SearchKeys(req *apiv1.SearchKeysRequest) (*apiv1.SearchKeysRespon SigningKey: name.String(), }, } - - // ensure the resource is released - key.Release() } return &apiv1.SearchKeysResponse{ @@ -699,12 +686,12 @@ func getPrivateKey(u *keyAttributes) (*security.SecKeyRef, error) { return security.NewSecKeyRef(key), nil } -func getPrivateKeys(u *keyAttributes) ([]*security.SecKeyRef, error) { +func getPrivateKeys(u *keyAttributes) ([]*security.SecKeychainItemRef, error) { dict := cf.Dictionary{ - security.KSecClass: security.KSecClassKey, - security.KSecAttrKeyClass: security.KSecAttrKeyClassPrivate, - security.KSecReturnRef: cf.True, - security.KSecMatchLimit: security.KSecMatchLimitAll, + security.KSecClass: security.KSecClassKey, + security.KSecAttrKeyClass: security.KSecAttrKeyClassPrivate, + security.KSecReturnAttributes: cf.True, // return keychain attributes, i.e. tag and label + security.KSecMatchLimit: security.KSecMatchLimitAll, } if u.tag != "" { @@ -749,10 +736,10 @@ func getPrivateKeys(u *keyAttributes) ([]*security.SecKeyRef, error) { array := cf.NewArrayRef(result) defer array.Release() - keys := make([]*security.SecKeyRef, array.Len()) + keys := make([]*security.SecKeychainItemRef, array.Len()) for i := 0; i < array.Len(); i++ { item := array.Get(i) - key := security.NewSecKeyRef(item) + key := security.NewSecKeychainItemRef(item) key.Retain() // retain the key, so that it's not released early keys[i] = key } diff --git a/kms/mackms/mackms_test.go b/kms/mackms/mackms_test.go index e265a0d9..d233dc52 100644 --- a/kms/mackms/mackms_test.go +++ b/kms/mackms/mackms_test.go @@ -1268,9 +1268,11 @@ func TestMacKMS_SearchKeys(t *testing.T) { for _, key := range got.Results { u, err := uri.ParseWithScheme(Scheme, key.Name) require.NoError(t, err) + assert.Equal(t, tag, u.Get("tag")) if hash := u.Get("hash"); hash != "" { hashes = append(hashes, hash) } + } assert.Equal(t, expectedHashes, hashes)