diff --git a/pkg/didcomm/common/didconnection/api.go b/pkg/didcomm/common/didconnection/api.go index 4ade1556d..4b7ba542c 100644 --- a/pkg/didcomm/common/didconnection/api.go +++ b/pkg/didcomm/common/didconnection/api.go @@ -5,7 +5,11 @@ SPDX-License-Identifier: Apache-2.0 package didconnection -import diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" +import ( + "errors" + + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" +) // Store stores DIDs indexed by public key, so agents can find the DID associated with a given key. type Store interface { @@ -13,10 +17,11 @@ type Store interface { SaveDID(did string, keys ...string) error // GetDID gets the DID stored under the given key GetDID(key string) (string, error) - // SaveDIDConnection saves a connection between this agent's DID and another agent's DID - SaveDIDConnection(myDID, theirDID string, theirKeys []string) error // SaveDIDByResolving resolves a DID using the VDR then saves the map from keys -> did - SaveDIDByResolving(did, serviceType, keyType string) error + SaveDIDByResolving(did string, keys ...string) error // SaveDIDFromDoc saves a map from keys -> did for a did doc - SaveDIDFromDoc(doc *diddoc.Doc, serviceType, keyType string) error + SaveDIDFromDoc(doc *diddoc.Doc) error } + +// ErrNotFound signals that the entry for the given DID and key is not present in the store. +var ErrNotFound = errors.New("did not found under given key") diff --git a/pkg/didcomm/common/didconnection/didconnection.go b/pkg/didcomm/common/didconnection/didconnection.go index f20b90b8a..d5b3eb51b 100644 --- a/pkg/didcomm/common/didconnection/didconnection.go +++ b/pkg/didcomm/common/didconnection/didconnection.go @@ -16,8 +16,8 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/storage" ) -// BaseDIDConnectionStore stores DIDs indexed by key -type BaseDIDConnectionStore struct { +// ConnectionStore stores DIDs indexed by key +type ConnectionStore struct { store storage.Store vdr vdri.Registry } @@ -32,17 +32,17 @@ type provider interface { } // New returns a new did lookup Store -func New(ctx provider) (*BaseDIDConnectionStore, error) { - store, err := ctx.StorageProvider().OpenStore("con-store") +func New(ctx provider) (*ConnectionStore, error) { + store, err := ctx.StorageProvider().OpenStore("didconnection") if err != nil { return nil, err } - return &BaseDIDConnectionStore{store: store, vdr: ctx.VDRIRegistry()}, nil + return &ConnectionStore{store: store, vdr: ctx.VDRIRegistry()}, nil } // saveDID saves a DID, indexed using the given public key -func (c *BaseDIDConnectionStore) saveDID(did, key string) error { +func (c *ConnectionStore) saveDID(did, key string) error { data := didRecord{ DID: did, } @@ -56,7 +56,7 @@ func (c *BaseDIDConnectionStore) saveDID(did, key string) error { } // SaveDID saves a DID, indexed using the given public keys -func (c *BaseDIDConnectionStore) SaveDID(did string, keys ...string) error { +func (c *ConnectionStore) SaveDID(did string, keys ...string) error { for _, key := range keys { err := c.saveDID(did, key) if err != nil { @@ -68,29 +68,34 @@ func (c *BaseDIDConnectionStore) SaveDID(did string, keys ...string) error { } // SaveDIDFromDoc saves a map from a did doc's keys to the did -func (c *BaseDIDConnectionStore) SaveDIDFromDoc(doc *diddoc.Doc, serviceType, keyType string) error { - keys, ok := diddoc.LookupRecipientKeys(doc, serviceType, keyType) - if !ok { - return fmt.Errorf("getting DID doc keys") +func (c *ConnectionStore) SaveDIDFromDoc(doc *diddoc.Doc) error { + var keys []string + for i := range doc.PublicKey { + keys = append(keys, string(doc.PublicKey[i].Value)) } return c.SaveDID(doc.ID, keys...) } // SaveDIDByResolving resolves a DID using the VDR then saves the map from keys -> did -func (c *BaseDIDConnectionStore) SaveDIDByResolving(did, serviceType, keyType string) error { +// keys: fallback keys in case the DID can't be resolved +func (c *ConnectionStore) SaveDIDByResolving(did string, keys ...string) error { doc, err := c.vdr.Resolve(did) - if err != nil { + if errors.Is(err, vdri.ErrNotFound) { + return c.SaveDID(did, keys...) + } else if err != nil { return err } - return c.SaveDIDFromDoc(doc, serviceType, keyType) + return c.SaveDIDFromDoc(doc) } // GetDID gets the DID stored under the given key -func (c *BaseDIDConnectionStore) GetDID(key string) (string, error) { +func (c *ConnectionStore) GetDID(key string) (string, error) { bytes, err := c.store.Get(key) - if err != nil { + if errors.Is(err, storage.ErrDataNotFound) { + return "", ErrNotFound + } else if err != nil { return "", err } @@ -103,44 +108,3 @@ func (c *BaseDIDConnectionStore) GetDID(key string) (string, error) { return record.DID, nil } - -func (c *BaseDIDConnectionStore) resolvePublicKeys(id string) ([]string, error) { - doc, err := c.vdr.Resolve(id) - if err != nil { - return nil, err - } - - var keys []string - - for i := range doc.PublicKey { - keys = append(keys, string(doc.PublicKey[i].Value)) - } - - return keys, nil -} - -// SaveDIDConnection saves a connection between this agent's did and another agent -func (c *BaseDIDConnectionStore) SaveDIDConnection(myDID, theirDID string, theirKeys []string) error { - var keys []string - - keys, err := c.resolvePublicKeys(theirDID) - if errors.Is(err, vdri.ErrNotFound) { - keys = theirKeys - } else if err != nil { - return err - } - - // map their pub keys -> their DID - err = c.SaveDID(theirDID, keys...) - if err != nil { - return err - } - - // map their DID -> my DID - err = c.SaveDID(myDID, theirDID) - if err != nil { - return fmt.Errorf("save DID in did map: %w", err) - } - - return nil -} diff --git a/pkg/didcomm/common/didconnection/didconnection_test.go b/pkg/didcomm/common/didconnection/didconnection_test.go index bb5ee1b23..7cad6e427 100644 --- a/pkg/didcomm/common/didconnection/didconnection_test.go +++ b/pkg/didcomm/common/didconnection/didconnection_test.go @@ -86,7 +86,7 @@ func TestBaseConnectionStore(t *testing.T) { require.Equal(t, "did:abcde", didVal) wrong, err := connStore.GetDID("fhtagn") - require.EqualError(t, err, storage.ErrDataNotFound.Error()) + require.EqualError(t, err, ErrNotFound.Error()) require.Equal(t, "", wrong) err = connStore.store.Put("bad-data", []byte("aaooga")) @@ -97,31 +97,11 @@ func TestBaseConnectionStore(t *testing.T) { require.Contains(t, err.Error(), "invalid character") }) - ed25519KeyType := "Ed25519VerificationKey2018" - didCommServiceType := "did-communication" - t.Run("SaveDIDFromDoc", func(t *testing.T) { connStore, err := New(&prov) require.NoError(t, err) - err = connStore.SaveDIDFromDoc( - mockdiddoc.GetMockDIDDoc(), - didCommServiceType, - "bad") - require.Error(t, err) - require.Contains(t, err.Error(), "getting DID doc keys") - - err = connStore.SaveDIDFromDoc( - mockdiddoc.GetMockDIDDoc(), - "bad", - ed25519KeyType) - require.Error(t, err) - require.Contains(t, err.Error(), "getting DID doc keys") - - err = connStore.SaveDIDFromDoc( - mockdiddoc.GetMockDIDDoc(), - didCommServiceType, - ed25519KeyType) + err = connStore.SaveDIDFromDoc(mockdiddoc.GetMockDIDDoc()) require.NoError(t, err) }) @@ -129,10 +109,7 @@ func TestBaseConnectionStore(t *testing.T) { cs, err := New(&prov) require.NoError(t, err) - err = cs.SaveDIDByResolving( - mockdiddoc.GetMockDIDDoc().ID, - didCommServiceType, - ed25519KeyType) + err = cs.SaveDIDByResolving(mockdiddoc.GetMockDIDDoc().ID) require.NoError(t, err) }) @@ -145,46 +122,8 @@ func TestBaseConnectionStore(t *testing.T) { cs, err := New(&prov) require.NoError(t, err) - err = cs.SaveDIDByResolving("did", "abc", "def") + err = cs.SaveDIDByResolving("did") require.Error(t, err) require.Contains(t, err.Error(), "resolve error") }) - - t.Run("SaveDIDConnection success", func(t *testing.T) { - prov := ctx{ - vdr: &mockvdri.MockVDRIRegistry{ - ResolveValue: mockdiddoc.GetMockDIDDoc(), - }, - store: mockstorage.NewMockStoreProvider(), - } - - cs, err := New(&prov) - require.NoError(t, err) - - err = cs.SaveDIDConnection("mine", mockdiddoc.GetMockDIDDoc().ID, []string{"abc", "def"}) - require.NoError(t, err) - }) - - t.Run("SaveDIDConnection error", func(t *testing.T) { - prov := ctx{ - vdr: &mockvdri.MockVDRIRegistry{ - ResolveValue: mockdiddoc.GetMockDIDDoc(), - }, - store: &mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: map[string][]byte{}, - ErrPut: fmt.Errorf("store error"), - }}, - } - - cs, err := New(&prov) - require.NoError(t, err) - - err = cs.SaveDIDConnection("mine", "theirs", []string{"abc", "def"}) - require.Error(t, err) - require.Contains(t, err.Error(), "store error") - - err = cs.SaveDIDConnection("mine", "theirs", nil) - require.Error(t, err) - require.Contains(t, err.Error(), "saving DID in did map") - }) } diff --git a/pkg/didcomm/common/service/destination.go b/pkg/didcomm/common/service/destination.go index 67bf07701..7b1339150 100644 --- a/pkg/didcomm/common/service/destination.go +++ b/pkg/didcomm/common/service/destination.go @@ -24,7 +24,8 @@ type Destination struct { const ( didCommServiceType = "did-communication" - ed25519KeyType = "Ed25519VerificationKey2018" + // TODO: hardcoded key type https://github.com/hyperledger/aries-framework-go/issues/1008 + ed25519KeyType = "Ed25519VerificationKey2018" ) // GetDestination constructs a Destination struct based on the given DID and parameters diff --git a/pkg/didcomm/common/transport/envelope.go b/pkg/didcomm/common/transport/envelope.go index c910c7a0f..709bec9be 100644 --- a/pkg/didcomm/common/transport/envelope.go +++ b/pkg/didcomm/common/transport/envelope.go @@ -6,9 +6,14 @@ SPDX-License-Identifier: Apache-2.0 package transport -// Envelope contain msg, FromVerKey and ToVerKeys +// Envelope holds message data and metadata for inbound and outbound messaging type Envelope struct { Message []byte - FromVerKey string - ToVerKeys []string + FromVerKey []byte + // ToVerKeys stores string (base58) verification keys for an outbound message + ToVerKeys []string + // ToVerKey holds the key that was used to decrypt an inbound message + ToVerKey []byte + FromDID string + ToDID string } diff --git a/pkg/didcomm/dispatcher/outbound.go b/pkg/didcomm/dispatcher/outbound.go index 6f8ba3fd0..2acbd193a 100644 --- a/pkg/didcomm/dispatcher/outbound.go +++ b/pkg/didcomm/dispatcher/outbound.go @@ -11,10 +11,13 @@ import ( "fmt" "strings" + "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" ) // provider interface for outbound ctx @@ -22,6 +25,7 @@ type provider interface { Packager() commontransport.Packager OutboundTransports() []transport.OutboundTransport TransportReturnRoute() string + VDRIRegistry() vdri.Registry } // OutboundDispatcher dispatch msgs to destination @@ -29,6 +33,7 @@ type OutboundDispatcher struct { outboundTransports []transport.OutboundTransport packager commontransport.Packager transportReturnRoute string + vdRegistry vdri.Registry } // NewOutbound return new dispatcher outbound instance @@ -37,12 +42,28 @@ func NewOutbound(prov provider) *OutboundDispatcher { outboundTransports: prov.OutboundTransports(), packager: prov.Packager(), transportReturnRoute: prov.TransportReturnRoute(), + vdRegistry: prov.VDRIRegistry(), } } -// SendToDID msg +// SendToDID sends a message from myDID to the agent who owns theirDID func (o *OutboundDispatcher) SendToDID(msg interface{}, myDID, theirDID string) error { - return nil + dest, err := service.GetDestination(theirDID, o.vdRegistry) + if err != nil { + return err + } + + src, err := service.GetDestination(myDID, o.vdRegistry) + if err != nil { + return err + } + + // We get at least one recipient key, so we can use the first one + // (right now, with only one key type used for sending) + // TODO: relies on hardcoded key type + key := src.RecipientKeys[0] + + return o.Send(msg, key, dest) } // Send sends the message after packing with the sender key and recipient keys. @@ -79,7 +100,7 @@ func (o *OutboundDispatcher) Send(msg interface{}, senderVerKey string, des *ser } packedMsg, err := o.packager.PackMessage( - &commontransport.Envelope{Message: req, FromVerKey: senderVerKey, ToVerKeys: des.RecipientKeys}) + &commontransport.Envelope{Message: req, FromVerKey: base58.Decode(senderVerKey), ToVerKeys: des.RecipientKeys}) if err != nil { return fmt.Errorf("failed to pack msg: %w", err) } diff --git a/pkg/didcomm/dispatcher/outbound_test.go b/pkg/didcomm/dispatcher/outbound_test.go index eb197dec1..637028be7 100644 --- a/pkg/didcomm/dispatcher/outbound_test.go +++ b/pkg/didcomm/dispatcher/outbound_test.go @@ -19,8 +19,11 @@ import ( commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" mockdidcomm "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" + mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" ) func TestOutboundDispatcher_Send(t *testing.T) { @@ -58,6 +61,40 @@ func TestOutboundDispatcher_Send(t *testing.T) { }) } +func TestOutboundDispatcher_SendToDID(t *testing.T) { + mockDoc := mockdiddoc.GetMockDIDDoc() + + t.Run("success", func(t *testing.T) { + o := NewOutbound(&mockProvider{ + packagerValue: &mockpackager.Packager{}, + vdriRegistry: &mockvdri.MockVDRIRegistry{ + ResolveValue: mockDoc, + }, + outboundTransportsValue: []transport.OutboundTransport{ + &mockdidcomm.MockOutboundTransport{AcceptValue: true}, + }, + }) + + require.NoError(t, o.SendToDID("data", "", "")) + }) + + t.Run("resolve err", func(t *testing.T) { + o := NewOutbound(&mockProvider{ + packagerValue: &mockpackager.Packager{}, + vdriRegistry: &mockvdri.MockVDRIRegistry{ + ResolveErr: fmt.Errorf("resolve error"), + }, + outboundTransportsValue: []transport.OutboundTransport{ + &mockdidcomm.MockOutboundTransport{AcceptValue: true}, + }, + }) + + err := o.SendToDID("data", "", "") + require.Error(t, err) + require.Contains(t, err.Error(), "resolve error") + }) +} + func TestOutboundDispatcherTransportReturnRoute(t *testing.T) { t.Run("transport route option - value set all", func(t *testing.T) { transportReturnRoute := "all" @@ -168,6 +205,7 @@ type mockProvider struct { packagerValue commontransport.Packager outboundTransportsValue []transport.OutboundTransport transportReturnRoute string + vdriRegistry vdri.Registry } func (p *mockProvider) Packager() commontransport.Packager { @@ -182,6 +220,10 @@ func (p *mockProvider) TransportReturnRoute() string { return p.transportReturnRoute } +func (p *mockProvider) VDRIRegistry() vdri.Registry { + return p.vdriRegistry +} + // mockOutboundTransport mock outbound transport type mockOutboundTransport struct { expectedRequest string diff --git a/pkg/didcomm/packager/package_test.go b/pkg/didcomm/packager/package_test.go index 5f7ab1f61..39e4827d4 100644 --- a/pkg/didcomm/packager/package_test.go +++ b/pkg/didcomm/packager/package_test.go @@ -11,14 +11,17 @@ import ( "fmt" "testing" + "github.com/btcsuite/btcutil/base58" "github.com/stretchr/testify/require" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" . "github.com/hyperledger/aries-framework-go/pkg/didcomm/packager" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" jwe "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer/jwe/authcrypt" legacy "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer/legacy/authcrypt" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" + mockdidconnection "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" "github.com/hyperledger/aries-framework-go/pkg/kms" "github.com/hyperledger/aries-framework-go/pkg/storage" @@ -26,16 +29,15 @@ import ( func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { t.Run("test failed to unmarshal encMessage", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -49,16 +51,15 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test bad encoding type", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -87,17 +88,16 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test key not found", func(t *testing.T) { - wp := newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }}) + wp := newMockKMSProvider(newMockStoreProvider()) w, err := kms.New(wp) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -117,7 +117,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { // PackMessage should pass with both value from and to verification keys packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) @@ -131,27 +131,26 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test Pack/Unpack fails", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) - decryptValue := func(envelope []byte) ([]byte, []byte, error) { - return nil, nil, fmt.Errorf("unpack error") + decryptValue := func(envelope []byte) (*transport.Envelope, error) { + return nil, fmt.Errorf("unpack error") } mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } // use a mocked packager with a mocked KMS to validate pack/unpack e := func(payload []byte, senderPubKey []byte, recipientsKeys [][]byte) (bytes []byte, e error) { - packer, e := jwe.New(mockedProviders, jwe.XC20P) + p, e := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, e) - return packer.Pack(payload, senderPubKey, recipientsKeys) + return p.Pack(payload, senderPubKey, recipientsKeys) } mockPacker := &didcomm.MockAuthCrypt{DecryptValue: decryptValue, EncryptValue: e, Type: "prs.hyperledger.aries-auth-message"} @@ -174,7 +173,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { // now try to pack with non empty envelope - should pass packMsg, err = packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) require.NotEmpty(t, packMsg) @@ -194,7 +193,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { packager, err = New(mockedProviders) require.NoError(t, err) packMsg, err = packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.Error(t, err) require.Empty(t, packMsg) @@ -203,14 +202,14 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { t.Run("test Pack/Unpack success", func(t *testing.T) { // create a mock KMS with storage as a map - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: map[string][]byte{}}})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } // create a real testPacker (no mocking here) @@ -233,7 +232,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { // pack an non empty envelope - should pass packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) @@ -250,7 +249,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { require.NoError(t, err) packMsg, err = packager2.PackMessage(&transport.Envelope{Message: []byte("msg2"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) @@ -259,10 +258,99 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { require.NoError(t, err) require.Equal(t, unpackedMsg.Message, []byte("msg2")) }) + + t.Run("test success - dids not found", func(t *testing.T) { + // create a mock KMS with storage as a map + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) + require.NoError(t, err) + mockedProviders := &mockProvider{ + storage: newMockStoreProvider(), + kms: w, + primaryPacker: nil, + packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDErr: didconnection.ErrNotFound}, + } + + // create a real testPacker (no mocking here) + testPacker := legacy.New(mockedProviders) + require.NoError(t, err) + mockedProviders.primaryPacker = testPacker + + mockedProviders.packers = []packer.Packer{testPacker} + + // now create a new packager with the above provider context + packager, err := New(mockedProviders) + require.NoError(t, err) + + _, base58FromVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + _, base58ToVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + // pack an non empty envelope - should pass + packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), + FromVerKey: base58.Decode(base58FromVerKey), + ToVerKeys: []string{base58ToVerKey}}) + require.NoError(t, err) + + // unpack the packed message above - should pass and match the same payload (msg1) + unpackedMsg, err := packager.UnpackMessage(packMsg) + require.NoError(t, err) + require.Equal(t, unpackedMsg.Message, []byte("msg1")) + }) + + t.Run("test failure - did lookup broke", func(t *testing.T) { + // create a mock KMS with storage as a map + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) + require.NoError(t, err) + mockedProviders := &mockProvider{ + storage: newMockStoreProvider(), + kms: w, + primaryPacker: nil, + packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDErr: fmt.Errorf("bad error")}, + } + + // create a real testPacker (no mocking here) + testPacker := legacy.New(mockedProviders) + require.NoError(t, err) + mockedProviders.primaryPacker = testPacker + + mockedProviders.packers = []packer.Packer{testPacker} + + // now create a new packager with the above provider context + packager, err := New(mockedProviders) + require.NoError(t, err) + + _, base58FromVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + _, base58ToVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + // pack an non empty envelope - should pass + packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), + FromVerKey: base58.Decode(base58FromVerKey), + ToVerKeys: []string{base58ToVerKey}}) + require.NoError(t, err) + + // unpack the packed message above - should pass and match the same payload (msg1) + unpackedMsg, err := packager.UnpackMessage(packMsg) + require.Error(t, err) + require.Contains(t, err.Error(), "bad error") + require.Nil(t, unpackedMsg) + }) } func newMockKMSProvider(storagePvdr *mockstorage.MockStoreProvider) *mockProvider { - return &mockProvider{storagePvdr, nil, nil, nil} + return &mockProvider{storagePvdr, nil, nil, nil, nil} +} + +func newMockStoreProvider() *mockstorage.MockStoreProvider { + return &mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ + Store: make(map[string][]byte), + }} } // mockProvider mocks provider for KMS @@ -271,6 +359,7 @@ type mockProvider struct { kms kms.KeyManager packers []packer.Packer primaryPacker packer.Packer + lookupStore didconnection.Store } func (m *mockProvider) Packers() []packer.Packer { @@ -288,3 +377,8 @@ func (m *mockProvider) StorageProvider() storage.Provider { func (m *mockProvider) PrimaryPacker() packer.Packer { return m.primaryPacker } + +// DIDConnectionStore returns a didconnection.Store service. +func (m *mockProvider) DIDConnectionStore() didconnection.Store { + return m.lookupStore +} diff --git a/pkg/didcomm/packager/packager.go b/pkg/didcomm/packager/packager.go index 8ceda79c6..cb36fc1f1 100644 --- a/pkg/didcomm/packager/packager.go +++ b/pkg/didcomm/packager/packager.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" ) @@ -22,6 +23,7 @@ import ( type Provider interface { Packers() []packer.Packer PrimaryPacker() packer.Packer + DIDConnectionStore() didconnection.Store } // Creator method to create new packager service @@ -29,8 +31,9 @@ type Creator func(prov Provider) (transport.Packager, error) // Packager is the basic implementation of Packager type Packager struct { - primaryPacker packer.Packer - packers map[string]packer.Packer + primaryPacker packer.Packer + packers map[string]packer.Packer + connectionStore didconnection.Store } // PackerCreator holds a creator function for a Packer and the name of the Packer's encoding method. @@ -42,8 +45,9 @@ type PackerCreator struct { // New return new instance of KMS implementation func New(ctx Provider) (*Packager, error) { basePackager := Packager{ - primaryPacker: nil, - packers: map[string]packer.Packer{}, + primaryPacker: nil, + packers: map[string]packer.Packer{}, + connectionStore: ctx.DIDConnectionStore(), } for _, packerType := range ctx.Packers() { @@ -52,7 +56,7 @@ func New(ctx Provider) (*Packager, error) { basePackager.primaryPacker = ctx.PrimaryPacker() if basePackager.primaryPacker == nil { - return nil, fmt.Errorf("need primary primaryPacker to initialize packager") + return nil, fmt.Errorf("need primary packer to initialize packager") } basePackager.addPacker(basePackager.primaryPacker) @@ -86,7 +90,7 @@ func (bp *Packager) PackMessage(messageEnvelope *transport.Envelope) ([]byte, er recipients = append(recipients, verKeyBytes) } // pack message - bytes, err := bp.primaryPacker.Pack(messageEnvelope.Message, base58.Decode(messageEnvelope.FromVerKey), recipients) + bytes, err := bp.primaryPacker.Pack(messageEnvelope.Message, messageEnvelope.FromVerKey, recipients) if err != nil { return nil, fmt.Errorf("pack: %w", err) } @@ -146,10 +150,27 @@ func (bp *Packager) UnpackMessage(encMessage []byte) (*transport.Envelope, error return nil, fmt.Errorf("message Type not recognized") } - data, senderVerKey, err := p.Unpack(encMessage) + envelope, err := p.Unpack(encMessage) if err != nil { return nil, fmt.Errorf("unpack: %w", err) } - return &transport.Envelope{Message: data, FromVerKey: base58.Encode(senderVerKey)}, nil + // ignore error - agents can communicate without using DIDs - for example, in DIDExchange + theirDID, err := bp.connectionStore.GetDID(base58.Encode(envelope.FromVerKey)) + if errors.Is(err, didconnection.ErrNotFound) { + } else if err != nil { + return nil, fmt.Errorf("failed to get their did: %w", err) + } + + // ignore error - at beginning of DIDExchange, you might be about to generate a DID + myDID, err := bp.connectionStore.GetDID(base58.Encode(envelope.ToVerKey)) + if errors.Is(err, didconnection.ErrNotFound) { + } else if err != nil { + return nil, fmt.Errorf("failed to get my did: %w", err) + } + + envelope.ToDID = myDID + envelope.FromDID = theirDID + + return envelope, nil } diff --git a/pkg/didcomm/packer/api.go b/pkg/didcomm/packer/api.go index 0c7b50503..ab5c0803c 100644 --- a/pkg/didcomm/packer/api.go +++ b/pkg/didcomm/packer/api.go @@ -6,7 +6,10 @@ SPDX-License-Identifier: Apache-2.0 package packer -import "github.com/hyperledger/aries-framework-go/pkg/kms" +import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + "github.com/hyperledger/aries-framework-go/pkg/kms" +) // Provider interface for Packer ctx type Provider interface { @@ -30,11 +33,10 @@ type Packer interface { // The recipient's key will be the one found in KMS that matches one of the list of recipients in the envelope // // returns: - // []byte containing the decrypted payload - // []byte contains the sender verification key + // Envelope containing the message, decryption key, and sender key // error if decryption failed // TODO add key type of recipients keys to be validated by the implementation - Issue #272 - Unpack(envelope []byte) ([]byte, []byte, error) + Unpack(envelope []byte) (*transport.Envelope, error) // Encoding returns the type of the encoding, as found in the header `Typ` field EncodingType() string diff --git a/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go b/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go index d2dadf8ca..87f134beb 100644 --- a/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go +++ b/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go @@ -304,12 +304,14 @@ func TestEncrypt(t *testing.T) { t.Logf("Encryption with XC20P: %s", m) // decrypt for rec1 (as found in kms) - dec, senderVerKey, e := packer.Unpack(enc) + env, e := packer.Unpack(enc) require.NoError(t, e) - require.NotEmpty(t, dec) - require.EqualValues(t, dec, pld) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.EqualValues(t, env.Message, pld) + require.NotEmpty(t, env.FromVerKey) require.Equal(t, base64.RawURLEncoding.EncodeToString(sender.EncKeyPair.Pub), - base64.RawURLEncoding.EncodeToString(senderVerKey)) + base64.RawURLEncoding.EncodeToString(env.FromVerKey)) }) t.Run("Success test case: : pack and unpack", func(t *testing.T) { @@ -331,10 +333,17 @@ func TestEncrypt(t *testing.T) { [][]byte{base58.Decode(recSign)}) require.NoError(t, err) - msgOut, sendKey, err := recPacker.Unpack(enc) + env, err := recPacker.Unpack(enc) require.NoError(t, err) - require.Equal(t, msgIn, msgOut) - require.Equal(t, base58.Encode(sendKey), encSign) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.Equal(t, msgIn, env.Message) + require.NotEmpty(t, env.FromVerKey) + require.Equal(t, base58.Encode(env.FromVerKey), encSign) + // this won't work, since input keys and output keys to this Packer + // are different types. + // require.Equal(t, base58.Encode(recKey), recSign) + require.NotEmpty(t, env.ToVerKey) }) t.Run("Success test case: Decrypting a message with two PackerValue instances to simulate two agents", func(t *testing.T) { //nolint:lll @@ -356,22 +365,26 @@ func TestEncrypt(t *testing.T) { // now decrypt with recipient3 packer1, e := New(recipient3KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e := packer1.Unpack(enc) + env, e := packer1.Unpack(enc) require.NoError(t, e) - require.NotEmpty(t, dec) - require.EqualValues(t, dec, pld) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.EqualValues(t, env.Message, pld) + require.NotEmpty(t, env.FromVerKey) require.Equal(t, base64.RawURLEncoding.EncodeToString(sender.EncKeyPair.Pub), - base64.RawURLEncoding.EncodeToString(senderVerKey)) + base64.RawURLEncoding.EncodeToString(env.FromVerKey)) // now try decrypting with recipient2 packer2, e := New(recipient2KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e = packer2.Unpack(enc) + env, e = packer2.Unpack(enc) require.NoError(t, e) - require.NotEmpty(t, dec) - require.EqualValues(t, dec, pld) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.EqualValues(t, env.Message, pld) + require.NotEmpty(t, env.FromVerKey) require.Equal(t, base64.RawURLEncoding.EncodeToString(sender.EncKeyPair.Pub), - base64.RawURLEncoding.EncodeToString(senderVerKey)) + base64.RawURLEncoding.EncodeToString(env.FromVerKey)) t.Logf("Decryption Payload with XC20P: %s", pld) }) @@ -392,10 +405,9 @@ func TestEncrypt(t *testing.T) { // decrypting for recipient 2 (unauthorized) packer1, e := New(recipient2KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e := packer1.Unpack(enc) + env, e := packer1.Unpack(enc) require.Error(t, e) - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) }) t.Run("Failure test case: Decrypting a message but scramble JWE beforehand", func(t *testing.T) { @@ -427,10 +439,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad nonce format - dec, senderVerKey, e := packer.Unpack(enc) + env, e := packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 12") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.CipherText = validJwe.CipherText // update jwe with bad nonce format @@ -438,10 +449,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad nonce format - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 5") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.IV = validJwe.IV // update jwe with bad tag format @@ -449,10 +459,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag format - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 6") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Tag = validJwe.Tag // update jwe with bad recipient spk (JWE format) @@ -460,10 +469,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: sender key: bad SPK format") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.SPK = validJwe.Recipients[0].Header.SPK // update jwe with bad recipient tag format @@ -471,10 +479,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 6") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.Tag = validJwe.Recipients[0].Header.Tag // update jwe with bad recipient nonce format @@ -482,10 +489,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 5") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.IV = validJwe.Recipients[0].Header.IV // update jwe with bad recipient nonce format @@ -493,10 +499,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: bad nonce size") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.IV = validJwe.Recipients[0].Header.IV // update jwe with bad recipient apu format @@ -504,10 +509,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 6") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.APU = validJwe.Recipients[0].Header.APU // update jwe with bad recipient kid (sender) format @@ -515,10 +519,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, fmt.Sprintf("unpack: %s", cryptoutil.ErrKeyNotFound.Error())) - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.KID = validJwe.Recipients[0].Header.KID // update jwe with bad recipient CEK (encrypted key) format @@ -526,10 +529,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 15") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].EncryptedKey = validJwe.Recipients[0].EncryptedKey // update jwe with bad recipient CEK (encrypted key) value @@ -537,10 +539,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: chacha20poly1305: message authentication failed") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].EncryptedKey = validJwe.Recipients[0].EncryptedKey // now try bad nonce size @@ -549,10 +550,9 @@ func TestEncrypt(t *testing.T) { require.NoError(t, e) // decrypt with bad nonce value require.PanicsWithValue(t, "chacha20poly1305: bad nonce length passed to Open", func() { - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) }) - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.IV = validJwe.IV // now try bad nonce value @@ -560,10 +560,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: chacha20poly1305: message authentication failed") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.IV = validJwe.Recipients[0].Header.IV }) } @@ -658,8 +657,9 @@ func TestRefEncrypt(t *testing.T) { require.NoError(t, err) require.NotNil(t, packer) - dec, senderVerKey, err := packer.Unpack([]byte(refJWE)) + env, err := packer.Unpack([]byte(refJWE)) require.NoError(t, err) - require.NotEmpty(t, dec) - require.NotEmpty(t, senderVerKey) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.NotEmpty(t, env.FromVerKey) } diff --git a/pkg/didcomm/packer/jwe/authcrypt/unpack.go b/pkg/didcomm/packer/jwe/authcrypt/unpack.go index ecbf4db55..01e70c3b9 100644 --- a/pkg/didcomm/packer/jwe/authcrypt/unpack.go +++ b/pkg/didcomm/packer/jwe/authcrypt/unpack.go @@ -14,6 +14,8 @@ import ( "github.com/btcsuite/btcutil/base58" chacha "golang.org/x/crypto/chacha20poly1305" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" ) // Unpack will JWE decode the envelope argument for the recipientPrivKey and validates @@ -22,22 +24,22 @@ import ( // encrypted CEK. // The current recipient is the one with the sender's encrypted key that successfully // decrypts with recipientKeyPair.Priv Key. -func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(envelope []byte) (*transport.Envelope, error) { jwe := &Envelope{} err := json.Unmarshal(envelope, jwe) if err != nil { - return nil, nil, fmt.Errorf("unpack: %w", err) + return nil, fmt.Errorf("unpack json: %w", err) } recipientPubKey, recipient, err := p.findRecipient(jwe.Recipients) if err != nil { - return nil, nil, fmt.Errorf("unpack: %w", err) + return nil, fmt.Errorf("unpack: %w", err) } senderKey, err := p.decryptSPK(recipientPubKey, recipient.Header.SPK) if err != nil { - return nil, nil, fmt.Errorf("unpack: sender key: %w", err) + return nil, fmt.Errorf("unpack: sender key: %w", err) } // senderKey must not be empty to proceed @@ -47,18 +49,22 @@ func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { sharedKey, er := p.decryptCEK(recipientPubKey, senderPubKey, recipient) if er != nil { - return nil, nil, fmt.Errorf("unpack: decrypt shared key: %w", er) + return nil, fmt.Errorf("unpack: decrypt shared key: %w", er) } symOutput, er := p.decryptPayload(sharedKey, jwe) if er != nil { - return nil, nil, fmt.Errorf("unpack: %w", er) + return nil, fmt.Errorf("unpack: %w", er) } - return symOutput, senderKey, nil + return &transport.Envelope{ + Message: symOutput, + FromVerKey: senderKey, + ToVerKey: recipientPubKey[:], + }, nil } - return nil, nil, errors.New("unpack: invalid sender key in envelope") + return nil, errors.New("unpack: invalid sender key in envelope") } func (p *Packer) decryptPayload(cek []byte, jwe *Envelope) ([]byte, error) { diff --git a/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go b/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go index 6f7e76737..d92146777 100644 --- a/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go +++ b/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go @@ -384,11 +384,12 @@ func TestDecrypt(t *testing.T) { enc, err := packer.Pack(msgIn, base58.Decode(senderKey), [][]byte{base58.Decode(recKey)}) require.NoError(t, err) - msgOut, senderVerKey, err := packer.Unpack(enc) + env, err := packer.Unpack(enc) require.NoError(t, err) - require.ElementsMatch(t, msgIn, msgOut) - require.Equal(t, senderKey, base58.Encode(senderVerKey)) + require.ElementsMatch(t, msgIn, env.Message) + require.Equal(t, senderKey, base58.Encode(env.FromVerKey)) + require.Equal(t, recKey, base58.Encode(env.ToVerKey)) }) t.Run("Success: pack and unpack, different packers, including fail recipient who wasn't sent the message", func(t *testing.T) { // nolint: lll @@ -413,15 +414,16 @@ func TestDecrypt(t *testing.T) { enc, err := sendPacker.Pack(msgIn, base58.Decode(senderKey), [][]byte{base58.Decode(rec1Key), base58.Decode(rec2Key), base58.Decode(rec3Key)}) require.NoError(t, err) - msgOut, senderVerKey, err := rec2Packer.Unpack(enc) + env, err := rec2Packer.Unpack(enc) require.NoError(t, err) - require.ElementsMatch(t, msgIn, msgOut) - require.Equal(t, senderKey, base58.Encode(senderVerKey)) + require.ElementsMatch(t, msgIn, env.Message) + require.Equal(t, senderKey, base58.Encode(env.FromVerKey)) + require.Equal(t, rec2Key, base58.Encode(env.ToVerKey)) emptyKMS, _ := newKMS(t) rec4Packer := newWithKMS(emptyKMS) - _, _, err = rec4Packer.Unpack(enc) + _, err = rec4Packer.Unpack(enc) require.NotNil(t, err) require.Contains(t, err.Error(), "no key accessible") }) @@ -440,10 +442,12 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - msgOut, senderVerKey, err := recPacker.Unpack([]byte(env)) + envOut, err := recPacker.Unpack([]byte(env)) require.NoError(t, err) - require.ElementsMatch(t, []byte(msg), msgOut) - require.NotEmpty(t, senderVerKey) + require.ElementsMatch(t, []byte(msg), envOut.Message) + require.NotEmpty(t, envOut.FromVerKey) + require.NotEmpty(t, envOut.ToVerKey) + require.Equal(t, recPub, base58.Encode(envOut.ToVerKey)) }) t.Run("Test unpacking python envelope with multiple recipients", func(t *testing.T) { @@ -461,10 +465,12 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - msgOut, senderVerKey, err := recPacker.Unpack([]byte(env)) + envOut, err := recPacker.Unpack([]byte(env)) require.NoError(t, err) - require.ElementsMatch(t, []byte(msg), msgOut) - require.NotEmpty(t, senderVerKey) + require.ElementsMatch(t, []byte(msg), envOut.Message) + require.NotEmpty(t, envOut.FromVerKey) + require.NotEmpty(t, envOut.ToVerKey) + require.Equal(t, recPub, base58.Encode(envOut.ToVerKey)) }) t.Run("Test unpacking python envelope with invalid recipient", func(t *testing.T) { @@ -479,7 +485,7 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - _, _, err = recPacker.Unpack([]byte(env)) + _, err = recPacker.Unpack([]byte(env)) require.NotNil(t, err) require.Contains(t, err.Error(), "no key accessible") }) @@ -502,7 +508,7 @@ func unpackComponentFailureTest( } recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(fullMessage)) + _, err = recPacker.Unpack([]byte(fullMessage)) require.NotNil(t, err) require.Contains(t, err.Error(), errString) } @@ -520,7 +526,7 @@ func TestUnpackComponents(t *testing.T) { recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(msg)) + _, err = recPacker.Unpack([]byte(msg)) require.EqualError(t, err, "invalid character 'e' looking for beginning of value") }) @@ -533,7 +539,7 @@ func TestUnpackComponents(t *testing.T) { recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(msg)) + _, err = recPacker.Unpack([]byte(msg)) require.EqualError(t, err, "illegal base64 data at input byte 0") }) @@ -694,7 +700,7 @@ func Test_getCEK(t *testing.T) { }, } - _, _, err := getCEK(recs, &k) + _, err := getCEK(recs, &k) require.EqualError(t, err, "mock error") } diff --git a/pkg/didcomm/packer/legacy/authcrypt/unpack.go b/pkg/didcomm/packer/legacy/authcrypt/unpack.go index ade93d45c..07419dc99 100644 --- a/pkg/didcomm/packer/legacy/authcrypt/unpack.go +++ b/pkg/didcomm/packer/legacy/authcrypt/unpack.go @@ -14,52 +14,65 @@ import ( "github.com/btcsuite/btcutil/base58" chacha "golang.org/x/crypto/chacha20poly1305" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/internal/cryptoutil" "github.com/hyperledger/aries-framework-go/pkg/kms" ) // Unpack will decode the envelope using the legacy format // Using (X)Chacha20 encryption algorithm and Poly1035 authenticator -func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(envelope []byte) (*transport.Envelope, error) { var envelopeData legacyEnvelope err := json.Unmarshal(envelope, &envelopeData) if err != nil { - return nil, nil, err + return nil, err } protectedBytes, err := base64.URLEncoding.DecodeString(envelopeData.Protected) if err != nil { - return nil, nil, err + return nil, err } var protectedData protected err = json.Unmarshal(protectedBytes, &protectedData) if err != nil { - return nil, nil, err + return nil, err } if protectedData.Typ != encodingType { - return nil, nil, fmt.Errorf("message type %s not supported", protectedData.Typ) + return nil, fmt.Errorf("message type %s not supported", protectedData.Typ) } if protectedData.Alg != "Authcrypt" { // TODO https://github.com/hyperledger/aries-framework-go/issues/41 change this when anoncrypt is introduced - return nil, nil, fmt.Errorf("message format %s not supported", protectedData.Alg) + return nil, fmt.Errorf("message format %s not supported", protectedData.Alg) } - cek, recKey, err := getCEK(protectedData.Recipients, p.kms) + keys, err := getCEK(protectedData.Recipients, p.kms) if err != nil { - return nil, nil, err + return nil, err } + cek, senderKey, recKey := keys.cek, keys.theirKey, keys.myKey + data, err := p.decodeCipherText(cek, &envelopeData) - return data, recKey, err + return &transport.Envelope{ + Message: data, + FromVerKey: senderKey, + ToVerKey: recKey, + }, err +} + +type keys struct { + cek *[chacha.KeySize]byte + theirKey []byte + myKey []byte } -func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, []byte, error) { +func getCEK(recipients []recipient, km kms.KeyManager) (*keys, error) { var candidateKeys []string for _, candidate := range recipients { @@ -68,47 +81,51 @@ func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, [ recKeyIdx, err := km.FindVerKey(candidateKeys) if err != nil { - return nil, nil, fmt.Errorf("no key accessible %w", err) + return nil, fmt.Errorf("no key accessible %w", err) } recip := recipients[recKeyIdx] - recKey := recip.Header.KID + recKey := base58.Decode(recip.Header.KID) - recCurvePub, err := km.ConvertToEncryptionKey(base58.Decode(recKey)) + recCurvePub, err := km.ConvertToEncryptionKey(recKey) if err != nil { - return nil, nil, err + return nil, err } senderPub, senderPubCurve, err := decodeSender(recip.Header.Sender, recCurvePub, km) if err != nil { - return nil, nil, err + return nil, err } nonceSlice, err := base64.URLEncoding.DecodeString(recip.Header.IV) if err != nil { - return nil, nil, err + return nil, err } encCEK, err := base64.URLEncoding.DecodeString(recip.EncryptedKey) if err != nil { - return nil, nil, err + return nil, err } b, err := kms.NewCryptoBox(km) if err != nil { - return nil, nil, err + return nil, err } cekSlice, err := b.EasyOpen(encCEK, nonceSlice, senderPubCurve, recCurvePub) if err != nil { - return nil, nil, fmt.Errorf("failed to decrypt CEK: %s", err) + return nil, fmt.Errorf("failed to decrypt CEK: %s", err) } var cek [chacha.KeySize]byte copy(cek[:], cekSlice) - return &cek, senderPub, nil + return &keys{ + cek: &cek, + theirKey: senderPub, + myKey: recKey, + }, nil } func decodeSender(b64Sender string, pk []byte, km kms.KeyManager) ([]byte, []byte, error) { diff --git a/pkg/didcomm/protocol/didexchange/persistence.go b/pkg/didcomm/protocol/didexchange/persistence.go index 636185a1d..e4d63b79e 100644 --- a/pkg/didcomm/protocol/didexchange/persistence.go +++ b/pkg/didcomm/protocol/didexchange/persistence.go @@ -55,15 +55,15 @@ func (r *ConnectionRecord) isValid() error { } // NewConnectionRecorder returns new connection record instance -func NewConnectionRecorder(transientStore, store storage.Store, didMap didconnection.Store) *ConnectionRecorder { - return &ConnectionRecorder{transientStore: transientStore, store: store, didMap: didMap} +func NewConnectionRecorder(transientStore, store storage.Store, didStore didconnection.Store) *ConnectionRecorder { + return &ConnectionRecorder{transientStore: transientStore, store: store, didStore: didStore} } // ConnectionRecorder takes care of connection related persistence features type ConnectionRecorder struct { transientStore storage.Store store storage.Store - didMap didconnection.Store + didStore didconnection.Store } // SaveInvitation saves connection invitation to underlying store @@ -234,7 +234,7 @@ func (c *ConnectionRecorder) saveConnectionRecord(record *ConnectionRecord) erro return fmt.Errorf("save connection record in permanent store: %w", err) } - if err := c.didMap.SaveDIDConnection(record.MyDID, record.TheirDID, record.RecipientKeys); err != nil { + if err := c.didStore.SaveDIDByResolving(record.TheirDID, record.RecipientKeys...); err != nil { return err } } @@ -266,7 +266,7 @@ func (c *ConnectionRecorder) saveNewConnectionRecord(record *ConnectionRecord) e } if record.MyDID != "" { - if err := c.didMap.SaveDIDByResolving(record.MyDID, didCommServiceType, ed25519KeyType); err != nil { + if err := c.didStore.SaveDIDByResolving(record.MyDID); err != nil { return err } } diff --git a/pkg/didcomm/protocol/didexchange/persistence_test.go b/pkg/didcomm/protocol/didexchange/persistence_test.go index 626817cc1..00f73aa58 100644 --- a/pkg/didcomm/protocol/didexchange/persistence_test.go +++ b/pkg/didcomm/protocol/didexchange/persistence_test.go @@ -285,24 +285,6 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { err := record.saveNewConnectionRecord(connRec) require.Contains(t, err.Error(), "get error") }) - t.Run("error saving DID", func(t *testing.T) { - transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store, &mockdidconnection.MockDIDConnection{ - SaveConnectionErr: fmt.Errorf("save error"), - }) - require.NotNil(t, record) - connRec := &ConnectionRecord{ThreadID: threadIDValue, - ConnectionID: connIDValue, State: stateNameCompleted, Namespace: theirNSPrefix} - err := record.saveNewConnectionRecord(connRec) - require.Error(t, err) - require.Contains(t, err.Error(), "save error") - - // note: record is still stored, since error happens afterwards - storedRecord, err := record.GetConnectionRecord(connRec.ConnectionID) - require.NoError(t, err) - require.Equal(t, connRec, storedRecord) - }) t.Run("error saving DID by resolving", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} diff --git a/pkg/didcomm/protocol/didexchange/service.go b/pkg/didcomm/protocol/didexchange/service.go index eb1d771c2..5a7178b7c 100644 --- a/pkg/didcomm/protocol/didexchange/service.go +++ b/pkg/didcomm/protocol/didexchange/service.go @@ -640,7 +640,6 @@ func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID, inviteeLabe return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err) } - // TODO: hardcoded key type dest, err := service.CreateDestination(didDoc) if err != nil { return "", err diff --git a/pkg/didcomm/protocol/didexchange/service_test.go b/pkg/didcomm/protocol/didexchange/service_test.go index 88b8d719c..f71468501 100644 --- a/pkg/didcomm/protocol/didexchange/service_test.go +++ b/pkg/didcomm/protocol/didexchange/service_test.go @@ -1385,7 +1385,8 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDIDDoc}, - connectionStore: NewConnectionRecorder(nil, store.Store, nil), + connectionStore: NewConnectionRecorder(nil, store.Store, + &didconnection.MockDIDConnection{}), } s, err := New(&protocol.MockProvider{StoreProvider: store}) @@ -1405,7 +1406,8 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolve error")}, - connectionStore: NewConnectionRecorder(nil, store.Store, nil), + connectionStore: NewConnectionRecorder(nil, store.Store, + &didconnection.MockDIDConnection{}), } s, err := New(&protocol.MockProvider{StoreProvider: store}) @@ -1428,7 +1430,8 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDIDDoc}, - connectionStore: NewConnectionRecorder(transientStore.Store, store.Store, nil), + connectionStore: NewConnectionRecorder(transientStore.Store, store.Store, + &didconnection.MockDIDConnection{}), } s, err := New(&protocol.MockProvider{StoreProvider: store, TransientStoreProvider: transientStore}) diff --git a/pkg/didcomm/protocol/didexchange/states.go b/pkg/didcomm/protocol/didexchange/states.go index 10d237f3a..0c8e57173 100644 --- a/pkg/didcomm/protocol/didexchange/states.go +++ b/pkg/didcomm/protocol/didexchange/states.go @@ -405,8 +405,7 @@ func (ctx *context) getDIDDocAndConnection(pubDID string) (*did.Doc, *Connection return nil, nil, fmt.Errorf("resolve public did[%s]: %w", pubDID, err) } - // TODO: x.y.z.foo - err = ctx.connectionStore.didMap.SaveDIDFromDoc(didDoc, didCommServiceType, ed25519KeyType) + err = ctx.connectionStore.didStore.SaveDIDFromDoc(didDoc) if err != nil { return nil, nil, err } @@ -422,12 +421,7 @@ func (ctx *context) getDIDDocAndConnection(pubDID string) (*did.Doc, *Connection return nil, nil, fmt.Errorf("create %s did: %w", didMethod, err) } - // TODO: initialize did map (or mock) in all eleventy billion tests - if ctx.connectionStore.didMap == nil { - return nil, nil, fmt.Errorf("NIL CONN STORE") - } - - err = ctx.connectionStore.didMap.SaveDIDFromDoc(newDidDoc, didCommServiceType, ed25519KeyType) + err = ctx.connectionStore.didStore.SaveDIDFromDoc(newDidDoc) if err != nil { return nil, nil, err } diff --git a/pkg/didcomm/transport/http/inbound.go b/pkg/didcomm/transport/http/inbound.go index 306dc24f7..ed0907497 100644 --- a/pkg/didcomm/transport/http/inbound.go +++ b/pkg/didcomm/transport/http/inbound.go @@ -65,7 +65,7 @@ func processPOSTRequest(w http.ResponseWriter, r *http.Request, prov transport.P messageHandler := prov.InboundMessageHandler() - err = messageHandler(unpackMsg.Message) + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) if err != nil { // TODO https://github.com/hyperledger/aries-framework-go/issues/271 HTTP Response Codes based on errors // from service diff --git a/pkg/didcomm/transport/http/inbound_test.go b/pkg/didcomm/transport/http/inbound_test.go index a4764f039..69769300b 100644 --- a/pkg/didcomm/transport/http/inbound_test.go +++ b/pkg/didcomm/transport/http/inbound_test.go @@ -30,7 +30,7 @@ type mockProvider struct { } func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { logger.Debugf("message received is %s", message) return nil } diff --git a/pkg/didcomm/transport/transport_interface.go b/pkg/didcomm/transport/transport_interface.go index e6b6cfd70..d7f58cc21 100644 --- a/pkg/didcomm/transport/transport_interface.go +++ b/pkg/didcomm/transport/transport_interface.go @@ -30,7 +30,7 @@ type OutboundTransport interface { // InboundMessageHandler handles the inbound requests. The transport will unpack the payload prior to the // message handle invocation. -type InboundMessageHandler func(message []byte) error +type InboundMessageHandler func(message []byte, myDID, theirDID string) error // Provider contains dependencies for starting the inbound/outbound transports. // It is typically created by using aries.Context(). diff --git a/pkg/didcomm/transport/ws/pool.go b/pkg/didcomm/transport/ws/pool.go index be8fd942f..86752d506 100644 --- a/pkg/didcomm/transport/ws/pool.go +++ b/pkg/didcomm/transport/ws/pool.go @@ -10,6 +10,7 @@ import ( "encoding/json" "sync" + "github.com/btcsuite/btcutil/base58" "nhooyr.io/websocket" commtransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" @@ -93,12 +94,12 @@ func (d *connPool) listener(conn *websocket.Conn) { } if trans != nil && trans.ReturnRoute != nil && trans.ReturnRoute.Value == decorator.TransportReturnRouteAll { - d.add(unpackMsg.FromVerKey, conn) + d.add(base58.Encode(unpackMsg.FromVerKey), conn) } messageHandler := d.msgHandler - err = messageHandler(unpackMsg.Message) + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) if err != nil { logger.Errorf("incoming msg processing failed: %v", err) } diff --git a/pkg/didcomm/transport/ws/pool_test.go b/pkg/didcomm/transport/ws/pool_test.go index a008956a6..8d6d78b93 100644 --- a/pkg/didcomm/transport/ws/pool_test.go +++ b/pkg/didcomm/transport/ws/pool_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcutil/base58" "github.com/google/uuid" "github.com/stretchr/testify/require" "nhooyr.io/websocket" @@ -40,14 +41,14 @@ func TestConnectionStore(t *testing.T) { // create a transport provider (framework context) verKey := "ABCD" mockPackager := &mockpackager.Packager{ - UnpackValue: &commontransport.Envelope{Message: request, FromVerKey: verKey}, + UnpackValue: &commontransport.Envelope{Message: request, FromVerKey: base58.Decode(verKey)}, } response := "Hello" transportProvider := &mockTransportProvider{ packagerValue: mockPackager, frameworkID: uuid.New().String(), - executeInbound: func(message []byte) error { + executeInbound: func(message []byte, myDID, theirDID string) error { resp, outboundErr := outbound.Send([]byte(response), prepareDestinationWithTransport("ws://doesnt-matter", "", []string{verKey})) require.NoError(t, outboundErr) @@ -100,7 +101,7 @@ func TestConnectionStore(t *testing.T) { transportProvider := &mockTransportProvider{ packagerValue: &mockPackager{verKey: verKey}, frameworkID: uuid.New().String(), - executeInbound: func(message []byte) error { + executeInbound: func(message []byte, myDID, theirDID string) error { // validate the echo server response with the outbound sent message require.Equal(t, request, message) done <- struct{}{} diff --git a/pkg/didcomm/transport/ws/support_test.go b/pkg/didcomm/transport/ws/support_test.go index acd49fc50..d5665d605 100644 --- a/pkg/didcomm/transport/ws/support_test.go +++ b/pkg/didcomm/transport/ws/support_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcutil/base58" "github.com/google/uuid" "github.com/stretchr/testify/require" "nhooyr.io/websocket" @@ -32,7 +33,7 @@ type mockProvider struct { } func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { logger.Infof("message received is %s", string(message)) if string(message) == "invalid-data" { return errors.New("error") @@ -146,12 +147,12 @@ func (m *mockPackager) PackMessage(e *commontransport.Envelope) ([]byte, error) } func (m *mockPackager) UnpackMessage(encMessage []byte) (*commontransport.Envelope, error) { - return &commontransport.Envelope{Message: encMessage, FromVerKey: m.verKey}, nil + return &commontransport.Envelope{Message: encMessage, FromVerKey: base58.Decode(m.verKey)}, nil } type mockTransportProvider struct { packagerValue commontransport.Packager - executeInbound func(message []byte) error + executeInbound func(message []byte, myDID, theirDID string) error frameworkID string } diff --git a/pkg/framework/aries/framework.go b/pkg/framework/aries/framework.go index 568720758..4d0508b95 100644 --- a/pkg/framework/aries/framework.go +++ b/pkg/framework/aries/framework.go @@ -339,6 +339,7 @@ func createOutboundDispatcher(frameworkOpts *Aries) error { context.WithOutboundTransports(frameworkOpts.outboundTransports...), context.WithPackager(frameworkOpts.packager), context.WithTransportReturnRoute(frameworkOpts.transportReturnRoute), + context.WithVDRIRegistry(frameworkOpts.vdriRegistry), ) if err != nil { return fmt.Errorf("context creation failed: %w", err) diff --git a/pkg/framework/context/context.go b/pkg/framework/context/context.go index 92d4ff4fb..9b4adf974 100644 --- a/pkg/framework/context/context.go +++ b/pkg/framework/context/context.go @@ -111,7 +111,7 @@ func (p *Provider) InboundTransportEndpoint() string { // InboundMessageHandler return an inbound message handler. func (p *Provider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { msg, err := service.NewDIDCommMsg(message) if err != nil { return err @@ -120,7 +120,7 @@ func (p *Provider) InboundMessageHandler() transport.InboundMessageHandler { // find the service which accepts the message type for _, svc := range p.services { if svc.Accept(msg.Header.Type) { - _, err = svc.HandleInbound(msg, "", "") + _, err = svc.HandleInbound(msg, myDID, theirDID) return err } } diff --git a/pkg/framework/context/context_test.go b/pkg/framework/context/context_test.go index d8fd5531d..726f2bbcf 100644 --- a/pkg/framework/context/context_test.go +++ b/pkg/framework/context/context_test.go @@ -98,16 +98,16 @@ func TestNewProvider(t *testing.T) { { "@frameworkID": "5678876542345", "@type": "valid-message-type" - }`)) + }`), "", "") require.NoError(t, err) // invalid json - err = inboundHandler([]byte("invalid json")) + err = inboundHandler([]byte("invalid json"), "", "") require.Error(t, err) require.Contains(t, err.Error(), "invalid payload data format") // invalid json - err = inboundHandler([]byte("invalid json")) + err = inboundHandler([]byte("invalid json"), "", "") require.Error(t, err) require.Contains(t, err.Error(), "invalid payload data format") @@ -116,7 +116,7 @@ func TestNewProvider(t *testing.T) { { "@type": "invalid-message-type", "label": "Bob" - }`)) + }`), "", "") require.Error(t, err) require.Contains(t, err.Error(), "no message handlers found for the message type: invalid-message-type") @@ -125,7 +125,7 @@ func TestNewProvider(t *testing.T) { { "label": "Carol", "@type": "valid-message-type" - }`)) + }`), "", "") require.Error(t, err) require.Contains(t, err.Error(), "error handling the message") }) diff --git a/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go b/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go index bac4b7ec4..b5fcf0e72 100644 --- a/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go +++ b/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go @@ -11,13 +11,12 @@ import ( // MockDIDConnection mocks the did lookup store. type MockDIDConnection struct { - SaveRecordErr error - SaveConnectionErr error - SaveKeysErr error - GetDIDValue string - GetDIDErr error - SaveDIDErr error - ResolveDIDErr error + SaveRecordErr error + SaveKeysErr error + GetDIDValue string + GetDIDErr error + SaveDIDErr error + ResolveDIDErr error } // SaveDID saves a DID to the store @@ -30,17 +29,12 @@ func (m *MockDIDConnection) GetDID(key string) (string, error) { return m.GetDIDValue, m.GetDIDErr } -// SaveDIDConnection saves a DID connection -func (m *MockDIDConnection) SaveDIDConnection(myDID, theirDID string, theirKeys []string) error { - return m.SaveConnectionErr -} - // SaveDIDByResolving saves a DID by resolving it then using its doc -func (m *MockDIDConnection) SaveDIDByResolving(did, serviceType, keyType string) error { +func (m *MockDIDConnection) SaveDIDByResolving(did string, keys ...string) error { return m.ResolveDIDErr } // SaveDIDFromDoc saves a DID using the given doc -func (m *MockDIDConnection) SaveDIDFromDoc(doc *diddoc.Doc, serviceType, keyType string) error { +func (m *MockDIDConnection) SaveDIDFromDoc(doc *diddoc.Doc) error { return m.SaveDIDErr } diff --git a/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go b/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go index 25f683e42..99ae4a076 100644 --- a/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go +++ b/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go @@ -27,7 +27,7 @@ func (m *MockOutbound) Send(msg interface{}, senderVerKey string, des *service.D // SendToDID msg func (m *MockOutbound) SendToDID(msg interface{}, myDID, theirDID string) error { - return nil + return m.SendErr } // Forward msg diff --git a/pkg/internal/mock/didcomm/mock_authcrypt.go b/pkg/internal/mock/didcomm/mock_authcrypt.go index d8ef0c72f..b75acc33f 100644 --- a/pkg/internal/mock/didcomm/mock_authcrypt.go +++ b/pkg/internal/mock/didcomm/mock_authcrypt.go @@ -6,10 +6,12 @@ SPDX-License-Identifier: Apache-2.0 package didcomm +import "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + // MockAuthCrypt mock auth crypt type MockAuthCrypt struct { EncryptValue func(payload, senderPubKey []byte, recipients [][]byte) ([]byte, error) - DecryptValue func(envelope []byte) ([]byte, []byte, error) + DecryptValue func(envelope []byte) (*transport.Envelope, error) Type string } @@ -20,7 +22,7 @@ func (m *MockAuthCrypt) Pack(payload, senderPubKey []byte, } // Unpack mock message unpacking -func (m *MockAuthCrypt) Unpack(envelope []byte) ([]byte, []byte, error) { +func (m *MockAuthCrypt) Unpack(envelope []byte) (*transport.Envelope, error) { return m.DecryptValue(envelope) } diff --git a/pkg/internal/mock/packer/noop.go b/pkg/internal/mock/didcomm/packer/noop.go similarity index 69% rename from pkg/internal/mock/packer/noop.go rename to pkg/internal/mock/didcomm/packer/noop.go index 3b00679a2..8c888c310 100644 --- a/pkg/internal/mock/packer/noop.go +++ b/pkg/internal/mock/didcomm/packer/noop.go @@ -9,16 +9,19 @@ package packer import ( "encoding/base64" "encoding/json" + "fmt" "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" ) type envelope struct { - Header string `json:"protected,omitempty"` - Sender string `json:"spk,omitempty"` - Message string `json:"msg,omitempty"` + Header string `json:"protected,omitempty"` + Sender string `json:"spk,omitempty"` + Recipient string `json:"kid,omitempty"` + Message string `json:"msg,omitempty"` } type header struct { @@ -50,10 +53,15 @@ func (p *Packer) Pack(payload, sender []byte, recipientPubKeys [][]byte) ([]byte headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + if len(recipientPubKeys) == 0 { + return nil, fmt.Errorf("no recipients") + } + message := envelope{ - Header: headerB64, - Sender: base58.Encode(sender), - Message: string(payload), + Header: headerB64, + Sender: base58.Encode(sender), + Recipient: base58.Encode(recipientPubKeys[0]), + Message: string(payload), } msgBytes, err := json.Marshal(&message) @@ -62,27 +70,31 @@ func (p *Packer) Pack(payload, sender []byte, recipientPubKeys [][]byte) ([]byte } // Unpack will decode the envelope using the NOOP format. -func (p *Packer) Unpack(message []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(message []byte) (*transport.Envelope, error) { var env envelope err := json.Unmarshal(message, &env) if err != nil { - return nil, nil, err + return nil, err } headerBytes, err := base64.URLEncoding.DecodeString(env.Header) if err != nil { - return nil, nil, err + return nil, err } var head header err = json.Unmarshal(headerBytes, &head) if err != nil { - return nil, nil, err + return nil, err } - return []byte(env.Message), base58.Decode(env.Sender), nil + return &transport.Envelope{ + Message: []byte(env.Message), + FromVerKey: base58.Decode(env.Sender), + ToVerKey: base58.Decode(env.Recipient), + }, nil } // EncodingType returns the type of the encoding, as found in the header `Typ` field diff --git a/pkg/internal/mock/didcomm/packer/noop_test.go b/pkg/internal/mock/didcomm/packer/noop_test.go new file mode 100644 index 000000000..ecd44aa60 --- /dev/null +++ b/pkg/internal/mock/didcomm/packer/noop_test.go @@ -0,0 +1,111 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package packer + +import ( + "encoding/base64" + "testing" + + "github.com/btcsuite/btcutil/base58" + "github.com/stretchr/testify/require" +) + +// note: does not replicate correct packing +// when msg needs to be escaped. +func testPack(msg, senderKey, recKey []byte) []byte { + headerValue := base64.URLEncoding.EncodeToString([]byte(`{"typ":"NOOP"}`)) + + return []byte(`{"protected":"` + headerValue + + `","spk":"` + base58.Encode(senderKey) + + `","kid":"` + base58.Encode(recKey) + + `","msg":"` + string(msg) + `"}`) +} + +func TestPacker(t *testing.T) { + p := New(nil) + require.NotNil(t, p) + require.Equal(t, encodingType, p.EncodingType()) + + t.Run("no rec keys", func(t *testing.T) { + _, err := p.Pack(nil, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "no recipients") + }) + + t.Run("pack, compare against correct data", func(t *testing.T) { + msgin := []byte("hello my name is zoop") + key := []byte("senderkey") + rec := []byte("recipient") + + msgout, err := p.Pack(msgin, key, [][]byte{rec}) + require.NoError(t, err) + + correct := testPack(msgin, key, rec) + require.Equal(t, correct, msgout) + }) + + t.Run("unpack fixed value, confirm data", func(t *testing.T) { + correct := []byte("this is not a test message") + key := []byte("testKey") + rec := []byte("key2") + msgin := testPack(correct, key, rec) + + envOut, err := p.Unpack(msgin) + require.NoError(t, err) + + require.Equal(t, correct, envOut.Message) + require.Equal(t, key, envOut.FromVerKey) + require.Equal(t, rec, envOut.ToVerKey) + }) + + t.Run("multiple pack/unpacks", func(t *testing.T) { + cleartext := []byte("this is not a test message") + key1 := []byte("testKey") + rec1 := []byte("rec1") + key2 := []byte("wrapperKey") + rec2 := []byte("rec2") + + correct1 := testPack(cleartext, key1, rec1) + + msg1, err := p.Pack(cleartext, key1, [][]byte{rec1}) + require.NoError(t, err) + require.Equal(t, correct1, msg1) + + msg2, err := p.Pack(msg1, key2, [][]byte{rec2}) + require.NoError(t, err) + + env1, err := p.Unpack(msg2) + require.NoError(t, err) + require.Equal(t, key2, env1.FromVerKey) + require.Equal(t, rec2, env1.ToVerKey) + require.Equal(t, correct1, env1.Message) + + env2, err := p.Unpack(env1.Message) + require.NoError(t, err) + require.Equal(t, key1, env2.FromVerKey) + require.Equal(t, rec1, env2.ToVerKey) + require.Equal(t, cleartext, env2.Message) + }) + + t.Run("unpack errors", func(t *testing.T) { + _, err := p.Unpack(nil) + require.Error(t, err) + require.Contains(t, err.Error(), "end of JSON input") + + _, err = p.Unpack([]byte("{}")) + require.Error(t, err) + require.Contains(t, err.Error(), "end of JSON input") + + _, err = p.Unpack([]byte("{\"protected\":\"$$$$$$$$$$$$\"}")) + require.Error(t, err) + require.Contains(t, err.Error(), "illegal base64 data") + + _, err = p.Unpack([]byte("{\"protected\":\"e3t7\"}")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid character") + }) +} diff --git a/pkg/internal/mock/packer/noop_test.go b/pkg/internal/mock/packer/noop_test.go deleted file mode 100644 index 32f6696bf..000000000 --- a/pkg/internal/mock/packer/noop_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* -Copyright SecureKey Technologies Inc. All Rights Reserved. - -SPDX-License-Identifier: Apache-2.0 -*/ - -package packer - -import ( - "encoding/base64" - "testing" - - "github.com/btcsuite/btcutil/base58" - "github.com/stretchr/testify/require" -) - -// note: does not replicate correct packing -// when msg needs to be escaped. -func testPack(msg, key []byte) []byte { - headerValue := base64.URLEncoding.EncodeToString([]byte(`{"typ":"NOOP"}`)) - - return []byte(`{"protected":"` + headerValue + - `","spk":"` + base58.Encode(key) + - `","msg":"` + string(msg) + `"}`) -} - -func TestPacker(t *testing.T) { - p := New(nil) - require.NotNil(t, p) - require.Equal(t, encodingType, p.EncodingType()) - - t.Run("pack, compare against correct data", func(t *testing.T) { - msgin := []byte("hello my name is zoop") - key := []byte("senderkey") - - msgout, err := p.Pack(msgin, key, nil) - require.NoError(t, err) - - correct := testPack(msgin, key) - require.Equal(t, correct, msgout) - }) - - t.Run("unpack fixed value, confirm data", func(t *testing.T) { - correct := []byte("this is not a test message") - key := []byte("testKey") - msgin := testPack(correct, key) - - msgout, keyOut, err := p.Unpack(msgin) - require.NoError(t, err) - - require.Equal(t, correct, msgout) - require.Equal(t, key, keyOut) - }) - - t.Run("multiple pack/unpacks", func(t *testing.T) { - cleartext := []byte("this is not a test message") - key1 := []byte("testKey") - key2 := []byte("wrapperKey") - - correct1 := testPack(cleartext, key1) - - msg1, err := p.Pack(cleartext, key1, nil) - require.NoError(t, err) - require.Equal(t, correct1, msg1) - - msg2, err := p.Pack(msg1, key2, nil) - require.NoError(t, err) - - msg3, key1Out, err := p.Unpack(msg2) - require.NoError(t, err) - require.Equal(t, key2, key1Out) - require.Equal(t, correct1, msg3) - - msg4, key2Out, err := p.Unpack(msg3) - require.NoError(t, err) - require.Equal(t, key1, key2Out) - require.Equal(t, cleartext, msg4) - }) -}