diff --git a/pkg/client/didexchange/client.go b/pkg/client/didexchange/client.go index 2e9fb3fd0..462fcc182 100644 --- a/pkg/client/didexchange/client.go +++ b/pkg/client/didexchange/client.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange" "github.com/hyperledger/aries-framework-go/pkg/kms" @@ -40,6 +41,7 @@ type provider interface { InboundTransportEndpoint() string StorageProvider() storage.Provider TransientStorageProvider() storage.Provider + DIDConnectionStore() didconnection.Store } // Client enable access to didexchange api @@ -94,7 +96,7 @@ func New(ctx provider) (*Client, error) { didexchangeSvc: didexchangeSvc, kms: ctx.KMS(), inboundTransportEndpoint: ctx.InboundTransportEndpoint(), - connectionStore: didexchange.NewConnectionRecorder(transientStore, store), + connectionStore: didexchange.NewConnectionRecorder(transientStore, store, ctx.DIDConnectionStore()), }, nil } diff --git a/pkg/client/didexchange/client_test.go b/pkg/client/didexchange/client_test.go index 13029842c..7d826fb36 100644 --- a/pkg/client/didexchange/client_test.go +++ b/pkg/client/didexchange/client_test.go @@ -202,7 +202,7 @@ func TestClient_QueryConnectionByID(t *testing.T) { require.NoError(t, err) require.NoError(t, err) require.NoError(t, transientStore.Put("conn_id1", connBytes)) - c := didexchange.NewConnectionRecorder(transientStore, store) + c := didexchange.NewConnectionRecorder(transientStore, store, nil) result, err := c.GetConnectionRecord(connID) require.NoError(t, err) require.Equal(t, "complete", result.State) @@ -219,7 +219,7 @@ func TestClient_QueryConnectionByID(t *testing.T) { connRec := &didexchange.ConnectionRecord{ConnectionID: connID, ThreadID: threadID, State: "complete"} connBytes, err := json.Marshal(connRec) require.NoError(t, err) - c := didexchange.NewConnectionRecorder(transientStore, store) + c := didexchange.NewConnectionRecorder(transientStore, store, nil) require.NoError(t, transientStore.Put("conn_id1", connBytes)) _, err = c.GetConnectionRecord(connID) require.Error(t, err) @@ -234,7 +234,7 @@ func TestClient_QueryConnectionByID(t *testing.T) { transientStore := mockstore.MockStore{ErrGet: storage.ErrDataNotFound} store := mockstore.MockStore{} require.NoError(t, err) - c := didexchange.NewConnectionRecorder(&transientStore, &store) + c := didexchange.NewConnectionRecorder(&transientStore, &store, nil) _, err = c.GetConnectionRecord(connID) require.Error(t, err) require.True(t, errors.Is(err, storage.ErrDataNotFound)) diff --git a/pkg/didcomm/common/didconnection/api.go b/pkg/didcomm/common/didconnection/api.go new file mode 100644 index 000000000..4b7ba542c --- /dev/null +++ b/pkg/didcomm/common/didconnection/api.go @@ -0,0 +1,27 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "errors" + + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" +) + +// Store stores DIDs indexed by public key, so agents can find the DID associated with a given key. +type Store interface { + // SaveDID saves a DID indexed by the given public keys to the Store + SaveDID(did string, keys ...string) error + // GetDID gets the DID stored under the given key + GetDID(key string) (string, error) + // SaveDIDByResolving resolves a DID using the VDR then saves the map from keys -> did + SaveDIDByResolving(did string, keys ...string) error + // SaveDIDFromDoc saves a map from keys -> did for a did doc + SaveDIDFromDoc(doc *diddoc.Doc) error +} + +// ErrNotFound signals that the entry for the given DID and key is not present in the store. +var ErrNotFound = errors.New("did not found under given key") diff --git a/pkg/didcomm/common/didconnection/didconnection.go b/pkg/didcomm/common/didconnection/didconnection.go new file mode 100644 index 000000000..d5b3eb51b --- /dev/null +++ b/pkg/didcomm/common/didconnection/didconnection.go @@ -0,0 +1,110 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "encoding/json" + "errors" + "fmt" + + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" + "github.com/hyperledger/aries-framework-go/pkg/storage" +) + +// ConnectionStore stores DIDs indexed by key +type ConnectionStore struct { + store storage.Store + vdr vdri.Registry +} + +type didRecord struct { + DID string `json:"did,omitempty"` +} + +type provider interface { + StorageProvider() storage.Provider + VDRIRegistry() vdri.Registry +} + +// New returns a new did lookup Store +func New(ctx provider) (*ConnectionStore, error) { + store, err := ctx.StorageProvider().OpenStore("didconnection") + if err != nil { + return nil, err + } + + return &ConnectionStore{store: store, vdr: ctx.VDRIRegistry()}, nil +} + +// saveDID saves a DID, indexed using the given public key +func (c *ConnectionStore) saveDID(did, key string) error { + data := didRecord{ + DID: did, + } + + bytes, err := json.Marshal(data) + if err != nil { + return err + } + + return c.store.Put(key, bytes) +} + +// SaveDID saves a DID, indexed using the given public keys +func (c *ConnectionStore) SaveDID(did string, keys ...string) error { + for _, key := range keys { + err := c.saveDID(did, key) + if err != nil { + return fmt.Errorf("saving DID in did map: %w", err) + } + } + + return nil +} + +// SaveDIDFromDoc saves a map from a did doc's keys to the did +func (c *ConnectionStore) SaveDIDFromDoc(doc *diddoc.Doc) error { + var keys []string + for i := range doc.PublicKey { + keys = append(keys, string(doc.PublicKey[i].Value)) + } + + return c.SaveDID(doc.ID, keys...) +} + +// SaveDIDByResolving resolves a DID using the VDR then saves the map from keys -> did +// keys: fallback keys in case the DID can't be resolved +func (c *ConnectionStore) SaveDIDByResolving(did string, keys ...string) error { + doc, err := c.vdr.Resolve(did) + if errors.Is(err, vdri.ErrNotFound) { + return c.SaveDID(did, keys...) + } else if err != nil { + return err + } + + return c.SaveDIDFromDoc(doc) +} + +// GetDID gets the DID stored under the given key +func (c *ConnectionStore) GetDID(key string) (string, error) { + bytes, err := c.store.Get(key) + if errors.Is(err, storage.ErrDataNotFound) { + return "", ErrNotFound + } else if err != nil { + return "", err + } + + var record didRecord + + err = json.Unmarshal(bytes, &record) + if err != nil { + return "", err + } + + return record.DID, nil +} diff --git a/pkg/didcomm/common/didconnection/didconnection_test.go b/pkg/didcomm/common/didconnection/didconnection_test.go new file mode 100644 index 000000000..7cad6e427 --- /dev/null +++ b/pkg/didcomm/common/didconnection/didconnection_test.go @@ -0,0 +1,129 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" + mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" + mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" + "github.com/hyperledger/aries-framework-go/pkg/storage" +) + +type ctx struct { + store storage.Provider + vdr vdriapi.Registry +} + +func (c *ctx) StorageProvider() storage.Provider { + return c.store +} + +func (c *ctx) VDRIRegistry() vdriapi.Registry { + return c.vdr +} + +func TestBaseConnectionStore(t *testing.T) { + prov := ctx{ + store: mockstorage.NewMockStoreProvider(), + vdr: &mockvdri.MockVDRIRegistry{ + CreateValue: mockdiddoc.GetMockDIDDoc(), + ResolveValue: mockdiddoc.GetMockDIDDoc(), + }, + } + + t.Run("New", func(t *testing.T) { + _, err := New(&prov) + require.NoError(t, err) + + _, err = New(&ctx{ + store: &mockstorage.MockStoreProvider{ + ErrOpenStoreHandle: fmt.Errorf("store error"), + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "store error") + }) + + t.Run("SaveDID error", func(t *testing.T) { + cs, err := New(&ctx{ + store: &mockstorage.MockStoreProvider{ + Store: &mockstorage.MockStore{ + Store: map[string][]byte{}, + ErrPut: fmt.Errorf("put error"), + }, + }, + vdr: &mockvdri.MockVDRIRegistry{ + CreateValue: mockdiddoc.GetMockDIDDoc(), + ResolveValue: mockdiddoc.GetMockDIDDoc(), + }, + }) + require.NoError(t, err) + + err = cs.SaveDID("did", "key") + require.Error(t, err) + require.Contains(t, err.Error(), "put error") + }) + + t.Run("SaveDID + GetDID", func(t *testing.T) { + connStore, err := New(&prov) + require.NoError(t, err) + + err = connStore.SaveDID("did:abcde", "abcde") + require.NoError(t, err) + + didVal, err := connStore.GetDID("abcde") + require.NoError(t, err) + require.Equal(t, "did:abcde", didVal) + + wrong, err := connStore.GetDID("fhtagn") + require.EqualError(t, err, ErrNotFound.Error()) + require.Equal(t, "", wrong) + + err = connStore.store.Put("bad-data", []byte("aaooga")) + require.NoError(t, err) + + _, err = connStore.GetDID("bad-data") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid character") + }) + + t.Run("SaveDIDFromDoc", func(t *testing.T) { + connStore, err := New(&prov) + require.NoError(t, err) + + err = connStore.SaveDIDFromDoc(mockdiddoc.GetMockDIDDoc()) + require.NoError(t, err) + }) + + t.Run("SaveDIDByResolving success", func(t *testing.T) { + cs, err := New(&prov) + require.NoError(t, err) + + err = cs.SaveDIDByResolving(mockdiddoc.GetMockDIDDoc().ID) + require.NoError(t, err) + }) + + t.Run("SaveDIDByResolving error", func(t *testing.T) { + prov := ctx{ + store: mockstorage.NewMockStoreProvider(), + vdr: &mockvdri.MockVDRIRegistry{ResolveErr: fmt.Errorf("resolve error")}, + } + + cs, err := New(&prov) + require.NoError(t, err) + + err = cs.SaveDIDByResolving("did") + require.Error(t, err) + require.Contains(t, err.Error(), "resolve error") + }) +} diff --git a/pkg/didcomm/common/service/destination.go b/pkg/didcomm/common/service/destination.go index 67bf07701..7b1339150 100644 --- a/pkg/didcomm/common/service/destination.go +++ b/pkg/didcomm/common/service/destination.go @@ -24,7 +24,8 @@ type Destination struct { const ( didCommServiceType = "did-communication" - ed25519KeyType = "Ed25519VerificationKey2018" + // TODO: hardcoded key type https://github.com/hyperledger/aries-framework-go/issues/1008 + ed25519KeyType = "Ed25519VerificationKey2018" ) // GetDestination constructs a Destination struct based on the given DID and parameters diff --git a/pkg/didcomm/common/transport/envelope.go b/pkg/didcomm/common/transport/envelope.go index c910c7a0f..709bec9be 100644 --- a/pkg/didcomm/common/transport/envelope.go +++ b/pkg/didcomm/common/transport/envelope.go @@ -6,9 +6,14 @@ SPDX-License-Identifier: Apache-2.0 package transport -// Envelope contain msg, FromVerKey and ToVerKeys +// Envelope holds message data and metadata for inbound and outbound messaging type Envelope struct { Message []byte - FromVerKey string - ToVerKeys []string + FromVerKey []byte + // ToVerKeys stores string (base58) verification keys for an outbound message + ToVerKeys []string + // ToVerKey holds the key that was used to decrypt an inbound message + ToVerKey []byte + FromDID string + ToDID string } diff --git a/pkg/didcomm/dispatcher/outbound.go b/pkg/didcomm/dispatcher/outbound.go index 6f8ba3fd0..2acbd193a 100644 --- a/pkg/didcomm/dispatcher/outbound.go +++ b/pkg/didcomm/dispatcher/outbound.go @@ -11,10 +11,13 @@ import ( "fmt" "strings" + "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" ) // provider interface for outbound ctx @@ -22,6 +25,7 @@ type provider interface { Packager() commontransport.Packager OutboundTransports() []transport.OutboundTransport TransportReturnRoute() string + VDRIRegistry() vdri.Registry } // OutboundDispatcher dispatch msgs to destination @@ -29,6 +33,7 @@ type OutboundDispatcher struct { outboundTransports []transport.OutboundTransport packager commontransport.Packager transportReturnRoute string + vdRegistry vdri.Registry } // NewOutbound return new dispatcher outbound instance @@ -37,12 +42,28 @@ func NewOutbound(prov provider) *OutboundDispatcher { outboundTransports: prov.OutboundTransports(), packager: prov.Packager(), transportReturnRoute: prov.TransportReturnRoute(), + vdRegistry: prov.VDRIRegistry(), } } -// SendToDID msg +// SendToDID sends a message from myDID to the agent who owns theirDID func (o *OutboundDispatcher) SendToDID(msg interface{}, myDID, theirDID string) error { - return nil + dest, err := service.GetDestination(theirDID, o.vdRegistry) + if err != nil { + return err + } + + src, err := service.GetDestination(myDID, o.vdRegistry) + if err != nil { + return err + } + + // We get at least one recipient key, so we can use the first one + // (right now, with only one key type used for sending) + // TODO: relies on hardcoded key type + key := src.RecipientKeys[0] + + return o.Send(msg, key, dest) } // Send sends the message after packing with the sender key and recipient keys. @@ -79,7 +100,7 @@ func (o *OutboundDispatcher) Send(msg interface{}, senderVerKey string, des *ser } packedMsg, err := o.packager.PackMessage( - &commontransport.Envelope{Message: req, FromVerKey: senderVerKey, ToVerKeys: des.RecipientKeys}) + &commontransport.Envelope{Message: req, FromVerKey: base58.Decode(senderVerKey), ToVerKeys: des.RecipientKeys}) if err != nil { return fmt.Errorf("failed to pack msg: %w", err) } diff --git a/pkg/didcomm/dispatcher/outbound_test.go b/pkg/didcomm/dispatcher/outbound_test.go index eb197dec1..637028be7 100644 --- a/pkg/didcomm/dispatcher/outbound_test.go +++ b/pkg/didcomm/dispatcher/outbound_test.go @@ -19,8 +19,11 @@ import ( commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" mockdidcomm "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" + mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" ) func TestOutboundDispatcher_Send(t *testing.T) { @@ -58,6 +61,40 @@ func TestOutboundDispatcher_Send(t *testing.T) { }) } +func TestOutboundDispatcher_SendToDID(t *testing.T) { + mockDoc := mockdiddoc.GetMockDIDDoc() + + t.Run("success", func(t *testing.T) { + o := NewOutbound(&mockProvider{ + packagerValue: &mockpackager.Packager{}, + vdriRegistry: &mockvdri.MockVDRIRegistry{ + ResolveValue: mockDoc, + }, + outboundTransportsValue: []transport.OutboundTransport{ + &mockdidcomm.MockOutboundTransport{AcceptValue: true}, + }, + }) + + require.NoError(t, o.SendToDID("data", "", "")) + }) + + t.Run("resolve err", func(t *testing.T) { + o := NewOutbound(&mockProvider{ + packagerValue: &mockpackager.Packager{}, + vdriRegistry: &mockvdri.MockVDRIRegistry{ + ResolveErr: fmt.Errorf("resolve error"), + }, + outboundTransportsValue: []transport.OutboundTransport{ + &mockdidcomm.MockOutboundTransport{AcceptValue: true}, + }, + }) + + err := o.SendToDID("data", "", "") + require.Error(t, err) + require.Contains(t, err.Error(), "resolve error") + }) +} + func TestOutboundDispatcherTransportReturnRoute(t *testing.T) { t.Run("transport route option - value set all", func(t *testing.T) { transportReturnRoute := "all" @@ -168,6 +205,7 @@ type mockProvider struct { packagerValue commontransport.Packager outboundTransportsValue []transport.OutboundTransport transportReturnRoute string + vdriRegistry vdri.Registry } func (p *mockProvider) Packager() commontransport.Packager { @@ -182,6 +220,10 @@ func (p *mockProvider) TransportReturnRoute() string { return p.transportReturnRoute } +func (p *mockProvider) VDRIRegistry() vdri.Registry { + return p.vdriRegistry +} + // mockOutboundTransport mock outbound transport type mockOutboundTransport struct { expectedRequest string diff --git a/pkg/didcomm/packager/package_test.go b/pkg/didcomm/packager/package_test.go index 5f7ab1f61..39e4827d4 100644 --- a/pkg/didcomm/packager/package_test.go +++ b/pkg/didcomm/packager/package_test.go @@ -11,14 +11,17 @@ import ( "fmt" "testing" + "github.com/btcsuite/btcutil/base58" "github.com/stretchr/testify/require" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" . "github.com/hyperledger/aries-framework-go/pkg/didcomm/packager" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" jwe "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer/jwe/authcrypt" legacy "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer/legacy/authcrypt" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" + mockdidconnection "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" "github.com/hyperledger/aries-framework-go/pkg/kms" "github.com/hyperledger/aries-framework-go/pkg/storage" @@ -26,16 +29,15 @@ import ( func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { t.Run("test failed to unmarshal encMessage", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -49,16 +51,15 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test bad encoding type", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -87,17 +88,16 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test key not found", func(t *testing.T) { - wp := newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }}) + wp := newMockKMSProvider(newMockStoreProvider()) w, err := kms.New(wp) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -117,7 +117,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { // PackMessage should pass with both value from and to verification keys packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) @@ -131,27 +131,26 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test Pack/Unpack fails", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) - decryptValue := func(envelope []byte) ([]byte, []byte, error) { - return nil, nil, fmt.Errorf("unpack error") + decryptValue := func(envelope []byte) (*transport.Envelope, error) { + return nil, fmt.Errorf("unpack error") } mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } // use a mocked packager with a mocked KMS to validate pack/unpack e := func(payload []byte, senderPubKey []byte, recipientsKeys [][]byte) (bytes []byte, e error) { - packer, e := jwe.New(mockedProviders, jwe.XC20P) + p, e := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, e) - return packer.Pack(payload, senderPubKey, recipientsKeys) + return p.Pack(payload, senderPubKey, recipientsKeys) } mockPacker := &didcomm.MockAuthCrypt{DecryptValue: decryptValue, EncryptValue: e, Type: "prs.hyperledger.aries-auth-message"} @@ -174,7 +173,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { // now try to pack with non empty envelope - should pass packMsg, err = packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) require.NotEmpty(t, packMsg) @@ -194,7 +193,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { packager, err = New(mockedProviders) require.NoError(t, err) packMsg, err = packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.Error(t, err) require.Empty(t, packMsg) @@ -203,14 +202,14 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { t.Run("test Pack/Unpack success", func(t *testing.T) { // create a mock KMS with storage as a map - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: map[string][]byte{}}})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } // create a real testPacker (no mocking here) @@ -233,7 +232,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { // pack an non empty envelope - should pass packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) @@ -250,7 +249,7 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { require.NoError(t, err) packMsg, err = packager2.PackMessage(&transport.Envelope{Message: []byte("msg2"), - FromVerKey: base58FromVerKey, + FromVerKey: base58.Decode(base58FromVerKey), ToVerKeys: []string{base58ToVerKey}}) require.NoError(t, err) @@ -259,10 +258,99 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { require.NoError(t, err) require.Equal(t, unpackedMsg.Message, []byte("msg2")) }) + + t.Run("test success - dids not found", func(t *testing.T) { + // create a mock KMS with storage as a map + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) + require.NoError(t, err) + mockedProviders := &mockProvider{ + storage: newMockStoreProvider(), + kms: w, + primaryPacker: nil, + packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDErr: didconnection.ErrNotFound}, + } + + // create a real testPacker (no mocking here) + testPacker := legacy.New(mockedProviders) + require.NoError(t, err) + mockedProviders.primaryPacker = testPacker + + mockedProviders.packers = []packer.Packer{testPacker} + + // now create a new packager with the above provider context + packager, err := New(mockedProviders) + require.NoError(t, err) + + _, base58FromVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + _, base58ToVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + // pack an non empty envelope - should pass + packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), + FromVerKey: base58.Decode(base58FromVerKey), + ToVerKeys: []string{base58ToVerKey}}) + require.NoError(t, err) + + // unpack the packed message above - should pass and match the same payload (msg1) + unpackedMsg, err := packager.UnpackMessage(packMsg) + require.NoError(t, err) + require.Equal(t, unpackedMsg.Message, []byte("msg1")) + }) + + t.Run("test failure - did lookup broke", func(t *testing.T) { + // create a mock KMS with storage as a map + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) + require.NoError(t, err) + mockedProviders := &mockProvider{ + storage: newMockStoreProvider(), + kms: w, + primaryPacker: nil, + packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDErr: fmt.Errorf("bad error")}, + } + + // create a real testPacker (no mocking here) + testPacker := legacy.New(mockedProviders) + require.NoError(t, err) + mockedProviders.primaryPacker = testPacker + + mockedProviders.packers = []packer.Packer{testPacker} + + // now create a new packager with the above provider context + packager, err := New(mockedProviders) + require.NoError(t, err) + + _, base58FromVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + _, base58ToVerKey, err := w.CreateKeySet() + require.NoError(t, err) + + // pack an non empty envelope - should pass + packMsg, err := packager.PackMessage(&transport.Envelope{Message: []byte("msg1"), + FromVerKey: base58.Decode(base58FromVerKey), + ToVerKeys: []string{base58ToVerKey}}) + require.NoError(t, err) + + // unpack the packed message above - should pass and match the same payload (msg1) + unpackedMsg, err := packager.UnpackMessage(packMsg) + require.Error(t, err) + require.Contains(t, err.Error(), "bad error") + require.Nil(t, unpackedMsg) + }) } func newMockKMSProvider(storagePvdr *mockstorage.MockStoreProvider) *mockProvider { - return &mockProvider{storagePvdr, nil, nil, nil} + return &mockProvider{storagePvdr, nil, nil, nil, nil} +} + +func newMockStoreProvider() *mockstorage.MockStoreProvider { + return &mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ + Store: make(map[string][]byte), + }} } // mockProvider mocks provider for KMS @@ -271,6 +359,7 @@ type mockProvider struct { kms kms.KeyManager packers []packer.Packer primaryPacker packer.Packer + lookupStore didconnection.Store } func (m *mockProvider) Packers() []packer.Packer { @@ -288,3 +377,8 @@ func (m *mockProvider) StorageProvider() storage.Provider { func (m *mockProvider) PrimaryPacker() packer.Packer { return m.primaryPacker } + +// DIDConnectionStore returns a didconnection.Store service. +func (m *mockProvider) DIDConnectionStore() didconnection.Store { + return m.lookupStore +} diff --git a/pkg/didcomm/packager/packager.go b/pkg/didcomm/packager/packager.go index 8ceda79c6..cb36fc1f1 100644 --- a/pkg/didcomm/packager/packager.go +++ b/pkg/didcomm/packager/packager.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" ) @@ -22,6 +23,7 @@ import ( type Provider interface { Packers() []packer.Packer PrimaryPacker() packer.Packer + DIDConnectionStore() didconnection.Store } // Creator method to create new packager service @@ -29,8 +31,9 @@ type Creator func(prov Provider) (transport.Packager, error) // Packager is the basic implementation of Packager type Packager struct { - primaryPacker packer.Packer - packers map[string]packer.Packer + primaryPacker packer.Packer + packers map[string]packer.Packer + connectionStore didconnection.Store } // PackerCreator holds a creator function for a Packer and the name of the Packer's encoding method. @@ -42,8 +45,9 @@ type PackerCreator struct { // New return new instance of KMS implementation func New(ctx Provider) (*Packager, error) { basePackager := Packager{ - primaryPacker: nil, - packers: map[string]packer.Packer{}, + primaryPacker: nil, + packers: map[string]packer.Packer{}, + connectionStore: ctx.DIDConnectionStore(), } for _, packerType := range ctx.Packers() { @@ -52,7 +56,7 @@ func New(ctx Provider) (*Packager, error) { basePackager.primaryPacker = ctx.PrimaryPacker() if basePackager.primaryPacker == nil { - return nil, fmt.Errorf("need primary primaryPacker to initialize packager") + return nil, fmt.Errorf("need primary packer to initialize packager") } basePackager.addPacker(basePackager.primaryPacker) @@ -86,7 +90,7 @@ func (bp *Packager) PackMessage(messageEnvelope *transport.Envelope) ([]byte, er recipients = append(recipients, verKeyBytes) } // pack message - bytes, err := bp.primaryPacker.Pack(messageEnvelope.Message, base58.Decode(messageEnvelope.FromVerKey), recipients) + bytes, err := bp.primaryPacker.Pack(messageEnvelope.Message, messageEnvelope.FromVerKey, recipients) if err != nil { return nil, fmt.Errorf("pack: %w", err) } @@ -146,10 +150,27 @@ func (bp *Packager) UnpackMessage(encMessage []byte) (*transport.Envelope, error return nil, fmt.Errorf("message Type not recognized") } - data, senderVerKey, err := p.Unpack(encMessage) + envelope, err := p.Unpack(encMessage) if err != nil { return nil, fmt.Errorf("unpack: %w", err) } - return &transport.Envelope{Message: data, FromVerKey: base58.Encode(senderVerKey)}, nil + // ignore error - agents can communicate without using DIDs - for example, in DIDExchange + theirDID, err := bp.connectionStore.GetDID(base58.Encode(envelope.FromVerKey)) + if errors.Is(err, didconnection.ErrNotFound) { + } else if err != nil { + return nil, fmt.Errorf("failed to get their did: %w", err) + } + + // ignore error - at beginning of DIDExchange, you might be about to generate a DID + myDID, err := bp.connectionStore.GetDID(base58.Encode(envelope.ToVerKey)) + if errors.Is(err, didconnection.ErrNotFound) { + } else if err != nil { + return nil, fmt.Errorf("failed to get my did: %w", err) + } + + envelope.ToDID = myDID + envelope.FromDID = theirDID + + return envelope, nil } diff --git a/pkg/didcomm/packer/api.go b/pkg/didcomm/packer/api.go index 0c7b50503..ab5c0803c 100644 --- a/pkg/didcomm/packer/api.go +++ b/pkg/didcomm/packer/api.go @@ -6,7 +6,10 @@ SPDX-License-Identifier: Apache-2.0 package packer -import "github.com/hyperledger/aries-framework-go/pkg/kms" +import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + "github.com/hyperledger/aries-framework-go/pkg/kms" +) // Provider interface for Packer ctx type Provider interface { @@ -30,11 +33,10 @@ type Packer interface { // The recipient's key will be the one found in KMS that matches one of the list of recipients in the envelope // // returns: - // []byte containing the decrypted payload - // []byte contains the sender verification key + // Envelope containing the message, decryption key, and sender key // error if decryption failed // TODO add key type of recipients keys to be validated by the implementation - Issue #272 - Unpack(envelope []byte) ([]byte, []byte, error) + Unpack(envelope []byte) (*transport.Envelope, error) // Encoding returns the type of the encoding, as found in the header `Typ` field EncodingType() string diff --git a/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go b/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go index d2dadf8ca..87f134beb 100644 --- a/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go +++ b/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go @@ -304,12 +304,14 @@ func TestEncrypt(t *testing.T) { t.Logf("Encryption with XC20P: %s", m) // decrypt for rec1 (as found in kms) - dec, senderVerKey, e := packer.Unpack(enc) + env, e := packer.Unpack(enc) require.NoError(t, e) - require.NotEmpty(t, dec) - require.EqualValues(t, dec, pld) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.EqualValues(t, env.Message, pld) + require.NotEmpty(t, env.FromVerKey) require.Equal(t, base64.RawURLEncoding.EncodeToString(sender.EncKeyPair.Pub), - base64.RawURLEncoding.EncodeToString(senderVerKey)) + base64.RawURLEncoding.EncodeToString(env.FromVerKey)) }) t.Run("Success test case: : pack and unpack", func(t *testing.T) { @@ -331,10 +333,17 @@ func TestEncrypt(t *testing.T) { [][]byte{base58.Decode(recSign)}) require.NoError(t, err) - msgOut, sendKey, err := recPacker.Unpack(enc) + env, err := recPacker.Unpack(enc) require.NoError(t, err) - require.Equal(t, msgIn, msgOut) - require.Equal(t, base58.Encode(sendKey), encSign) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.Equal(t, msgIn, env.Message) + require.NotEmpty(t, env.FromVerKey) + require.Equal(t, base58.Encode(env.FromVerKey), encSign) + // this won't work, since input keys and output keys to this Packer + // are different types. + // require.Equal(t, base58.Encode(recKey), recSign) + require.NotEmpty(t, env.ToVerKey) }) t.Run("Success test case: Decrypting a message with two PackerValue instances to simulate two agents", func(t *testing.T) { //nolint:lll @@ -356,22 +365,26 @@ func TestEncrypt(t *testing.T) { // now decrypt with recipient3 packer1, e := New(recipient3KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e := packer1.Unpack(enc) + env, e := packer1.Unpack(enc) require.NoError(t, e) - require.NotEmpty(t, dec) - require.EqualValues(t, dec, pld) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.EqualValues(t, env.Message, pld) + require.NotEmpty(t, env.FromVerKey) require.Equal(t, base64.RawURLEncoding.EncodeToString(sender.EncKeyPair.Pub), - base64.RawURLEncoding.EncodeToString(senderVerKey)) + base64.RawURLEncoding.EncodeToString(env.FromVerKey)) // now try decrypting with recipient2 packer2, e := New(recipient2KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e = packer2.Unpack(enc) + env, e = packer2.Unpack(enc) require.NoError(t, e) - require.NotEmpty(t, dec) - require.EqualValues(t, dec, pld) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.EqualValues(t, env.Message, pld) + require.NotEmpty(t, env.FromVerKey) require.Equal(t, base64.RawURLEncoding.EncodeToString(sender.EncKeyPair.Pub), - base64.RawURLEncoding.EncodeToString(senderVerKey)) + base64.RawURLEncoding.EncodeToString(env.FromVerKey)) t.Logf("Decryption Payload with XC20P: %s", pld) }) @@ -392,10 +405,9 @@ func TestEncrypt(t *testing.T) { // decrypting for recipient 2 (unauthorized) packer1, e := New(recipient2KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e := packer1.Unpack(enc) + env, e := packer1.Unpack(enc) require.Error(t, e) - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) }) t.Run("Failure test case: Decrypting a message but scramble JWE beforehand", func(t *testing.T) { @@ -427,10 +439,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad nonce format - dec, senderVerKey, e := packer.Unpack(enc) + env, e := packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 12") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.CipherText = validJwe.CipherText // update jwe with bad nonce format @@ -438,10 +449,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad nonce format - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 5") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.IV = validJwe.IV // update jwe with bad tag format @@ -449,10 +459,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag format - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 6") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Tag = validJwe.Tag // update jwe with bad recipient spk (JWE format) @@ -460,10 +469,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: sender key: bad SPK format") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.SPK = validJwe.Recipients[0].Header.SPK // update jwe with bad recipient tag format @@ -471,10 +479,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 6") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.Tag = validJwe.Recipients[0].Header.Tag // update jwe with bad recipient nonce format @@ -482,10 +489,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 5") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.IV = validJwe.Recipients[0].Header.IV // update jwe with bad recipient nonce format @@ -493,10 +499,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: bad nonce size") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.IV = validJwe.Recipients[0].Header.IV // update jwe with bad recipient apu format @@ -504,10 +509,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 6") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.APU = validJwe.Recipients[0].Header.APU // update jwe with bad recipient kid (sender) format @@ -515,10 +519,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, fmt.Sprintf("unpack: %s", cryptoutil.ErrKeyNotFound.Error())) - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.KID = validJwe.Recipients[0].Header.KID // update jwe with bad recipient CEK (encrypted key) format @@ -526,10 +529,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 15") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].EncryptedKey = validJwe.Recipients[0].EncryptedKey // update jwe with bad recipient CEK (encrypted key) value @@ -537,10 +539,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: chacha20poly1305: message authentication failed") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].EncryptedKey = validJwe.Recipients[0].EncryptedKey // now try bad nonce size @@ -549,10 +550,9 @@ func TestEncrypt(t *testing.T) { require.NoError(t, e) // decrypt with bad nonce value require.PanicsWithValue(t, "chacha20poly1305: bad nonce length passed to Open", func() { - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) }) - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.IV = validJwe.IV // now try bad nonce value @@ -560,10 +560,9 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + env, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: chacha20poly1305: message authentication failed") - require.Empty(t, dec) - require.Empty(t, senderVerKey) + require.Empty(t, env) jwe.Recipients[0].Header.IV = validJwe.Recipients[0].Header.IV }) } @@ -658,8 +657,9 @@ func TestRefEncrypt(t *testing.T) { require.NoError(t, err) require.NotNil(t, packer) - dec, senderVerKey, err := packer.Unpack([]byte(refJWE)) + env, err := packer.Unpack([]byte(refJWE)) require.NoError(t, err) - require.NotEmpty(t, dec) - require.NotEmpty(t, senderVerKey) + require.NotEmpty(t, env) + require.NotEmpty(t, env.Message) + require.NotEmpty(t, env.FromVerKey) } diff --git a/pkg/didcomm/packer/jwe/authcrypt/unpack.go b/pkg/didcomm/packer/jwe/authcrypt/unpack.go index ecbf4db55..01e70c3b9 100644 --- a/pkg/didcomm/packer/jwe/authcrypt/unpack.go +++ b/pkg/didcomm/packer/jwe/authcrypt/unpack.go @@ -14,6 +14,8 @@ import ( "github.com/btcsuite/btcutil/base58" chacha "golang.org/x/crypto/chacha20poly1305" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" ) // Unpack will JWE decode the envelope argument for the recipientPrivKey and validates @@ -22,22 +24,22 @@ import ( // encrypted CEK. // The current recipient is the one with the sender's encrypted key that successfully // decrypts with recipientKeyPair.Priv Key. -func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(envelope []byte) (*transport.Envelope, error) { jwe := &Envelope{} err := json.Unmarshal(envelope, jwe) if err != nil { - return nil, nil, fmt.Errorf("unpack: %w", err) + return nil, fmt.Errorf("unpack json: %w", err) } recipientPubKey, recipient, err := p.findRecipient(jwe.Recipients) if err != nil { - return nil, nil, fmt.Errorf("unpack: %w", err) + return nil, fmt.Errorf("unpack: %w", err) } senderKey, err := p.decryptSPK(recipientPubKey, recipient.Header.SPK) if err != nil { - return nil, nil, fmt.Errorf("unpack: sender key: %w", err) + return nil, fmt.Errorf("unpack: sender key: %w", err) } // senderKey must not be empty to proceed @@ -47,18 +49,22 @@ func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { sharedKey, er := p.decryptCEK(recipientPubKey, senderPubKey, recipient) if er != nil { - return nil, nil, fmt.Errorf("unpack: decrypt shared key: %w", er) + return nil, fmt.Errorf("unpack: decrypt shared key: %w", er) } symOutput, er := p.decryptPayload(sharedKey, jwe) if er != nil { - return nil, nil, fmt.Errorf("unpack: %w", er) + return nil, fmt.Errorf("unpack: %w", er) } - return symOutput, senderKey, nil + return &transport.Envelope{ + Message: symOutput, + FromVerKey: senderKey, + ToVerKey: recipientPubKey[:], + }, nil } - return nil, nil, errors.New("unpack: invalid sender key in envelope") + return nil, errors.New("unpack: invalid sender key in envelope") } func (p *Packer) decryptPayload(cek []byte, jwe *Envelope) ([]byte, error) { diff --git a/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go b/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go index 6f7e76737..d92146777 100644 --- a/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go +++ b/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go @@ -384,11 +384,12 @@ func TestDecrypt(t *testing.T) { enc, err := packer.Pack(msgIn, base58.Decode(senderKey), [][]byte{base58.Decode(recKey)}) require.NoError(t, err) - msgOut, senderVerKey, err := packer.Unpack(enc) + env, err := packer.Unpack(enc) require.NoError(t, err) - require.ElementsMatch(t, msgIn, msgOut) - require.Equal(t, senderKey, base58.Encode(senderVerKey)) + require.ElementsMatch(t, msgIn, env.Message) + require.Equal(t, senderKey, base58.Encode(env.FromVerKey)) + require.Equal(t, recKey, base58.Encode(env.ToVerKey)) }) t.Run("Success: pack and unpack, different packers, including fail recipient who wasn't sent the message", func(t *testing.T) { // nolint: lll @@ -413,15 +414,16 @@ func TestDecrypt(t *testing.T) { enc, err := sendPacker.Pack(msgIn, base58.Decode(senderKey), [][]byte{base58.Decode(rec1Key), base58.Decode(rec2Key), base58.Decode(rec3Key)}) require.NoError(t, err) - msgOut, senderVerKey, err := rec2Packer.Unpack(enc) + env, err := rec2Packer.Unpack(enc) require.NoError(t, err) - require.ElementsMatch(t, msgIn, msgOut) - require.Equal(t, senderKey, base58.Encode(senderVerKey)) + require.ElementsMatch(t, msgIn, env.Message) + require.Equal(t, senderKey, base58.Encode(env.FromVerKey)) + require.Equal(t, rec2Key, base58.Encode(env.ToVerKey)) emptyKMS, _ := newKMS(t) rec4Packer := newWithKMS(emptyKMS) - _, _, err = rec4Packer.Unpack(enc) + _, err = rec4Packer.Unpack(enc) require.NotNil(t, err) require.Contains(t, err.Error(), "no key accessible") }) @@ -440,10 +442,12 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - msgOut, senderVerKey, err := recPacker.Unpack([]byte(env)) + envOut, err := recPacker.Unpack([]byte(env)) require.NoError(t, err) - require.ElementsMatch(t, []byte(msg), msgOut) - require.NotEmpty(t, senderVerKey) + require.ElementsMatch(t, []byte(msg), envOut.Message) + require.NotEmpty(t, envOut.FromVerKey) + require.NotEmpty(t, envOut.ToVerKey) + require.Equal(t, recPub, base58.Encode(envOut.ToVerKey)) }) t.Run("Test unpacking python envelope with multiple recipients", func(t *testing.T) { @@ -461,10 +465,12 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - msgOut, senderVerKey, err := recPacker.Unpack([]byte(env)) + envOut, err := recPacker.Unpack([]byte(env)) require.NoError(t, err) - require.ElementsMatch(t, []byte(msg), msgOut) - require.NotEmpty(t, senderVerKey) + require.ElementsMatch(t, []byte(msg), envOut.Message) + require.NotEmpty(t, envOut.FromVerKey) + require.NotEmpty(t, envOut.ToVerKey) + require.Equal(t, recPub, base58.Encode(envOut.ToVerKey)) }) t.Run("Test unpacking python envelope with invalid recipient", func(t *testing.T) { @@ -479,7 +485,7 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - _, _, err = recPacker.Unpack([]byte(env)) + _, err = recPacker.Unpack([]byte(env)) require.NotNil(t, err) require.Contains(t, err.Error(), "no key accessible") }) @@ -502,7 +508,7 @@ func unpackComponentFailureTest( } recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(fullMessage)) + _, err = recPacker.Unpack([]byte(fullMessage)) require.NotNil(t, err) require.Contains(t, err.Error(), errString) } @@ -520,7 +526,7 @@ func TestUnpackComponents(t *testing.T) { recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(msg)) + _, err = recPacker.Unpack([]byte(msg)) require.EqualError(t, err, "invalid character 'e' looking for beginning of value") }) @@ -533,7 +539,7 @@ func TestUnpackComponents(t *testing.T) { recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(msg)) + _, err = recPacker.Unpack([]byte(msg)) require.EqualError(t, err, "illegal base64 data at input byte 0") }) @@ -694,7 +700,7 @@ func Test_getCEK(t *testing.T) { }, } - _, _, err := getCEK(recs, &k) + _, err := getCEK(recs, &k) require.EqualError(t, err, "mock error") } diff --git a/pkg/didcomm/packer/legacy/authcrypt/unpack.go b/pkg/didcomm/packer/legacy/authcrypt/unpack.go index ade93d45c..07419dc99 100644 --- a/pkg/didcomm/packer/legacy/authcrypt/unpack.go +++ b/pkg/didcomm/packer/legacy/authcrypt/unpack.go @@ -14,52 +14,65 @@ import ( "github.com/btcsuite/btcutil/base58" chacha "golang.org/x/crypto/chacha20poly1305" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/internal/cryptoutil" "github.com/hyperledger/aries-framework-go/pkg/kms" ) // Unpack will decode the envelope using the legacy format // Using (X)Chacha20 encryption algorithm and Poly1035 authenticator -func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(envelope []byte) (*transport.Envelope, error) { var envelopeData legacyEnvelope err := json.Unmarshal(envelope, &envelopeData) if err != nil { - return nil, nil, err + return nil, err } protectedBytes, err := base64.URLEncoding.DecodeString(envelopeData.Protected) if err != nil { - return nil, nil, err + return nil, err } var protectedData protected err = json.Unmarshal(protectedBytes, &protectedData) if err != nil { - return nil, nil, err + return nil, err } if protectedData.Typ != encodingType { - return nil, nil, fmt.Errorf("message type %s not supported", protectedData.Typ) + return nil, fmt.Errorf("message type %s not supported", protectedData.Typ) } if protectedData.Alg != "Authcrypt" { // TODO https://github.com/hyperledger/aries-framework-go/issues/41 change this when anoncrypt is introduced - return nil, nil, fmt.Errorf("message format %s not supported", protectedData.Alg) + return nil, fmt.Errorf("message format %s not supported", protectedData.Alg) } - cek, recKey, err := getCEK(protectedData.Recipients, p.kms) + keys, err := getCEK(protectedData.Recipients, p.kms) if err != nil { - return nil, nil, err + return nil, err } + cek, senderKey, recKey := keys.cek, keys.theirKey, keys.myKey + data, err := p.decodeCipherText(cek, &envelopeData) - return data, recKey, err + return &transport.Envelope{ + Message: data, + FromVerKey: senderKey, + ToVerKey: recKey, + }, err +} + +type keys struct { + cek *[chacha.KeySize]byte + theirKey []byte + myKey []byte } -func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, []byte, error) { +func getCEK(recipients []recipient, km kms.KeyManager) (*keys, error) { var candidateKeys []string for _, candidate := range recipients { @@ -68,47 +81,51 @@ func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, [ recKeyIdx, err := km.FindVerKey(candidateKeys) if err != nil { - return nil, nil, fmt.Errorf("no key accessible %w", err) + return nil, fmt.Errorf("no key accessible %w", err) } recip := recipients[recKeyIdx] - recKey := recip.Header.KID + recKey := base58.Decode(recip.Header.KID) - recCurvePub, err := km.ConvertToEncryptionKey(base58.Decode(recKey)) + recCurvePub, err := km.ConvertToEncryptionKey(recKey) if err != nil { - return nil, nil, err + return nil, err } senderPub, senderPubCurve, err := decodeSender(recip.Header.Sender, recCurvePub, km) if err != nil { - return nil, nil, err + return nil, err } nonceSlice, err := base64.URLEncoding.DecodeString(recip.Header.IV) if err != nil { - return nil, nil, err + return nil, err } encCEK, err := base64.URLEncoding.DecodeString(recip.EncryptedKey) if err != nil { - return nil, nil, err + return nil, err } b, err := kms.NewCryptoBox(km) if err != nil { - return nil, nil, err + return nil, err } cekSlice, err := b.EasyOpen(encCEK, nonceSlice, senderPubCurve, recCurvePub) if err != nil { - return nil, nil, fmt.Errorf("failed to decrypt CEK: %s", err) + return nil, fmt.Errorf("failed to decrypt CEK: %s", err) } var cek [chacha.KeySize]byte copy(cek[:], cekSlice) - return &cek, senderPub, nil + return &keys{ + cek: &cek, + theirKey: senderPub, + myKey: recKey, + }, nil } func decodeSender(b64Sender string, pk []byte, km kms.KeyManager) ([]byte, []byte, error) { diff --git a/pkg/didcomm/protocol/didexchange/persistence.go b/pkg/didcomm/protocol/didexchange/persistence.go index fc9f670fe..e4d63b79e 100644 --- a/pkg/didcomm/protocol/didexchange/persistence.go +++ b/pkg/didcomm/protocol/didexchange/persistence.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/storage" ) @@ -54,14 +55,15 @@ func (r *ConnectionRecord) isValid() error { } // NewConnectionRecorder returns new connection record instance -func NewConnectionRecorder(transientStore, store storage.Store) *ConnectionRecorder { - return &ConnectionRecorder{transientStore: transientStore, store: store} +func NewConnectionRecorder(transientStore, store storage.Store, didStore didconnection.Store) *ConnectionRecorder { + return &ConnectionRecorder{transientStore: transientStore, store: store, didStore: didStore} } // ConnectionRecorder takes care of connection related persistence features type ConnectionRecorder struct { transientStore storage.Store store storage.Store + didStore didconnection.Store } // SaveInvitation saves connection invitation to underlying store @@ -231,6 +233,10 @@ func (c *ConnectionRecorder) saveConnectionRecord(record *ConnectionRecord) erro if err := marshalAndSave(connectionKeyPrefix(record.ConnectionID), record, c.store); err != nil { return fmt.Errorf("save connection record in permanent store: %w", err) } + + if err := c.didStore.SaveDIDByResolving(record.TheirDID, record.RecipientKeys...); err != nil { + return err + } } return nil @@ -259,6 +265,12 @@ func (c *ConnectionRecorder) saveNewConnectionRecord(record *ConnectionRecord) e return fmt.Errorf("save new connection record: %w", err) } + if record.MyDID != "" { + if err := c.didStore.SaveDIDByResolving(record.MyDID); err != nil { + return err + } + } + return c.saveNSThreadID(record.ThreadID, record.Namespace, record.ConnectionID) } diff --git a/pkg/didcomm/protocol/didexchange/persistence_test.go b/pkg/didcomm/protocol/didexchange/persistence_test.go index 46272b26a..00f73aa58 100644 --- a/pkg/didcomm/protocol/didexchange/persistence_test.go +++ b/pkg/didcomm/protocol/didexchange/persistence_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" + mockdidconnection "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" "github.com/hyperledger/aries-framework-go/pkg/storage" "github.com/hyperledger/aries-framework-go/pkg/storage/mem" @@ -47,7 +48,7 @@ func Test_ComputeHash(t *testing.T) { func TestConnectionRecord_SaveInvitation(t *testing.T) { t.Run("test save invitation success", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) value := &Invitation{ @@ -71,7 +72,7 @@ func TestConnectionRecord_SaveInvitation(t *testing.T) { t.Run("test save invitation failure due to invalid key", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) value := &Invitation{ @@ -86,7 +87,7 @@ func TestConnectionRecord_SaveInvitation(t *testing.T) { func TestConnectionRecorder_GetInvitation(t *testing.T) { t.Run("test get invitation - success", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) valueStored := &Invitation{ @@ -104,7 +105,7 @@ func TestConnectionRecorder_GetInvitation(t *testing.T) { t.Run("test get invitation - not found scenario", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) valueFound, err := record.GetInvitation("sample-key4") @@ -115,7 +116,7 @@ func TestConnectionRecorder_GetInvitation(t *testing.T) { t.Run("test get invitation - invalid key scenario", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) valueFound, err := record.GetInvitation("") @@ -129,7 +130,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { t.Run("test success found data in transient store ", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix} @@ -148,7 +149,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { t.Run("test success found data in store ", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix} @@ -169,7 +170,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { Store: make(map[string][]byte), ErrGet: fmt.Errorf("get error transientstore")} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRecBytes, err := json.Marshal(&ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix}) @@ -183,7 +184,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { t.Run("test error from store", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte), ErrGet: fmt.Errorf("get error store")} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRecBytes, err := json.Marshal(&ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix}) @@ -198,7 +199,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { func TestConnectionRecordByState(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrGet: nil} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ConnectionID: generateRandomID(), ThreadID: threadIDValue, @@ -235,7 +236,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { t.Run("save connection record and get connection Record success", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: theirNSPrefix} @@ -248,7 +249,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { }) t.Run("save connection record and fetch from no namespace error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited} @@ -258,7 +259,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { }) t.Run("save connection record error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: fmt.Errorf("get error")} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: "", ConnectionID: "test", State: stateNameInvited, Namespace: theirNSPrefix} @@ -267,7 +268,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { }) t.Run("save connection record error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: fmt.Errorf("get error")} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: theirNSPrefix} @@ -277,19 +278,37 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { t.Run("save connection record in permanent store error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: fmt.Errorf("get error")} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameCompleted, Namespace: theirNSPrefix} err := record.saveNewConnectionRecord(connRec) require.Contains(t, err.Error(), "get error") }) + t.Run("error saving DID by resolving", func(t *testing.T) { + transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} + store := &mockstorage.MockStore{Store: make(map[string][]byte)} + record := NewConnectionRecorder(transientStore, store, &mockdidconnection.MockDIDConnection{ + ResolveDIDErr: fmt.Errorf("save error"), + }) + require.NotNil(t, record) + connRec := &ConnectionRecord{ThreadID: threadIDValue, MyDID: "did:foo", + ConnectionID: connIDValue, State: stateNameCompleted, Namespace: theirNSPrefix} + err := record.saveNewConnectionRecord(connRec) + require.Error(t, err) + require.Contains(t, err.Error(), "save error") + + // note: record is still stored, since error happens afterwards + storedRecord, err := record.GetConnectionRecord(connRec.ConnectionID) + require.NoError(t, err) + require.Equal(t, connRec, storedRecord) + }) } func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { t.Run(" get connection record by namespace threadID in my namespace", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: myNSPrefix} @@ -305,7 +324,7 @@ func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { }) t.Run(" get connection record by namespace threadID their namespace", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: theirNSPrefix} @@ -321,7 +340,7 @@ func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { }) t.Run(" data not found error due to missing input parameter", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec, err := record.GetConnectionRecordByNSThreadID("") require.Contains(t, err.Error(), "data not found") @@ -332,7 +351,7 @@ func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { func TestConnectionRecorder_PrepareConnectionRecord(t *testing.T) { t.Run(" prepare connection record error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec, err := prepareConnectionRecord(nil) require.Contains(t, err.Error(), "prepare connection record") @@ -355,7 +374,7 @@ func TestConnectionRecorder_CreateNSKeys(t *testing.T) { func TestConnectionRecorder_SaveNSThreadID(t *testing.T) { t.Run("missing required parameters", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) err := record.saveNSThreadID("", theirNSPrefix, connIDValue) require.Error(t, err) @@ -398,7 +417,7 @@ func TestConnectionRecorder_QueryConnectionRecord(t *testing.T) { require.NoError(t, err) } - recorder := NewConnectionRecorder(transientStore, store) + recorder := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, recorder) result, err := recorder.QueryConnectionRecords() require.NoError(t, err) @@ -410,7 +429,7 @@ func TestConnectionRecorder_QueryConnectionRecord(t *testing.T) { err := store.Put(fmt.Sprintf("%s_abc123", connIDKeyPrefix), []byte("-----")) require.NoError(t, err) - recorder := NewConnectionRecorder(nil, store) + recorder := NewConnectionRecorder(nil, store, nil) require.NotNil(t, recorder) result, err := recorder.QueryConnectionRecords() require.Error(t, err) diff --git a/pkg/didcomm/protocol/didexchange/service.go b/pkg/didcomm/protocol/didexchange/service.go index 5545c2aab..5a7178b7c 100644 --- a/pkg/didcomm/protocol/didexchange/service.go +++ b/pkg/didcomm/protocol/didexchange/service.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/hyperledger/aries-framework-go/pkg/common/log" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" @@ -57,6 +58,7 @@ type provider interface { OutboundDispatcher() dispatcher.Outbound StorageProvider() storage.Provider TransientStorageProvider() storage.Provider + DIDConnectionStore() didconnection.Store Signer() kms.Signer VDRIRegistry() vdriapi.Registry } @@ -76,6 +78,7 @@ type Service struct { ctx *context callbackChannel chan *message connectionStore *ConnectionRecorder + didConnections didconnection.Store } type context struct { @@ -106,7 +109,7 @@ func New(prov provider) (*Service, error) { return nil, err } - connRecorder := NewConnectionRecorder(transientStore, store) + connRecorder := NewConnectionRecorder(transientStore, store, prov.DIDConnectionStore()) svc := &Service{ ctx: &context{ outboundDispatcher: prov.OutboundDispatcher(), @@ -117,6 +120,7 @@ func New(prov provider) (*Service, error) { // TODO channel size - https://github.com/hyperledger/aries-framework-go/issues/246 callbackChannel: make(chan *message, 10), connectionStore: connRecorder, + didConnections: prov.DIDConnectionStore(), } // start the listener @@ -271,7 +275,7 @@ func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) erro if canTriggerActionEvents(connectionRecord.State, connectionRecord.Namespace) { msg.NextStateName = next.Name() if err = s.sendActionEvent(msg, aEvent); err != nil { - return fmt.Errorf("handle inbound : %w", err) + return fmt.Errorf("handle inbound: %w", err) } haltExecution = true @@ -636,7 +640,6 @@ func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID, inviteeLabe return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err) } - // TODO: hardcoded key type dest, err := service.CreateDestination(didDoc) if err != nil { return "", err diff --git a/pkg/didcomm/protocol/didexchange/service_test.go b/pkg/didcomm/protocol/didexchange/service_test.go index f43d79ed7..f71468501 100644 --- a/pkg/didcomm/protocol/didexchange/service_test.go +++ b/pkg/didcomm/protocol/didexchange/service_test.go @@ -20,6 +20,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/protocol" mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" @@ -68,17 +69,21 @@ func TestService_Handle_Inviter(t *testing.T) { prov := protocol.MockProvider{} store := mockstorage.NewMockStoreProvider() pubKey, privKey := generateKeyPair() + didConnectionStore := didconnection.MockDIDConnection{} + ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: createDIDDocWithKey(pubKey)}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(nil, store.Store), + connectionStore: NewConnectionRecorder(nil, store.Store, &didConnectionStore), } newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) - s, err := New(&protocol.MockProvider{StoreProvider: store}) + s, err := New(&protocol.MockProvider{StoreProvider: store, + DIDConnectionStoreValue: &didConnectionStore, + }) require.NoError(t, err) actionCh := make(chan service.DIDCommAction, 10) @@ -127,7 +132,7 @@ func TestService_Handle_Inviter(t *testing.T) { require.NoError(t, err) msg, err := service.NewDIDCommMsg(payloadBytes) require.NoError(t, err) - _, err = s.HandleInbound(msg, "", "") + _, err = s.HandleInbound(msg, newDidDoc.ID, "") require.NoError(t, err) select { @@ -149,7 +154,7 @@ func TestService_Handle_Inviter(t *testing.T) { didMsg, err := service.NewDIDCommMsg(payloadBytes) require.NoError(t, err) - _, err = s.HandleInbound(didMsg, "", "") + _, err = s.HandleInbound(didMsg, newDidDoc.ID, "theirDID") require.NoError(t, err) select { @@ -202,19 +207,24 @@ func TestService_Handle_Invitee(t *testing.T) { store := mockstorage.NewMockStoreProvider() prov := protocol.MockProvider{} pubKey, privKey := generateKeyPair() + didConnectionStore := didconnection.MockDIDConnection{} + ctx := context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: createDIDDocWithKey(pubKey)}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(transientStore.Store, store.Store), + connectionStore: NewConnectionRecorder(transientStore.Store, store.Store, &didConnectionStore), } newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) s, err := New( - &protocol.MockProvider{StoreProvider: store, - TransientStoreProvider: transientStore}) + &protocol.MockProvider{ + StoreProvider: store, + TransientStoreProvider: transientStore, + DIDConnectionStoreValue: &didConnectionStore, + }) require.NoError(t, err) s.ctx.vdriRegistry = &mockvdri.MockVDRIRegistry{ResolveValue: newDidDoc} @@ -378,7 +388,7 @@ func TestService_Handle_EdgeCases(t *testing.T) { require.NoError(t, err) transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")} - svc.connectionStore = NewConnectionRecorder(transientStore, nil) + svc.connectionStore = NewConnectionRecorder(transientStore, nil, nil) _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), ""), "", "") require.Error(t, err) @@ -452,7 +462,7 @@ func TestService_CurrentState(t *testing.T) { svc := &Service{ connectionStore: NewConnectionRecorder(&mockStore{ get: func(string) ([]byte, error) { return nil, storage.ErrDataNotFound }, - }, nil), + }, nil, nil), } thid, err := createNSKey(theirNSPrefix, "ignored") require.NoError(t, err) @@ -468,7 +478,7 @@ func TestService_CurrentState(t *testing.T) { svc := &Service{ connectionStore: NewConnectionRecorder(&mockStore{ get: func(string) ([]byte, error) { return connRec, nil }, - }, nil), + }, nil, nil), } thid, err := createNSKey(theirNSPrefix, "ignored") require.NoError(t, err) @@ -483,7 +493,7 @@ func TestService_CurrentState(t *testing.T) { get: func(string) ([]byte, error) { return nil, errors.New("test") }, - }, nil), + }, nil, nil), } thid, err := createNSKey(theirNSPrefix, "ignored") require.NoError(t, err) @@ -510,7 +520,7 @@ func TestService_Update(t *testing.T) { get: func(k string) ([]byte, error) { return bytes, nil }, - }, nil), + }, nil, nil), } require.NoError(t, svc.update(RequestMsgType, connRec)) @@ -762,7 +772,7 @@ func TestServiceErrors(t *testing.T) { mockStore := &mockStore{get: func(s string) (bytes []byte, e error) { return nil, errors.New("error") }} - svc.connectionStore = NewConnectionRecorder(mockStore, nil) + svc.connectionStore = NewConnectionRecorder(mockStore, nil, nil) payload := generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), "") _, err = svc.HandleInbound(payload, "", "") require.Error(t, err) @@ -773,7 +783,7 @@ func TestServiceErrors(t *testing.T) { transientStore, err := mockstorage.NewMockStoreProvider().OpenStore(DIDExchange) require.NoError(t, err) - svc.connectionStore = NewConnectionRecorder(transientStore, nil) + svc.connectionStore = NewConnectionRecorder(transientStore, nil, nil) _, err = svc.HandleInbound(msg, "", "") require.Error(t, err) @@ -858,7 +868,7 @@ func TestInvitationRecord(t *testing.T) { // db error svc.connectionStore = NewConnectionRecorder( - &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil) + &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil, nil) invitationBytes, err = json.Marshal(&Invitation{ Type: InvitationMsgType, @@ -886,7 +896,7 @@ func TestRequestRecord(t *testing.T) { // db error svc.connectionStore = NewConnectionRecorder( - &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil) + &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil, nil) _, err = svc.requestMsgRecord(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), "")) @@ -1343,7 +1353,7 @@ func generateRequestMsgPayload(t *testing.T, prov provider, id, invitationID str store := mockstorage.NewMockStoreProvider() ctx := context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}, - connectionStore: NewConnectionRecorder(nil, store.Store)} + connectionStore: NewConnectionRecorder(nil, store.Store, nil)} newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) @@ -1375,7 +1385,8 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDIDDoc}, - connectionStore: NewConnectionRecorder(nil, store.Store), + connectionStore: NewConnectionRecorder(nil, store.Store, + &didconnection.MockDIDConnection{}), } s, err := New(&protocol.MockProvider{StoreProvider: store}) @@ -1395,7 +1406,8 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolve error")}, - connectionStore: NewConnectionRecorder(nil, store.Store), + connectionStore: NewConnectionRecorder(nil, store.Store, + &didconnection.MockDIDConnection{}), } s, err := New(&protocol.MockProvider{StoreProvider: store}) @@ -1418,7 +1430,8 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDIDDoc}, - connectionStore: NewConnectionRecorder(transientStore.Store, store.Store), + connectionStore: NewConnectionRecorder(transientStore.Store, store.Store, + &didconnection.MockDIDConnection{}), } s, err := New(&protocol.MockProvider{StoreProvider: store, TransientStoreProvider: transientStore}) diff --git a/pkg/didcomm/protocol/didexchange/states.go b/pkg/didcomm/protocol/didexchange/states.go index d98db1633..0c8e57173 100644 --- a/pkg/didcomm/protocol/didexchange/states.go +++ b/pkg/didcomm/protocol/didexchange/states.go @@ -326,6 +326,7 @@ func (ctx *context) handleInboundRequest(request *Request, options *options, con } // get did document that will be used in exchange response + // (my did doc) responseDidDoc, connection, err := ctx.getDIDDocAndConnection(getPublicDID(options)) if err != nil { return nil, nil, err @@ -404,6 +405,11 @@ func (ctx *context) getDIDDocAndConnection(pubDID string) (*did.Doc, *Connection return nil, nil, fmt.Errorf("resolve public did[%s]: %w", pubDID, err) } + err = ctx.connectionStore.didStore.SaveDIDFromDoc(didDoc) + if err != nil { + return nil, nil, err + } + return didDoc, &Connection{DID: didDoc.ID}, nil } @@ -415,6 +421,11 @@ func (ctx *context) getDIDDocAndConnection(pubDID string) (*did.Doc, *Connection return nil, nil, fmt.Errorf("create %s did: %w", didMethod, err) } + err = ctx.connectionStore.didStore.SaveDIDFromDoc(newDidDoc) + if err != nil { + return nil, nil, err + } + connection := &Connection{ DID: newDidDoc.ID, DIDDoc: newDidDoc, diff --git a/pkg/didcomm/protocol/didexchange/states_test.go b/pkg/didcomm/protocol/didexchange/states_test.go index 118ffeae9..c600b548f 100644 --- a/pkg/didcomm/protocol/didexchange/states_test.go +++ b/pkg/didcomm/protocol/didexchange/states_test.go @@ -25,6 +25,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/protocol" mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" @@ -308,7 +309,12 @@ func TestRequestedState_Execute(t *testing.T) { didDoc.Service[0].RecipientKeys = []string{"invalid"} ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: didDoc}, - signer: &mockSigner{}} + signer: &mockSigner{}, + connectionStore: NewConnectionRecorder( + &mockstorage.MockStore{Store: map[string][]byte{}}, + &mockstorage.MockStore{Store: map[string][]byte{}}, + &didconnection.MockDIDConnection{}, + )} _, followup, _, err := (&requested{}).ExecuteInbound(&stateMachineMsg{ header: &service.Header{Type: InvitationMsgType}, payload: invitationPayloadBytes, @@ -374,7 +380,7 @@ func TestRespondedState_Execute(t *testing.T) { didDoc.Service[0].RecipientKeys = []string{"invalid"} ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: didDoc}, signer: &mockSigner{}, - connectionStore: NewConnectionRecorder(store, store)} + connectionStore: NewConnectionRecorder(store, store, &didconnection.MockDIDConnection{})} _, followup, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{ header: &service.Header{Type: RequestMsgType}, payload: requestPayloadBytes, @@ -408,7 +414,7 @@ func TestCompletedState_Execute(t *testing.T) { pubKey, privKey := generateKeyPair() _, store := getProvider() ctx := &context{signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store)} + connectionStore: NewConnectionRecorder(store, store, &didconnection.MockDIDConnection{})} newDIDDoc := createDIDDocWithKey(pubKey) connection := &Connection{ DID: newDIDDoc.ID, @@ -513,7 +519,7 @@ func TestVerifySignature(t *testing.T) { pubKey, privKey := generateKeyPair() _, store := getProvider() ctx := &context{signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(nil, store)} + connectionStore: NewConnectionRecorder(nil, store, nil)} newDIDDoc := createDIDDocWithKey(pubKey) connection := &Connection{ DID: newDIDDoc.ID, @@ -651,7 +657,7 @@ func TestPrepareConnectionSignature(t *testing.T) { ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDidDoc}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store), + connectionStore: NewConnectionRecorder(store, store, nil), } connectionSignature, err := ctx2.prepareConnectionSignature(connection, newDidDoc.ID) require.NoError(t, err) @@ -669,7 +675,7 @@ func TestPrepareConnectionSignature(t *testing.T) { ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDidDoc}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store), + connectionStore: NewConnectionRecorder(store, store, nil), } connectionSignature, err := ctx2.prepareConnectionSignature(connection, newDidDoc.ID) require.Error(t, err) @@ -697,7 +703,7 @@ func TestPrepareConnectionSignature(t *testing.T) { }) t.Run("prepare connection signature error", func(t *testing.T) { ctx := &context{signer: &mockSigner{err: errors.New("sign error")}, - connectionStore: NewConnectionRecorder(nil, store)} + connectionStore: NewConnectionRecorder(nil, store, nil)} connection := &Connection{ DIDDoc: mockdiddoc.GetMockDIDDoc(), } @@ -734,7 +740,13 @@ func TestNewRequestFromInvitation(t *testing.T) { }) t.Run("successful response to invitation with public did", func(t *testing.T) { doc := createDIDDoc() - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}} + ctx := context{ + vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}, + connectionStore: NewConnectionRecorder( + &mockstorage.MockStore{Store: map[string][]byte{}}, + &mockstorage.MockStore{Store: map[string][]byte{}}, + &didconnection.MockDIDConnection{}, + )} invitationBytes, err := json.Marshal(invitation) require.NoError(t, err) @@ -793,8 +805,9 @@ func TestNewResponseFromRequest(t *testing.T) { }) t.Run("unsuccessful new response from request due to sign error", func(t *testing.T) { ctx := &context{vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}, - signer: &mockSigner{err: errors.New("sign error")}, - connectionStore: NewConnectionRecorder(nil, store)} + signer: &mockSigner{err: errors.New("sign error")}, + connectionStore: NewConnectionRecorder(nil, store, + &didconnection.MockDIDConnection{})} request, err := createRequest(ctx) require.NoError(t, err) _, connRec, err := ctx.handleInboundRequest(request, &options{}, &ConnectionRecord{}) @@ -909,7 +922,13 @@ func TestGetPublicKey(t *testing.T) { func TestGetDIDDocAndConnection(t *testing.T) { t.Run("successfully getting did doc and connection for public did", func(t *testing.T) { doc := createDIDDoc() - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}} + ctx := context{ + vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}, + connectionStore: NewConnectionRecorder( + &mockstorage.MockStore{Store: map[string][]byte{}}, + &mockstorage.MockStore{Store: map[string][]byte{}}, + &didconnection.MockDIDConnection{}, + )} didDoc, conn, err := ctx.getDIDDocAndConnection(doc.ID) require.NoError(t, err) require.NotNil(t, didDoc) @@ -917,13 +936,29 @@ func TestGetDIDDocAndConnection(t *testing.T) { require.Equal(t, didDoc.ID, conn.DID) }) t.Run("error getting public did doc from resolver", func(t *testing.T) { - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolver error")}} + ctx := context{ + vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolver error")}} didDoc, conn, err := ctx.getDIDDocAndConnection("did-id") require.Error(t, err) require.Contains(t, err.Error(), "resolver error") require.Nil(t, didDoc) require.Nil(t, conn) }) + t.Run("error saving pub did connection", func(t *testing.T) { + doc := createDIDDoc() + ctx := context{ + vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}, + connectionStore: NewConnectionRecorder( + &mockstorage.MockStore{Store: map[string][]byte{}}, + &mockstorage.MockStore{Store: map[string][]byte{}}, + &didconnection.MockDIDConnection{SaveDIDErr: fmt.Errorf("did error")}, + )} + didDoc, conn, err := ctx.getDIDDocAndConnection(doc.ID) + require.Error(t, err) + require.Contains(t, err.Error(), "did error") + require.Nil(t, didDoc) + require.Nil(t, conn) + }) t.Run("error creating peer did", func(t *testing.T) { ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{CreateErr: errors.New("creator error")}} didDoc, conn, err := ctx.getDIDDocAndConnection("") @@ -935,13 +970,31 @@ func TestGetDIDDocAndConnection(t *testing.T) { t.Run("successfully created peer did", func(t *testing.T) { ctx := context{ vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}, - } + connectionStore: NewConnectionRecorder( + &mockstorage.MockStore{Store: map[string][]byte{}}, + &mockstorage.MockStore{Store: map[string][]byte{}}, + &didconnection.MockDIDConnection{}, + )} didDoc, conn, err := ctx.getDIDDocAndConnection("") require.NoError(t, err) require.NotNil(t, didDoc) require.NotNil(t, conn) require.Equal(t, didDoc.ID, conn.DID) }) + t.Run("error saving peer did connection", func(t *testing.T) { + ctx := context{ + vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}, + connectionStore: NewConnectionRecorder( + &mockstorage.MockStore{Store: map[string][]byte{}}, + &mockstorage.MockStore{Store: map[string][]byte{}}, + &didconnection.MockDIDConnection{SaveDIDErr: fmt.Errorf("did error")}, + )} + didDoc, conn, err := ctx.getDIDDocAndConnection("") + require.Error(t, err) + require.Contains(t, err.Error(), "did error") + require.Nil(t, didDoc) + require.Nil(t, conn) + }) } type mockSigner struct { @@ -1010,7 +1063,7 @@ func getContext(prov protocol.MockProvider, store storage.Store) *context { return &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: createDIDDocWithKey(pubKey)}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store), + connectionStore: NewConnectionRecorder(store, store, &didconnection.MockDIDConnection{}), } } diff --git a/pkg/didcomm/transport/http/inbound.go b/pkg/didcomm/transport/http/inbound.go index 306dc24f7..ed0907497 100644 --- a/pkg/didcomm/transport/http/inbound.go +++ b/pkg/didcomm/transport/http/inbound.go @@ -65,7 +65,7 @@ func processPOSTRequest(w http.ResponseWriter, r *http.Request, prov transport.P messageHandler := prov.InboundMessageHandler() - err = messageHandler(unpackMsg.Message) + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) if err != nil { // TODO https://github.com/hyperledger/aries-framework-go/issues/271 HTTP Response Codes based on errors // from service diff --git a/pkg/didcomm/transport/http/inbound_test.go b/pkg/didcomm/transport/http/inbound_test.go index a4764f039..69769300b 100644 --- a/pkg/didcomm/transport/http/inbound_test.go +++ b/pkg/didcomm/transport/http/inbound_test.go @@ -30,7 +30,7 @@ type mockProvider struct { } func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { logger.Debugf("message received is %s", message) return nil } diff --git a/pkg/didcomm/transport/transport_interface.go b/pkg/didcomm/transport/transport_interface.go index e6b6cfd70..d7f58cc21 100644 --- a/pkg/didcomm/transport/transport_interface.go +++ b/pkg/didcomm/transport/transport_interface.go @@ -30,7 +30,7 @@ type OutboundTransport interface { // InboundMessageHandler handles the inbound requests. The transport will unpack the payload prior to the // message handle invocation. -type InboundMessageHandler func(message []byte) error +type InboundMessageHandler func(message []byte, myDID, theirDID string) error // Provider contains dependencies for starting the inbound/outbound transports. // It is typically created by using aries.Context(). diff --git a/pkg/didcomm/transport/ws/pool.go b/pkg/didcomm/transport/ws/pool.go index be8fd942f..86752d506 100644 --- a/pkg/didcomm/transport/ws/pool.go +++ b/pkg/didcomm/transport/ws/pool.go @@ -10,6 +10,7 @@ import ( "encoding/json" "sync" + "github.com/btcsuite/btcutil/base58" "nhooyr.io/websocket" commtransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" @@ -93,12 +94,12 @@ func (d *connPool) listener(conn *websocket.Conn) { } if trans != nil && trans.ReturnRoute != nil && trans.ReturnRoute.Value == decorator.TransportReturnRouteAll { - d.add(unpackMsg.FromVerKey, conn) + d.add(base58.Encode(unpackMsg.FromVerKey), conn) } messageHandler := d.msgHandler - err = messageHandler(unpackMsg.Message) + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) if err != nil { logger.Errorf("incoming msg processing failed: %v", err) } diff --git a/pkg/didcomm/transport/ws/pool_test.go b/pkg/didcomm/transport/ws/pool_test.go index a008956a6..8d6d78b93 100644 --- a/pkg/didcomm/transport/ws/pool_test.go +++ b/pkg/didcomm/transport/ws/pool_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcutil/base58" "github.com/google/uuid" "github.com/stretchr/testify/require" "nhooyr.io/websocket" @@ -40,14 +41,14 @@ func TestConnectionStore(t *testing.T) { // create a transport provider (framework context) verKey := "ABCD" mockPackager := &mockpackager.Packager{ - UnpackValue: &commontransport.Envelope{Message: request, FromVerKey: verKey}, + UnpackValue: &commontransport.Envelope{Message: request, FromVerKey: base58.Decode(verKey)}, } response := "Hello" transportProvider := &mockTransportProvider{ packagerValue: mockPackager, frameworkID: uuid.New().String(), - executeInbound: func(message []byte) error { + executeInbound: func(message []byte, myDID, theirDID string) error { resp, outboundErr := outbound.Send([]byte(response), prepareDestinationWithTransport("ws://doesnt-matter", "", []string{verKey})) require.NoError(t, outboundErr) @@ -100,7 +101,7 @@ func TestConnectionStore(t *testing.T) { transportProvider := &mockTransportProvider{ packagerValue: &mockPackager{verKey: verKey}, frameworkID: uuid.New().String(), - executeInbound: func(message []byte) error { + executeInbound: func(message []byte, myDID, theirDID string) error { // validate the echo server response with the outbound sent message require.Equal(t, request, message) done <- struct{}{} diff --git a/pkg/didcomm/transport/ws/support_test.go b/pkg/didcomm/transport/ws/support_test.go index acd49fc50..d5665d605 100644 --- a/pkg/didcomm/transport/ws/support_test.go +++ b/pkg/didcomm/transport/ws/support_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcutil/base58" "github.com/google/uuid" "github.com/stretchr/testify/require" "nhooyr.io/websocket" @@ -32,7 +33,7 @@ type mockProvider struct { } func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { logger.Infof("message received is %s", string(message)) if string(message) == "invalid-data" { return errors.New("error") @@ -146,12 +147,12 @@ func (m *mockPackager) PackMessage(e *commontransport.Envelope) ([]byte, error) } func (m *mockPackager) UnpackMessage(encMessage []byte) (*commontransport.Envelope, error) { - return &commontransport.Envelope{Message: encMessage, FromVerKey: m.verKey}, nil + return &commontransport.Envelope{Message: encMessage, FromVerKey: base58.Decode(m.verKey)}, nil } type mockTransportProvider struct { packagerValue commontransport.Packager - executeInbound func(message []byte) error + executeInbound func(message []byte, myDID, theirDID string) error frameworkID string } diff --git a/pkg/framework/aries/api/protocol.go b/pkg/framework/aries/api/protocol.go index 0d2e941f1..9e70373ab 100644 --- a/pkg/framework/aries/api/protocol.go +++ b/pkg/framework/aries/api/protocol.go @@ -9,6 +9,7 @@ package api import ( "errors" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" @@ -25,6 +26,7 @@ type Provider interface { Service(id string) (interface{}, error) StorageProvider() storage.Provider KMS() kms.KeyManager + DIDConnectionStore() didconnection.Store Packager() transport.Packager InboundTransportEndpoint() string VDRIRegistry() vdriapi.Registry diff --git a/pkg/framework/aries/framework.go b/pkg/framework/aries/framework.go index de233078e..4d0508b95 100644 --- a/pkg/framework/aries/framework.go +++ b/pkg/framework/aries/framework.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packager" @@ -43,6 +44,7 @@ type Aries struct { inboundTransport transport.InboundTransport kmsCreator api.KMSCreator kms api.CloseableKMS + didConnectionStore didconnection.Store packagerCreator packager.Creator packager commontransport.Packager packerCreator packer.Creator @@ -87,8 +89,11 @@ func New(opts ...Option) (*Aries, error) { // on the context. The inbound transports require ctx.InboundMessageHandler(), which in-turn depends on // protocolServices. At the moment, there is a looping issue among these. - // Order of initializing service is important + return initializeServices(frameworkOpts) +} +func initializeServices(frameworkOpts *Aries) (*Aries, error) { + // Order of initializing service is important // Create kms if e := createKMS(frameworkOpts); e != nil { return nil, e @@ -99,7 +104,12 @@ func New(opts ...Option) (*Aries, error) { return nil, e } - // create packers and packager (must be done after KMS) + // Create connection store + if e := createDIDConnectionStore(frameworkOpts); e != nil { + return nil, e + } + + // create packers and packager (must be done after KMS and connection store) if err := createPackersAndPackager(frameworkOpts); err != nil { return nil, err } @@ -217,6 +227,7 @@ func (a *Aries) Context() (*context.Provider, error) { context.WithOutboundTransports(a.outboundTransports...), context.WithProtocolServices(a.services...), context.WithKMS(a.kms), + context.WithDIDConnectionStore(a.didConnectionStore), context.WithInboundTransportEndpoint(endPoint), context.WithStorageProvider(a.storeProvider), context.WithTransientStorageProvider(a.transientStoreProvider), @@ -328,6 +339,7 @@ func createOutboundDispatcher(frameworkOpts *Aries) error { context.WithOutboundTransports(frameworkOpts.outboundTransports...), context.WithPackager(frameworkOpts.packager), context.WithTransportReturnRoute(frameworkOpts.transportReturnRoute), + context.WithVDRIRegistry(frameworkOpts.vdriRegistry), ) if err != nil { return fmt.Errorf("context creation failed: %w", err) @@ -380,6 +392,7 @@ func loadServices(frameworkOpts *Aries) error { context.WithPackager(frameworkOpts.packager), context.WithInboundTransportEndpoint(endPoint), context.WithVDRIRegistry(frameworkOpts.vdriRegistry), + context.WithDIDConnectionStore(frameworkOpts.didConnectionStore), ) if err != nil { @@ -398,10 +411,27 @@ func loadServices(frameworkOpts *Aries) error { return nil } +func createDIDConnectionStore(frameworkOpts *Aries) error { + ctx, err := context.New( + context.WithStorageProvider(frameworkOpts.storeProvider), + context.WithVDRIRegistry(frameworkOpts.vdriRegistry), + ) + if err != nil { + return fmt.Errorf("create did lookup store context failed: %w", err) + } + + frameworkOpts.didConnectionStore, err = didconnection.New(ctx) + if err != nil { + return fmt.Errorf("create did lookup store failed: %w", err) + } + + return nil +} + func createPackersAndPackager(frameworkOpts *Aries) error { ctx, err := context.New(context.WithKMS(frameworkOpts.kms)) if err != nil { - return fmt.Errorf("create envelope context failed: %w", err) + return fmt.Errorf("create packer context failed: %w", err) } frameworkOpts.primaryPacker, err = frameworkOpts.packerCreator(ctx) @@ -423,7 +453,8 @@ func createPackersAndPackager(frameworkOpts *Aries) error { } ctx, err = context.New( - context.WithPacker(frameworkOpts.primaryPacker, frameworkOpts.packers...)) + context.WithPacker(frameworkOpts.primaryPacker, frameworkOpts.packers...), + context.WithDIDConnectionStore(frameworkOpts.didConnectionStore)) if err != nil { return fmt.Errorf("create packager context failed: %w", err) } diff --git a/pkg/framework/aries/framework_test.go b/pkg/framework/aries/framework_test.go index 8d629bf93..5536a1f75 100644 --- a/pkg/framework/aries/framework_test.go +++ b/pkg/framework/aries/framework_test.go @@ -319,6 +319,16 @@ func TestFramework(t *testing.T) { require.Contains(t, err.Error(), "error from kms") }) + t.Run("test did connection store svc", func(t *testing.T) { + fw := Aries{storeProvider: &storage.MockStoreProvider{ + ErrOpenStoreHandle: fmt.Errorf("store err"), + }} + + err := createDIDConnectionStore(&fw) + require.Error(t, err) + require.Contains(t, err.Error(), "store err") + }) + t.Run("test transient store - with user provided transient store", func(t *testing.T) { path, cleanup := generateTempDir(t) defer cleanup() diff --git a/pkg/framework/context/context.go b/pkg/framework/context/context.go index 80ab24c5b..9b4adf974 100644 --- a/pkg/framework/context/context.go +++ b/pkg/framework/context/context.go @@ -9,6 +9,7 @@ package context import ( "fmt" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" @@ -26,6 +27,7 @@ type Provider struct { storeProvider storage.Provider transientStoreProvider storage.Provider kms kms.KMS + didConnectionStore didconnection.Store packager commontransport.Packager primaryPacker packer.Packer packers []packer.Packer @@ -77,6 +79,11 @@ func (p *Provider) KMS() kms.KeyManager { return p.kms } +// DIDConnectionStore returns a didconnection.Store service. +func (p *Provider) DIDConnectionStore() didconnection.Store { + return p.didConnectionStore +} + // Packager returns a packager service. func (p *Provider) Packager() commontransport.Packager { return p.packager @@ -104,7 +111,7 @@ func (p *Provider) InboundTransportEndpoint() string { // InboundMessageHandler return an inbound message handler. func (p *Provider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { msg, err := service.NewDIDCommMsg(message) if err != nil { return err @@ -113,7 +120,7 @@ func (p *Provider) InboundMessageHandler() transport.InboundMessageHandler { // find the service which accepts the message type for _, svc := range p.services { if svc.Accept(msg.Header.Type) { - _, err = svc.HandleInbound(msg, "", "") + _, err = svc.HandleInbound(msg, myDID, theirDID) return err } } @@ -221,6 +228,14 @@ func WithTransientStorageProvider(s storage.Provider) ProviderOption { } } +// WithDIDConnectionStore injects a didconnection.Store into the context. +func WithDIDConnectionStore(cs didconnection.Store) ProviderOption { + return func(opts *Provider) error { + opts.didConnectionStore = cs + return nil + } +} + // WithPackager injects a packager into the context. func WithPackager(p commontransport.Packager) ProviderOption { return func(opts *Provider) error { diff --git a/pkg/framework/context/context_test.go b/pkg/framework/context/context_test.go index e1fca0dc1..726f2bbcf 100644 --- a/pkg/framework/context/context_test.go +++ b/pkg/framework/context/context_test.go @@ -18,6 +18,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" mockdidcomm "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" + mockdidconnection "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" mockdispatcher "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/dispatcher" mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/protocol" @@ -97,16 +98,16 @@ func TestNewProvider(t *testing.T) { { "@frameworkID": "5678876542345", "@type": "valid-message-type" - }`)) + }`), "", "") require.NoError(t, err) // invalid json - err = inboundHandler([]byte("invalid json")) + err = inboundHandler([]byte("invalid json"), "", "") require.Error(t, err) require.Contains(t, err.Error(), "invalid payload data format") // invalid json - err = inboundHandler([]byte("invalid json")) + err = inboundHandler([]byte("invalid json"), "", "") require.Error(t, err) require.Contains(t, err.Error(), "invalid payload data format") @@ -115,7 +116,7 @@ func TestNewProvider(t *testing.T) { { "@type": "invalid-message-type", "label": "Bob" - }`)) + }`), "", "") require.Error(t, err) require.Contains(t, err.Error(), "no message handlers found for the message type: invalid-message-type") @@ -124,7 +125,7 @@ func TestNewProvider(t *testing.T) { { "label": "Carol", "@type": "valid-message-type" - }`)) + }`), "", "") require.Error(t, err) require.Contains(t, err.Error(), "error handling the message") }) @@ -195,6 +196,13 @@ func TestNewProvider(t *testing.T) { require.Equal(t, r, prov.VDRIRegistry()) }) + t.Run("test new with did connection store", func(t *testing.T) { + cs := &mockdidconnection.MockDIDConnection{} + prov, err := New(WithDIDConnectionStore(cs)) + require.NoError(t, err) + require.Equal(t, cs, prov.DIDConnectionStore()) + }) + t.Run("test new with outbound transport service", func(t *testing.T) { prov, err := New(WithOutboundTransports(&mockdidcomm.MockOutboundTransport{ExpectedResponse: "data"}, &mockdidcomm.MockOutboundTransport{ExpectedResponse: "data1"})) diff --git a/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go b/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go new file mode 100644 index 000000000..b5fcf0e72 --- /dev/null +++ b/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go @@ -0,0 +1,40 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" +) + +// MockDIDConnection mocks the did lookup store. +type MockDIDConnection struct { + SaveRecordErr error + SaveKeysErr error + GetDIDValue string + GetDIDErr error + SaveDIDErr error + ResolveDIDErr error +} + +// SaveDID saves a DID to the store +func (m *MockDIDConnection) SaveDID(did string, keys ...string) error { + return m.SaveRecordErr +} + +// GetDID gets the DID stored under the given key +func (m *MockDIDConnection) GetDID(key string) (string, error) { + return m.GetDIDValue, m.GetDIDErr +} + +// SaveDIDByResolving saves a DID by resolving it then using its doc +func (m *MockDIDConnection) SaveDIDByResolving(did string, keys ...string) error { + return m.ResolveDIDErr +} + +// SaveDIDFromDoc saves a DID using the given doc +func (m *MockDIDConnection) SaveDIDFromDoc(doc *diddoc.Doc) error { + return m.SaveDIDErr +} diff --git a/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go b/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go index 25f683e42..99ae4a076 100644 --- a/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go +++ b/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go @@ -27,7 +27,7 @@ func (m *MockOutbound) Send(msg interface{}, senderVerKey string, des *service.D // SendToDID msg func (m *MockOutbound) SendToDID(msg interface{}, myDID, theirDID string) error { - return nil + return m.SendErr } // Forward msg diff --git a/pkg/internal/mock/didcomm/mock_authcrypt.go b/pkg/internal/mock/didcomm/mock_authcrypt.go index d8ef0c72f..b75acc33f 100644 --- a/pkg/internal/mock/didcomm/mock_authcrypt.go +++ b/pkg/internal/mock/didcomm/mock_authcrypt.go @@ -6,10 +6,12 @@ SPDX-License-Identifier: Apache-2.0 package didcomm +import "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + // MockAuthCrypt mock auth crypt type MockAuthCrypt struct { EncryptValue func(payload, senderPubKey []byte, recipients [][]byte) ([]byte, error) - DecryptValue func(envelope []byte) ([]byte, []byte, error) + DecryptValue func(envelope []byte) (*transport.Envelope, error) Type string } @@ -20,7 +22,7 @@ func (m *MockAuthCrypt) Pack(payload, senderPubKey []byte, } // Unpack mock message unpacking -func (m *MockAuthCrypt) Unpack(envelope []byte) ([]byte, []byte, error) { +func (m *MockAuthCrypt) Unpack(envelope []byte) (*transport.Envelope, error) { return m.DecryptValue(envelope) } diff --git a/pkg/internal/mock/packer/noop.go b/pkg/internal/mock/didcomm/packer/noop.go similarity index 69% rename from pkg/internal/mock/packer/noop.go rename to pkg/internal/mock/didcomm/packer/noop.go index 3b00679a2..8c888c310 100644 --- a/pkg/internal/mock/packer/noop.go +++ b/pkg/internal/mock/didcomm/packer/noop.go @@ -9,16 +9,19 @@ package packer import ( "encoding/base64" "encoding/json" + "fmt" "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" ) type envelope struct { - Header string `json:"protected,omitempty"` - Sender string `json:"spk,omitempty"` - Message string `json:"msg,omitempty"` + Header string `json:"protected,omitempty"` + Sender string `json:"spk,omitempty"` + Recipient string `json:"kid,omitempty"` + Message string `json:"msg,omitempty"` } type header struct { @@ -50,10 +53,15 @@ func (p *Packer) Pack(payload, sender []byte, recipientPubKeys [][]byte) ([]byte headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + if len(recipientPubKeys) == 0 { + return nil, fmt.Errorf("no recipients") + } + message := envelope{ - Header: headerB64, - Sender: base58.Encode(sender), - Message: string(payload), + Header: headerB64, + Sender: base58.Encode(sender), + Recipient: base58.Encode(recipientPubKeys[0]), + Message: string(payload), } msgBytes, err := json.Marshal(&message) @@ -62,27 +70,31 @@ func (p *Packer) Pack(payload, sender []byte, recipientPubKeys [][]byte) ([]byte } // Unpack will decode the envelope using the NOOP format. -func (p *Packer) Unpack(message []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(message []byte) (*transport.Envelope, error) { var env envelope err := json.Unmarshal(message, &env) if err != nil { - return nil, nil, err + return nil, err } headerBytes, err := base64.URLEncoding.DecodeString(env.Header) if err != nil { - return nil, nil, err + return nil, err } var head header err = json.Unmarshal(headerBytes, &head) if err != nil { - return nil, nil, err + return nil, err } - return []byte(env.Message), base58.Decode(env.Sender), nil + return &transport.Envelope{ + Message: []byte(env.Message), + FromVerKey: base58.Decode(env.Sender), + ToVerKey: base58.Decode(env.Recipient), + }, nil } // EncodingType returns the type of the encoding, as found in the header `Typ` field diff --git a/pkg/internal/mock/didcomm/packer/noop_test.go b/pkg/internal/mock/didcomm/packer/noop_test.go new file mode 100644 index 000000000..ecd44aa60 --- /dev/null +++ b/pkg/internal/mock/didcomm/packer/noop_test.go @@ -0,0 +1,111 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package packer + +import ( + "encoding/base64" + "testing" + + "github.com/btcsuite/btcutil/base58" + "github.com/stretchr/testify/require" +) + +// note: does not replicate correct packing +// when msg needs to be escaped. +func testPack(msg, senderKey, recKey []byte) []byte { + headerValue := base64.URLEncoding.EncodeToString([]byte(`{"typ":"NOOP"}`)) + + return []byte(`{"protected":"` + headerValue + + `","spk":"` + base58.Encode(senderKey) + + `","kid":"` + base58.Encode(recKey) + + `","msg":"` + string(msg) + `"}`) +} + +func TestPacker(t *testing.T) { + p := New(nil) + require.NotNil(t, p) + require.Equal(t, encodingType, p.EncodingType()) + + t.Run("no rec keys", func(t *testing.T) { + _, err := p.Pack(nil, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "no recipients") + }) + + t.Run("pack, compare against correct data", func(t *testing.T) { + msgin := []byte("hello my name is zoop") + key := []byte("senderkey") + rec := []byte("recipient") + + msgout, err := p.Pack(msgin, key, [][]byte{rec}) + require.NoError(t, err) + + correct := testPack(msgin, key, rec) + require.Equal(t, correct, msgout) + }) + + t.Run("unpack fixed value, confirm data", func(t *testing.T) { + correct := []byte("this is not a test message") + key := []byte("testKey") + rec := []byte("key2") + msgin := testPack(correct, key, rec) + + envOut, err := p.Unpack(msgin) + require.NoError(t, err) + + require.Equal(t, correct, envOut.Message) + require.Equal(t, key, envOut.FromVerKey) + require.Equal(t, rec, envOut.ToVerKey) + }) + + t.Run("multiple pack/unpacks", func(t *testing.T) { + cleartext := []byte("this is not a test message") + key1 := []byte("testKey") + rec1 := []byte("rec1") + key2 := []byte("wrapperKey") + rec2 := []byte("rec2") + + correct1 := testPack(cleartext, key1, rec1) + + msg1, err := p.Pack(cleartext, key1, [][]byte{rec1}) + require.NoError(t, err) + require.Equal(t, correct1, msg1) + + msg2, err := p.Pack(msg1, key2, [][]byte{rec2}) + require.NoError(t, err) + + env1, err := p.Unpack(msg2) + require.NoError(t, err) + require.Equal(t, key2, env1.FromVerKey) + require.Equal(t, rec2, env1.ToVerKey) + require.Equal(t, correct1, env1.Message) + + env2, err := p.Unpack(env1.Message) + require.NoError(t, err) + require.Equal(t, key1, env2.FromVerKey) + require.Equal(t, rec1, env2.ToVerKey) + require.Equal(t, cleartext, env2.Message) + }) + + t.Run("unpack errors", func(t *testing.T) { + _, err := p.Unpack(nil) + require.Error(t, err) + require.Contains(t, err.Error(), "end of JSON input") + + _, err = p.Unpack([]byte("{}")) + require.Error(t, err) + require.Contains(t, err.Error(), "end of JSON input") + + _, err = p.Unpack([]byte("{\"protected\":\"$$$$$$$$$$$$\"}")) + require.Error(t, err) + require.Contains(t, err.Error(), "illegal base64 data") + + _, err = p.Unpack([]byte("{\"protected\":\"e3t7\"}")) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid character") + }) +} diff --git a/pkg/internal/mock/didcomm/protocol/mock_didexchange.go b/pkg/internal/mock/didcomm/protocol/mock_didexchange.go index 3bcd1a75b..03063acc5 100644 --- a/pkg/internal/mock/didcomm/protocol/mock_didexchange.go +++ b/pkg/internal/mock/didcomm/protocol/mock_didexchange.go @@ -9,9 +9,11 @@ package protocol import ( "github.com/google/uuid" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" + mockdidconnection "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" mockdispatcher "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/dispatcher" mockkms "github.com/hyperledger/aries-framework-go/pkg/internal/mock/kms" mockstore "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" @@ -135,9 +137,19 @@ func (m *MockDIDExchangeSvc) CreateImplicitInvitation(inviterLabel, inviterDID, // MockProvider is provider for DIDExchange Service type MockProvider struct { - StoreProvider *mockstore.MockStoreProvider - TransientStoreProvider *mockstore.MockStoreProvider - CustomVDRI vdriapi.Registry + StoreProvider *mockstore.MockStoreProvider + TransientStoreProvider *mockstore.MockStoreProvider + CustomVDRI vdriapi.Registry + DIDConnectionStoreValue didconnection.Store +} + +// DIDConnectionStore returns the did lookup store +func (p *MockProvider) DIDConnectionStore() didconnection.Store { + if p.DIDConnectionStoreValue == nil { + return &mockdidconnection.MockDIDConnection{} + } + + return p.DIDConnectionStoreValue } // OutboundDispatcher is mock outbound dispatcher for DID exchange service diff --git a/pkg/internal/mock/packer/noop_test.go b/pkg/internal/mock/packer/noop_test.go deleted file mode 100644 index 32f6696bf..000000000 --- a/pkg/internal/mock/packer/noop_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* -Copyright SecureKey Technologies Inc. All Rights Reserved. - -SPDX-License-Identifier: Apache-2.0 -*/ - -package packer - -import ( - "encoding/base64" - "testing" - - "github.com/btcsuite/btcutil/base58" - "github.com/stretchr/testify/require" -) - -// note: does not replicate correct packing -// when msg needs to be escaped. -func testPack(msg, key []byte) []byte { - headerValue := base64.URLEncoding.EncodeToString([]byte(`{"typ":"NOOP"}`)) - - return []byte(`{"protected":"` + headerValue + - `","spk":"` + base58.Encode(key) + - `","msg":"` + string(msg) + `"}`) -} - -func TestPacker(t *testing.T) { - p := New(nil) - require.NotNil(t, p) - require.Equal(t, encodingType, p.EncodingType()) - - t.Run("pack, compare against correct data", func(t *testing.T) { - msgin := []byte("hello my name is zoop") - key := []byte("senderkey") - - msgout, err := p.Pack(msgin, key, nil) - require.NoError(t, err) - - correct := testPack(msgin, key) - require.Equal(t, correct, msgout) - }) - - t.Run("unpack fixed value, confirm data", func(t *testing.T) { - correct := []byte("this is not a test message") - key := []byte("testKey") - msgin := testPack(correct, key) - - msgout, keyOut, err := p.Unpack(msgin) - require.NoError(t, err) - - require.Equal(t, correct, msgout) - require.Equal(t, key, keyOut) - }) - - t.Run("multiple pack/unpacks", func(t *testing.T) { - cleartext := []byte("this is not a test message") - key1 := []byte("testKey") - key2 := []byte("wrapperKey") - - correct1 := testPack(cleartext, key1) - - msg1, err := p.Pack(cleartext, key1, nil) - require.NoError(t, err) - require.Equal(t, correct1, msg1) - - msg2, err := p.Pack(msg1, key2, nil) - require.NoError(t, err) - - msg3, key1Out, err := p.Unpack(msg2) - require.NoError(t, err) - require.Equal(t, key2, key1Out) - require.Equal(t, correct1, msg3) - - msg4, key2Out, err := p.Unpack(msg3) - require.NoError(t, err) - require.Equal(t, key1, key2Out) - require.Equal(t, cleartext, msg4) - }) -} diff --git a/pkg/internal/mock/provider/mock_provider.go b/pkg/internal/mock/provider/mock_provider.go index 15aac74c6..9cfc649b2 100644 --- a/pkg/internal/mock/provider/mock_provider.go +++ b/pkg/internal/mock/provider/mock_provider.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0 package provider import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" @@ -26,6 +27,7 @@ type Provider struct { PackerValue packer.Packer OutboundDispatcherValue dispatcher.Outbound VDRIRegistryValue vdriapi.Registry + ConnectionStoreValue didconnection.Store } // Service return service @@ -72,3 +74,8 @@ func (p *Provider) OutboundDispatcher() dispatcher.Outbound { func (p *Provider) VDRIRegistry() vdriapi.Registry { return p.VDRIRegistryValue } + +// DIDConnectionStore returns a didconnection.Store service. +func (p *Provider) DIDConnectionStore() didconnection.Store { + return p.ConnectionStoreValue +} diff --git a/pkg/restapi/operation/didexchange/didexchange.go b/pkg/restapi/operation/didexchange/didexchange.go index 47c39d9ec..4dc0837e5 100644 --- a/pkg/restapi/operation/didexchange/didexchange.go +++ b/pkg/restapi/operation/didexchange/didexchange.go @@ -18,6 +18,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/client/didexchange" "github.com/hyperledger/aries-framework-go/pkg/common/log" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/internal/common/support" "github.com/hyperledger/aries-framework-go/pkg/kms" @@ -74,6 +75,7 @@ const ( type provider interface { Service(id string) (interface{}, error) KMS() kms.KeyManager + DIDConnectionStore() didconnection.Store InboundTransportEndpoint() string StorageProvider() storage.Provider TransientStorageProvider() storage.Provider