Skip to content
This repository has been archived by the owner on Apr 17, 2024. It is now read-only.

Commit

Permalink
Preserve keyset key order during PRF-based key derivation in Go.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 533317761
  • Loading branch information
cindylindeed authored and copybara-github committed May 19, 2023
1 parent a9ab106 commit fa543d8
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 81 deletions.
56 changes: 26 additions & 30 deletions go/keyderivation/keyset_deriver_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,44 +52,40 @@ func newWrappedKeysetDeriver(ps *primitiveset.PrimitiveSet) (*wrappedKeysetDeriv
if _, ok := (ps.Primary.Primitive).(KeysetDeriver); !ok {
return nil, errNotKeysetDeriverPrimitive
}
for _, primitives := range ps.Entries {
for _, p := range primitives {
if _, ok := (p.Primitive).(KeysetDeriver); !ok {
return nil, errNotKeysetDeriverPrimitive
}
for _, p := range ps.EntriesInKeysetOrder {
if _, ok := (p.Primitive).(KeysetDeriver); !ok {
return nil, errNotKeysetDeriverPrimitive
}
}
return &wrappedKeysetDeriver{ps: ps}, nil
}

func (w *wrappedKeysetDeriver) DeriveKeyset(salt []byte) (*keyset.Handle, error) {
keys := []*tinkpb.Keyset_Key{}
for _, entriesWithSamePrefix := range w.ps.Entries {
for _, e := range entriesWithSamePrefix {
p, ok := (e.Primitive).(KeysetDeriver)
if !ok {
return nil, errNotKeysetDeriverPrimitive
}
handle, err := p.DeriveKeyset(salt)
if err != nil {
return nil, errors.New("keyset_deriver_factory: keyset derivation failed")
}
if len(handle.KeysetInfo().GetKeyInfo()) != 1 {
return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key")
}
ks := insecurecleartextkeyset.KeysetMaterial(handle)
if len(ks.GetKey()) != 1 {
return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key")
}
// Set all fields, except for KeyData, to match the Entry's in the keyset.
key := &tinkpb.Keyset_Key{
KeyData: ks.GetKey()[0].GetKeyData(),
Status: e.Status,
KeyId: e.KeyID,
OutputPrefixType: e.PrefixType,
}
keys = append(keys, key)
for _, e := range w.ps.EntriesInKeysetOrder {
p, ok := (e.Primitive).(KeysetDeriver)
if !ok {
return nil, errNotKeysetDeriverPrimitive
}
handle, err := p.DeriveKeyset(salt)
if err != nil {
return nil, errors.New("keyset_deriver_factory: keyset derivation failed")
}
if len(handle.KeysetInfo().GetKeyInfo()) != 1 {
return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key")
}
ks := insecurecleartextkeyset.KeysetMaterial(handle)
if len(ks.GetKey()) != 1 {
return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key")
}
// Set all fields, except for KeyData, to match the Entry's in the keyset.
key := &tinkpb.Keyset_Key{
KeyData: ks.GetKey()[0].GetKeyData(),
Status: e.Status,
KeyId: e.KeyID,
OutputPrefixType: e.PrefixType,
}
keys = append(keys, key)
}
ks := &tinkpb.Keyset{
PrimaryKeyId: w.ps.Primary.KeyID,
Expand Down
36 changes: 27 additions & 9 deletions go/keyderivation/keyset_deriver_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ type invalidDeriver struct{}
var _ KeysetDeriver = (*invalidDeriver)(nil)

func (i *invalidDeriver) DeriveKeyset(salt []byte) (*keyset.Handle, error) {
kh, err := keyset.NewHandle(aead.AES128GCMKeyTemplate())
manager := keyset.NewManager()
keyID, err := manager.Add(aead.AES128GCMKeyTemplate())
if err != nil {
return nil, err
}
manager := keyset.NewManagerFromHandle(kh)
if _, err := manager.Add(aead.AES128GCMKeyTemplate()); err != nil {
manager.SetPrimary(keyID)
if _, err = manager.Add(aead.AES256GCMKeyTemplate()); err != nil {
return nil, err
}
return kh, nil
return manager.Handle()
}

func TestInvalidKeysetDeriverImplementationFails(t *testing.T) {
func TestDeriveKeysetWithInvalidPrimitiveImplementationFails(t *testing.T) {
entry := &primitiveset.Entry{
KeyID: 119,
Primitive: &invalidDeriver{},
Expand All @@ -57,14 +58,31 @@ func TestInvalidKeysetDeriverImplementationFails(t *testing.T) {
Entries: map[string][]*primitiveset.Entry{
cryptofmt.RawPrefix: []*primitiveset.Entry{entry},
},
EntriesInKeysetOrder: []*primitiveset.Entry{entry},
}
wkd, err := newWrappedKeysetDeriver(ps)
wrappedDeriver, err := newWrappedKeysetDeriver(ps)
if err != nil {
t.Fatalf("newWrappedKeysetDeriver() err = %v, want nil", err)
}
if _, err := wkd.DeriveKeyset([]byte("salt")); err == nil {
t.Error("DeriveKeyset() err = nil, want non-nil")
} else if !strings.Contains(err.Error(), "exactly one key") {
_, err = wrappedDeriver.DeriveKeyset([]byte("salt"))
if err == nil {
t.Fatal("DeriveKeyset() err = nil, want non-nil")
}
if !strings.Contains(err.Error(), "exactly one key") {
t.Errorf("DeriveKeyset() err = %q, doesn't contain %q", err, "exactly one key")
}
}

func TestNewWrappedKeysetDeriverWrongPrimitiveFails(t *testing.T) {
handle, err := keyset.NewHandle(aead.AES128GCMKeyTemplate())
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}
ps, err := handle.Primitives()
if err != nil {
t.Fatalf("handle.Primitives() err = %v, want nil", err)
}
if _, err := newWrappedKeysetDeriver(ps); err == nil {
t.Errorf("newWrappedKeysetDeriver() err = nil, want non-nil")
}
}
114 changes: 72 additions & 42 deletions go/keyderivation/keyset_deriver_factory_x_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,63 +31,98 @@ import (
)

func TestWrappedKeysetDeriver(t *testing.T) {
// Construct a deriving keyset handle containing one key.
sha256AES128GCMkeyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
// Construct deriving keyset handle containing one key.
aes128GCMKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{
PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
Params: &prfderpb.PrfBasedDeriverParams{
DerivedKeyTemplate: aead.AES128GCMKeyTemplate(),
},
}
serializedKeyFormat, err := proto.Marshal(sha256AES128GCMkeyFormat)
})
if err != nil {
t.Fatalf("proto.Marshal(%v) err = %v, want nil", sha256AES128GCMkeyFormat, err)
t.Fatalf("proto.Marshal(aes128GCMKeyFormat) err = %v, want nil", err)
}
template := &tinkpb.KeyTemplate{
singleKeyHandle, err := keyset.NewHandle(&tinkpb.KeyTemplate{
TypeUrl: prfBasedDeriverTypeURL,
OutputPrefixType: tinkpb.OutputPrefixType_RAW,
Value: serializedKeyFormat,
}
singleKeyHandle, err := keyset.NewHandle(template)
Value: aes128GCMKeyFormat,
})
if err != nil {
t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
}

// Construct a deriving keyset handle containing two different types of keys.
sha256AES256GCMNoPrefixKeyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
// Construct deriving keyset handle containing three keys.
xChaChaKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{
PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
Params: &prfderpb.PrfBasedDeriverParams{
DerivedKeyTemplate: aead.AES256GCMNoPrefixKeyTemplate(),
DerivedKeyTemplate: aead.XChaCha20Poly1305KeyTemplate(),
},
})
if err != nil {
t.Fatalf("proto.Marshal(xChaChaKeyFormat) err = %v, want nil", err)
}
serializedKeyFormat, err = proto.Marshal(sha256AES256GCMNoPrefixKeyFormat)
aes256GCMKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{
PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
Params: &prfderpb.PrfBasedDeriverParams{
DerivedKeyTemplate: aead.AES256GCMKeyTemplate(),
},
})
if err != nil {
t.Fatalf("proto.Marshal(%v) err = %v, want nil", sha256AES256GCMNoPrefixKeyFormat, err)
t.Fatalf("proto.Marshal(aes256GCMKeyFormat) err = %v, want nil", err)
}
template = &tinkpb.KeyTemplate{
manager := keyset.NewManager()
aes128GCMKeyID, err := manager.Add(&tinkpb.KeyTemplate{
TypeUrl: prfBasedDeriverTypeURL,
OutputPrefixType: tinkpb.OutputPrefixType_RAW,
Value: serializedKeyFormat,
Value: aes128GCMKeyFormat,
})
if err != nil {
t.Fatalf("manager.Add(aes128GCMTemplate) err = %v, want nil", err)
}
manager := keyset.NewManagerFromHandle(singleKeyHandle)
if _, err := manager.Add(template); err != nil {
t.Fatalf("manager.Add() err = %v, want nil", err)
if err := manager.SetPrimary(aes128GCMKeyID); err != nil {
t.Fatalf("manager.SetPrimary() err = %v, want nil", err)
}
if _, err := manager.Add(&tinkpb.KeyTemplate{
TypeUrl: prfBasedDeriverTypeURL,
OutputPrefixType: tinkpb.OutputPrefixType_TINK,
Value: xChaChaKeyFormat,
}); err != nil {
t.Fatalf("manager.Add(xChaChaTemplate) err = %v, want nil", err)
}
if _, err := manager.Add(&tinkpb.KeyTemplate{
TypeUrl: prfBasedDeriverTypeURL,
OutputPrefixType: tinkpb.OutputPrefixType_CRUNCHY,
Value: aes256GCMKeyFormat,
}); err != nil {
t.Fatalf("manager.Add(aes256GCMTemplate) err = %v, want nil", err)
}
multipleKeysHandle, err := manager.Handle()
if err != nil {
t.Fatalf("manager.Handle() err %v, want nil", err)
t.Fatalf("manager.Handle() err = %v, want nil", err)
}
if got, want := len(multipleKeysHandle.KeysetInfo().GetKeyInfo()), 3; got != want {
t.Fatalf("len(multipleKeysHandle) = %d, want %d", got, want)
}

for _, test := range []struct {
name string
handle *keyset.Handle
name string
handle *keyset.Handle
wantTypeURLs []string
}{
{
name: "single key",
handle: singleKeyHandle,
wantTypeURLs: []string{
"type.googleapis.com/google.crypto.tink.AesGcmKey",
},
},
{
name: "multiple keys",
handle: multipleKeysHandle,
wantTypeURLs: []string{
"type.googleapis.com/google.crypto.tink.AesGcmKey",
"type.googleapis.com/google.crypto.tink.XChaCha20Poly1305Key",
"type.googleapis.com/google.crypto.tink.AesGcmKey",
},
},
} {
t.Run(test.name, func(t *testing.T) {
Expand All @@ -107,30 +142,25 @@ func TestWrappedKeysetDeriver(t *testing.T) {
if len(derivedKeyInfo) != len(keyInfo) {
t.Errorf("number of derived keys = %d, want %d", len(derivedKeyInfo), len(keyInfo))
}
if len(derivedKeyInfo) != len(test.wantTypeURLs) {
t.Errorf("number of derived keys = %d, want %d", len(derivedKeyInfo), len(keyInfo))
}

// Verify each derived key.
// Verify derived keys.
hasPrimaryKey := false
for _, derivedKey := range derivedKeyInfo {
if derivedKey.GetOutputPrefixType() != tinkpb.OutputPrefixType_RAW {
t.Errorf("GetOutputPrefixType() = %s, want %s", derivedKey.GetOutputPrefixType(), tinkpb.OutputPrefixType_RAW)
for i, derivedKey := range derivedKeyInfo {
derivingKey := keyInfo[i]
if got, want := derivedKey.GetOutputPrefixType(), derivingKey.GetOutputPrefixType(); got != want {
t.Errorf("GetOutputPrefixType() = %s, want %s", got, want)
}
if got, want := derivedKey.GetKeyId(), derivingKey.GetKeyId(); got != want {
t.Errorf("GetKeyId() = %d, want %d", got, want)
}
// Verify each derived key has the same key ID as a deriving key.
hasMatchingDerivingKey := false
for _, key := range keyInfo {
if key.GetKeyId() == derivedKey.GetKeyId() {
hasMatchingDerivingKey = true
} else {
continue
}
if got, want := derivedKey.GetTypeUrl(), "type.googleapis.com/google.crypto.tink.AesGcmKey"; got != want {
t.Errorf("GetTypeUrl() = %q, want %q", got, want)
}
if derivedKey.GetStatus() != key.GetStatus() {
t.Errorf("GetStatus() = %s, want %s", derivedKey.GetStatus(), key.GetStatus())
}
if got, want := derivedKey.GetTypeUrl(), test.wantTypeURLs[i]; got != want {
t.Errorf("GetTypeUrl() = %q, want %q", got, want)
}
if !hasMatchingDerivingKey {
t.Errorf("derived key has no matching deriving key")
if got, want := derivedKey.GetStatus(), derivingKey.GetStatus(); got != want {
t.Errorf("GetStatus() = %s, want %s", got, want)
}
if derivedKey.GetKeyId() == derivedHandle.KeysetInfo().GetPrimaryKeyId() {
hasPrimaryKey = true
Expand Down

0 comments on commit fa543d8

Please sign in to comment.