diff --git a/go/keyderivation/keyset_deriver_factory.go b/go/keyderivation/keyset_deriver_factory.go index fca706ca78..6c030b9a3a 100644 --- a/go/keyderivation/keyset_deriver_factory.go +++ b/go/keyderivation/keyset_deriver_factory.go @@ -52,11 +52,9 @@ func newWrappedKeysetDeriver(ps *primitiveset.PrimitiveSet) (*wrappedKeysetDeriv if _, ok := (ps.Primary.Primitive).(KeysetDeriver); !ok { return nil, errNotKeysetDeriverPrimitive } - for _, primitives := range ps.Entries { - for _, p := range primitives { - if _, ok := (p.Primitive).(KeysetDeriver); !ok { - return nil, errNotKeysetDeriverPrimitive - } + for _, p := range ps.EntriesInKeysetOrder { + if _, ok := (p.Primitive).(KeysetDeriver); !ok { + return nil, errNotKeysetDeriverPrimitive } } return &wrappedKeysetDeriver{ps: ps}, nil @@ -64,32 +62,30 @@ func newWrappedKeysetDeriver(ps *primitiveset.PrimitiveSet) (*wrappedKeysetDeriv func (w *wrappedKeysetDeriver) DeriveKeyset(salt []byte) (*keyset.Handle, error) { keys := []*tinkpb.Keyset_Key{} - for _, entriesWithSamePrefix := range w.ps.Entries { - for _, e := range entriesWithSamePrefix { - p, ok := (e.Primitive).(KeysetDeriver) - if !ok { - return nil, errNotKeysetDeriverPrimitive - } - handle, err := p.DeriveKeyset(salt) - if err != nil { - return nil, errors.New("keyset_deriver_factory: keyset derivation failed") - } - if len(handle.KeysetInfo().GetKeyInfo()) != 1 { - return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key") - } - ks := insecurecleartextkeyset.KeysetMaterial(handle) - if len(ks.GetKey()) != 1 { - return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key") - } - // Set all fields, except for KeyData, to match the Entry's in the keyset. - key := &tinkpb.Keyset_Key{ - KeyData: ks.GetKey()[0].GetKeyData(), - Status: e.Status, - KeyId: e.KeyID, - OutputPrefixType: e.PrefixType, - } - keys = append(keys, key) + for _, e := range w.ps.EntriesInKeysetOrder { + p, ok := (e.Primitive).(KeysetDeriver) + if !ok { + return nil, errNotKeysetDeriverPrimitive } + handle, err := p.DeriveKeyset(salt) + if err != nil { + return nil, errors.New("keyset_deriver_factory: keyset derivation failed") + } + if len(handle.KeysetInfo().GetKeyInfo()) != 1 { + return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key") + } + ks := insecurecleartextkeyset.KeysetMaterial(handle) + if len(ks.GetKey()) != 1 { + return nil, errors.New("keyset_deriver_factory: primitive must derive keyset handle with exactly one key") + } + // Set all fields, except for KeyData, to match the Entry's in the keyset. + key := &tinkpb.Keyset_Key{ + KeyData: ks.GetKey()[0].GetKeyData(), + Status: e.Status, + KeyId: e.KeyID, + OutputPrefixType: e.PrefixType, + } + keys = append(keys, key) } ks := &tinkpb.Keyset{ PrimaryKeyId: w.ps.Primary.KeyID, diff --git a/go/keyderivation/keyset_deriver_factory_test.go b/go/keyderivation/keyset_deriver_factory_test.go index a96816ebb6..b833ef0794 100644 --- a/go/keyderivation/keyset_deriver_factory_test.go +++ b/go/keyderivation/keyset_deriver_factory_test.go @@ -33,18 +33,19 @@ type invalidDeriver struct{} var _ KeysetDeriver = (*invalidDeriver)(nil) func (i *invalidDeriver) DeriveKeyset(salt []byte) (*keyset.Handle, error) { - kh, err := keyset.NewHandle(aead.AES128GCMKeyTemplate()) + manager := keyset.NewManager() + keyID, err := manager.Add(aead.AES128GCMKeyTemplate()) if err != nil { return nil, err } - manager := keyset.NewManagerFromHandle(kh) - if _, err := manager.Add(aead.AES128GCMKeyTemplate()); err != nil { + manager.SetPrimary(keyID) + if _, err = manager.Add(aead.AES256GCMKeyTemplate()); err != nil { return nil, err } - return kh, nil + return manager.Handle() } -func TestInvalidKeysetDeriverImplementationFails(t *testing.T) { +func TestDeriveKeysetWithInvalidPrimitiveImplementationFails(t *testing.T) { entry := &primitiveset.Entry{ KeyID: 119, Primitive: &invalidDeriver{}, @@ -57,14 +58,31 @@ func TestInvalidKeysetDeriverImplementationFails(t *testing.T) { Entries: map[string][]*primitiveset.Entry{ cryptofmt.RawPrefix: []*primitiveset.Entry{entry}, }, + EntriesInKeysetOrder: []*primitiveset.Entry{entry}, } - wkd, err := newWrappedKeysetDeriver(ps) + wrappedDeriver, err := newWrappedKeysetDeriver(ps) if err != nil { t.Fatalf("newWrappedKeysetDeriver() err = %v, want nil", err) } - if _, err := wkd.DeriveKeyset([]byte("salt")); err == nil { - t.Error("DeriveKeyset() err = nil, want non-nil") - } else if !strings.Contains(err.Error(), "exactly one key") { + _, err = wrappedDeriver.DeriveKeyset([]byte("salt")) + if err == nil { + t.Fatal("DeriveKeyset() err = nil, want non-nil") + } + if !strings.Contains(err.Error(), "exactly one key") { t.Errorf("DeriveKeyset() err = %q, doesn't contain %q", err, "exactly one key") } } + +func TestNewWrappedKeysetDeriverWrongPrimitiveFails(t *testing.T) { + handle, err := keyset.NewHandle(aead.AES128GCMKeyTemplate()) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + ps, err := handle.Primitives() + if err != nil { + t.Fatalf("handle.Primitives() err = %v, want nil", err) + } + if _, err := newWrappedKeysetDeriver(ps); err == nil { + t.Errorf("newWrappedKeysetDeriver() err = nil, want non-nil") + } +} diff --git a/go/keyderivation/keyset_deriver_factory_x_test.go b/go/keyderivation/keyset_deriver_factory_x_test.go index 7148577321..c403a33089 100644 --- a/go/keyderivation/keyset_deriver_factory_x_test.go +++ b/go/keyderivation/keyset_deriver_factory_x_test.go @@ -31,63 +31,98 @@ import ( ) func TestWrappedKeysetDeriver(t *testing.T) { - // Construct a deriving keyset handle containing one key. - sha256AES128GCMkeyFormat := &prfderpb.PrfBasedDeriverKeyFormat{ + // Construct deriving keyset handle containing one key. + aes128GCMKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{ PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(), Params: &prfderpb.PrfBasedDeriverParams{ DerivedKeyTemplate: aead.AES128GCMKeyTemplate(), }, - } - serializedKeyFormat, err := proto.Marshal(sha256AES128GCMkeyFormat) + }) if err != nil { - t.Fatalf("proto.Marshal(%v) err = %v, want nil", sha256AES128GCMkeyFormat, err) + t.Fatalf("proto.Marshal(aes128GCMKeyFormat) err = %v, want nil", err) } - template := &tinkpb.KeyTemplate{ + singleKeyHandle, err := keyset.NewHandle(&tinkpb.KeyTemplate{ TypeUrl: prfBasedDeriverTypeURL, OutputPrefixType: tinkpb.OutputPrefixType_RAW, - Value: serializedKeyFormat, - } - singleKeyHandle, err := keyset.NewHandle(template) + Value: aes128GCMKeyFormat, + }) if err != nil { t.Fatalf("keyset.NewHandle() err = %v, want nil", err) } - // Construct a deriving keyset handle containing two different types of keys. - sha256AES256GCMNoPrefixKeyFormat := &prfderpb.PrfBasedDeriverKeyFormat{ + // Construct deriving keyset handle containing three keys. + xChaChaKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{ PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(), Params: &prfderpb.PrfBasedDeriverParams{ - DerivedKeyTemplate: aead.AES256GCMNoPrefixKeyTemplate(), + DerivedKeyTemplate: aead.XChaCha20Poly1305KeyTemplate(), }, + }) + if err != nil { + t.Fatalf("proto.Marshal(xChaChaKeyFormat) err = %v, want nil", err) } - serializedKeyFormat, err = proto.Marshal(sha256AES256GCMNoPrefixKeyFormat) + aes256GCMKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{ + PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(), + Params: &prfderpb.PrfBasedDeriverParams{ + DerivedKeyTemplate: aead.AES256GCMKeyTemplate(), + }, + }) if err != nil { - t.Fatalf("proto.Marshal(%v) err = %v, want nil", sha256AES256GCMNoPrefixKeyFormat, err) + t.Fatalf("proto.Marshal(aes256GCMKeyFormat) err = %v, want nil", err) } - template = &tinkpb.KeyTemplate{ + manager := keyset.NewManager() + aes128GCMKeyID, err := manager.Add(&tinkpb.KeyTemplate{ TypeUrl: prfBasedDeriverTypeURL, OutputPrefixType: tinkpb.OutputPrefixType_RAW, - Value: serializedKeyFormat, + Value: aes128GCMKeyFormat, + }) + if err != nil { + t.Fatalf("manager.Add(aes128GCMTemplate) err = %v, want nil", err) } - manager := keyset.NewManagerFromHandle(singleKeyHandle) - if _, err := manager.Add(template); err != nil { - t.Fatalf("manager.Add() err = %v, want nil", err) + if err := manager.SetPrimary(aes128GCMKeyID); err != nil { + t.Fatalf("manager.SetPrimary() err = %v, want nil", err) + } + if _, err := manager.Add(&tinkpb.KeyTemplate{ + TypeUrl: prfBasedDeriverTypeURL, + OutputPrefixType: tinkpb.OutputPrefixType_TINK, + Value: xChaChaKeyFormat, + }); err != nil { + t.Fatalf("manager.Add(xChaChaTemplate) err = %v, want nil", err) + } + if _, err := manager.Add(&tinkpb.KeyTemplate{ + TypeUrl: prfBasedDeriverTypeURL, + OutputPrefixType: tinkpb.OutputPrefixType_CRUNCHY, + Value: aes256GCMKeyFormat, + }); err != nil { + t.Fatalf("manager.Add(aes256GCMTemplate) err = %v, want nil", err) } multipleKeysHandle, err := manager.Handle() if err != nil { - t.Fatalf("manager.Handle() err %v, want nil", err) + t.Fatalf("manager.Handle() err = %v, want nil", err) + } + if got, want := len(multipleKeysHandle.KeysetInfo().GetKeyInfo()), 3; got != want { + t.Fatalf("len(multipleKeysHandle) = %d, want %d", got, want) } for _, test := range []struct { - name string - handle *keyset.Handle + name string + handle *keyset.Handle + wantTypeURLs []string }{ { name: "single key", handle: singleKeyHandle, + wantTypeURLs: []string{ + "type.googleapis.com/google.crypto.tink.AesGcmKey", + }, }, { name: "multiple keys", handle: multipleKeysHandle, + wantTypeURLs: []string{ + "type.googleapis.com/google.crypto.tink.AesGcmKey", + "type.googleapis.com/google.crypto.tink.XChaCha20Poly1305Key", + "type.googleapis.com/google.crypto.tink.AesGcmKey", + }, }, } { t.Run(test.name, func(t *testing.T) { @@ -107,30 +142,25 @@ func TestWrappedKeysetDeriver(t *testing.T) { if len(derivedKeyInfo) != len(keyInfo) { t.Errorf("number of derived keys = %d, want %d", len(derivedKeyInfo), len(keyInfo)) } + if len(derivedKeyInfo) != len(test.wantTypeURLs) { + t.Errorf("number of derived keys = %d, want %d", len(derivedKeyInfo), len(keyInfo)) + } - // Verify each derived key. + // Verify derived keys. hasPrimaryKey := false - for _, derivedKey := range derivedKeyInfo { - if derivedKey.GetOutputPrefixType() != tinkpb.OutputPrefixType_RAW { - t.Errorf("GetOutputPrefixType() = %s, want %s", derivedKey.GetOutputPrefixType(), tinkpb.OutputPrefixType_RAW) + for i, derivedKey := range derivedKeyInfo { + derivingKey := keyInfo[i] + if got, want := derivedKey.GetOutputPrefixType(), derivingKey.GetOutputPrefixType(); got != want { + t.Errorf("GetOutputPrefixType() = %s, want %s", got, want) + } + if got, want := derivedKey.GetKeyId(), derivingKey.GetKeyId(); got != want { + t.Errorf("GetKeyId() = %d, want %d", got, want) } - // Verify each derived key has the same key ID as a deriving key. - hasMatchingDerivingKey := false - for _, key := range keyInfo { - if key.GetKeyId() == derivedKey.GetKeyId() { - hasMatchingDerivingKey = true - } else { - continue - } - if got, want := derivedKey.GetTypeUrl(), "type.googleapis.com/google.crypto.tink.AesGcmKey"; got != want { - t.Errorf("GetTypeUrl() = %q, want %q", got, want) - } - if derivedKey.GetStatus() != key.GetStatus() { - t.Errorf("GetStatus() = %s, want %s", derivedKey.GetStatus(), key.GetStatus()) - } + if got, want := derivedKey.GetTypeUrl(), test.wantTypeURLs[i]; got != want { + t.Errorf("GetTypeUrl() = %q, want %q", got, want) } - if !hasMatchingDerivingKey { - t.Errorf("derived key has no matching deriving key") + if got, want := derivedKey.GetStatus(), derivingKey.GetStatus(); got != want { + t.Errorf("GetStatus() = %s, want %s", got, want) } if derivedKey.GetKeyId() == derivedHandle.KeysetInfo().GetPrimaryKeyId() { hasPrimaryKey = true