diff --git a/pkg/client/didexchange/client.go b/pkg/client/didexchange/client.go index f1f870c4fd..462fcc1824 100644 --- a/pkg/client/didexchange/client.go +++ b/pkg/client/didexchange/client.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange" "github.com/hyperledger/aries-framework-go/pkg/kms" @@ -40,6 +41,7 @@ type provider interface { InboundTransportEndpoint() string StorageProvider() storage.Provider TransientStorageProvider() storage.Provider + DIDConnectionStore() didconnection.Store } // Client enable access to didexchange api @@ -94,7 +96,7 @@ func New(ctx provider) (*Client, error) { didexchangeSvc: didexchangeSvc, kms: ctx.KMS(), inboundTransportEndpoint: ctx.InboundTransportEndpoint(), - connectionStore: didexchange.NewConnectionRecorder(transientStore, store), + connectionStore: didexchange.NewConnectionRecorder(transientStore, store, ctx.DIDConnectionStore()), }, nil } @@ -157,7 +159,7 @@ func (c *Client) HandleInvitation(invitation *Invitation) (string, error) { return "", fmt.Errorf("failed to create DIDCommMsg: %w", err) } - connectionID, err := c.didexchangeSvc.HandleInbound(msg) + connectionID, err := c.didexchangeSvc.HandleInbound(msg, "", "") if err != nil { return "", fmt.Errorf("failed from didexchange service handle: %w", err) } diff --git a/pkg/client/didexchange/client_test.go b/pkg/client/didexchange/client_test.go index febc07f441..7d826fb361 100644 --- a/pkg/client/didexchange/client_test.go +++ b/pkg/client/didexchange/client_test.go @@ -202,7 +202,7 @@ func TestClient_QueryConnectionByID(t *testing.T) { require.NoError(t, err) require.NoError(t, err) require.NoError(t, transientStore.Put("conn_id1", connBytes)) - c := didexchange.NewConnectionRecorder(transientStore, store) + c := didexchange.NewConnectionRecorder(transientStore, store, nil) result, err := c.GetConnectionRecord(connID) require.NoError(t, err) require.Equal(t, "complete", result.State) @@ -219,7 +219,7 @@ func TestClient_QueryConnectionByID(t *testing.T) { connRec := &didexchange.ConnectionRecord{ConnectionID: connID, ThreadID: threadID, State: "complete"} connBytes, err := json.Marshal(connRec) require.NoError(t, err) - c := didexchange.NewConnectionRecorder(transientStore, store) + c := didexchange.NewConnectionRecorder(transientStore, store, nil) require.NoError(t, transientStore.Put("conn_id1", connBytes)) _, err = c.GetConnectionRecord(connID) require.Error(t, err) @@ -234,7 +234,7 @@ func TestClient_QueryConnectionByID(t *testing.T) { transientStore := mockstore.MockStore{ErrGet: storage.ErrDataNotFound} store := mockstore.MockStore{} require.NoError(t, err) - c := didexchange.NewConnectionRecorder(&transientStore, &store) + c := didexchange.NewConnectionRecorder(&transientStore, &store, nil) _, err = c.GetConnectionRecord(connID) require.Error(t, err) require.True(t, errors.Is(err, storage.ErrDataNotFound)) @@ -589,7 +589,7 @@ func TestServiceEvents(t *testing.T) { msg, err := service.NewDIDCommMsg(request) require.NoError(t, err) - _, err = didExSvc.HandleInbound(msg) + _, err = didExSvc.HandleInbound(msg, "", "") require.NoError(t, err) select { @@ -679,7 +679,7 @@ func TestAcceptExchangeRequest(t *testing.T) { msg, err := service.NewDIDCommMsg(request) require.NoError(t, err) - _, err = didExSvc.HandleInbound(msg) + _, err = didExSvc.HandleInbound(msg, "", "") require.NoError(t, err) select { @@ -760,7 +760,7 @@ func TestAcceptInvitation(t *testing.T) { msg, svcErr := service.NewDIDCommMsg(invitation) require.NoError(t, svcErr) - _, err = didExSvc.HandleInbound(msg) + _, err = didExSvc.HandleInbound(msg, "", "") require.NoError(t, err) select { diff --git a/pkg/didcomm/common/didconnection/api.go b/pkg/didcomm/common/didconnection/api.go new file mode 100644 index 0000000000..cde9c3ef30 --- /dev/null +++ b/pkg/didcomm/common/didconnection/api.go @@ -0,0 +1,18 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +// Store stores DIDs indexed by public key, so agents can find the DID associated with a given key. +type Store interface { + // SaveDID saves a DID indexed by the given public keys to the Store + SaveDID(did string, keys ...string) error + // GetDID gets the DID stored under the given key + GetDID(key string) (string, error) + // SaveDIDConnection saves a connection between this agent's DID and another agent's DID + SaveDIDConnection(myDID, theirDID string, theirKeys []string) error + // SaveDIDFromDoc resolves a DID using the VDR then saves the map from keys -> did + SaveDIDFromDoc(did, serviceType, keyType string) error +} diff --git a/pkg/didcomm/common/didconnection/didconnection.go b/pkg/didcomm/common/didconnection/didconnection.go new file mode 100644 index 0000000000..0ff71a63fe --- /dev/null +++ b/pkg/didcomm/common/didconnection/didconnection.go @@ -0,0 +1,121 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "encoding/json" + "fmt" + + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" + "github.com/hyperledger/aries-framework-go/pkg/storage" +) + +// BaseDIDConnectionStore stores DIDs indexed by key +type BaseDIDConnectionStore struct { + store storage.Store + vdr vdri.Registry +} + +type didRecord struct { + DID string `json:"did,omitempty"` +} + +type provider interface { + StorageProvider() storage.Provider + VDRIRegistry() vdri.Registry +} + +// New returns a new did lookup Store +func New(ctx provider) (*BaseDIDConnectionStore, error) { + store, err := ctx.StorageProvider().OpenStore("con-store") + if err != nil { + return nil, err + } + + return &BaseDIDConnectionStore{store: store, vdr: ctx.VDRIRegistry()}, nil +} + +// saveDID saves a DID, indexed using the given public key +func (c *BaseDIDConnectionStore) saveDID(did, key string) error { + data := didRecord{ + DID: did, + } + + bytes, err := json.Marshal(data) + if err != nil { + return err + } + + return c.store.Put(key, bytes) +} + +// SaveDID saves a DID, indexed using the given public keys +func (c *BaseDIDConnectionStore) SaveDID(did string, keys ...string) error { + for _, key := range keys { + err := c.saveDID(did, key) + if err != nil { + return fmt.Errorf("saving DID in did map: %w", err) + } + } + + return nil +} + +// SaveDIDFromDoc resolves a DID using the VDR then saves the map from keys -> did +func (c *BaseDIDConnectionStore) SaveDIDFromDoc(did, serviceType, keyType string) error { + doc, err := c.vdr.Resolve(did) + if err != nil { + return err + } + + keys, err := diddoc.GetRecipientKeys(doc, serviceType, keyType) + if err != nil { + return err + } + + err = c.SaveDID(did, keys...) + if err != nil { + return err + } + + return nil +} + +// GetDID gets the DID stored under the given key +func (c *BaseDIDConnectionStore) GetDID(key string) (string, error) { + bytes, err := c.store.Get(key) + if err != nil { + return "", err + } + + var record didRecord + + err = json.Unmarshal(bytes, &record) + if err != nil { + return "", err + } + + return record.DID, nil +} + +// SaveDIDConnection saves a connection between this agent's did and another agent +func (c *BaseDIDConnectionStore) SaveDIDConnection(myDID, theirDID string, theirKeys []string) error { + // map their pub keys -> their DID + err := c.SaveDID(theirDID, theirKeys...) + if err != nil { + return err + } + + // map their DID -> my DID + err = c.SaveDID(myDID, theirDID) + if err != nil { + return fmt.Errorf("save DID in did map: %w", err) + } + + return nil +} diff --git a/pkg/didcomm/common/didconnection/didconnection_test.go b/pkg/didcomm/common/didconnection/didconnection_test.go new file mode 100644 index 0000000000..03c68e38c7 --- /dev/null +++ b/pkg/didcomm/common/didconnection/didconnection_test.go @@ -0,0 +1,67 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +import ( + "testing" + + "github.com/stretchr/testify/require" + + vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" + mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" + mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" + "github.com/hyperledger/aries-framework-go/pkg/storage" +) + +type ctx struct { + store storage.Provider + vdr vdriapi.Registry +} + +func (c *ctx) StorageProvider() storage.Provider { + return c.store +} + +func (c *ctx) VDRIRegistry() vdriapi.Registry { + return c.vdr +} + +func TestBaseConnectionStore(t *testing.T) { + prov := ctx{ + store: mockstorage.NewMockStoreProvider(), + vdr: &mockvdri.MockVDRIRegistry{ + CreateErr: nil, + CreateValue: mockdiddoc.GetMockDIDDoc(), + MemStore: nil, + PutErr: nil, + ResolveErr: nil, + ResolveValue: mockdiddoc.GetMockDIDDoc(), + }, + } + + connStore, err := New(&prov) + require.NoError(t, err) + + err = connStore.SaveDID("did:abcde", "abcde") + require.NoError(t, err) + + did, err := connStore.GetDID("abcde") + require.NoError(t, err) + require.Equal(t, "did:abcde", did) + + wrong, err := connStore.GetDID("fhtagn") + require.EqualError(t, err, storage.ErrDataNotFound.Error()) + require.Equal(t, "", wrong) + + err = connStore.store.Put("bad-data", []byte("aaooga")) + require.NoError(t, err) + + _, err = connStore.GetDID("bad-data") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid character") +} diff --git a/pkg/didcomm/common/service/destination.go b/pkg/didcomm/common/service/destination.go new file mode 100644 index 0000000000..d0cc32a43d --- /dev/null +++ b/pkg/didcomm/common/service/destination.go @@ -0,0 +1,50 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package service + +import ( + diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" +) + +// Destination provides the recipientKeys, routingKeys, and serviceEndpoint for an outbound message. +// Can be populated from an Invitation or DIDDoc. +type Destination struct { + RecipientKeys []string + ServiceEndpoint string + RoutingKeys []string + TransportReturnRoute string +} + +// GetDestination constructs a Destination struct based on the given DID and parameters +// It resolves the DID using the given VDR, and collects relevant data from the resolved DIDDoc. +func GetDestination(did, serviceType, keyType string, vdr vdri.Registry) (*Destination, error) { + didDoc, err := vdr.Resolve(did) + if err != nil { + return nil, err + } + + return MakeDestination(didDoc, serviceType, keyType) +} + +// MakeDestination makes a Destination object from a DID Doc, following the given parameters. +func MakeDestination(didDoc *diddoc.Doc, serviceType, keyType string) (*Destination, error) { + didCommService, err := diddoc.GetDIDCommService(didDoc, serviceType) + if err != nil { + return nil, err + } + + recipientKeys, err := diddoc.GetRecipientKeys(didDoc, serviceType, keyType) + if err != nil { + return nil, err + } + + return &Destination{ + RecipientKeys: recipientKeys, + ServiceEndpoint: didCommService.ServiceEndpoint, + }, nil +} diff --git a/pkg/didcomm/common/service/destination_test.go b/pkg/didcomm/common/service/destination_test.go new file mode 100644 index 0000000000..694ba8cd2a --- /dev/null +++ b/pkg/didcomm/common/service/destination_test.go @@ -0,0 +1,146 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package service + +import ( + "crypto/ed25519" + "crypto/rand" + "errors" + "fmt" + "testing" + "time" + + "github.com/btcsuite/btcutil/base58" + "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/doc/did" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" + mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" +) + +func TestGetDestinationFromDID(t *testing.T) { + doc := createDIDDoc() + ed25519KeyType := "Ed25519VerificationKey2018" + didCommServiceType := "did-communication" + + t.Run("successfully getting destination from public DID", func(t *testing.T) { + vdr := mockvdri.MockVDRIRegistry{ResolveValue: doc} + destination, err := GetDestination(doc.ID, didCommServiceType, ed25519KeyType, &vdr) + require.NoError(t, err) + require.NotNil(t, destination) + }) + + t.Run("test public key not found", func(t *testing.T) { + doc.PublicKey = nil + vdr := mockvdri.MockVDRIRegistry{ResolveValue: doc} + destination, err := GetDestination(doc.ID, didCommServiceType, ed25519KeyType, &vdr) + require.Error(t, err) + require.Contains(t, err.Error(), "key not found in DID document") + require.Nil(t, destination) + }) + + t.Run("test service not found", func(t *testing.T) { + doc2 := createDIDDoc() + doc2.Service = nil + vdr := mockvdri.MockVDRIRegistry{ResolveValue: doc2} + destination, err := GetDestination(doc2.ID, didCommServiceType, ed25519KeyType, &vdr) + require.Error(t, err) + require.Contains(t, err.Error(), "service not found in DID document: did-communication") + require.Nil(t, destination) + }) + + t.Run("test did document not found", func(t *testing.T) { + vdr := mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolver error")} + destination, err := GetDestination(doc.ID, didCommServiceType, ed25519KeyType, &vdr) + require.Error(t, err) + require.Contains(t, err.Error(), "resolver error") + require.Nil(t, destination) + }) +} + +func TestPrepareDestination(t *testing.T) { + ed25519KeyType := "Ed25519VerificationKey2018" + didCommServiceType := "did-communication" + + t.Run("successfully prepared destination", func(t *testing.T) { + dest, err := MakeDestination(mockdiddoc.GetMockDIDDoc(), didCommServiceType, ed25519KeyType) + require.NoError(t, err) + require.NotNil(t, dest) + require.Equal(t, dest.ServiceEndpoint, "https://localhost:8090") + }) + + t.Run("error while getting service", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service = nil + + dest, err := MakeDestination(didDoc, didCommServiceType, ed25519KeyType) + require.Error(t, err) + require.Contains(t, err.Error(), "service not found in DID document: did-communication") + require.Nil(t, dest) + }) + + t.Run("error while getting recipient keys from did doc", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service[0].RecipientKeys = []string{} + + recipientKeys, err := did.GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) + require.Error(t, err) + require.Contains(t, err.Error(), "missing recipient keys in did-communication service") + require.Nil(t, recipientKeys) + }) +} + +func createDIDDoc() *did.Doc { + pubKey, _ := generateKeyPair() + return createDIDDocWithKey(pubKey) +} + +func generateKeyPair() (string, []byte) { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + return base58.Encode(pubKey[:]), privKey +} + +func createDIDDocWithKey(pub string) *did.Doc { + const ( + didFormat = "did:%s:%s" + didPKID = "%s#keys-%d" + didServiceID = "%s#endpoint-%d" + method = "test" + ) + + id := fmt.Sprintf(didFormat, method, pub[:16]) + pubKeyID := fmt.Sprintf(didPKID, id, 1) + pubKey := did.PublicKey{ + ID: pubKeyID, + Type: "Ed25519VerificationKey2018", + Controller: id, + Value: []byte(pub), + } + services := []did.Service{ + { + ID: fmt.Sprintf(didServiceID, id, 1), + Type: "did-communication", + ServiceEndpoint: "http://localhost:58416", + Priority: 0, + RecipientKeys: []string{pubKeyID}, + }, + } + createdTime := time.Now() + didDoc := &did.Doc{ + Context: []string{did.Context}, + ID: id, + PublicKey: []did.PublicKey{pubKey}, + Service: services, + Created: &createdTime, + Updated: &createdTime, + } + + return didDoc +} diff --git a/pkg/didcomm/common/service/mocks/mocks.go b/pkg/didcomm/common/service/mocks/mocks.go index 9cb165e734..e40e76320a 100644 --- a/pkg/didcomm/common/service/mocks/mocks.go +++ b/pkg/didcomm/common/service/mocks/mocks.go @@ -5,9 +5,11 @@ package mocks import ( + reflect "reflect" + gomock "github.com/golang/mock/gomock" + service "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" - reflect "reflect" ) // MockDIDComm is a mock of DIDComm interface @@ -34,9 +36,9 @@ func (m *MockDIDComm) EXPECT() *MockDIDCommMockRecorder { } // HandleInbound mocks base method -func (m *MockDIDComm) HandleInbound(arg0 *service.DIDCommMsg) (string, error) { +func (m *MockDIDComm) HandleInbound(msg *service.DIDCommMsg, myDID string, theirDID string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleInbound", arg0) + ret := m.ctrl.Call(m, "HandleInbound", msg) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 diff --git a/pkg/didcomm/common/service/service.go b/pkg/didcomm/common/service/service.go index f729400e1e..a83d7716dc 100644 --- a/pkg/didcomm/common/service/service.go +++ b/pkg/didcomm/common/service/service.go @@ -16,7 +16,7 @@ import ( // Handler provides protocol service handle api. type Handler interface { // HandleInbound handles inbound messages. - HandleInbound(msg *DIDCommMsg) (string, error) + HandleInbound(msg *DIDCommMsg, myDID string, theirDID string) (string, error) // HandleOutbound handles outbound messages. HandleOutbound(msg *DIDCommMsg, dest *Destination) error } @@ -107,11 +107,3 @@ func (m *DIDCommMsg) ThreadID() (string, error) { return "", ErrThreadIDNotFound } - -// Destination provides the recipientKeys, routingKeys, and serviceEndpoint populated from Invitation -type Destination struct { - RecipientKeys []string - ServiceEndpoint string - RoutingKeys []string - TransportReturnRoute string -} diff --git a/pkg/didcomm/common/transport/envelope.go b/pkg/didcomm/common/transport/envelope.go index c910c7a0f8..a2c6998499 100644 --- a/pkg/didcomm/common/transport/envelope.go +++ b/pkg/didcomm/common/transport/envelope.go @@ -11,4 +11,6 @@ type Envelope struct { Message []byte FromVerKey string ToVerKeys []string + FromDID string + ToDID string } diff --git a/pkg/didcomm/dispatcher/api.go b/pkg/didcomm/dispatcher/api.go index 99ca95b14d..004b5f8df4 100644 --- a/pkg/didcomm/dispatcher/api.go +++ b/pkg/didcomm/dispatcher/api.go @@ -20,4 +20,5 @@ type Service interface { // Outbound interface type Outbound interface { Send(interface{}, string, *service.Destination) error + SendToDID(msg interface{}, myDID, theirDID string) error } diff --git a/pkg/didcomm/dispatcher/outbound.go b/pkg/didcomm/dispatcher/outbound.go index b6f651c84b..05e5662ab6 100644 --- a/pkg/didcomm/dispatcher/outbound.go +++ b/pkg/didcomm/dispatcher/outbound.go @@ -15,6 +15,7 @@ import ( commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" ) // provider interface for outbound ctx @@ -22,6 +23,7 @@ type provider interface { Packager() commontransport.Packager OutboundTransports() []transport.OutboundTransport TransportReturnRoute() string + VDRIRegistry() vdri.Registry } // OutboundDispatcher dispatch msgs to destination @@ -29,15 +31,39 @@ type OutboundDispatcher struct { outboundTransports []transport.OutboundTransport packager commontransport.Packager transportReturnRoute string + vdRegistry vdri.Registry } +const ( + ed25519KeyType = "Ed25519VerificationKey2018" + didCommServiceType = "did-communication" +) + // NewOutbound return new dispatcher outbound instance func NewOutbound(prov provider) *OutboundDispatcher { return &OutboundDispatcher{ outboundTransports: prov.OutboundTransports(), packager: prov.Packager(), transportReturnRoute: prov.TransportReturnRoute(), + vdRegistry: prov.VDRIRegistry(), + } +} + +// SendToDID sends a message from myDID to the agent who owns theirDID +func (o *OutboundDispatcher) SendToDID(msg interface{}, myDID, theirDID string) error { + dest, err := service.GetDestination(theirDID, didCommServiceType, ed25519KeyType, o.vdRegistry) + if err != nil { + return err } + + src, err := service.GetDestination(myDID, didCommServiceType, ed25519KeyType, o.vdRegistry) + if err != nil { + return err + } + + key := src.RecipientKeys[0] + + return o.Send(msg, key, dest) } // Send msg diff --git a/pkg/didcomm/dispatcher/outbound_test.go b/pkg/didcomm/dispatcher/outbound_test.go index c69056f1ce..b1ea174dff 100644 --- a/pkg/didcomm/dispatcher/outbound_test.go +++ b/pkg/didcomm/dispatcher/outbound_test.go @@ -19,6 +19,7 @@ import ( commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" mockdidcomm "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager" ) @@ -141,6 +142,7 @@ type mockProvider struct { packagerValue commontransport.Packager outboundTransportsValue []transport.OutboundTransport transportReturnRoute string + vdriRegistry vdri.Registry } func (p *mockProvider) Packager() commontransport.Packager { @@ -155,6 +157,10 @@ func (p *mockProvider) TransportReturnRoute() string { return p.transportReturnRoute } +func (p *mockProvider) VDRIRegistry() vdri.Registry { + return p.vdriRegistry +} + // mockOutboundTransport mock outbound transport type mockOutboundTransport struct { expectedRequest string diff --git a/pkg/didcomm/packager/package_test.go b/pkg/didcomm/packager/package_test.go index 99070b66e4..fec11af816 100644 --- a/pkg/didcomm/packager/package_test.go +++ b/pkg/didcomm/packager/package_test.go @@ -13,12 +13,14 @@ import ( "github.com/stretchr/testify/require" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" . "github.com/hyperledger/aries-framework-go/pkg/didcomm/packager" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" jwe "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer/jwe/authcrypt" legacy "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer/legacy/authcrypt" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" + mockdidconnection "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" "github.com/hyperledger/aries-framework-go/pkg/kms" "github.com/hyperledger/aries-framework-go/pkg/storage" @@ -26,16 +28,15 @@ import ( func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { t.Run("test failed to unmarshal encMessage", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -49,16 +50,15 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test bad encoding type", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -87,17 +87,16 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test key not found", func(t *testing.T) { - wp := newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }}) + wp := newMockKMSProvider(newMockStoreProvider()) w, err := kms.New(wp) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } testPacker, err := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, err) @@ -131,27 +130,26 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { }) t.Run("test Pack/Unpack fails", func(t *testing.T) { - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: make(map[string][]byte), - }})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) - decryptValue := func(envelope []byte) ([]byte, []byte, error) { - return nil, nil, fmt.Errorf("unpack error") + decryptValue := func(envelope []byte) ([]byte, []byte, []byte, error) { + return nil, nil, nil, fmt.Errorf("unpack error") } mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } // use a mocked packager with a mocked KMS to validate pack/unpack e := func(payload []byte, senderPubKey []byte, recipientsKeys [][]byte) (bytes []byte, e error) { - packer, e := jwe.New(mockedProviders, jwe.XC20P) + p, e := jwe.New(mockedProviders, jwe.XC20P) require.NoError(t, e) - return packer.Pack(payload, senderPubKey, recipientsKeys) + return p.Pack(payload, senderPubKey, recipientsKeys) } mockPacker := &didcomm.MockAuthCrypt{DecryptValue: decryptValue, EncryptValue: e, Type: "prs.hyperledger.aries-auth-message"} @@ -203,14 +201,14 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { t.Run("test Pack/Unpack success", func(t *testing.T) { // create a mock KMS with storage as a map - w, err := kms.New(newMockKMSProvider(&mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ - Store: map[string][]byte{}}})) + w, err := kms.New(newMockKMSProvider(newMockStoreProvider())) require.NoError(t, err) mockedProviders := &mockProvider{ - storage: nil, + storage: newMockStoreProvider(), kms: w, primaryPacker: nil, packers: nil, + lookupStore: &mockdidconnection.MockDIDConnection{GetDIDValue: ""}, } // create a real testPacker (no mocking here) @@ -262,7 +260,13 @@ func TestBaseKMSInPackager_UnpackMessage(t *testing.T) { } func newMockKMSProvider(storagePvdr *mockstorage.MockStoreProvider) *mockProvider { - return &mockProvider{storagePvdr, nil, nil, nil} + return &mockProvider{storagePvdr, nil, nil, nil, nil} +} + +func newMockStoreProvider() *mockstorage.MockStoreProvider { + return &mockstorage.MockStoreProvider{Store: &mockstorage.MockStore{ + Store: make(map[string][]byte), + }} } // mockProvider mocks provider for KMS @@ -271,6 +275,7 @@ type mockProvider struct { kms kms.KeyManager packers []packer.Packer primaryPacker packer.Packer + lookupStore didconnection.Store } func (m *mockProvider) Packers() []packer.Packer { @@ -292,3 +297,8 @@ func (m *mockProvider) InboundTransportEndpoint() string { func (m *mockProvider) PrimaryPacker() packer.Packer { return m.primaryPacker } + +// DIDConnectionStore returns a didconnection.Store service. +func (m *mockProvider) DIDConnectionStore() didconnection.Store { + return m.lookupStore +} diff --git a/pkg/didcomm/packager/packager.go b/pkg/didcomm/packager/packager.go index 8ceda79c68..8df10f0541 100644 --- a/pkg/didcomm/packager/packager.go +++ b/pkg/didcomm/packager/packager.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcutil/base58" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" ) @@ -22,6 +23,7 @@ import ( type Provider interface { Packers() []packer.Packer PrimaryPacker() packer.Packer + DIDConnectionStore() didconnection.Store } // Creator method to create new packager service @@ -29,8 +31,9 @@ type Creator func(prov Provider) (transport.Packager, error) // Packager is the basic implementation of Packager type Packager struct { - primaryPacker packer.Packer - packers map[string]packer.Packer + primaryPacker packer.Packer + packers map[string]packer.Packer + connectionStore didconnection.Store } // PackerCreator holds a creator function for a Packer and the name of the Packer's encoding method. @@ -52,11 +55,13 @@ func New(ctx Provider) (*Packager, error) { basePackager.primaryPacker = ctx.PrimaryPacker() if basePackager.primaryPacker == nil { - return nil, fmt.Errorf("need primary primaryPacker to initialize packager") + return nil, fmt.Errorf("need primary packer to initialize packager") } basePackager.addPacker(basePackager.primaryPacker) + basePackager.connectionStore = ctx.DIDConnectionStore() + return &basePackager, nil } @@ -146,10 +151,27 @@ func (bp *Packager) UnpackMessage(encMessage []byte) (*transport.Envelope, error return nil, fmt.Errorf("message Type not recognized") } - data, senderVerKey, err := p.Unpack(encMessage) + data, senderVerKey, recVerKey, err := p.Unpack(encMessage) if err != nil { return nil, fmt.Errorf("unpack: %w", err) } - return &transport.Envelope{Message: data, FromVerKey: base58.Encode(senderVerKey)}, nil + theirDID, err := bp.connectionStore.GetDID(base58.Encode(senderVerKey)) + if err != nil { + // ignore error - agents can communicate without using DIDs - for example, in DIDExchange + theirDID = "" + } + + myDID, err := bp.connectionStore.GetDID(base58.Encode(recVerKey)) + if err != nil { + // TODO: find the reason my DID isn't persisted + return nil, fmt.Errorf("cannot find my DID from my key, %w", err) + } + + return &transport.Envelope{ + Message: data, + FromVerKey: base58.Encode(senderVerKey), + FromDID: theirDID, + ToDID: myDID, + }, nil } diff --git a/pkg/didcomm/packer/api.go b/pkg/didcomm/packer/api.go index 0c7b505034..ed24c9794c 100644 --- a/pkg/didcomm/packer/api.go +++ b/pkg/didcomm/packer/api.go @@ -32,9 +32,10 @@ type Packer interface { // returns: // []byte containing the decrypted payload // []byte contains the sender verification key + // []byte contains the recipient verification key // error if decryption failed // TODO add key type of recipients keys to be validated by the implementation - Issue #272 - Unpack(envelope []byte) ([]byte, []byte, error) + Unpack(envelope []byte) ([]byte, []byte, []byte, error) // Encoding returns the type of the encoding, as found in the header `Typ` field EncodingType() string diff --git a/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go b/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go index d2dadf8ca5..8006b8ca83 100644 --- a/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go +++ b/pkg/didcomm/packer/jwe/authcrypt/authcrypt_test.go @@ -304,7 +304,7 @@ func TestEncrypt(t *testing.T) { t.Logf("Encryption with XC20P: %s", m) // decrypt for rec1 (as found in kms) - dec, senderVerKey, e := packer.Unpack(enc) + dec, senderVerKey, _, e := packer.Unpack(enc) require.NoError(t, e) require.NotEmpty(t, dec) require.EqualValues(t, dec, pld) @@ -331,10 +331,14 @@ func TestEncrypt(t *testing.T) { [][]byte{base58.Decode(recSign)}) require.NoError(t, err) - msgOut, sendKey, err := recPacker.Unpack(enc) + msgOut, sendKey, recKey, err := recPacker.Unpack(enc) require.NoError(t, err) require.Equal(t, msgIn, msgOut) require.Equal(t, base58.Encode(sendKey), encSign) + // this won't work, since input keys and output keys to this Packer + // are different types. + // require.Equal(t, base58.Encode(recKey), recSign) + require.NotNil(t, recKey) }) t.Run("Success test case: Decrypting a message with two PackerValue instances to simulate two agents", func(t *testing.T) { //nolint:lll @@ -356,7 +360,7 @@ func TestEncrypt(t *testing.T) { // now decrypt with recipient3 packer1, e := New(recipient3KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e := packer1.Unpack(enc) + dec, senderVerKey, _, e := packer1.Unpack(enc) require.NoError(t, e) require.NotEmpty(t, dec) require.EqualValues(t, dec, pld) @@ -366,7 +370,7 @@ func TestEncrypt(t *testing.T) { // now try decrypting with recipient2 packer2, e := New(recipient2KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e = packer2.Unpack(enc) + dec, senderVerKey, _, e = packer2.Unpack(enc) require.NoError(t, e) require.NotEmpty(t, dec) require.EqualValues(t, dec, pld) @@ -392,7 +396,7 @@ func TestEncrypt(t *testing.T) { // decrypting for recipient 2 (unauthorized) packer1, e := New(recipient2KMSProvider, XC20P) require.NoError(t, e) - dec, senderVerKey, e := packer1.Unpack(enc) + dec, senderVerKey, _, e := packer1.Unpack(enc) require.Error(t, e) require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -427,7 +431,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad nonce format - dec, senderVerKey, e := packer.Unpack(enc) + dec, senderVerKey, _, e := packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 12") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -438,7 +442,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad nonce format - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 5") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -449,7 +453,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag format - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: illegal base64 data at input byte 6") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -460,7 +464,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: sender key: bad SPK format") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -471,7 +475,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 6") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -482,7 +486,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 5") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -493,7 +497,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: bad nonce size") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -504,7 +508,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 6") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -515,7 +519,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, fmt.Sprintf("unpack: %s", cryptoutil.ErrKeyNotFound.Error())) require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -526,7 +530,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: illegal base64 data at input byte 15") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -537,7 +541,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: chacha20poly1305: message authentication failed") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -549,7 +553,7 @@ func TestEncrypt(t *testing.T) { require.NoError(t, e) // decrypt with bad nonce value require.PanicsWithValue(t, "chacha20poly1305: bad nonce length passed to Open", func() { - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) }) require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -560,7 +564,7 @@ func TestEncrypt(t *testing.T) { enc, e = json.Marshal(jwe) require.NoError(t, e) // decrypt with bad tag - dec, senderVerKey, e = packer.Unpack(enc) + dec, senderVerKey, _, e = packer.Unpack(enc) require.EqualError(t, e, "unpack: decrypt shared key: chacha20poly1305: message authentication failed") require.Empty(t, dec) require.Empty(t, senderVerKey) @@ -658,7 +662,7 @@ func TestRefEncrypt(t *testing.T) { require.NoError(t, err) require.NotNil(t, packer) - dec, senderVerKey, err := packer.Unpack([]byte(refJWE)) + dec, senderVerKey, _, err := packer.Unpack([]byte(refJWE)) require.NoError(t, err) require.NotEmpty(t, dec) require.NotEmpty(t, senderVerKey) diff --git a/pkg/didcomm/packer/jwe/authcrypt/unpack.go b/pkg/didcomm/packer/jwe/authcrypt/unpack.go index ecbf4db551..4a321d62cd 100644 --- a/pkg/didcomm/packer/jwe/authcrypt/unpack.go +++ b/pkg/didcomm/packer/jwe/authcrypt/unpack.go @@ -22,22 +22,22 @@ import ( // encrypted CEK. // The current recipient is the one with the sender's encrypted key that successfully // decrypts with recipientKeyPair.Priv Key. -func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, []byte, error) { jwe := &Envelope{} err := json.Unmarshal(envelope, jwe) if err != nil { - return nil, nil, fmt.Errorf("unpack: %w", err) + return nil, nil, nil, fmt.Errorf("unpack: %w", err) } recipientPubKey, recipient, err := p.findRecipient(jwe.Recipients) if err != nil { - return nil, nil, fmt.Errorf("unpack: %w", err) + return nil, nil, nil, fmt.Errorf("unpack: %w", err) } senderKey, err := p.decryptSPK(recipientPubKey, recipient.Header.SPK) if err != nil { - return nil, nil, fmt.Errorf("unpack: sender key: %w", err) + return nil, nil, nil, fmt.Errorf("unpack: sender key: %w", err) } // senderKey must not be empty to proceed @@ -47,18 +47,18 @@ func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { sharedKey, er := p.decryptCEK(recipientPubKey, senderPubKey, recipient) if er != nil { - return nil, nil, fmt.Errorf("unpack: decrypt shared key: %w", er) + return nil, nil, nil, fmt.Errorf("unpack: decrypt shared key: %w", er) } symOutput, er := p.decryptPayload(sharedKey, jwe) if er != nil { - return nil, nil, fmt.Errorf("unpack: %w", er) + return nil, nil, nil, fmt.Errorf("unpack: %w", er) } - return symOutput, senderKey, nil + return symOutput, senderKey, recipientPubKey[:], nil } - return nil, nil, errors.New("unpack: invalid sender key in envelope") + return nil, nil, nil, errors.New("unpack: invalid sender key in envelope") } func (p *Packer) decryptPayload(cek []byte, jwe *Envelope) ([]byte, error) { diff --git a/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go b/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go index 60d13677cc..acac27f26f 100644 --- a/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go +++ b/pkg/didcomm/packer/legacy/authcrypt/authcrypt_test.go @@ -388,11 +388,12 @@ func TestDecrypt(t *testing.T) { enc, err := packer.Pack(msgIn, base58.Decode(senderKey), [][]byte{base58.Decode(recKey)}) require.NoError(t, err) - msgOut, senderVerKey, err := packer.Unpack(enc) + msgOut, senderVerKey, recVerKey, err := packer.Unpack(enc) require.NoError(t, err) require.ElementsMatch(t, msgIn, msgOut) require.Equal(t, senderKey, base58.Encode(senderVerKey)) + require.Equal(t, recKey, base58.Encode(recVerKey)) }) t.Run("Success: pack and unpack, different packers, including fail recipient who wasn't sent the message", func(t *testing.T) { // nolint: lll @@ -417,15 +418,16 @@ func TestDecrypt(t *testing.T) { enc, err := sendPacker.Pack(msgIn, base58.Decode(senderKey), [][]byte{base58.Decode(rec1Key), base58.Decode(rec2Key), base58.Decode(rec3Key)}) require.NoError(t, err) - msgOut, senderVerKey, err := rec2Packer.Unpack(enc) + msgOut, senderVerKey, recVerKey, err := rec2Packer.Unpack(enc) require.NoError(t, err) require.ElementsMatch(t, msgIn, msgOut) require.Equal(t, senderKey, base58.Encode(senderVerKey)) + require.Equal(t, rec2Key, base58.Encode(recVerKey)) emptyKMS, _ := newKMS(t) rec4Packer := newWithKMS(emptyKMS) - _, _, err = rec4Packer.Unpack(enc) + _, _, _, err = rec4Packer.Unpack(enc) require.NotNil(t, err) require.Contains(t, err.Error(), "no key accessible") }) @@ -444,10 +446,11 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - msgOut, senderVerKey, err := recPacker.Unpack([]byte(env)) + msgOut, senderVerKey, recVerKey, err := recPacker.Unpack([]byte(env)) require.NoError(t, err) require.ElementsMatch(t, []byte(msg), msgOut) require.NotEmpty(t, senderVerKey) + require.NotEmpty(t, recVerKey) }) t.Run("Test unpacking python envelope with multiple recipients", func(t *testing.T) { @@ -465,10 +468,11 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - msgOut, senderVerKey, err := recPacker.Unpack([]byte(env)) + msgOut, senderVerKey, recVerKey, err := recPacker.Unpack([]byte(env)) require.NoError(t, err) require.ElementsMatch(t, []byte(msg), msgOut) require.NotEmpty(t, senderVerKey) + require.NotEmpty(t, recVerKey) }) t.Run("Test unpacking python envelope with invalid recipient", func(t *testing.T) { @@ -483,7 +487,7 @@ func TestDecrypt(t *testing.T) { recPacker := newWithKMS(recKMS) - _, _, err = recPacker.Unpack([]byte(env)) + _, _, _, err = recPacker.Unpack([]byte(env)) require.NotNil(t, err) require.Contains(t, err.Error(), "no key accessible") }) @@ -506,7 +510,8 @@ func unpackComponentFailureTest( } recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(fullMessage)) + // nolint: dogsled + _, _, _, err = recPacker.Unpack([]byte(fullMessage)) require.NotNil(t, err) require.Contains(t, err.Error(), errString) } @@ -524,7 +529,7 @@ func TestUnpackComponents(t *testing.T) { recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(msg)) + _, _, _, err = recPacker.Unpack([]byte(msg)) require.EqualError(t, err, "invalid character 'e' looking for beginning of value") }) @@ -537,7 +542,7 @@ func TestUnpackComponents(t *testing.T) { recPacker := newWithKMS(w) - _, _, err = recPacker.Unpack([]byte(msg)) + _, _, _, err = recPacker.Unpack([]byte(msg)) require.EqualError(t, err, "illegal base64 data at input byte 0") }) @@ -698,7 +703,8 @@ func Test_getCEK(t *testing.T) { }, } - _, _, err := getCEK(recs, &k) + // nolint: dogsled + _, _, _, err := getCEK(recs, &k) require.EqualError(t, err, "mock error") } diff --git a/pkg/didcomm/packer/legacy/authcrypt/unpack.go b/pkg/didcomm/packer/legacy/authcrypt/unpack.go index ade93d45cc..3542ba995a 100644 --- a/pkg/didcomm/packer/legacy/authcrypt/unpack.go +++ b/pkg/didcomm/packer/legacy/authcrypt/unpack.go @@ -20,46 +20,46 @@ import ( // Unpack will decode the envelope using the legacy format // Using (X)Chacha20 encryption algorithm and Poly1035 authenticator -func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(envelope []byte) ([]byte, []byte, []byte, error) { var envelopeData legacyEnvelope err := json.Unmarshal(envelope, &envelopeData) if err != nil { - return nil, nil, err + return nil, nil, nil, err } protectedBytes, err := base64.URLEncoding.DecodeString(envelopeData.Protected) if err != nil { - return nil, nil, err + return nil, nil, nil, err } var protectedData protected err = json.Unmarshal(protectedBytes, &protectedData) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if protectedData.Typ != encodingType { - return nil, nil, fmt.Errorf("message type %s not supported", protectedData.Typ) + return nil, nil, nil, fmt.Errorf("message type %s not supported", protectedData.Typ) } if protectedData.Alg != "Authcrypt" { // TODO https://github.com/hyperledger/aries-framework-go/issues/41 change this when anoncrypt is introduced - return nil, nil, fmt.Errorf("message format %s not supported", protectedData.Alg) + return nil, nil, nil, fmt.Errorf("message format %s not supported", protectedData.Alg) } - cek, recKey, err := getCEK(protectedData.Recipients, p.kms) + cek, senderKey, recKey, err := getCEK(protectedData.Recipients, p.kms) if err != nil { - return nil, nil, err + return nil, nil, nil, err } data, err := p.decodeCipherText(cek, &envelopeData) - return data, recKey, err + return data, senderKey, recKey, err } -func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, []byte, error) { +func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, []byte, []byte, error) { var candidateKeys []string for _, candidate := range recipients { @@ -68,47 +68,47 @@ func getCEK(recipients []recipient, km kms.KeyManager) (*[chacha.KeySize]byte, [ recKeyIdx, err := km.FindVerKey(candidateKeys) if err != nil { - return nil, nil, fmt.Errorf("no key accessible %w", err) + return nil, nil, nil, fmt.Errorf("no key accessible %w", err) } recip := recipients[recKeyIdx] - recKey := recip.Header.KID + recKey := base58.Decode(recip.Header.KID) - recCurvePub, err := km.ConvertToEncryptionKey(base58.Decode(recKey)) + recCurvePub, err := km.ConvertToEncryptionKey(recKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } senderPub, senderPubCurve, err := decodeSender(recip.Header.Sender, recCurvePub, km) if err != nil { - return nil, nil, err + return nil, nil, nil, err } nonceSlice, err := base64.URLEncoding.DecodeString(recip.Header.IV) if err != nil { - return nil, nil, err + return nil, nil, nil, err } encCEK, err := base64.URLEncoding.DecodeString(recip.EncryptedKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } b, err := kms.NewCryptoBox(km) if err != nil { - return nil, nil, err + return nil, nil, nil, err } cekSlice, err := b.EasyOpen(encCEK, nonceSlice, senderPubCurve, recCurvePub) if err != nil { - return nil, nil, fmt.Errorf("failed to decrypt CEK: %s", err) + return nil, nil, nil, fmt.Errorf("failed to decrypt CEK: %s", err) } var cek [chacha.KeySize]byte copy(cek[:], cekSlice) - return &cek, senderPub, nil + return &cek, senderPub, recKey, nil } func decodeSender(b64Sender string, pk []byte, km kms.KeyManager) ([]byte, []byte, error) { diff --git a/pkg/didcomm/protocol/didexchange/persistence.go b/pkg/didcomm/protocol/didexchange/persistence.go index fc9f670fe9..4071a2f3ac 100644 --- a/pkg/didcomm/protocol/didexchange/persistence.go +++ b/pkg/didcomm/protocol/didexchange/persistence.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/storage" ) @@ -54,14 +55,15 @@ func (r *ConnectionRecord) isValid() error { } // NewConnectionRecorder returns new connection record instance -func NewConnectionRecorder(transientStore, store storage.Store) *ConnectionRecorder { - return &ConnectionRecorder{transientStore: transientStore, store: store} +func NewConnectionRecorder(transientStore, store storage.Store, didMap didconnection.Store) *ConnectionRecorder { + return &ConnectionRecorder{transientStore: transientStore, store: store, didMap: didMap} } // ConnectionRecorder takes care of connection related persistence features type ConnectionRecorder struct { transientStore storage.Store store storage.Store + didMap didconnection.Store } // SaveInvitation saves connection invitation to underlying store @@ -231,6 +233,10 @@ func (c *ConnectionRecorder) saveConnectionRecord(record *ConnectionRecord) erro if err := marshalAndSave(connectionKeyPrefix(record.ConnectionID), record, c.store); err != nil { return fmt.Errorf("save connection record in permanent store: %w", err) } + + if err := c.didMap.SaveDIDConnection(record.MyDID, record.TheirDID, record.RecipientKeys); err != nil { + return err + } } return nil @@ -259,6 +265,12 @@ func (c *ConnectionRecorder) saveNewConnectionRecord(record *ConnectionRecord) e return fmt.Errorf("save new connection record: %w", err) } + if record.MyDID != "" { + if err := c.didMap.SaveDIDFromDoc(record.MyDID, didCommServiceType, ed25519KeyType); err != nil { + return err + } + } + return c.saveNSThreadID(record.ThreadID, record.Namespace, record.ConnectionID) } diff --git a/pkg/didcomm/protocol/didexchange/persistence_test.go b/pkg/didcomm/protocol/didexchange/persistence_test.go index 46272b26ac..cac636edd5 100644 --- a/pkg/didcomm/protocol/didexchange/persistence_test.go +++ b/pkg/didcomm/protocol/didexchange/persistence_test.go @@ -47,7 +47,7 @@ func Test_ComputeHash(t *testing.T) { func TestConnectionRecord_SaveInvitation(t *testing.T) { t.Run("test save invitation success", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) value := &Invitation{ @@ -71,7 +71,7 @@ func TestConnectionRecord_SaveInvitation(t *testing.T) { t.Run("test save invitation failure due to invalid key", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) value := &Invitation{ @@ -86,7 +86,7 @@ func TestConnectionRecord_SaveInvitation(t *testing.T) { func TestConnectionRecorder_GetInvitation(t *testing.T) { t.Run("test get invitation - success", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) valueStored := &Invitation{ @@ -104,7 +104,7 @@ func TestConnectionRecorder_GetInvitation(t *testing.T) { t.Run("test get invitation - not found scenario", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) valueFound, err := record.GetInvitation("sample-key4") @@ -115,7 +115,7 @@ func TestConnectionRecorder_GetInvitation(t *testing.T) { t.Run("test get invitation - invalid key scenario", func(t *testing.T) { store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(nil, store) + record := NewConnectionRecorder(nil, store, nil) require.NotNil(t, record) valueFound, err := record.GetInvitation("") @@ -129,7 +129,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { t.Run("test success found data in transient store ", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix} @@ -148,7 +148,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { t.Run("test success found data in store ", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix} @@ -169,7 +169,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { Store: make(map[string][]byte), ErrGet: fmt.Errorf("get error transientstore")} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRecBytes, err := json.Marshal(&ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix}) @@ -183,7 +183,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { t.Run("test error from store", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte), ErrGet: fmt.Errorf("get error store")} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRecBytes, err := json.Marshal(&ConnectionRecord{ConnectionID: connIDValue, ThreadID: threadIDValue, Namespace: myNSPrefix}) @@ -198,7 +198,7 @@ func TestConnectionRecorder_GetConnectionRecord(t *testing.T) { func TestConnectionRecordByState(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrGet: nil} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ConnectionID: generateRandomID(), ThreadID: threadIDValue, @@ -235,7 +235,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { t.Run("save connection record and get connection Record success", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: theirNSPrefix} @@ -248,7 +248,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { }) t.Run("save connection record and fetch from no namespace error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited} @@ -258,7 +258,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { }) t.Run("save connection record error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: fmt.Errorf("get error")} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: "", ConnectionID: "test", State: stateNameInvited, Namespace: theirNSPrefix} @@ -267,7 +267,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { }) t.Run("save connection record error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: fmt.Errorf("get error")} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: theirNSPrefix} @@ -277,7 +277,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { t.Run("save connection record in permanent store error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} store := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: fmt.Errorf("get error")} - record := NewConnectionRecorder(transientStore, store) + record := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameCompleted, Namespace: theirNSPrefix} @@ -289,7 +289,7 @@ func TestConnectionRecorder_SaveConnectionRecord(t *testing.T) { func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { t.Run(" get connection record by namespace threadID in my namespace", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: myNSPrefix} @@ -305,7 +305,7 @@ func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { }) t.Run(" get connection record by namespace threadID their namespace", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec := &ConnectionRecord{ThreadID: threadIDValue, ConnectionID: connIDValue, State: stateNameInvited, Namespace: theirNSPrefix} @@ -321,7 +321,7 @@ func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { }) t.Run(" data not found error due to missing input parameter", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec, err := record.GetConnectionRecordByNSThreadID("") require.Contains(t, err.Error(), "data not found") @@ -332,7 +332,7 @@ func TestConnectionRecorder_GetConnectionRecordByNSThreadID(t *testing.T) { func TestConnectionRecorder_PrepareConnectionRecord(t *testing.T) { t.Run(" prepare connection record error", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) connRec, err := prepareConnectionRecord(nil) require.Contains(t, err.Error(), "prepare connection record") @@ -355,7 +355,7 @@ func TestConnectionRecorder_CreateNSKeys(t *testing.T) { func TestConnectionRecorder_SaveNSThreadID(t *testing.T) { t.Run("missing required parameters", func(t *testing.T) { transientStore := &mockstorage.MockStore{Store: make(map[string][]byte)} - record := NewConnectionRecorder(transientStore, nil) + record := NewConnectionRecorder(transientStore, nil, nil) require.NotNil(t, record) err := record.saveNSThreadID("", theirNSPrefix, connIDValue) require.Error(t, err) @@ -398,7 +398,7 @@ func TestConnectionRecorder_QueryConnectionRecord(t *testing.T) { require.NoError(t, err) } - recorder := NewConnectionRecorder(transientStore, store) + recorder := NewConnectionRecorder(transientStore, store, nil) require.NotNil(t, recorder) result, err := recorder.QueryConnectionRecords() require.NoError(t, err) @@ -410,7 +410,7 @@ func TestConnectionRecorder_QueryConnectionRecord(t *testing.T) { err := store.Put(fmt.Sprintf("%s_abc123", connIDKeyPrefix), []byte("-----")) require.NoError(t, err) - recorder := NewConnectionRecorder(nil, store) + recorder := NewConnectionRecorder(nil, store, nil) require.NotNil(t, recorder) result, err := recorder.QueryConnectionRecords() require.Error(t, err) diff --git a/pkg/didcomm/protocol/didexchange/service.go b/pkg/didcomm/protocol/didexchange/service.go index 78420cffcd..32fa1b6eaf 100644 --- a/pkg/didcomm/protocol/didexchange/service.go +++ b/pkg/didcomm/protocol/didexchange/service.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/hyperledger/aries-framework-go/pkg/common/log" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" @@ -57,6 +58,7 @@ type provider interface { OutboundDispatcher() dispatcher.Outbound StorageProvider() storage.Provider TransientStorageProvider() storage.Provider + DIDConnectionStore() didconnection.Store Signer() kms.Signer VDRIRegistry() vdriapi.Registry } @@ -76,6 +78,7 @@ type Service struct { ctx *context callbackChannel chan *message connectionStore *ConnectionRecorder + didConnections didconnection.Store } type context struct { @@ -106,7 +109,7 @@ func New(prov provider) (*Service, error) { return nil, err } - connRecorder := NewConnectionRecorder(transientStore, store) + connRecorder := NewConnectionRecorder(transientStore, store, prov.DIDConnectionStore()) svc := &Service{ ctx: &context{ outboundDispatcher: prov.OutboundDispatcher(), @@ -117,6 +120,7 @@ func New(prov provider) (*Service, error) { // TODO channel size - https://github.com/hyperledger/aries-framework-go/issues/246 callbackChannel: make(chan *message, 10), connectionStore: connRecorder, + didConnections: prov.DIDConnectionStore(), } // start the listener @@ -126,7 +130,7 @@ func New(prov provider) (*Service, error) { } // HandleInbound handles inbound didexchange messages. -func (s *Service) HandleInbound(msg *service.DIDCommMsg) (string, error) { +func (s *Service) HandleInbound(msg *service.DIDCommMsg, myDID, theirDID string) (string, error) { logger.Debugf("receive inbound message : %s", msg.Payload) // fetch the thread id @@ -252,13 +256,13 @@ func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) erro logger.Debugf("finished execute state: %s", next.Name()) if err = s.update(msg.Msg.Header.Type, connectionRecord); err != nil { - return fmt.Errorf("failed to persist state %s %w", next.Name(), err) + return fmt.Errorf("failed to persist state %s: %w", next.Name(), err) } logger.Debugf("persisted the connection record using connection id %s", connectionRecord.ConnectionID) if err = action(); err != nil { - return fmt.Errorf("failed to execute state action %s %w", next.Name(), err) + return fmt.Errorf("failed to execute state action %s: %w", next.Name(), err) } logger.Debugf("finish execute state action: %s", next.Name()) @@ -271,7 +275,7 @@ func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) erro if canTriggerActionEvents(connectionRecord.State, connectionRecord.Namespace) { msg.NextStateName = next.Name() if err = s.sendActionEvent(msg, aEvent); err != nil { - return fmt.Errorf("handle inbound : %w", err) + return fmt.Errorf("handle inbound: %w", err) } haltExecution = true @@ -636,7 +640,11 @@ func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID, inviteeLabe return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err) } - dest, err := prepareDestination(didDoc) + dest, err := service.MakeDestination(didDoc, didCommServiceType, ed25519KeyType) + if err != nil { + return "", err + } + thID := generateRandomID() connRecord := &ConnectionRecord{ ConnectionID: generateRandomID(), diff --git a/pkg/didcomm/protocol/didexchange/service_test.go b/pkg/didcomm/protocol/didexchange/service_test.go index ea62e3b90c..fb91033f6d 100644 --- a/pkg/didcomm/protocol/didexchange/service_test.go +++ b/pkg/didcomm/protocol/didexchange/service_test.go @@ -14,15 +14,15 @@ import ( "testing" "time" - "github.com/btcsuite/btcutil/base58" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" - "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/protocol" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" "github.com/hyperledger/aries-framework-go/pkg/storage" @@ -69,17 +69,21 @@ func TestService_Handle_Inviter(t *testing.T) { prov := protocol.MockProvider{} store := mockstorage.NewMockStoreProvider() pubKey, privKey := generateKeyPair() + didConnectionStore := didconnection.MockDIDConnection{} + ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: createDIDDocWithKey(pubKey)}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(nil, store.Store), + connectionStore: NewConnectionRecorder(nil, store.Store, &didConnectionStore), } newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) - s, err := New(&protocol.MockProvider{StoreProvider: store}) + s, err := New(&protocol.MockProvider{StoreProvider: store, + DIDConnectionStoreValue: &didConnectionStore, + }) require.NoError(t, err) actionCh := make(chan service.DIDCommAction, 10) @@ -128,7 +132,7 @@ func TestService_Handle_Inviter(t *testing.T) { require.NoError(t, err) msg, err := service.NewDIDCommMsg(payloadBytes) require.NoError(t, err) - _, err = s.HandleInbound(msg) + _, err = s.HandleInbound(msg, newDidDoc.ID, "") require.NoError(t, err) select { @@ -150,7 +154,7 @@ func TestService_Handle_Inviter(t *testing.T) { didMsg, err := service.NewDIDCommMsg(payloadBytes) require.NoError(t, err) - _, err = s.HandleInbound(didMsg) + _, err = s.HandleInbound(didMsg, newDidDoc.ID, "theirDID") require.NoError(t, err) select { @@ -203,19 +207,24 @@ func TestService_Handle_Invitee(t *testing.T) { store := mockstorage.NewMockStoreProvider() prov := protocol.MockProvider{} pubKey, privKey := generateKeyPair() + didConnectionStore := didconnection.MockDIDConnection{} + ctx := context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: createDIDDocWithKey(pubKey)}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(transientStore.Store, store.Store), + connectionStore: NewConnectionRecorder(transientStore.Store, store.Store, &didConnectionStore), } newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) s, err := New( - &protocol.MockProvider{StoreProvider: store, - TransientStoreProvider: transientStore}) + &protocol.MockProvider{ + StoreProvider: store, + TransientStoreProvider: transientStore, + DIDConnectionStoreValue: &didConnectionStore, + }) require.NoError(t, err) s.ctx.vdriRegistry = &mockvdri.MockVDRIRegistry{ResolveValue: newDidDoc} @@ -251,7 +260,7 @@ func TestService_Handle_Invitee(t *testing.T) { didMsg, err := service.NewDIDCommMsg(payloadBytes) require.NoError(t, err) - _, err = s.HandleInbound(didMsg) + _, err = s.HandleInbound(didMsg, "", "") require.NoError(t, err) var connID string @@ -293,7 +302,7 @@ func TestService_Handle_Invitee(t *testing.T) { didMsg, err = service.NewDIDCommMsg(payloadBytes) require.NoError(t, err) - _, err = s.HandleInbound(didMsg) + _, err = s.HandleInbound(didMsg, "", "") require.NoError(t, err) // Alice automatically sends an ACK to Bob @@ -345,7 +354,7 @@ func TestService_Handle_EdgeCases(t *testing.T) { didMsg, err := service.NewDIDCommMsg(response) require.NoError(t, err) - _, err = s.HandleInbound(didMsg) + _, err = s.HandleInbound(didMsg, "", "") require.Error(t, err) require.Contains(t, err.Error(), "handle inbound - next state : invalid state transition: "+ "null -> responded") @@ -366,7 +375,7 @@ func TestService_Handle_EdgeCases(t *testing.T) { didMsg, err := service.NewDIDCommMsg(requestBytes) require.NoError(t, err) - _, err = svc.HandleInbound(didMsg) + _, err = svc.HandleInbound(didMsg, "", "") require.Error(t, err) require.Equal(t, err.Error(), "threadID not found") }) @@ -379,9 +388,9 @@ func TestService_Handle_EdgeCases(t *testing.T) { require.NoError(t, err) transientStore := &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")} - svc.connectionStore = NewConnectionRecorder(transientStore, nil) + svc.connectionStore = NewConnectionRecorder(transientStore, nil, nil) - _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), "")) + _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), ""), "", "") require.Error(t, err) require.Contains(t, err.Error(), "save connection record") }) @@ -419,7 +428,7 @@ func TestService_Handle_EdgeCases(t *testing.T) { didMsg, err := service.NewDIDCommMsg(requestBytes) require.NoError(t, err) - _, err = svc.HandleInbound(didMsg) + _, err = svc.HandleInbound(didMsg, "", "") require.NoError(t, err) }) } @@ -453,7 +462,7 @@ func TestService_CurrentState(t *testing.T) { svc := &Service{ connectionStore: NewConnectionRecorder(&mockStore{ get: func(string) ([]byte, error) { return nil, storage.ErrDataNotFound }, - }, nil), + }, nil, nil), } thid, err := createNSKey(theirNSPrefix, "ignored") require.NoError(t, err) @@ -469,7 +478,7 @@ func TestService_CurrentState(t *testing.T) { svc := &Service{ connectionStore: NewConnectionRecorder(&mockStore{ get: func(string) ([]byte, error) { return connRec, nil }, - }, nil), + }, nil, nil), } thid, err := createNSKey(theirNSPrefix, "ignored") require.NoError(t, err) @@ -484,7 +493,7 @@ func TestService_CurrentState(t *testing.T) { get: func(string) ([]byte, error) { return nil, errors.New("test") }, - }, nil), + }, nil, nil), } thid, err := createNSKey(theirNSPrefix, "ignored") require.NoError(t, err) @@ -511,7 +520,7 @@ func TestService_Update(t *testing.T) { get: func(k string) ([]byte, error) { return bytes, nil }, - }, nil), + }, nil, nil), } require.NoError(t, svc.update(RequestMsgType, connRec)) @@ -542,47 +551,6 @@ func (m *mockStore) Iterator(start, limit string) storage.StoreIterator { return nil } -func getMockDID() *did.Doc { - return &did.Doc{ - Context: []string{"https://w3id.org/did/v1"}, - ID: "did:peer:123456789abcdefghi#inbox", - Service: []did.Service{ - { - ServiceEndpoint: "https://localhost:8090", - Type: "did-communication", - Priority: 0, - RecipientKeys: []string{"did:example:123456789abcdefghi#keys-2"}, - }, - { - ServiceEndpoint: "https://localhost:8090", - Type: "did-communication", - Priority: 1, - RecipientKeys: []string{"did:example:123456789abcdefghi#keys-1"}, - }, - }, - PublicKey: []did.PublicKey{ - { - ID: "did:example:123456789abcdefghi#keys-1", - Controller: "did:example:123456789abcdefghi", - Type: "Secp256k1VerificationKey2018", - Value: base58.Decode("H3C2AVvLMv6gmMNam3uVAjZpfkcJCwDwnZn6z3wXmqPV"), - }, - { - ID: "did:example:123456789abcdefghi#keys-2", - Controller: "did:example:123456789abcdefghi", - Type: "Ed25519VerificationKey2018", - Value: base58.Decode("H3C2AVvLMv6gmMNam3uVAjZpfkcJCwDwnZn6z3wXmqPV"), - }, - { - ID: "did:example:123456789abcdefghw#key2", - Controller: "did:example:123456789abcdefghw", - Type: "RsaVerificationKey2018", - Value: base58.Decode("H3C2AVvLMv6gmMNam3uVAjZpfkcJCwDwnZn6z3wXmqPV"), - }, - }, - } -} - func randomString() string { u := uuid.New() return u.String() @@ -628,7 +596,7 @@ func TestEventsSuccess(t *testing.T) { didMsg, err := service.NewDIDCommMsg(invite) require.NoError(t, err) - _, err = svc.HandleInbound(didMsg) + _, err = svc.HandleInbound(didMsg, "", "") require.NoError(t, err) select { @@ -639,7 +607,7 @@ func TestEventsSuccess(t *testing.T) { } func TestContinueWithPublicDID(t *testing.T) { - didDoc := getMockDID() + didDoc := mockdiddoc.GetMockDIDDoc() svc, err := New(&protocol.MockProvider{}) require.NoError(t, err) @@ -665,7 +633,7 @@ func TestContinueWithPublicDID(t *testing.T) { didMsg, err := service.NewDIDCommMsg(invite) require.NoError(t, err) - _, err = svc.HandleInbound(didMsg) + _, err = svc.HandleInbound(didMsg, "", "") require.NoError(t, err) } @@ -722,7 +690,7 @@ func TestEventsUserError(t *testing.T) { err = svc.connectionStore.saveNewConnectionRecord(connRec) require.NoError(t, err) - _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, id, "")) + _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, id, ""), "", "") require.NoError(t, err) select { @@ -749,7 +717,7 @@ func TestEventStoreError(t *testing.T) { } }() - _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), "")) + _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), ""), "", "") require.NoError(t, err) } @@ -803,8 +771,8 @@ func TestServiceErrors(t *testing.T) { mockStore := &mockStore{get: func(s string) (bytes []byte, e error) { return nil, errors.New("error") }} - svc.connectionStore = NewConnectionRecorder(mockStore, nil) - _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), "")) + svc.connectionStore = NewConnectionRecorder(mockStore, nil, nil) + _, err = svc.HandleInbound(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), ""), "", "") require.Error(t, err) require.Contains(t, err.Error(), "cannot fetch state from store") @@ -813,9 +781,9 @@ func TestServiceErrors(t *testing.T) { transientStore, err := mockstorage.NewMockStoreProvider().OpenStore(DIDExchange) require.NoError(t, err) - svc.connectionStore = NewConnectionRecorder(transientStore, nil) + svc.connectionStore = NewConnectionRecorder(transientStore, nil, nil) - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, "", "") require.Error(t, err) require.Contains(t, err.Error(), "unrecognized msgType: invalid") @@ -898,7 +866,7 @@ func TestInvitationRecord(t *testing.T) { // db error svc.connectionStore = NewConnectionRecorder( - &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil) + &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil, nil) invitationBytes, err = json.Marshal(&Invitation{ Type: InvitationMsgType, @@ -926,7 +894,7 @@ func TestRequestRecord(t *testing.T) { // db error svc.connectionStore = NewConnectionRecorder( - &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil) + &mockstorage.MockStore{Store: make(map[string][]byte), ErrPut: errors.New("db error")}, nil, nil) _, err = svc.requestMsgRecord(generateRequestMsgPayload(t, &protocol.MockProvider{}, randomString(), "")) @@ -977,7 +945,7 @@ func TestAcceptExchangeRequest(t *testing.T) { }() _, err = svc.HandleInbound(generateRequestMsgPayload(t, - &protocol.MockProvider{StoreProvider: mockstorage.NewMockStoreProvider()}, randomString(), invitation.ID)) + &protocol.MockProvider{StoreProvider: mockstorage.NewMockStoreProvider()}, randomString(), invitation.ID), "", "") require.NoError(t, err) select { @@ -1037,7 +1005,7 @@ func TestAcceptExchangeRequestWithPublicDID(t *testing.T) { }() _, err = svc.HandleInbound(generateRequestMsgPayload(t, - &protocol.MockProvider{StoreProvider: mockstorage.NewMockStoreProvider()}, randomString(), invitation.ID)) + &protocol.MockProvider{StoreProvider: mockstorage.NewMockStoreProvider()}, randomString(), invitation.ID), "", "") require.NoError(t, err) select { @@ -1098,7 +1066,7 @@ func TestAcceptInvitation(t *testing.T) { didMsg, err := service.NewDIDCommMsg(invitationBytes) require.NoError(t, err) - _, err = svc.HandleInbound(didMsg) + _, err = svc.HandleInbound(didMsg, "", "") require.NoError(t, err) select { @@ -1213,7 +1181,7 @@ func TestAcceptInvitationWithPublicDID(t *testing.T) { didMsg, err := service.NewDIDCommMsg(invitationBytes) require.NoError(t, err) - _, err = svc.HandleInbound(didMsg) + _, err = svc.HandleInbound(didMsg, "", "") require.NoError(t, err) select { @@ -1380,8 +1348,8 @@ func TestFetchConnectionRecord(t *testing.T) { func generateRequestMsgPayload(t *testing.T, prov provider, id, invitationID string) *service.DIDCommMsg { store := mockstorage.NewMockStoreProvider() ctx := context{outboundDispatcher: prov.OutboundDispatcher(), - vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: getMockDID()}, - connectionStore: NewConnectionRecorder(nil, store.Store)} + vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}, + connectionStore: NewConnectionRecorder(nil, store.Store, nil)} newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) @@ -1413,7 +1381,7 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDIDDoc}, - connectionStore: NewConnectionRecorder(nil, store.Store), + connectionStore: NewConnectionRecorder(nil, store.Store, nil), } s, err := New(&protocol.MockProvider{StoreProvider: store}) @@ -1433,7 +1401,7 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolve error")}, - connectionStore: NewConnectionRecorder(nil, store.Store), + connectionStore: NewConnectionRecorder(nil, store.Store, nil), } s, err := New(&protocol.MockProvider{StoreProvider: store}) @@ -1456,7 +1424,7 @@ func TestService_CreateImplicitInvitation(t *testing.T) { ctx := &context{ outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDIDDoc}, - connectionStore: NewConnectionRecorder(transientStore.Store, store.Store), + connectionStore: NewConnectionRecorder(transientStore.Store, store.Store, nil), } s, err := New(&protocol.MockProvider{StoreProvider: store, TransientStoreProvider: transientStore}) diff --git a/pkg/didcomm/protocol/didexchange/states.go b/pkg/didcomm/protocol/didexchange/states.go index a6f43367d2..8d7ad7ccba 100644 --- a/pkg/didcomm/protocol/didexchange/states.go +++ b/pkg/didcomm/protocol/didexchange/states.go @@ -308,7 +308,7 @@ func (ctx *context) handleInboundInvitation(invitation *Invitation, } connRec.MyDID = request.Connection.DID - senderVerKeys, err := getRecipientKeys(didDoc) + senderVerKeys, err := did.GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) if err != nil { return nil, nil, fmt.Errorf("getting sender verification keys: %w", err) } @@ -326,6 +326,7 @@ func (ctx *context) handleInboundRequest(request *Request, options *options, con } // get did document that will be used in exchange response + // (my did doc) responseDidDoc, connection, err := ctx.getDIDDocAndConnection(getPublicDID(options)) if err != nil { return nil, nil, err @@ -351,12 +352,12 @@ func (ctx *context) handleInboundRequest(request *Request, options *options, con connRec.MyDID = connection.DID connRec.TheirLabel = request.Label - destination, err := prepareDestination(requestDidDoc) + destination, err := service.MakeDestination(requestDidDoc, didCommServiceType, ed25519KeyType) if err != nil { return nil, nil, err } - senderVerKeys, err := getRecipientKeys(responseDidDoc) + senderVerKeys, err := did.GetRecipientKeys(responseDidDoc, didCommServiceType, ed25519KeyType) if err != nil { return nil, nil, err } @@ -385,7 +386,7 @@ func getLabel(options *options) string { func (ctx *context) getDestination(invitation *Invitation) (*service.Destination, error) { if invitation.DID != "" { - return ctx.getDestinationFromDID(invitation.DID) + return service.GetDestination(invitation.DID, didCommServiceType, ed25519KeyType, ctx.vdriRegistry) } return &service.Destination{ @@ -432,95 +433,6 @@ func (ctx *context) resolveDidDocFromConnection(conn *Connection) (*did.Doc, err return didDoc, nil } -func (ctx *context) getDestinationFromDID(id string) (*service.Destination, error) { - didDoc, err := ctx.vdriRegistry.Resolve(id) - if err != nil { - return nil, err - } - - return prepareDestination(didDoc) -} - -func getRecipientKeys(didDoc *did.Doc) ([]string, error) { - didCommService, err := getDidCommService(didDoc) - if err != nil { - return nil, err - } - - if len(didCommService.RecipientKeys) == 0 { - return nil, fmt.Errorf("missing recipient keys in did-communication service") - } - - var recipientKeys []string - - for _, keyID := range didCommService.RecipientKeys { - key, err := getPublicKey(keyID, didDoc) - if err != nil { - return nil, err - } - - if isSupportedKeyType(key.Type) { - recipientKeys = append(recipientKeys, string(key.Value)) - } - } - - if len(recipientKeys) == 0 { - return nil, fmt.Errorf("recipient keys in did-communication service not supported") - } - - return recipientKeys, nil -} - -func getPublicKey(id string, didDoc *did.Doc) (*did.PublicKey, error) { - for _, key := range didDoc.PublicKey { - if key.ID == id { - return &key, nil - } - } - - return nil, fmt.Errorf("key not found in DID document: %s", id) -} - -func isSupportedKeyType(keyType string) bool { - return keyType == ed25519KeyType -} - -func getDidCommService(didDoc *did.Doc) (*did.Service, error) { - const notFound = -1 - index := notFound - - for i, s := range didDoc.Service { - if s.Type == didCommServiceType { - if index == notFound || didDoc.Service[index].Priority > s.Priority { - index = i - } - } - } - - if index == notFound { - return nil, fmt.Errorf("service not found in DID document: %s", didCommServiceType) - } - - return &didDoc.Service[index], nil -} - -func prepareDestination(didDoc *did.Doc) (*service.Destination, error) { - didCommService, err := getDidCommService(didDoc) - if err != nil { - return nil, err - } - - recipientKeys, err := getRecipientKeys(didDoc) - if err != nil { - return nil, err - } - - return &service.Destination{ - RecipientKeys: recipientKeys, - ServiceEndpoint: didCommService.ServiceEndpoint, - }, nil -} - // Encode the connection and convert to Connection Signature as per the spec: // https://github.com/hyperledger/aries-rfcs/tree/master/features/0023-did-exchange func (ctx *context) prepareConnectionSignature(connection *Connection, @@ -597,7 +509,7 @@ func (ctx *context) handleInboundResponse(response *Response) (stateAction, *Con return nil, nil, fmt.Errorf("resolve did doc from exchange response connection: %w", err) } - destination, err := prepareDestination(responseDidDoc) + destination, err := service.MakeDestination(responseDidDoc, didCommServiceType, ed25519KeyType) if err != nil { return nil, nil, fmt.Errorf("prepare destination from response did doc: %w", err) } @@ -607,7 +519,7 @@ func (ctx *context) handleInboundResponse(response *Response) (stateAction, *Con return nil, nil, fmt.Errorf("fetching did document: %w", err) } - senderVerKeys, err := getRecipientKeys(myDidDoc) + senderVerKeys, err := did.GetRecipientKeys(myDidDoc, didCommServiceType, ed25519KeyType) if err != nil { return nil, nil, fmt.Errorf("get public keys: %w", err) } @@ -672,7 +584,7 @@ func (ctx *context) getInvitationRecipientKey(invitation *Invitation) (string, e return "", fmt.Errorf("get invitation recipient key: %w", err) } - recipientKeys, err := getRecipientKeys(didDoc) + recipientKeys, err := did.GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) if err != nil { return "", fmt.Errorf("get recipient keys from did: %w", err) } diff --git a/pkg/didcomm/protocol/didexchange/states_test.go b/pkg/didcomm/protocol/didexchange/states_test.go index 7adad15aea..c5d7e613a0 100644 --- a/pkg/didcomm/protocol/didexchange/states_test.go +++ b/pkg/didcomm/protocol/didexchange/states_test.go @@ -25,7 +25,9 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" diddoc "github.com/hyperledger/aries-framework-go/pkg/doc/did" + "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/didconnection" "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/protocol" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" mockstorage "github.com/hyperledger/aries-framework-go/pkg/internal/mock/storage" mockvdri "github.com/hyperledger/aries-framework-go/pkg/internal/mock/vdri" "github.com/hyperledger/aries-framework-go/pkg/storage" @@ -303,7 +305,7 @@ func TestRequestedState_Execute(t *testing.T) { ThreadID: "test", ConnectionID: "123", } - didDoc := getMockDID() + didDoc := mockdiddoc.GetMockDIDDoc() didDoc.Service[0].RecipientKeys = []string{"invalid"} ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: didDoc}, @@ -369,11 +371,11 @@ func TestRespondedState_Execute(t *testing.T) { require.Equal(t, (&completed{}).Name(), followup.Name()) }) t.Run("handle inbound request public key error", func(t *testing.T) { - didDoc := getMockDID() + didDoc := mockdiddoc.GetMockDIDDoc() didDoc.Service[0].RecipientKeys = []string{"invalid"} ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: didDoc}, signer: &mockSigner{}, - connectionStore: NewConnectionRecorder(store, store)} + connectionStore: NewConnectionRecorder(store, store, nil)} _, followup, _, err := (&responded{}).ExecuteInbound(&stateMachineMsg{ header: &service.Header{Type: RequestMsgType}, payload: requestPayloadBytes, @@ -407,7 +409,7 @@ func TestCompletedState_Execute(t *testing.T) { pubKey, privKey := generateKeyPair() _, store := getProvider() ctx := &context{signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store)} + connectionStore: NewConnectionRecorder(store, store, &didconnection.MockDIDConnection{})} newDIDDoc := createDIDDocWithKey(pubKey) connection := &Connection{ DID: newDIDDoc.ID, @@ -441,7 +443,7 @@ func TestCompletedState_Execute(t *testing.T) { } err = ctx.connectionStore.saveNewConnectionRecord(connRec) require.NoError(t, err) - ctx.vdriRegistry = &mockvdri.MockVDRIRegistry{ResolveValue: getMockDID()} + ctx.vdriRegistry = &mockvdri.MockVDRIRegistry{ResolveValue: mockdiddoc.GetMockDIDDoc()} require.NoError(t, err) _, followup, _, e := (&completed{}).ExecuteInbound(&stateMachineMsg{ header: &service.Header{Type: ResponseMsgType}, @@ -512,7 +514,7 @@ func TestVerifySignature(t *testing.T) { pubKey, privKey := generateKeyPair() _, store := getProvider() ctx := &context{signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(nil, store)} + connectionStore: NewConnectionRecorder(nil, store, nil)} newDIDDoc := createDIDDocWithKey(pubKey) connection := &Connection{ DID: newDIDDoc.ID, @@ -650,7 +652,7 @@ func TestPrepareConnectionSignature(t *testing.T) { ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDidDoc}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store), + connectionStore: NewConnectionRecorder(store, store, nil), } connectionSignature, err := ctx2.prepareConnectionSignature(connection, newDidDoc.ID) require.NoError(t, err) @@ -668,7 +670,7 @@ func TestPrepareConnectionSignature(t *testing.T) { ctx2 := &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: newDidDoc}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store), + connectionStore: NewConnectionRecorder(store, store, nil), } connectionSignature, err := ctx2.prepareConnectionSignature(connection, newDidDoc.ID) require.Error(t, err) @@ -696,9 +698,9 @@ func TestPrepareConnectionSignature(t *testing.T) { }) t.Run("prepare connection signature error", func(t *testing.T) { ctx := &context{signer: &mockSigner{err: errors.New("sign error")}, - connectionStore: NewConnectionRecorder(nil, store)} + connectionStore: NewConnectionRecorder(nil, store, nil)} connection := &Connection{ - DIDDoc: getMockDID(), + DIDDoc: mockdiddoc.GetMockDIDDoc(), } connectionSignature, err := ctx.prepareConnectionSignature(connection, invitation.ID) require.Error(t, err) @@ -707,35 +709,6 @@ func TestPrepareConnectionSignature(t *testing.T) { }) } -func TestPrepareDestination(t *testing.T) { - t.Run("successfully prepared destination", func(t *testing.T) { - dest, err := prepareDestination(getMockDID()) - require.NoError(t, err) - require.NotNil(t, dest) - require.Equal(t, dest.ServiceEndpoint, "https://localhost:8090") - }) - - t.Run("error while getting service", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service = nil - - dest, err := prepareDestination(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "service not found in DID document: did-communication") - require.Nil(t, dest) - }) - - t.Run("error while getting recipient keys from did doc", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service[0].RecipientKeys = []string{} - - recipientKeys, err := getRecipientKeys(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "missing recipient keys in did-communication service") - require.Nil(t, recipientKeys) - }) -} - func TestNewRequestFromInvitation(t *testing.T) { invitation := &Invitation{ Type: InvitationMsgType, @@ -807,9 +780,12 @@ func TestNewResponseFromRequest(t *testing.T) { require.NotNil(t, connRec.TheirDID) }) t.Run("unsuccessful new response from request due to create did error", func(t *testing.T) { - didDoc := getMockDID() + didDoc := mockdiddoc.GetMockDIDDoc() ctx := &context{ - vdriRegistry: &mockvdri.MockVDRIRegistry{CreateErr: fmt.Errorf("create DID error"), ResolveValue: getMockDID()}} + vdriRegistry: &mockvdri.MockVDRIRegistry{ + CreateErr: fmt.Errorf("create DID error"), + ResolveValue: mockdiddoc.GetMockDIDDoc(), + }} request := &Request{Connection: &Connection{DID: didDoc.ID, DIDDoc: didDoc}} _, connRec, err := ctx.handleInboundRequest(request, &options{}, &ConnectionRecord{}) require.Error(t, err) @@ -817,9 +793,9 @@ func TestNewResponseFromRequest(t *testing.T) { require.Nil(t, connRec) }) t.Run("unsuccessful new response from request due to sign error", func(t *testing.T) { - ctx := &context{vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: getMockDID()}, + ctx := &context{vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}, signer: &mockSigner{err: errors.New("sign error")}, - connectionStore: NewConnectionRecorder(nil, store)} + connectionStore: NewConnectionRecorder(nil, store, nil)} request, err := createRequest(ctx) require.NoError(t, err) _, connRec, err := ctx.handleInboundRequest(request, &options{}, &ConnectionRecord{}) @@ -887,7 +863,7 @@ func TestGetInvitationRecipientKey(t *testing.T) { require.Equal(t, invitation.RecipientKeys[0], recKey) }) t.Run("failed to get invitation recipient key", func(t *testing.T) { - doc := getMockDID() + doc := mockdiddoc.GetMockDIDDoc() ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}} invitation := &Invitation{ Type: InvitationMsgType, @@ -916,7 +892,7 @@ func TestGetPublicKey(t *testing.T) { ctx := getContext(prov, nil) newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) - pubkey, err := getPublicKey(newDidDoc.PublicKey[0].ID, newDidDoc) + pubkey, err := diddoc.GetPublicKey(newDidDoc.PublicKey[0].ID, newDidDoc) require.NoError(t, err) require.NotNil(t, pubkey) }) @@ -925,95 +901,13 @@ func TestGetPublicKey(t *testing.T) { ctx := getContext(prov, nil) newDidDoc, err := ctx.vdriRegistry.Create(testMethod) require.NoError(t, err) - pubkey, err := getPublicKey("invalid-key", newDidDoc) + pubkey, err := diddoc.GetPublicKey("invalid-key", newDidDoc) require.Error(t, err) require.Contains(t, err.Error(), "key not found in DID document: invalid-key") require.Nil(t, pubkey) }) } -func TestGetDidCommService(t *testing.T) { - t.Run("successfully getting did-communication service", func(t *testing.T) { - didDoc := getMockDID() - - s, err := getDidCommService(didDoc) - require.NoError(t, err) - require.Equal(t, "did-communication", s.Type) - require.Equal(t, uint(0), s.Priority) - }) - - t.Run("error due to missing service", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service = nil - - s, err := getDidCommService(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "service not found in DID document: did-communication") - require.Nil(t, s) - }) - - t.Run("error due to missing did-communication service", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service[0].Type = "some-type" - didDoc.Service[1].Type = "other-type" - - s, err := getDidCommService(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "service not found in DID document: did-communication") - require.Nil(t, s) - }) -} - -func TestGetRecipientKeys(t *testing.T) { - t.Run("successfully getting recipient keys", func(t *testing.T) { - didDoc := getMockDID() - - recipientKeys, err := getRecipientKeys(didDoc) - require.NoError(t, err) - require.Equal(t, 1, len(recipientKeys)) - }) - - t.Run("error due to missing did-communication service", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service = nil - - recipientKeys, err := getRecipientKeys(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "service not found in DID document: did-communication") - require.Nil(t, recipientKeys) - }) - - t.Run("error due to missing recipient keys in did-communication service", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service[0].RecipientKeys = []string{} - - recipientKeys, err := getRecipientKeys(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "missing recipient keys in did-communication service") - require.Nil(t, recipientKeys) - }) - - t.Run("error due to missing public key in did doc", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service[0].RecipientKeys = []string{"invalid"} - - recipientKeys, err := getRecipientKeys(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "key not found in DID document: invalid") - require.Nil(t, recipientKeys) - }) - - t.Run("error due to unsupported key types", func(t *testing.T) { - didDoc := getMockDID() - didDoc.Service[0].RecipientKeys = []string{didDoc.PublicKey[0].ID} - - recipientKeys, err := getRecipientKeys(didDoc) - require.Error(t, err) - require.Contains(t, err.Error(), "recipient keys in did-communication service not supported") - require.Nil(t, recipientKeys) - }) -} - func TestGetDIDDocAndConnection(t *testing.T) { t.Run("successfully getting did doc and connection for public did", func(t *testing.T) { doc := createDIDDoc() @@ -1041,7 +935,7 @@ func TestGetDIDDocAndConnection(t *testing.T) { require.Nil(t, conn) }) t.Run("successfully created peer did", func(t *testing.T) { - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: getMockDID()}} + ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: mockdiddoc.GetMockDIDDoc()}} didDoc, conn, err := ctx.getDIDDocAndConnection("") require.NoError(t, err) require.NotNil(t, didDoc) @@ -1050,48 +944,6 @@ func TestGetDIDDocAndConnection(t *testing.T) { }) } -func TestGetDestinationFromDID(t *testing.T) { - doc := createDIDDoc() - - t.Run("successfully getting destination from public DID", func(t *testing.T) { - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}} - destination, err := ctx.getDestinationFromDID(doc.ID) - require.NoError(t, err) - require.NotNil(t, destination) - }) - t.Run("test public key not found", func(t *testing.T) { - doc.PublicKey = nil - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc}} - destination, err := ctx.getDestinationFromDID(doc.ID) - require.Error(t, err) - require.Contains(t, err.Error(), "key not found in DID document") - require.Nil(t, destination) - }) - t.Run("test service not found", func(t *testing.T) { - doc2 := createDIDDoc() - doc2.Service = nil - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: doc2}} - destination, err := ctx.getDestinationFromDID(doc2.ID) - require.Error(t, err) - require.Contains(t, err.Error(), "service not found in DID document: did-communication") - require.Nil(t, destination) - }) - t.Run("get destination by invitation", func(t *testing.T) { - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveValue: createDIDDoc()}} - invitation := &Invitation{DID: "test"} - destination, err := ctx.getDestination(invitation) - require.NoError(t, err) - require.NotNil(t, destination) - }) - t.Run("test did document not found", func(t *testing.T) { - ctx := context{vdriRegistry: &mockvdri.MockVDRIRegistry{ResolveErr: errors.New("resolver error")}} - destination, err := ctx.getDestinationFromDID(doc.ID) - require.Error(t, err) - require.Contains(t, err.Error(), "resolver error") - require.Nil(t, destination) - }) -} - type mockSigner struct { privateKey []byte err error @@ -1158,7 +1010,7 @@ func getContext(prov protocol.MockProvider, store storage.Store) *context { return &context{outboundDispatcher: prov.OutboundDispatcher(), vdriRegistry: &mockvdri.MockVDRIRegistry{CreateValue: createDIDDocWithKey(pubKey)}, signer: &mockSigner{privateKey: privKey}, - connectionStore: NewConnectionRecorder(store, store), + connectionStore: NewConnectionRecorder(store, store, nil), } } diff --git a/pkg/didcomm/protocol/introduce/service.go b/pkg/didcomm/protocol/introduce/service.go index 9eb908d590..64092a5dc1 100644 --- a/pkg/didcomm/protocol/introduce/service.go +++ b/pkg/didcomm/protocol/introduce/service.go @@ -274,13 +274,13 @@ func (s *Service) InvitationReceived(thID string) error { return fmt.Errorf("invitation received new DIDComm msg: %w", err) } - _, err = s.HandleInbound(msg) + _, err = s.HandleInbound(msg, "", "") return err } // HandleInbound handles inbound message (introduce protocol) -func (s *Service) HandleInbound(msg *service.DIDCommMsg) (string, error) { +func (s *Service) HandleInbound(msg *service.DIDCommMsg, myDID, theirDID string) (string, error) { aEvent := s.ActionEvent() logger.Infof("entered into HandleInbound: %v", msg.Header) diff --git a/pkg/didcomm/protocol/introduce/service_test.go b/pkg/didcomm/protocol/introduce/service_test.go index 252ed4fe5e..321ccc6b59 100644 --- a/pkg/didcomm/protocol/introduce/service_test.go +++ b/pkg/didcomm/protocol/introduce/service_test.go @@ -302,7 +302,7 @@ func TestService_HandleInboundStop(t *testing.T) { require.NoError(t, svc.RegisterMsgEvent(sCh)) go func() { - _, err := svc.HandleInbound(msg) + _, err := svc.HandleInbound(msg, "", "") require.NoError(t, err) }() @@ -339,7 +339,7 @@ func TestService_HandleInbound(t *testing.T) { svc, err := New(provider) require.NoError(t, err) defer stop(t, svc) - _, err = svc.HandleInbound(&service.DIDCommMsg{}) + _, err = svc.HandleInbound(&service.DIDCommMsg{}, "", "") require.EqualError(t, err, "no clients are registered to handle the message") }) @@ -363,7 +363,7 @@ func TestService_HandleInbound(t *testing.T) { ch := make(chan service.DIDCommAction) require.NoError(t, svc.RegisterActionEvent(ch)) - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, "", "") require.EqualError(t, err, service.ErrThreadIDNotFound.Error()) }) @@ -392,7 +392,7 @@ func TestService_HandleInbound(t *testing.T) { ch := make(chan service.DIDCommAction) require.NoError(t, svc.RegisterActionEvent(ch)) - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, "", "") require.EqualError(t, err, "cannot fetch state from store: thid=ID : test err") }) @@ -424,7 +424,7 @@ func TestService_HandleInbound(t *testing.T) { ch := make(chan service.DIDCommAction) require.NoError(t, svc.RegisterActionEvent(ch)) - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, "", "") require.EqualError(t, err, "invalid state transition: noop -> deciding") }) @@ -452,7 +452,7 @@ func TestService_HandleInbound(t *testing.T) { ch := make(chan service.DIDCommAction) require.NoError(t, svc.RegisterActionEvent(ch)) - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, "", "") require.EqualError(t, err, "invalid state name unknown") }) @@ -479,7 +479,7 @@ func TestService_HandleInbound(t *testing.T) { ch := make(chan service.DIDCommAction) require.NoError(t, svc.RegisterActionEvent(ch)) - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, "", "") require.EqualError(t, err, "unrecognized msgType: unknown") }) @@ -507,7 +507,7 @@ func TestService_HandleInbound(t *testing.T) { ch := make(chan service.DIDCommAction) require.NoError(t, svc.RegisterActionEvent(ch)) go func() { - _, err := svc.HandleInbound(msg) + _, err := svc.HandleInbound(msg, "", "") require.NoError(t, err) }() @@ -964,7 +964,7 @@ func handleIntroducer(f *flow) { // nolint: govet resp, err := service.NewDIDCommMsg(toBytes(f.t, <-f.transport[f.transportKey])) require.NoError(f.t, err) - _, err = f.svc.HandleInbound(resp) + _, err = f.svc.HandleInbound(resp, "", "") require.NoError(f.t, err) }() continueAction(f.t, aCh, RequestMsgType, f.dep) @@ -996,7 +996,7 @@ func handleIntroducer(f *flow) { // nolint: govet resp, err := service.NewDIDCommMsg(toBytes(f.t, <-f.transport[f.transportKey])) require.NoError(f.t, err) - _, err = f.svc.HandleInbound(resp) + _, err = f.svc.HandleInbound(resp, "", "") require.NoError(f.t, err) }() continueAction(f.t, aCh, ResponseMsgType, f.dep) @@ -1009,7 +1009,7 @@ func handleIntroducer(f *flow) { // nolint: govet resp, err := service.NewDIDCommMsg(toBytes(f.t, <-f.transport[f.transportKey])) require.NoError(f.t, err) - _, err = f.svc.HandleInbound(resp) + _, err = f.svc.HandleInbound(resp, "", "") require.NoError(f.t, err) }() continueAction(f.t, aCh, ResponseMsgType, f.dep) @@ -1081,7 +1081,7 @@ func handleIntroduceeDone(f *flow) { // nolint: govet reqMsg, err := service.NewDIDCommMsg(toBytes(f.t, <-f.transport[f.transportKey])) require.NoError(f.t, err) - _, err = f.svc.HandleInbound(reqMsg) + _, err = f.svc.HandleInbound(reqMsg, "", "") require.NoError(f.t, err) }() @@ -1096,7 +1096,7 @@ func handleIntroduceeDone(f *flow) { // nolint: govet ackMsg, err := service.NewDIDCommMsg(toBytes(f.t, <-f.transport[f.transportKey])) require.NoError(f.t, err) - _, err = f.svc.HandleInbound(ackMsg) + _, err = f.svc.HandleInbound(ackMsg, "", "") require.NoError(f.t, err) }() checkStateMsg(f.t, sCh, service.PreState, AckMsgType, stateNameDone) @@ -1138,7 +1138,7 @@ func handleIntroduceeRecipient(f *flow) { // creates proposal msg reqMsg, err := service.NewDIDCommMsg(toBytes(f.t, <-f.transport[f.transportKey])) require.NoError(f.t, err) - _, err = f.svc.HandleInbound(reqMsg) + _, err = f.svc.HandleInbound(reqMsg, "", "") require.NoError(f.t, err) }() diff --git a/pkg/didcomm/transport/http/inbound.go b/pkg/didcomm/transport/http/inbound.go index 1c1d0ea173..f50b17e692 100644 --- a/pkg/didcomm/transport/http/inbound.go +++ b/pkg/didcomm/transport/http/inbound.go @@ -58,15 +58,18 @@ func processPOSTRequest(w http.ResponseWriter, r *http.Request, prov transport.P logger.Errorf("failed to unpack msg: %s - returning Code: %d", err, http.StatusInternalServerError) http.Error(w, "failed to unpack msg", http.StatusInternalServerError) + println("########## Failed to UnpackMessage ##########") + return } messageHandler := prov.InboundMessageHandler() - err = messageHandler(unpackMsg.Message) + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) if err != nil { // TODO https://github.com/hyperledger/aries-framework-go/issues/271 HTTP Response Codes based on errors // from service + println("########## Failed to handle message ##########") logger.Errorf("incoming msg processing failed: %s", err) w.WriteHeader(http.StatusInternalServerError) } else { diff --git a/pkg/didcomm/transport/http/inbound_test.go b/pkg/didcomm/transport/http/inbound_test.go index a4764f039d..69769300b7 100644 --- a/pkg/didcomm/transport/http/inbound_test.go +++ b/pkg/didcomm/transport/http/inbound_test.go @@ -30,7 +30,7 @@ type mockProvider struct { } func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { logger.Debugf("message received is %s", message) return nil } diff --git a/pkg/didcomm/transport/transport_interface.go b/pkg/didcomm/transport/transport_interface.go index 283f802e04..93af54fa53 100644 --- a/pkg/didcomm/transport/transport_interface.go +++ b/pkg/didcomm/transport/transport_interface.go @@ -26,7 +26,7 @@ type OutboundTransport interface { // InboundMessageHandler handles the inbound requests. The transport will unpack the payload prior to the // message handle invocation. -type InboundMessageHandler func(message []byte) error +type InboundMessageHandler func(message []byte, myDID, theirDID string) error // Provider contains dependencies for starting the inbound/outbound transports. // It is typically created by using aries.Context(). diff --git a/pkg/didcomm/transport/ws/pool.go b/pkg/didcomm/transport/ws/pool.go index 705b198121..97e31afe2e 100644 --- a/pkg/didcomm/transport/ws/pool.go +++ b/pkg/didcomm/transport/ws/pool.go @@ -91,7 +91,7 @@ func (d *connPool) listener(conn *websocket.Conn) { messageHandler := d.msgHandler - err = messageHandler(unpackMsg.Message) + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) if err != nil { logger.Errorf("incoming msg processing failed: %v", err) } diff --git a/pkg/didcomm/transport/ws/pool_test.go b/pkg/didcomm/transport/ws/pool_test.go index a008956a6b..22596ada07 100644 --- a/pkg/didcomm/transport/ws/pool_test.go +++ b/pkg/didcomm/transport/ws/pool_test.go @@ -47,7 +47,7 @@ func TestConnectionStore(t *testing.T) { transportProvider := &mockTransportProvider{ packagerValue: mockPackager, frameworkID: uuid.New().String(), - executeInbound: func(message []byte) error { + executeInbound: func(message []byte, myDID, theirDID string) error { resp, outboundErr := outbound.Send([]byte(response), prepareDestinationWithTransport("ws://doesnt-matter", "", []string{verKey})) require.NoError(t, outboundErr) @@ -100,7 +100,7 @@ func TestConnectionStore(t *testing.T) { transportProvider := &mockTransportProvider{ packagerValue: &mockPackager{verKey: verKey}, frameworkID: uuid.New().String(), - executeInbound: func(message []byte) error { + executeInbound: func(message []byte, myDID, theirDID string) error { // validate the echo server response with the outbound sent message require.Equal(t, request, message) done <- struct{}{} diff --git a/pkg/didcomm/transport/ws/support_test.go b/pkg/didcomm/transport/ws/support_test.go index acd49fc503..4c75df12f7 100644 --- a/pkg/didcomm/transport/ws/support_test.go +++ b/pkg/didcomm/transport/ws/support_test.go @@ -32,7 +32,7 @@ type mockProvider struct { } func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { logger.Infof("message received is %s", string(message)) if string(message) == "invalid-data" { return errors.New("error") @@ -151,7 +151,7 @@ func (m *mockPackager) UnpackMessage(encMessage []byte) (*commontransport.Envelo type mockTransportProvider struct { packagerValue commontransport.Packager - executeInbound func(message []byte) error + executeInbound func(message []byte, myDID, theirDID string) error frameworkID string } diff --git a/pkg/doc/did/helpers.go b/pkg/doc/did/helpers.go new file mode 100644 index 0000000000..a1d3f27e3f --- /dev/null +++ b/pkg/doc/did/helpers.go @@ -0,0 +1,72 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package did + +import ( + "fmt" +) + +// GetDIDCommService returns the service from the given DIDDoc matching the given service type. +func GetDIDCommService(didDoc *Doc, serviceType string) (*Service, error) { + const notFound = -1 + index := notFound + + for i, s := range didDoc.Service { + if s.Type == serviceType { + if index == notFound || didDoc.Service[index].Priority > s.Priority { + index = i + } + } + } + + if index == notFound { + return nil, fmt.Errorf("service not found in DID document: %s", serviceType) + } + + return &didDoc.Service[index], nil +} + +// GetRecipientKeys gets the recipient keys from the did doc which match the given parameters. +func GetRecipientKeys(didDoc *Doc, serviceType, keyType string) ([]string, error) { + didCommService, err := GetDIDCommService(didDoc, serviceType) + if err != nil { + return nil, err + } + + if len(didCommService.RecipientKeys) == 0 { + return nil, fmt.Errorf("missing recipient keys in did-communication service") + } + + var recipientKeys []string + + for _, keyID := range didCommService.RecipientKeys { + key, err := GetPublicKey(keyID, didDoc) + if err != nil { + return nil, err + } + + if key.Type == keyType { + recipientKeys = append(recipientKeys, string(key.Value)) + } + } + + if len(recipientKeys) == 0 { + return nil, fmt.Errorf("recipient keys in did-communication service not supported") + } + + return recipientKeys, nil +} + +// GetPublicKey returns the public key with the given id from the given DID Doc +func GetPublicKey(id string, didDoc *Doc) (*PublicKey, error) { + for _, key := range didDoc.PublicKey { + if key.ID == id { + return &key, nil + } + } + + return nil, fmt.Errorf("key not found in DID document: %s", id) +} diff --git a/pkg/doc/did/helpers_test.go b/pkg/doc/did/helpers_test.go new file mode 100644 index 0000000000..db1b4b5e32 --- /dev/null +++ b/pkg/doc/did/helpers_test.go @@ -0,0 +1,102 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package did_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + . "github.com/hyperledger/aries-framework-go/pkg/doc/did" + mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/internal/mock/diddoc" +) + +func TestGetRecipientKeys(t *testing.T) { + ed25519KeyType := "Ed25519VerificationKey2018" + didCommServiceType := "did-communication" + + t.Run("successfully getting recipient keys", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + + recipientKeys, err := GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) + require.NoError(t, err) + require.Equal(t, 1, len(recipientKeys)) + }) + + t.Run("error due to missing did-communication service", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service = nil + + recipientKeys, err := GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) + require.Error(t, err) + require.Contains(t, err.Error(), "service not found in DID document: did-communication") + require.Nil(t, recipientKeys) + }) + + t.Run("error due to missing recipient keys in did-communication service", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service[0].RecipientKeys = []string{} + + recipientKeys, err := GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) + require.Error(t, err) + require.Contains(t, err.Error(), "missing recipient keys in did-communication service") + require.Nil(t, recipientKeys) + }) + + t.Run("error due to missing public key in did doc", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service[0].RecipientKeys = []string{"invalid"} + + recipientKeys, err := GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) + require.Error(t, err) + require.Contains(t, err.Error(), "key not found in DID document: invalid") + require.Nil(t, recipientKeys) + }) + + t.Run("error due to unsupported key types", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service[0].RecipientKeys = []string{didDoc.PublicKey[0].ID} + + recipientKeys, err := GetRecipientKeys(didDoc, didCommServiceType, ed25519KeyType) + require.Error(t, err) + require.Contains(t, err.Error(), "recipient keys in did-communication service not supported") + require.Nil(t, recipientKeys) + }) +} + +func TestGetDidCommService(t *testing.T) { + didCommServiceType := "did-communication" + + t.Run("successfully getting did-communication service", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + + s, err := GetDIDCommService(didDoc, didCommServiceType) + require.NoError(t, err) + require.Equal(t, "did-communication", s.Type) + require.Equal(t, uint(0), s.Priority) + }) + + t.Run("error due to missing service", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service = nil + + s, err := GetDIDCommService(didDoc, didCommServiceType) + require.Error(t, err) + require.Contains(t, err.Error(), "service not found in DID document: did-communication") + require.Nil(t, s) + }) + + t.Run("error due to missing did-communication service", func(t *testing.T) { + didDoc := mockdiddoc.GetMockDIDDoc() + didDoc.Service[0].Type = "some-type" + didDoc.Service[1].Type = "other-type" + + s, err := GetDIDCommService(didDoc, didCommServiceType) + require.Error(t, err) + require.Contains(t, err.Error(), "service not found in DID document: did-communication") + require.Nil(t, s) + }) +} diff --git a/pkg/framework/aries/api/protocol.go b/pkg/framework/aries/api/protocol.go index 0d2e941f16..9e70373aba 100644 --- a/pkg/framework/aries/api/protocol.go +++ b/pkg/framework/aries/api/protocol.go @@ -9,6 +9,7 @@ package api import ( "errors" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" @@ -25,6 +26,7 @@ type Provider interface { Service(id string) (interface{}, error) StorageProvider() storage.Provider KMS() kms.KeyManager + DIDConnectionStore() didconnection.Store Packager() transport.Packager InboundTransportEndpoint() string VDRIRegistry() vdriapi.Registry diff --git a/pkg/framework/aries/framework.go b/pkg/framework/aries/framework.go index 3820831026..10c7bcbb0b 100644 --- a/pkg/framework/aries/framework.go +++ b/pkg/framework/aries/framework.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packager" @@ -37,6 +38,7 @@ type Aries struct { inboundTransport transport.InboundTransport kmsCreator api.KMSCreator kms api.CloseableKMS + didConnectionStore didconnection.Store packagerCreator packager.Creator packager commontransport.Packager packerCreator packer.Creator @@ -81,8 +83,11 @@ func New(opts ...Option) (*Aries, error) { // on the context. The inbound transports require ctx.InboundMessageHandler(), which in-turn depends on // protocolServices. At the moment, there is a looping issue among these. - // Order of initializing service is important + return initializeServices(frameworkOpts) +} +func initializeServices(frameworkOpts *Aries) (*Aries, error) { + // Order of initializing service is important // Create kms if e := createKMS(frameworkOpts); e != nil { return nil, e @@ -93,7 +98,12 @@ func New(opts ...Option) (*Aries, error) { return nil, e } - // create packers and packager (must be done after KMS) + // Create connection store + if e := createDIDConnectionStore(frameworkOpts); e != nil { + return nil, e + } + + // create packers and packager (must be done after KMS and connection store) if err := createPackersAndPackager(frameworkOpts); err != nil { return nil, err } @@ -205,6 +215,7 @@ func (a *Aries) Context() (*context.Provider, error) { context.WithOutboundTransports(a.outboundTransports...), context.WithProtocolServices(a.services...), context.WithKMS(a.kms), + context.WithDIDConnectionStore(a.didConnectionStore), context.WithInboundTransportEndpoint(a.inboundTransport.Endpoint()), context.WithStorageProvider(a.storeProvider), context.WithTransientStorageProvider(a.transientStoreProvider), @@ -349,7 +360,9 @@ func loadServices(frameworkOpts *Aries) error { context.WithKMS(frameworkOpts.kms), context.WithPackager(frameworkOpts.packager), context.WithInboundTransportEndpoint(frameworkOpts.inboundTransport.Endpoint()), - context.WithVDRIRegistry(frameworkOpts.vdriRegistry)) + context.WithVDRIRegistry(frameworkOpts.vdriRegistry), + context.WithDIDConnectionStore(frameworkOpts.didConnectionStore), + ) if err != nil { return fmt.Errorf("create context failed: %w", err) @@ -367,10 +380,27 @@ func loadServices(frameworkOpts *Aries) error { return nil } +func createDIDConnectionStore(frameworkOpts *Aries) error { + ctx, err := context.New( + context.WithStorageProvider(frameworkOpts.storeProvider), + context.WithVDRIRegistry(frameworkOpts.vdriRegistry), + ) + if err != nil { + return fmt.Errorf("create did lookup store context failed: %w", err) + } + + frameworkOpts.didConnectionStore, err = didconnection.New(ctx) + if err != nil { + return fmt.Errorf("create did lookup store failed: %w", err) + } + + return nil +} + func createPackersAndPackager(frameworkOpts *Aries) error { ctx, err := context.New(context.WithKMS(frameworkOpts.kms)) if err != nil { - return fmt.Errorf("create envelope context failed: %w", err) + return fmt.Errorf("create packer context failed: %w", err) } frameworkOpts.primaryPacker, err = frameworkOpts.packerCreator(ctx) @@ -392,7 +422,8 @@ func createPackersAndPackager(frameworkOpts *Aries) error { } ctx, err = context.New( - context.WithPacker(frameworkOpts.primaryPacker, frameworkOpts.packers...)) + context.WithPacker(frameworkOpts.primaryPacker, frameworkOpts.packers...), + context.WithDIDConnectionStore(frameworkOpts.didConnectionStore)) if err != nil { return fmt.Errorf("create packager context failed: %w", err) } diff --git a/pkg/framework/context/context.go b/pkg/framework/context/context.go index d748a71ec2..9b4adf9745 100644 --- a/pkg/framework/context/context.go +++ b/pkg/framework/context/context.go @@ -9,6 +9,7 @@ package context import ( "fmt" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" @@ -26,6 +27,7 @@ type Provider struct { storeProvider storage.Provider transientStoreProvider storage.Provider kms kms.KMS + didConnectionStore didconnection.Store packager commontransport.Packager primaryPacker packer.Packer packers []packer.Packer @@ -77,6 +79,11 @@ func (p *Provider) KMS() kms.KeyManager { return p.kms } +// DIDConnectionStore returns a didconnection.Store service. +func (p *Provider) DIDConnectionStore() didconnection.Store { + return p.didConnectionStore +} + // Packager returns a packager service. func (p *Provider) Packager() commontransport.Packager { return p.packager @@ -104,7 +111,7 @@ func (p *Provider) InboundTransportEndpoint() string { // InboundMessageHandler return an inbound message handler. func (p *Provider) InboundMessageHandler() transport.InboundMessageHandler { - return func(message []byte) error { + return func(message []byte, myDID, theirDID string) error { msg, err := service.NewDIDCommMsg(message) if err != nil { return err @@ -113,7 +120,7 @@ func (p *Provider) InboundMessageHandler() transport.InboundMessageHandler { // find the service which accepts the message type for _, svc := range p.services { if svc.Accept(msg.Header.Type) { - _, err = svc.HandleInbound(msg) + _, err = svc.HandleInbound(msg, myDID, theirDID) return err } } @@ -221,6 +228,14 @@ func WithTransientStorageProvider(s storage.Provider) ProviderOption { } } +// WithDIDConnectionStore injects a didconnection.Store into the context. +func WithDIDConnectionStore(cs didconnection.Store) ProviderOption { + return func(opts *Provider) error { + opts.didConnectionStore = cs + return nil + } +} + // WithPackager injects a packager into the context. func WithPackager(p commontransport.Packager) ProviderOption { return func(opts *Provider) error { diff --git a/pkg/framework/context/context_test.go b/pkg/framework/context/context_test.go index e1fca0dc1b..f42e1a3742 100644 --- a/pkg/framework/context/context_test.go +++ b/pkg/framework/context/context_test.go @@ -97,16 +97,16 @@ func TestNewProvider(t *testing.T) { { "@frameworkID": "5678876542345", "@type": "valid-message-type" - }`)) + }`), "", "") require.NoError(t, err) // invalid json - err = inboundHandler([]byte("invalid json")) + err = inboundHandler([]byte("invalid json"), "", "") require.Error(t, err) require.Contains(t, err.Error(), "invalid payload data format") // invalid json - err = inboundHandler([]byte("invalid json")) + err = inboundHandler([]byte("invalid json"), "", "") require.Error(t, err) require.Contains(t, err.Error(), "invalid payload data format") @@ -115,7 +115,7 @@ func TestNewProvider(t *testing.T) { { "@type": "invalid-message-type", "label": "Bob" - }`)) + }`), "", "") require.Error(t, err) require.Contains(t, err.Error(), "no message handlers found for the message type: invalid-message-type") @@ -124,7 +124,7 @@ func TestNewProvider(t *testing.T) { { "label": "Carol", "@type": "valid-message-type" - }`)) + }`), "", "") require.Error(t, err) require.Contains(t, err.Error(), "error handling the message") }) diff --git a/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go b/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go new file mode 100644 index 0000000000..114a92d13d --- /dev/null +++ b/pkg/internal/mock/didcomm/didconnection/mock_didconnection.go @@ -0,0 +1,36 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package didconnection + +// MockDIDConnection mocks the did lookup store. +type MockDIDConnection struct { + SaveRecordErr error + SaveConnectionErr error + SaveKeysErr error + GetDIDValue string + GetDIDErr error + SaveDIDErr error +} + +// SaveDID saves a DID to the store +func (m *MockDIDConnection) SaveDID(did string, keys ...string) error { + return m.SaveRecordErr +} + +// GetDID gets the DID stored under the given key +func (m *MockDIDConnection) GetDID(key string) (string, error) { + return m.GetDIDValue, m.GetDIDErr +} + +// SaveDIDConnection saves a DID connection +func (m *MockDIDConnection) SaveDIDConnection(myDID, theirDID string, theirKeys []string) error { + return m.SaveConnectionErr +} + +// SaveDIDFromDoc saves a DID using its doc +func (m *MockDIDConnection) SaveDIDFromDoc(did, serviceType, keyType string) error { + return m.SaveDIDErr +} diff --git a/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go b/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go index 7d09863014..49a5bec816 100644 --- a/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go +++ b/pkg/internal/mock/didcomm/dispatcher/mock_outbound.go @@ -18,3 +18,8 @@ type MockOutbound struct { func (m *MockOutbound) Send(msg interface{}, senderVerKey string, des *service.Destination) error { return m.SendErr } + +// SendToDID msg +func (m *MockOutbound) SendToDID(msg interface{}, myDID, theirDID string) error { + return m.SendErr +} diff --git a/pkg/internal/mock/didcomm/mock_authcrypt.go b/pkg/internal/mock/didcomm/mock_authcrypt.go index d8ef0c72fc..b7a804b942 100644 --- a/pkg/internal/mock/didcomm/mock_authcrypt.go +++ b/pkg/internal/mock/didcomm/mock_authcrypt.go @@ -9,7 +9,7 @@ package didcomm // MockAuthCrypt mock auth crypt type MockAuthCrypt struct { EncryptValue func(payload, senderPubKey []byte, recipients [][]byte) ([]byte, error) - DecryptValue func(envelope []byte) ([]byte, []byte, error) + DecryptValue func(envelope []byte) ([]byte, []byte, []byte, error) Type string } @@ -20,7 +20,7 @@ func (m *MockAuthCrypt) Pack(payload, senderPubKey []byte, } // Unpack mock message unpacking -func (m *MockAuthCrypt) Unpack(envelope []byte) ([]byte, []byte, error) { +func (m *MockAuthCrypt) Unpack(envelope []byte) ([]byte, []byte, []byte, error) { return m.DecryptValue(envelope) } diff --git a/pkg/internal/mock/didcomm/protocol/mock_didexchange.go b/pkg/internal/mock/didcomm/protocol/mock_didexchange.go index 7641649fd1..994166738e 100644 --- a/pkg/internal/mock/didcomm/protocol/mock_didexchange.go +++ b/pkg/internal/mock/didcomm/protocol/mock_didexchange.go @@ -9,6 +9,7 @@ package protocol import ( "github.com/google/uuid" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" vdriapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdri" @@ -35,7 +36,7 @@ type MockDIDExchangeSvc struct { } // HandleInbound msg -func (m *MockDIDExchangeSvc) HandleInbound(msg *service.DIDCommMsg) (string, error) { +func (m *MockDIDExchangeSvc) HandleInbound(msg *service.DIDCommMsg, myDID, theirDID string) (string, error) { if m.HandleFunc != nil { return m.HandleFunc(msg) } @@ -135,9 +136,15 @@ func (m *MockDIDExchangeSvc) CreateImplicitInvitation(inviterLabel, inviterDID, // MockProvider is provider for DIDExchange Service type MockProvider struct { - StoreProvider *mockstore.MockStoreProvider - TransientStoreProvider *mockstore.MockStoreProvider - CustomVDRI vdriapi.Registry + StoreProvider *mockstore.MockStoreProvider + TransientStoreProvider *mockstore.MockStoreProvider + CustomVDRI vdriapi.Registry + DIDConnectionStoreValue didconnection.Store +} + +// DIDConnectionStore returns the did lookup store +func (p *MockProvider) DIDConnectionStore() didconnection.Store { + return p.DIDConnectionStoreValue } // OutboundDispatcher is mock outbound dispatcher for DID exchange service diff --git a/pkg/internal/mock/diddoc/mock_diddoc.go b/pkg/internal/mock/diddoc/mock_diddoc.go new file mode 100644 index 0000000000..3a82ab807c --- /dev/null +++ b/pkg/internal/mock/diddoc/mock_diddoc.go @@ -0,0 +1,54 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +*/ + +package mockdiddoc + +import ( + "github.com/btcsuite/btcutil/base58" + + "github.com/hyperledger/aries-framework-go/pkg/doc/did" +) + +// GetMockDIDDoc creates a mock DID Doc for testing. +func GetMockDIDDoc() *did.Doc { + return &did.Doc{ + Context: []string{"https://w3id.org/did/v1"}, + ID: "did:peer:123456789abcdefghi#inbox", + Service: []did.Service{ + { + ServiceEndpoint: "https://localhost:8090", + Type: "did-communication", + Priority: 0, + RecipientKeys: []string{"did:example:123456789abcdefghi#keys-2"}, + }, + { + ServiceEndpoint: "https://localhost:8090", + Type: "did-communication", + Priority: 1, + RecipientKeys: []string{"did:example:123456789abcdefghi#keys-1"}, + }, + }, + PublicKey: []did.PublicKey{ + { + ID: "did:example:123456789abcdefghi#keys-1", + Controller: "did:example:123456789abcdefghi", + Type: "Secp256k1VerificationKey2018", + Value: base58.Decode("H3C2AVvLMv6gmMNam3uVAjZpfkcJCwDwnZn6z3wXmqPV"), + }, + { + ID: "did:example:123456789abcdefghi#keys-2", + Controller: "did:example:123456789abcdefghi", + Type: "Ed25519VerificationKey2018", + Value: base58.Decode("H3C2AVvLMv6gmMNam3uVAjZpfkcJCwDwnZn6z3wXmqPV"), + }, + { + ID: "did:example:123456789abcdefghw#key2", + Controller: "did:example:123456789abcdefghw", + Type: "RsaVerificationKey2018", + Value: base58.Decode("H3C2AVvLMv6gmMNam3uVAjZpfkcJCwDwnZn6z3wXmqPV"), + }, + }, + } +} diff --git a/pkg/internal/mock/packer/noop.go b/pkg/internal/mock/packer/noop.go index 3b00679a24..edd8c5311f 100644 --- a/pkg/internal/mock/packer/noop.go +++ b/pkg/internal/mock/packer/noop.go @@ -9,6 +9,7 @@ package packer import ( "encoding/base64" "encoding/json" + "fmt" "github.com/btcsuite/btcutil/base58" @@ -16,9 +17,10 @@ import ( ) type envelope struct { - Header string `json:"protected,omitempty"` - Sender string `json:"spk,omitempty"` - Message string `json:"msg,omitempty"` + Header string `json:"protected,omitempty"` + Sender string `json:"spk,omitempty"` + Recipient string `json:"kid,omitempty"` + Message string `json:"msg,omitempty"` } type header struct { @@ -50,10 +52,15 @@ func (p *Packer) Pack(payload, sender []byte, recipientPubKeys [][]byte) ([]byte headerB64 := base64.URLEncoding.EncodeToString(headerBytes) + if len(recipientPubKeys) == 0 { + return nil, fmt.Errorf("no recipients") + } + message := envelope{ - Header: headerB64, - Sender: base58.Encode(sender), - Message: string(payload), + Header: headerB64, + Sender: base58.Encode(sender), + Recipient: base58.Encode(recipientPubKeys[0]), + Message: string(payload), } msgBytes, err := json.Marshal(&message) @@ -62,27 +69,27 @@ func (p *Packer) Pack(payload, sender []byte, recipientPubKeys [][]byte) ([]byte } // Unpack will decode the envelope using the NOOP format. -func (p *Packer) Unpack(message []byte) ([]byte, []byte, error) { +func (p *Packer) Unpack(message []byte) ([]byte, []byte, []byte, error) { var env envelope err := json.Unmarshal(message, &env) if err != nil { - return nil, nil, err + return nil, nil, nil, err } headerBytes, err := base64.URLEncoding.DecodeString(env.Header) if err != nil { - return nil, nil, err + return nil, nil, nil, err } var head header err = json.Unmarshal(headerBytes, &head) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return []byte(env.Message), base58.Decode(env.Sender), nil + return []byte(env.Message), base58.Decode(env.Sender), base58.Decode(env.Recipient), nil } // EncodingType returns the type of the encoding, as found in the header `Typ` field diff --git a/pkg/internal/mock/packer/noop_test.go b/pkg/internal/mock/packer/noop_test.go index 32f6696bf0..71b69907e8 100644 --- a/pkg/internal/mock/packer/noop_test.go +++ b/pkg/internal/mock/packer/noop_test.go @@ -16,11 +16,12 @@ import ( // note: does not replicate correct packing // when msg needs to be escaped. -func testPack(msg, key []byte) []byte { +func testPack(msg, senderKey, recKey []byte) []byte { headerValue := base64.URLEncoding.EncodeToString([]byte(`{"typ":"NOOP"}`)) return []byte(`{"protected":"` + headerValue + - `","spk":"` + base58.Encode(key) + + `","spk":"` + base58.Encode(senderKey) + + `","kid":"` + base58.Encode(recKey) + `","msg":"` + string(msg) + `"}`) } @@ -32,48 +33,55 @@ func TestPacker(t *testing.T) { t.Run("pack, compare against correct data", func(t *testing.T) { msgin := []byte("hello my name is zoop") key := []byte("senderkey") + rec := []byte("recipient") - msgout, err := p.Pack(msgin, key, nil) + msgout, err := p.Pack(msgin, key, [][]byte{rec}) require.NoError(t, err) - correct := testPack(msgin, key) + correct := testPack(msgin, key, rec) require.Equal(t, correct, msgout) }) t.Run("unpack fixed value, confirm data", func(t *testing.T) { correct := []byte("this is not a test message") key := []byte("testKey") - msgin := testPack(correct, key) + rec := []byte("key2") + msgin := testPack(correct, key, rec) - msgout, keyOut, err := p.Unpack(msgin) + msgout, keyOut, myKey, err := p.Unpack(msgin) require.NoError(t, err) require.Equal(t, correct, msgout) require.Equal(t, key, keyOut) + require.Equal(t, rec, myKey) }) t.Run("multiple pack/unpacks", func(t *testing.T) { cleartext := []byte("this is not a test message") key1 := []byte("testKey") + rec1 := []byte("rec1") key2 := []byte("wrapperKey") + rec2 := []byte("rec2") - correct1 := testPack(cleartext, key1) + correct1 := testPack(cleartext, key1, rec1) - msg1, err := p.Pack(cleartext, key1, nil) + msg1, err := p.Pack(cleartext, key1, [][]byte{rec1}) require.NoError(t, err) require.Equal(t, correct1, msg1) - msg2, err := p.Pack(msg1, key2, nil) + msg2, err := p.Pack(msg1, key2, [][]byte{rec2}) require.NoError(t, err) - msg3, key1Out, err := p.Unpack(msg2) + msg3, key1Out, rec1Out, err := p.Unpack(msg2) require.NoError(t, err) require.Equal(t, key2, key1Out) + require.Equal(t, rec2, rec1Out) require.Equal(t, correct1, msg3) - msg4, key2Out, err := p.Unpack(msg3) + msg4, key2Out, rec2Out, err := p.Unpack(msg3) require.NoError(t, err) require.Equal(t, key1, key2Out) + require.Equal(t, rec1, rec2Out) require.Equal(t, cleartext, msg4) }) } diff --git a/pkg/internal/mock/provider/mock_provider.go b/pkg/internal/mock/provider/mock_provider.go index e52373a389..f2d77f1261 100644 --- a/pkg/internal/mock/provider/mock_provider.go +++ b/pkg/internal/mock/provider/mock_provider.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0 package provider import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packer" "github.com/hyperledger/aries-framework-go/pkg/kms" "github.com/hyperledger/aries-framework-go/pkg/storage" @@ -22,6 +23,7 @@ type Provider struct { TransientStorageProviderValue storage.Provider PackerList []packer.Packer PackerValue packer.Packer + ConnectionStoreValue didconnection.Store } // Service return service @@ -58,3 +60,8 @@ func (p *Provider) Packers() []packer.Packer { func (p *Provider) PrimaryPacker() packer.Packer { return p.PackerValue } + +// DIDConnectionStore returns a didconnection.Store service. +func (p *Provider) DIDConnectionStore() didconnection.Store { + return p.ConnectionStoreValue +} diff --git a/pkg/restapi/operation/didexchange/didexchange.go b/pkg/restapi/operation/didexchange/didexchange.go index 47c39d9ec7..4dc0837e51 100644 --- a/pkg/restapi/operation/didexchange/didexchange.go +++ b/pkg/restapi/operation/didexchange/didexchange.go @@ -18,6 +18,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/client/didexchange" "github.com/hyperledger/aries-framework-go/pkg/common/log" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/didconnection" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/internal/common/support" "github.com/hyperledger/aries-framework-go/pkg/kms" @@ -74,6 +75,7 @@ const ( type provider interface { Service(id string) (interface{}, error) KMS() kms.KeyManager + DIDConnectionStore() didconnection.Store InboundTransportEndpoint() string StorageProvider() storage.Provider TransientStorageProvider() storage.Provider diff --git a/pkg/restapi/operation/didexchange/didexchange_test.go b/pkg/restapi/operation/didexchange/didexchange_test.go index 428e33dd4d..85b5dfbe6b 100644 --- a/pkg/restapi/operation/didexchange/didexchange_test.go +++ b/pkg/restapi/operation/didexchange/didexchange_test.go @@ -514,7 +514,7 @@ func TestAcceptExchangeRequest(t *testing.T) { msg, err := service.NewDIDCommMsg(request) require.NoError(t, err) - _, err = didExSvc.HandleInbound(msg) + _, err = didExSvc.HandleInbound(msg, "", "") require.NoError(t, err) cid := <-connID @@ -586,7 +586,7 @@ func TestAcceptInvitation(t *testing.T) { msg, err := service.NewDIDCommMsg(invitation) require.NoError(t, err) - _, err = didExSvc.HandleInbound(msg) + _, err = didExSvc.HandleInbound(msg, "", "") require.NoError(t, err) var cid string diff --git a/pkg/vdri/registry.go b/pkg/vdri/registry.go index 91eb67bfea..3fee13cb85 100644 --- a/pkg/vdri/registry.go +++ b/pkg/vdri/registry.go @@ -129,7 +129,7 @@ func (r *Registry) applyDefaultDocOpts(docOpts *vdriapi.CreateDIDOpts, opts ...v return opts } -// Store did store +// Store stores the given DID doc. func (r *Registry) Store(doc *diddoc.Doc) error { didMethod, err := getDidMethod(doc.ID) if err != nil {