diff --git a/pkg/didcomm/protocol/route/models.go b/pkg/didcomm/protocol/route/models.go index 26a2c14e4..16c43059b 100644 --- a/pkg/didcomm/protocol/route/models.go +++ b/pkg/didcomm/protocol/route/models.go @@ -22,9 +22,9 @@ type Grant struct { RoutingKeys []string `json:"routing_keys,omitempty"` } -// KeyUpdate route key update message. +// KeylistUpdate route keylist update message. // https://github.com/hyperledger/aries-rfcs/tree/master/features/0211-route-coordination#keylist-update -type KeyUpdate struct { +type KeylistUpdate struct { Type string `json:"@type,omitempty"` ID string `json:"@id,omitempty"` Updates []Update `json:"updates,omitempty"` @@ -36,9 +36,9 @@ type Update struct { Action string `json:"action,omitempty"` } -// KeyUpdateResponse route key update response message. +// KeylistUpdateResponse route keylist update response message. // https://github.com/hyperledger/aries-rfcs/tree/master/features/0211-route-coordination#keylist-update-response -type KeyUpdateResponse struct { +type KeylistUpdateResponse struct { Type string `json:"@type,omitempty"` ID string `json:"@id,omitempty"` Updated []UpdateResponse `json:"updated,omitempty"` diff --git a/pkg/didcomm/protocol/route/service.go b/pkg/didcomm/protocol/route/service.go index 69846f8fb..be9ab9b67 100644 --- a/pkg/didcomm/protocol/route/service.go +++ b/pkg/didcomm/protocol/route/service.go @@ -7,15 +7,19 @@ SPDX-License-Identifier: Apache-2.0 package route import ( + "encoding/json" "errors" "fmt" + "github.com/hyperledger/aries-framework-go/pkg/common/log" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/kms" "github.com/hyperledger/aries-framework-go/pkg/storage" ) +var logger = log.New("aries-framework/did-exchange/service") + // constants for route coordination spec types const ( // Coordination route coordination protocol @@ -31,10 +35,26 @@ const ( GrantMsgType = CoordinationSpec + "route-grant" // KeyListUpdateMsgType defines the route coordination key list update message type. - KeyListUpdateMsgType = CoordinationSpec + "keylist_update" + KeylistUpdateMsgType = CoordinationSpec + "keylist_update" // KeyListUpdateResponseMsgType defines the route coordination key list update message response type. - KeyListUpdateResponseMsgType = CoordinationSpec + "keylist_update_response" + KeylistUpdateResponseMsgType = CoordinationSpec + "keylist_update_response" +) + +// constants for key list update processing +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0211-route-coordination#keylist-update +const ( + // add key to the store + add = "add" + + // remove key from the store + remove = "remove" + + // server error while storing the key + serverError = "server_error" + + // key save success + success = "success" ) // provider contains dependencies for the Routing protocol and is typically created by using aries.Context() @@ -73,7 +93,21 @@ func New(prov provider) (*Service, error) { // HandleInbound handles inbound route coordination messages. func (s *Service) HandleInbound(msg *service.DIDCommMsg) (string, error) { - return "", errors.New("not implemented") + // perform action on inbound message asynchronously + go func() { + switch msg.Header.Type { + case RequestMsgType: + if err := s.handleRequest(msg); err != nil { + logger.Errorf("handle route request error : %s", err) + } + case KeylistUpdateMsgType: + if err := s.handleKeylistUpdate(msg); err != nil { + logger.Errorf("handle route request error : %s", err) + } + } + }() + + return msg.Header.ID, nil } // HandleOutbound handles outbound route coordination messages. @@ -84,7 +118,7 @@ func (s *Service) HandleOutbound(msg *service.DIDCommMsg, destination *service.D // Accept checks whether the service can handle the message type. func (s *Service) Accept(msgType string) bool { switch msgType { - case RequestMsgType, GrantMsgType, KeyListUpdateMsgType, KeyListUpdateResponseMsgType: + case RequestMsgType, GrantMsgType, KeylistUpdateMsgType, KeylistUpdateResponseMsgType: return true } @@ -95,3 +129,84 @@ func (s *Service) Accept(msgType string) bool { func (s *Service) Name() string { return Coordination } + +func (s *Service) handleRequest(msg *service.DIDCommMsg) error { + // unmarshal the payload + request := &Request{} + + err := json.Unmarshal(msg.Payload, request) + if err != nil { + return fmt.Errorf("route request message unmarshal : %w", err) + } + + // create keys + _, sigPubKey, err := s.kms.CreateKeySet() + if err != nil { + return fmt.Errorf("failed to create keys : %w", err) + } + + // send the grant response + grant := &Grant{ + Type: GrantMsgType, + ID: msg.Header.ID, + Endpoint: s.endpoint, + RoutingKeys: []string{sigPubKey}, + } + + // TODO https://github.com/hyperledger/aries-framework-go/issues/725 get destination details from the connection + return s.outbound.Send(grant, "", nil) +} + +func (s *Service) handleKeylistUpdate(msg *service.DIDCommMsg) error { + // unmarshal the payload + keyUpdate := &KeylistUpdate{} + + err := json.Unmarshal(msg.Payload, keyUpdate) + if err != nil { + return fmt.Errorf("route key list update message unmarshal : %w", err) + } + + var updates []UpdateResponse + + // update the db + for _, v := range keyUpdate.Updates { + if v.Action == add { + // TODO https://github.com/hyperledger/aries-framework-go/issues/725 need to get the DID from the inbound transport + val := "" + result := success + + err = s.routeStore.Put(v.RecipientKey, []byte(val)) + if err != nil { + logger.Errorf("failed to add the route key to store : %s", err) + + result = serverError + } + + // construct the response doc + updates = append(updates, UpdateResponse{ + RecipientKey: v.RecipientKey, + Action: v.Action, + Result: result, + }) + } else if v.Action == remove { + // TODO remove from the store + + // construct the response doc + updates = append(updates, UpdateResponse{ + RecipientKey: v.RecipientKey, + Action: v.Action, + Result: serverError, + }) + } + } + + // send the key update response + updateResponse := &KeylistUpdateResponse{ + Type: KeylistUpdateResponseMsgType, + ID: msg.Header.ID, + Updated: updates, + } + + // TODO https://github.com/hyperledger/aries-framework-go/issues/725 get destination details from the connection + return s.outbound.Send(updateResponse, "", nil) +} diff --git a/pkg/didcomm/protocol/route/service_test.go b/pkg/didcomm/protocol/route/service_test.go index af29a3623..ca797c5c4 100644 --- a/pkg/didcomm/protocol/route/service_test.go +++ b/pkg/didcomm/protocol/route/service_test.go @@ -7,12 +7,20 @@ SPDX-License-Identifier: Apache-2.0 package route import ( + "encoding/json" "errors" "testing" "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" ) +type updateResult struct { + action string + result string +} + func TestServiceNew(t *testing.T) { t.Run("test new service - success", func(t *testing.T) { svc, err := New(&mockProvider{}) @@ -21,9 +29,10 @@ func TestServiceNew(t *testing.T) { }) t.Run("test new service name - failure", func(t *testing.T) { - _, err := New(&mockProvider{openStoreErr: errors.New("error opening the store")}) + svc, err := New(&mockProvider{openStoreErr: errors.New("error opening the store")}) require.Error(t, err) require.Contains(t, err.Error(), "open route coordination store") + require.Nil(t, svc) }) } @@ -32,8 +41,8 @@ func TestServiceAccept(t *testing.T) { require.Equal(t, true, s.Accept(RequestMsgType)) require.Equal(t, true, s.Accept(GrantMsgType)) - require.Equal(t, true, s.Accept(KeyListUpdateMsgType)) - require.Equal(t, true, s.Accept(KeyListUpdateResponseMsgType)) + require.Equal(t, true, s.Accept(KeylistUpdateMsgType)) + require.Equal(t, true, s.Accept(KeylistUpdateResponseMsgType)) require.Equal(t, false, s.Accept("unsupported msg type")) } @@ -42,9 +51,13 @@ func TestServiceHandleInbound(t *testing.T) { svc, err := New(&mockProvider{}) require.NoError(t, err) - _, err = svc.HandleInbound(nil) - require.Error(t, err) - require.Contains(t, err.Error(), "not implemented") + msgID := randomID() + + id, err := svc.HandleInbound(&service.DIDCommMsg{Header: &service.Header{ + ID: msgID, + }}) + require.NoError(t, err) + require.Equal(t, msgID, id) }) } @@ -58,3 +71,124 @@ func TestServiceHandleOutbound(t *testing.T) { require.Contains(t, err.Error(), "not implemented") }) } + +func TestServiceRequestMsg(t *testing.T) { + t.Run("test service handle inbound request msg - success", func(t *testing.T) { + svc, err := New(&mockProvider{}) + require.NoError(t, err) + + msgID := randomID() + + id, err := svc.HandleInbound(generateRequestMsgPayload(t, msgID)) + require.NoError(t, err) + require.Equal(t, msgID, id) + }) + + t.Run("test service handle request msg - success", func(t *testing.T) { + svc, err := New(&mockProvider{}) + require.NoError(t, err) + + msg := &service.DIDCommMsg{Payload: []byte("invalid json")} + + err = svc.handleRequest(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "route request message unmarshal") + }) + + t.Run("test service handle request msg - verify outbound message", func(t *testing.T) { + endpoint := "ws://agent.example.com" + svc, err := New(&mockProvider{ + endpoint: endpoint, + outbound: &mockOutbound{validateSend: func(msg interface{}) error { + res, err := json.Marshal(msg) + require.NoError(t, err) + + grant := &Grant{} + err = json.Unmarshal(res, grant) + require.NoError(t, err) + + require.Equal(t, endpoint, grant.Endpoint) + require.Equal(t, 1, len(grant.RoutingKeys)) + + return nil + }, + }, + }) + require.NoError(t, err) + + msgID := randomID() + + err = svc.handleRequest(generateRequestMsgPayload(t, msgID)) + require.NoError(t, err) + }) +} + +func TestServiceUpdateKeyListMsg(t *testing.T) { + t.Run("test service handle inbound key list update msg - success", func(t *testing.T) { + svc, err := New(&mockProvider{}) + require.NoError(t, err) + + msgID := randomID() + + id, err := svc.HandleInbound(generateKeyUpdateListMsgPayload(t, msgID, []Update{{ + RecipientKey: "ABC", + Action: "add", + }})) + require.NoError(t, err) + require.Equal(t, msgID, id) + }) + + t.Run("test service handle key list update msg - success", func(t *testing.T) { + svc, err := New(&mockProvider{}) + require.NoError(t, err) + + msg := &service.DIDCommMsg{Payload: []byte("invalid json")} + + err = svc.handleKeylistUpdate(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "route key list update message unmarshal") + }) + + t.Run("test service handle request msg - verify outbound message", func(t *testing.T) { + update := make(map[string]updateResult) + update["ABC"] = updateResult{action: add, result: success} + update["XYZ"] = updateResult{action: remove, result: serverError} + update[""] = updateResult{action: add, result: serverError} + + svc, err := New(&mockProvider{ + + outbound: &mockOutbound{validateSend: func(msg interface{}) error { + res, err := json.Marshal(msg) + require.NoError(t, err) + + updateRes := &KeylistUpdateResponse{} + err = json.Unmarshal(res, updateRes) + require.NoError(t, err) + + require.Equal(t, len(update), len(updateRes.Updated)) + + for _, v := range updateRes.Updated { + require.Equal(t, update[v.RecipientKey].action, v.Action) + require.Equal(t, update[v.RecipientKey].result, v.Result) + } + + return nil + }, + }, + }) + require.NoError(t, err) + + msgID := randomID() + + var updates []Update + for k, v := range update { + updates = append(updates, Update{ + RecipientKey: k, + Action: v.action, + }) + } + + err = svc.handleKeylistUpdate(generateKeyUpdateListMsgPayload(t, msgID, updates)) + require.NoError(t, err) + }) +} diff --git a/pkg/didcomm/protocol/route/support_test.go b/pkg/didcomm/protocol/route/support_test.go index 82fabba67..1fc9fc08d 100644 --- a/pkg/didcomm/protocol/route/support_test.go +++ b/pkg/didcomm/protocol/route/support_test.go @@ -7,6 +7,13 @@ SPDX-License-Identifier: Apache-2.0 package route import ( + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" mockdispatcher "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/dispatcher" mockkms "github.com/hyperledger/aries-framework-go/pkg/internal/mock/kms" @@ -18,9 +25,15 @@ import ( // mock route coordination provider type mockProvider struct { openStoreErr error + outbound dispatcher.Outbound + endpoint string } func (p *mockProvider) OutboundDispatcher() dispatcher.Outbound { + if p.outbound != nil { + return p.outbound + } + return &mockdispatcher.MockOutbound{} } @@ -33,9 +46,53 @@ func (p *mockProvider) StorageProvider() storage.Provider { } func (p *mockProvider) InboundTransportEndpoint() string { + if p.endpoint != "" { + return p.endpoint + } + return "ws://example.com" } func (p *mockProvider) KMS() kms.KeyManager { return &mockkms.CloseableKMS{CreateEncryptionKeyValue: "sample-key"} } + +// mock outbound +type mockOutbound struct { + validateSend func(msg interface{}) error +} + +func (m *mockOutbound) Send(msg interface{}, senderVerKey string, des *service.Destination) error { + return m.validateSend(msg) +} + +func generateRequestMsgPayload(t *testing.T, id string) *service.DIDCommMsg { + requestBytes, err := json.Marshal(&Request{ + Type: RequestMsgType, + ID: id, + }) + require.NoError(t, err) + + didMsg, err := service.NewDIDCommMsg(requestBytes) + require.NoError(t, err) + + return didMsg +} + +func generateKeyUpdateListMsgPayload(t *testing.T, id string, updates []Update) *service.DIDCommMsg { + requestBytes, err := json.Marshal(&KeylistUpdate{ + Type: KeylistUpdateMsgType, + ID: id, + Updates: updates, + }) + require.NoError(t, err) + + didMsg, err := service.NewDIDCommMsg(requestBytes) + require.NoError(t, err) + + return didMsg +} + +func randomID() string { + return uuid.New().String() +}