diff --git a/cmd/aries-agent-rest/go.sum b/cmd/aries-agent-rest/go.sum index c8134da8c2..33fd4e0857 100644 --- a/cmd/aries-agent-rest/go.sum +++ b/cmd/aries-agent-rest/go.sum @@ -2,6 +2,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/VictoriaMetrics/fastcache v1.5.7 h1:4y6y0G8PRzszQUYIQHHssv/jgPHAb5qQuuDNdCbyAgw= github.com/VictoriaMetrics/fastcache v1.5.7/go.mod h1:ptDBkNMQI4RtmVo8VS/XwRY6RoTu1dAWCbrk+6WsEM8= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= @@ -32,10 +33,14 @@ github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/flimzy/diff v0.1.7/go.mod h1:lFJtC7SPsK0EroDmGTSrdtWKAxOk3rO+q+e04LL05Hs= +github.com/flimzy/testy v0.1.17/go.mod h1:3szguN8NXqgq9bt9Gu8TQVj698PJWmyx/VY1frwwKrM= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-kivik/couchdb v2.0.0+incompatible/go.mod h1:5XJRkAMpBlEVA4q0ktIZjUPYBjoBmRoiWvwUBzP3BOQ= github.com/go-kivik/kivik v2.0.0+incompatible/go.mod h1:nIuJ8z4ikBrVUSk3Ua8NoDqYKULPNjuddjqRvlSUyyQ= +github.com/go-kivik/kiviktest v2.0.0+incompatible/go.mod h1:JdhVyzixoYhoIDUt6hRf1yAfYyaDa5/u9SDOindDkfQ= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= @@ -73,6 +78,7 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= @@ -120,11 +126,18 @@ github.com/multiformats/go-varint v0.0.5/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/otiai10/copy v1.0.2/go.mod h1:c7RpqBkwMom4bYTSkLSym4VSJz/XtncWRAj/J4PEIMY= +github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= +github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/piprate/json-gold v0.3.0 h1:a1vHx7Q1jOO1pjCtKwTI/WCzwaQwRt9VM7apK2uy200= github.com/piprate/json-gold v0.3.0/go.mod h1:OK1z7UgtBZk06n2cDE2OSq1kffmjFFp5/2yhLLCz9UM= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= @@ -166,6 +179,7 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1: github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +gitlab.com/flimzy/testy v0.2.1/go.mod h1:YObF4cq711ubd/3U0ydRQQVz7Cnq/ChgJpVwNr/AJac= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/go.mod b/go.mod index 0dca52f297..ed71b62320 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( github.com/onsi/ginkgo v1.10.1 // indirect github.com/onsi/gomega v1.7.0 // indirect github.com/piprate/json-gold v0.3.0 - github.com/pkg/errors v0.9.1 // indirect + github.com/pkg/errors v0.9.1 github.com/rs/cors v1.7.0 github.com/square/go-jose/v3 v3.0.0-20191119004800-96c717272387 github.com/stretchr/testify v1.4.0 diff --git a/pkg/didcomm/protocol/messagepickup/api.go b/pkg/didcomm/protocol/messagepickup/api.go new file mode 100644 index 0000000000..aef916f67b --- /dev/null +++ b/pkg/didcomm/protocol/messagepickup/api.go @@ -0,0 +1,16 @@ +/* +Copyright Scoir Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package messagepickup + +import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" +) + +// ProtocolService fix +type ProtocolService interface { + AddMessage(message *model.Envelope, theirDID string) error +} diff --git a/pkg/didcomm/protocol/messagepickup/models.go b/pkg/didcomm/protocol/messagepickup/models.go new file mode 100644 index 0000000000..dae4b448f8 --- /dev/null +++ b/pkg/didcomm/protocol/messagepickup/models.go @@ -0,0 +1,68 @@ +/* +Copyright Scoir Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package messagepickup + +import ( + "time" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" +) + +// StatusRequest sent by the recipient to the message_holder to request a status message./0212-pickup#statusrequest +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0212-pickup#statusrequest +type StatusRequest struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + Thread *decorator.Thread `json:"~thread,omitempty"` +} + +// Status details about pending messages +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0212-pickup#status +type Status struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + MessageCount int `json:"message_count"` + DurationWaited int `json:"duration_waited,omitempty"` + LastAddedTime time.Time `json:"last_added_time,omitempty"` + LastDeliveredTime time.Time `json:"last_delivered_time,omitempty"` + LastRemovedTime time.Time `json:"last_removed_time,omitempty"` + TotalSize int `json:"total_size,omitempty"` + Thread *decorator.Thread `json:"~thread,omitempty"` +} + +// BatchPickup a request to have multiple waiting messages sent inside a batch message. +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0212-pickup#batch-pickup +type BatchPickup struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + BatchSize int `json:"batch_size"` + Thread *decorator.Thread `json:"~thread,omitempty"` +} + +// Batch a message that contains multiple waiting messages. +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0212-pickup#batch +type Batch struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` + Messages []*Message `json:"messages~attach"` + Thread *decorator.Thread `json:"~thread,omitempty"` +} + +// Message messagepickup wrapper +type Message struct { + ID string `json:"id"` + AddedTime time.Time `json:"added_time"` + Message *model.Envelope `json:"msg,omitempty"` +} + +// Noop message +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0212-pickup#noop +type Noop struct { + Type string `json:"@type,omitempty"` + ID string `json:"@id,omitempty"` +} diff --git a/pkg/didcomm/protocol/messagepickup/service.go b/pkg/didcomm/protocol/messagepickup/service.go new file mode 100644 index 0000000000..453e46c8f5 --- /dev/null +++ b/pkg/didcomm/protocol/messagepickup/service.go @@ -0,0 +1,597 @@ +/* +Copyright Scoir Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package messagepickup + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" + + "github.com/hyperledger/aries-framework-go/pkg/common/log" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + commtransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api" + "github.com/hyperledger/aries-framework-go/pkg/storage" + "github.com/hyperledger/aries-framework-go/pkg/store/connection" +) + +const ( + // MessagePickup defines the protocol name + MessagePickup = "messagepickup" + // Spec defines the protocol spec + Spec = "https://didcomm.org/messagepickup/1.0/" + // StatusMsgType defines the protocol propose-credential message type. + StatusMsgType = Spec + "status" + // StatusRequestMsgType defines the protocol propose-credential message type. + StatusRequestMsgType = Spec + "status-request" + // BatchPickupMsgType defines the protocol offer-credential message type. + BatchPickupMsgType = Spec + "batch-pickup" + // BatchMsgType defines the protocol offer-credential message type. + BatchMsgType = Spec + "batch" + // NoopMsgType defines the protocol request-credential message type. + NoopMsgType = Spec + "noop" +) + +const ( + updateTimeout = 50 * time.Second + + // Namespace is namespace of messagepickup store name + Namespace = "mailbox" +) + +// ErrConnectionNotFound connection not found error +var ErrConnectionNotFound = errors.New("connection not found") +var logger = log.New("aries-framework/messagepickup") + +type provider interface { + OutboundDispatcher() dispatcher.Outbound + StorageProvider() storage.Provider + TransientStorageProvider() storage.Provider +} + +type connections interface { + GetConnectionRecord(string) (*connection.Record, error) +} + +// Service for the messagepickup protocol +type Service struct { + connectionLookup connections + outbound dispatcher.Outbound + msgStore storage.Store + packager commtransport.Packager + msgHandler transport.InboundMessageHandler + batchMap map[string]chan Batch + batchMapLock sync.RWMutex + statusMap map[string]chan Status + statusMapLock sync.RWMutex + inboxMap map[string]*lockBox + inboxMapLock sync.RWMutex +} + +type lockBox struct { + mux sync.RWMutex +} + +// ServiceCreator for the messagepickup protocol +func ServiceCreator() api.ProtocolSvcCreator { + return func(prv api.Provider) (dispatcher.ProtocolService, error) { + tp, ok := prv.(transport.Provider) + if !ok { + return nil, errors.New("failed to cast transport provider") + } + + return New(prv, tp) + } +} + +// New returns the messagepickup service +func New(prov provider, tp transport.Provider) (*Service, error) { + store, err := prov.StorageProvider().OpenStore(Namespace) + if err != nil { + return nil, fmt.Errorf("open mailbox store : %w", err) + } + + connectionLookup, err := connection.NewLookup(prov) + if err != nil { + return nil, err + } + + svc := &Service{ + outbound: prov.OutboundDispatcher(), + msgStore: store, + connectionLookup: connectionLookup, + packager: tp.Packager(), + msgHandler: tp.InboundMessageHandler(), + batchMap: make(map[string]chan Batch), + statusMap: make(map[string]chan Status), + inboxMap: make(map[string]*lockBox), + } + + return svc, nil +} + +// HandleInbound handles inbound message pick up messages +func (s *Service) HandleInbound(msg service.DIDCommMsg, myDID, theirDID string) (string, error) { + // perform action asynchronously + go func() { + var err error + + switch msg.Type() { + case StatusMsgType: + err = s.handleStatus(msg) + case StatusRequestMsgType: + err = s.handleStatusRequest(msg, myDID, theirDID) + case BatchPickupMsgType: + err = s.handleBatchPickup(msg, myDID, theirDID) + case BatchMsgType: + err = s.handleBatch(msg) + case NoopMsgType: + err = s.handleNoop(msg) + } + + if err != nil { + logger.Errorf("Error handling message: (%w)\n", err) + } + }() + + return msg.ID(), nil +} + +// HandleOutbound adherence to dispatcher.ProtocolService +func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) error { + return errors.New("not implemented") +} + +// Accept checks whether the service can handle the message type +func (s *Service) Accept(msgType string) bool { + switch msgType { + case BatchPickupMsgType, BatchMsgType, StatusRequestMsgType, StatusMsgType, NoopMsgType: + return true + } + + return false +} + +// Name of the service +func (s *Service) Name() string { + return MessagePickup +} + +func (s *Service) handleStatus(msg service.DIDCommMsg) error { + // unmarshal the payload + statusMsg := &Status{} + + err := msg.Decode(statusMsg) + if err != nil { + return fmt.Errorf("status message unmarshal: %w", err) + } + + // check if there are any channels registered for the message ID + statusCh := s.getStatusCh(statusMsg.ID) + if statusCh != nil { + // invoke the channel for the incoming message + statusCh <- *statusMsg + } + + return nil +} + +func (s *Service) handleStatusRequest(msg service.DIDCommMsg, myDID, theirDID string) error { + // unmarshal the payload + request := &StatusRequest{} + + err := msg.Decode(request) + if err != nil { + return fmt.Errorf("status request message unmarshal: %w", err) + } + + logger.Debugf("retrieving stored messages for %s\n", theirDID) + + outbox, err := s.getInbox(theirDID) + if err != nil { + return fmt.Errorf("error in status request getting inbox: %w", err) + } + + resp := &Status{ + Type: StatusMsgType, + ID: msg.ID(), + MessageCount: outbox.MessageCount, + DurationWaited: int(time.Since(outbox.LastDeliveredTime).Seconds()), + LastAddedTime: outbox.LastAddedTime, + LastDeliveredTime: outbox.LastDeliveredTime, + LastRemovedTime: outbox.LastRemovedTime, + TotalSize: outbox.TotalSize, + Thread: &decorator.Thread{ + PID: request.Thread.ID, + }, + } + + return s.outbound.SendToDID(resp, myDID, theirDID) +} + +func (s *Service) handleBatchPickup(msg service.DIDCommMsg, myDID, theirDID string) error { + // unmarshal the payload + request := &BatchPickup{} + + err := msg.Decode(request) + if err != nil { + return fmt.Errorf("batch pickup message unmarshal : %w", err) + } + + msgs, err := s.pullMessages(theirDID, request.BatchSize) + if err != nil { + return fmt.Errorf("batch pick up pull messages : %w", err) + } + + batch := &Batch{ + Type: BatchMsgType, + ID: msg.ID(), + Messages: msgs, + } + + return s.outbound.SendToDID(batch, myDID, theirDID) +} + +func (s *Service) handleBatch(msg service.DIDCommMsg) error { + // unmarshal the payload + batchMsg := &Batch{} + + err := msg.Decode(batchMsg) + if err != nil { + return fmt.Errorf("batch message unmarshal : %w", err) + } + + // check if there are any channels registered for the message ID + batchCh := s.getBatchCh(batchMsg.ID) + + if batchCh != nil { + // invoke the channel for the incoming message + batchCh <- *batchMsg + } + + return nil +} + +func (s *Service) handleNoop(msg service.DIDCommMsg) error { + // unmarshal the payload + request := &Noop{} + + err := msg.Decode(request) + if err != nil { + return fmt.Errorf("noop message unmarshal : %w", err) + } + + return nil +} + +type inbox struct { + DID string `json:"DID"` + MessageCount int `json:"message_count"` + LastAddedTime time.Time `json:"last_added_time,omitempty"` + LastDeliveredTime time.Time `json:"last_delivered_time,omitempty"` + LastRemovedTime time.Time `json:"last_removed_time,omitempty"` + TotalSize int `json:"total_size,omitempty"` + Messages json.RawMessage `json:"messages"` +} + +// DecodeMessages Messages +func (r *inbox) DecodeMessages() ([]*Message, error) { + var out []*Message + err := json.Unmarshal(r.Messages, &out) + + return out, err +} + +// EncodeMessages Messages +func (r *inbox) EncodeMessages(msg []*Message) error { + d, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("unable to marshal: %w", err) + } + + r.Messages = d + r.MessageCount = len(msg) + r.TotalSize = len(d) + + return nil +} + +// AddMessage add message to inbox +func (s *Service) AddMessage(message *model.Envelope, theirDID string) error { + outbox, err := s.getInbox(theirDID) + if err != nil { + return fmt.Errorf("unable to pull messages: %w", err) + } + + msgs, err := outbox.DecodeMessages() + if err != nil { + return fmt.Errorf("unable to decode messages: %w", err) + } + + m := Message{ + ID: uuid.New().String(), + AddedTime: time.Now(), + Message: message, + } + + msgs = append(msgs, &m) + + outbox.LastDeliveredTime = time.Now() + outbox.LastRemovedTime = outbox.LastDeliveredTime + + err = outbox.EncodeMessages(msgs) + if err != nil { + return fmt.Errorf("unable to pull messages: %w", err) + } + + err = s.putInbox(theirDID, outbox) + if err != nil { + return fmt.Errorf("unable to pull messages: %w", err) + } + + return nil +} + +func (s *Service) pullMessages(theirDID string, size int) ([]*Message, error) { + outbox, err := s.getInbox(theirDID) + if err != nil { + return nil, fmt.Errorf("unable to pull messages: %w", err) + } + + msgs, err := outbox.DecodeMessages() + if err != nil { + return nil, fmt.Errorf("unable to pull messages: %w", err) + } + + end := len(msgs) + if size < end { + end = size + } + + outbox.LastDeliveredTime = time.Now() + outbox.LastRemovedTime = time.Now() + + err = outbox.EncodeMessages(msgs[end:]) + if err != nil { + return nil, fmt.Errorf("unable to pull messages: %w", err) + } + + err = s.putInbox(theirDID, outbox) + if err != nil { + return nil, fmt.Errorf("unable to put messages: %w", err) + } + + return msgs[0:end], nil +} + +func (s *Service) getInboxLock(theirDID string) { + s.inboxMapLock.Lock() + defer s.inboxMapLock.Unlock() + + if _, ok := s.inboxMap[theirDID]; !ok { + s.inboxMap[theirDID] = &lockBox{} + } + + s.inboxMap[theirDID].mux.Lock() +} + +func (s *Service) releaseInboxLock(theirDID string) { + s.inboxMapLock.Lock() + defer s.inboxMapLock.Unlock() + + s.inboxMap[theirDID].mux.Unlock() + delete(s.inboxMap, theirDID) +} + +func (s *Service) getInbox(theirDID string) (*inbox, error) { + msgs := &inbox{DID: theirDID} + + s.getInboxLock(theirDID) + defer s.releaseInboxLock(theirDID) + + b, err := s.msgStore.Get(theirDID) + if err != nil { + return nil, err + } + + err = json.Unmarshal(b, msgs) + if err != nil { + return nil, err + } + + return msgs, nil +} + +func (s *Service) putInbox(theirDID string, o *inbox) error { + b, err := json.Marshal(o) + if err != nil { + return err + } + + s.getInboxLock(theirDID) + defer s.releaseInboxLock(theirDID) + + return s.msgStore.Put(theirDID, b) +} + +// StatusRequest request a status message +func (s *Service) StatusRequest(connectionID string) (*Status, error) { + // get the connection record for the ID to fetch DID information + conn, err := s.getConnection(connectionID) + if err != nil { + return nil, err + } + + // generate message ID + msgID := uuid.New().String() + + // register chan for callback processing + statusCh := make(chan Status) + s.setStatusCh(msgID, statusCh) + + defer s.setStatusCh(msgID, nil) + + // create request message + req := &StatusRequest{ + Type: StatusRequestMsgType, + ID: msgID, + Thread: &decorator.Thread{ + PID: uuid.New().String(), + }, + } + + // send message to the router + if err := s.outbound.SendToDID(req, conn.MyDID, conn.TheirDID); err != nil { + return nil, fmt.Errorf("send route request: %w", err) + } + + // callback processing (to make this function look like a sync function) + var sts *Status + select { + case s := <-statusCh: + sts = &s + // TODO https://github.com/hyperledger/aries-framework-go/issues/1134 configure this timeout at decorator level + case <-time.After(updateTimeout): + return nil, errors.New("timeout waiting for status request") + } + + return sts, nil +} + +// BatchPickup a request to have multiple waiting messages sent inside a batch message +func (s *Service) BatchPickup(connectionID string, size int) (int, error) { + // get the connection record for the ID to fetch DID information + conn, err := s.getConnection(connectionID) + if err != nil { + return -1, err + } + + // generate message ID + msgID := uuid.New().String() + + // register chan for callback processing + batchCh := make(chan Batch) + s.setBatchCh(msgID, batchCh) + + defer s.setBatchCh(msgID, nil) + + // create request message + req := &BatchPickup{ + Type: BatchPickupMsgType, + ID: msgID, + BatchSize: size, + } + + // send message to the router + if err := s.outbound.SendToDID(req, conn.MyDID, conn.TheirDID); err != nil { + return -1, fmt.Errorf("send route request: %w", err) + } + + // callback processing (to make this function look like a sync function) + var processed int + select { + case batchResp := <-batchCh: + for _, msg := range batchResp.Messages { + err := s.handle(msg) + if err != nil { + logger.Errorf("error handling batch message %s: %w", msg.ID, err) + continue + } + processed++ + } + // TODO https://github.com/hyperledger/aries-framework-go/issues/1134 configure this timeout at decorator level + case <-time.After(updateTimeout): + return -1, errors.New("timeout waiting for batch") + } + + return processed, nil +} + +func (s *Service) getConnection(routerConnID string) (*connection.Record, error) { + conn, err := s.connectionLookup.GetConnectionRecord(routerConnID) + if err != nil { + if errors.Is(err, storage.ErrDataNotFound) { + return nil, ErrConnectionNotFound + } + + return nil, fmt.Errorf("fetch connection record from store : %w", err) + } + + return conn, nil +} + +func (s *Service) getBatchCh(msgID string) chan Batch { + s.batchMapLock.RLock() + defer s.batchMapLock.RUnlock() + + return s.batchMap[msgID] +} + +func (s *Service) setBatchCh(msgID string, batchCh chan Batch) { + s.batchMapLock.Lock() + defer s.batchMapLock.Unlock() + + if batchCh == nil { + delete(s.batchMap, msgID) + } else { + s.batchMap[msgID] = batchCh + } +} + +func (s *Service) getStatusCh(msgID string) chan Status { + s.statusMapLock.RLock() + defer s.statusMapLock.RUnlock() + + return s.statusMap[msgID] +} + +func (s *Service) setStatusCh(msgID string, statusCh chan Status) { + s.statusMapLock.Lock() + defer s.statusMapLock.Unlock() + + if statusCh == nil { + delete(s.statusMap, msgID) + } else { + s.statusMap[msgID] = statusCh + } +} + +func (s *Service) handle(msg *Message) error { + d, err := json.Marshal(msg.Message) + if err != nil { + return fmt.Errorf("failed to marshal msg: %w", err) + } + + unpackMsg, err := s.packager.UnpackMessage(d) + if err != nil { + return fmt.Errorf("failed to unpack msg: %w", err) + } + + trans := &decorator.Transport{} + err = json.Unmarshal(unpackMsg.Message, trans) + + if err != nil { + return fmt.Errorf("unmarshal transport decorator : %w", err) + } + + messageHandler := s.msgHandler + + err = messageHandler(unpackMsg.Message, unpackMsg.ToDID, unpackMsg.FromDID) + if err != nil { + return fmt.Errorf("incoming msg processing failed: %w", err) + } + + return nil +} diff --git a/pkg/didcomm/protocol/messagepickup/service_test.go b/pkg/didcomm/protocol/messagepickup/service_test.go new file mode 100644 index 0000000000..0b66db077c --- /dev/null +++ b/pkg/didcomm/protocol/messagepickup/service_test.go @@ -0,0 +1,916 @@ +/* +Copyright Scoir Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package messagepickup + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" + mockdispatcher "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/dispatcher" + mockprovider "github.com/hyperledger/aries-framework-go/pkg/mock/provider" + mockstore "github.com/hyperledger/aries-framework-go/pkg/mock/storage" + "github.com/hyperledger/aries-framework-go/pkg/storage" + "github.com/hyperledger/aries-framework-go/pkg/store/connection" +) + +const ( + MYDID = "sample-my-did" + THEIRDID = "sample-their-did" +) + +func TestServiceNew(t *testing.T) { + t.Run("test new service - success", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + require.Equal(t, MessagePickup, svc.Name()) + }) + + t.Run("test new service name - store error", func(t *testing.T) { + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: &mockstore.MockStoreProvider{ + ErrOpenStoreHandle: fmt.Errorf("error opening the store")}, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "open mailbox store") + require.Nil(t, svc) + }) +} + +func TestHandleInbound(t *testing.T) { + t.Run("test MessagePickupService.HandleInbound() - Status", func(t *testing.T) { + const jsonStr = `{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/status", + "message_count": 7, + "duration_waited": 3600, + "last_added_time": "2019-05-01T12:00:00Z", + "last_delivered_time": "2019-05-01T12:00:00Z", + "last_removed_time": "2019-05-01T12:00:00Z", + "total_size": 8096 + }` + + svc, err := getService() + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(jsonStr)) + require.NoError(t, err) + + statusCh := make(chan Status) + svc.setStatusCh(msg.ID(), statusCh) + + _, err = svc.HandleInbound(msg, MYDID, THEIRDID) + require.NoError(t, err) + + tyme, err := time.Parse(time.RFC3339, "2019-05-01T12:00:00Z") + require.NoError(t, err) + + select { + case x := <-svc.statusMap[msg.ID()]: + require.NotNil(t, x) + require.Equal(t, "123456781", x.ID) + require.Equal(t, 3600, x.DurationWaited) + require.Equal(t, tyme, x.LastAddedTime) + require.Equal(t, tyme, x.LastDeliveredTime) + require.Equal(t, tyme, x.LastRemovedTime) + require.Equal(t, 7, x.MessageCount) + require.Equal(t, 8096, x.TotalSize) + + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive message to handle") + } + }) + + t.Run("test MessagePickupService.HandleInbound() - unknown type", func(t *testing.T) { + const jsonStr = `{ + "@id": "123456781", + "@type": "unknown" + }` + + svc, err := getService() + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(jsonStr)) + require.NoError(t, err) + + statusCh := make(chan Status) + svc.setStatusCh(msg.ID(), statusCh) + + _, err = svc.HandleInbound(msg, MYDID, THEIRDID) + require.NoError(t, err) + }) + + t.Run("test MessagePickupService.HandleInbound() - Status - msg error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + msg := &service.DIDCommMsgMap{"@id": map[int]int{}} + err = svc.handleStatus(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "status message unmarshal") + }) + + t.Run("test MessagePickupService.HandleInbound() - StatusRequest success", func(t *testing.T) { + const jsonStr = `{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/status-request", + "~thread" : {"thid": "2d798168-8abf-4410-8535-bc1e8406a5ff"} + }` + msgID := make(chan string) + + tyme, err := time.Parse(time.RFC3339, "2019-05-01T12:00:00Z") + require.NoError(t, err) + + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: &mockdispatcher.MockOutbound{ + ValidateSendToDID: func(msg interface{}, myDID, theirDID string) error { + require.Equal(t, myDID, MYDID) + require.Equal(t, theirDID, THEIRDID) + request, ok := msg.(*Status) + require.True(t, ok) + + require.Equal(t, 1, request.MessageCount) + require.Equal(t, tyme, request.LastAddedTime) + require.Equal(t, tyme, request.LastDeliveredTime) + require.Equal(t, tyme, request.LastRemovedTime) + require.Equal(t, 3096, request.TotalSize) + require.Equal(t, "2d798168-8abf-4410-8535-bc1e8406a5ff", request.Thread.PID) + + msgID <- request.ID + + return nil + }}, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + b, err := json.Marshal(inbox{ + DID: "sample-their-did", + MessageCount: 1, + LastAddedTime: tyme, + LastDeliveredTime: tyme, + LastRemovedTime: tyme, + TotalSize: 3096, + Messages: []byte(`[{"test": "message"}]`), + }) + require.NoError(t, err) + + err = svc.msgStore.Put(THEIRDID, b) + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(jsonStr)) + require.NoError(t, err) + + go func() { + _, err = svc.HandleInbound(msg, MYDID, THEIRDID) + require.NoError(t, err) + }() + + select { + case id := <-msgID: + require.NotNil(t, id) + require.Equal(t, "123456781", id) + + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive message to handle") + } + }) + + t.Run("test MessagePickupService.HandleInbound() - StatusRequest - msg error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + msg := &service.DIDCommMsgMap{"@id": map[int]int{}} + err = svc.handleStatusRequest(msg, MYDID, THEIRDID) + require.Error(t, err) + require.Contains(t, err.Error(), "status request message unmarshal") + }) + + t.Run("test MessagePickupService.HandleInbound() - StatusRequest - get error", func(t *testing.T) { + const jsonStr = `{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/status-request", + "~thread" : {"thid": "2d798168-8abf-4410-8535-bc1e8406a5ff"} + }` + + svc, err := getService() + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(jsonStr)) + require.NoError(t, err) + + err = svc.handleStatusRequest(msg, MYDID, "not found") + require.Error(t, err) + require.Contains(t, err.Error(), "error in status request getting inbox") + }) + + t.Run("test MessagePickupService.HandleInbound() - BatchPickup", func(t *testing.T) { + msgID := make(chan string) + + tyme, err := time.Parse(time.RFC3339, "2019-05-01T12:00:00Z") + require.NoError(t, err) + + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: &mockdispatcher.MockOutbound{ + ValidateSendToDID: func(msg interface{}, myDID, theirDID string) error { + require.Equal(t, myDID, MYDID) + require.Equal(t, theirDID, THEIRDID) + + request, ok := msg.(*Batch) + require.True(t, ok) + + require.Equal(t, 2, len(request.Messages)) + + msgID <- request.ID + + return nil + }}, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + b, err := json.Marshal(inbox{ + DID: "sample-their-did", + MessageCount: 2, + LastAddedTime: tyme, + LastDeliveredTime: tyme, + LastRemovedTime: tyme, + TotalSize: 3096, + Messages: []byte(`[{"id": "8910"}, {"id": "8911"}, {"id": "8912"}]`), + }) + require.NoError(t, err) + + err = svc.msgStore.Put(THEIRDID, b) + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(`{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/batch-pickup", + "batch_size": 2, + "~thread" : {"thid": "2d798168-8abf-4410-8535-bc1e8406a5ff"} + }`)) + require.NoError(t, err) + + go func() { + _, err = svc.HandleInbound(msg, MYDID, THEIRDID) + require.NoError(t, err) + }() + + select { + case id := <-msgID: + require.NotNil(t, id) + require.Equal(t, id, "123456781") + + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive message to handle") + } + }) + + t.Run("test MessagePickupService.HandleInbound() - BatchPickup - msg error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + msg := &service.DIDCommMsgMap{"@id": map[int]int{}} + err = svc.handleBatchPickup(msg, MYDID, THEIRDID) + require.Error(t, err) + require.Contains(t, err.Error(), "batch pickup message unmarshal") + }) + + t.Run("test MessagePickupService.HandleInbound() - BatchPickup - pull error", func(t *testing.T) { + mockStore := mockstore.NewMockStoreProvider() + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockStore, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + mockStore.Store.ErrGet = errors.New("error pull messages") + + msg, err := service.ParseDIDCommMsgMap([]byte(`{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/batch-pickup", + "batch_size": 2, + "~thread" : {"thid": "2d798168-8abf-4410-8535-bc1e8406a5ff"} + }`)) + require.NoError(t, err) + + err = svc.handleBatchPickup(msg, MYDID, THEIRDID) + require.Error(t, err) + require.Contains(t, err.Error(), "error pull messages") + }) + + t.Run("test MessagePickupService.HandleInbound() - Batch", func(t *testing.T) { + const jsonStr = `{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/batch", + "messages~attach": [ + { + "@id" : "06ca25f6-d3c5-48ac-8eee-1a9e29120c31", + "message" : "{\"id\": \"8910\"}" + }, + { + "@id" : "344a51cf-379f-40ab-ab2c-711dab3f53a9a", + "message" : "{\"id\": \"8910\"}" + } + ] + }` + + svc, err := getService() + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(jsonStr)) + require.NoError(t, err) + + batchCh := make(chan Batch) + svc.setBatchCh(msg.ID(), batchCh) + + _, err = svc.HandleInbound(msg, MYDID, THEIRDID) + require.NoError(t, err) + + select { + case x := <-svc.batchMap[msg.ID()]: + require.NotNil(t, x) + require.Equal(t, "123456781", x.ID) + require.Equal(t, 2, len(x.Messages)) + + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive message to handle") + } + }) + + t.Run("test MessagePickupService.HandleInbound() - Batch - msg error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + msg := &service.DIDCommMsgMap{"@id": map[int]int{}} + err = svc.handleBatch(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "batch message unmarshal") + }) + + t.Run("test MessagePickupService.HandleInbound() - Noop", func(t *testing.T) { + const jsonStr = `{ + "@id": "123456781", + "@type": "https://didcomm.org/messagepickup/1.0/noop" + }` + + svc, err := getService() + require.NoError(t, err) + + msg, err := service.ParseDIDCommMsgMap([]byte(jsonStr)) + require.NoError(t, err) + + _, err = svc.HandleInbound(msg, MYDID, THEIRDID) + require.NoError(t, err) + }) + + t.Run("test MessagePickupService.HandleInbound() - Noop - msg error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + msg := &service.DIDCommMsgMap{"@id": map[int]int{}} + err = svc.handleNoop(msg) + require.Error(t, err) + require.Contains(t, err.Error(), "noop message unmarshal") + }) +} + +func TestAccept(t *testing.T) { + t.Run("test MessagePickupService.Accept() - Status", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + require.True(t, svc.Accept(StatusMsgType)) + require.True(t, svc.Accept(StatusRequestMsgType)) + require.True(t, svc.Accept(NoopMsgType)) + require.True(t, svc.Accept(BatchMsgType)) + require.True(t, svc.Accept(BatchPickupMsgType)) + require.False(t, svc.Accept("random-msg-type")) + }) +} + +func TestAddMessage(t *testing.T) { + t.Run("test MessagePickupService.AddMessage() - success", func(t *testing.T) { + mockStore := mockstore.NewMockStoreProvider() + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockStore, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + message := &model.Envelope{ + Protected: "eyJ0eXAiOiJwcnMuaHlwZXJsZWRnZXIuYXJpZXMtYXV0aC1t" + + "ZXNzYWdlIiwiYWxnIjoiRUNESC1TUytYQzIwUEtXIiwiZW5jIjoiWEMyMFAifQ", + IV: "JS2FxjEKdndnt-J7QX5pEnVwyBTu0_3d", + CipherText: "qQyzvajdvCDJbwxM", + Tag: "2FqZMMQuNPYfL0JsSkj8LQ", + } + + tyme, err := time.Parse(time.RFC3339, "2019-05-01T12:00:00Z") + require.NoError(t, err) + + b, err := json.Marshal(inbox{ + DID: "sample-their-did", + MessageCount: 2, + LastAddedTime: tyme, + LastDeliveredTime: tyme, + LastRemovedTime: tyme, + TotalSize: 3096, + Messages: []byte(`[{"id": "8910"}, {"id": "8911"}, {"id": "8912"}]`), + }) + require.NoError(t, err) + + err = svc.msgStore.Put(THEIRDID, b) + require.NoError(t, err) + + err = svc.AddMessage(message, THEIRDID) + require.NoError(t, err) + + b, err = mockStore.Store.Get(THEIRDID) + require.NoError(t, err) + + ibx := &inbox{} + err = json.Unmarshal(b, ibx) + require.NoError(t, err) + + require.Equal(t, 4, ibx.MessageCount) + }) + + t.Run("test MessagePickupService.AddMessage() - put error", func(t *testing.T) { + mockStore := mockstore.NewMockStoreProvider() + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockStore, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + b, err := json.Marshal(inbox{ + DID: "sample-their-did", + }) + require.NoError(t, err) + + // seed data for initial get in AddMessage + err = mockStore.Store.Put(THEIRDID, b) + require.NoError(t, err) + + mockStore.Store.ErrPut = errors.New("error put") + + message := &model.Envelope{} + + err = svc.AddMessage(message, THEIRDID) + require.Error(t, err) + require.Contains(t, err.Error(), "error put") + }) + + t.Run("test MessagePickupService.AddMessage() - get error", func(t *testing.T) { + mockStore := mockstore.NewMockStoreProvider() + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockStore, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + message := &model.Envelope{} + + mockStore.Store.ErrGet = errors.New("error get") + + err = svc.AddMessage(message, "not found") + require.Error(t, err) + require.Contains(t, err.Error(), "error get") + }) +} + +func TestStatusRequest(t *testing.T) { + t.Run("test MessagePickupService.StatusRequest() - success", func(t *testing.T) { + msgID := make(chan string) + s := make(map[string][]byte) + + provider := &mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: &mockdispatcher.MockOutbound{ + ValidateSendToDID: func(msg interface{}, myDID, theirDID string) error { + require.Equal(t, myDID, MYDID) + require.Equal(t, theirDID, THEIRDID) + + request, ok := msg.(*StatusRequest) + require.True(t, ok) + + msgID <- request.ID + + return nil + }}, + } + + connRec := &connection.Record{ + ConnectionID: "conn1", MyDID: MYDID, TheirDID: THEIRDID, State: "completed"} + connBytes, err := json.Marshal(connRec) + require.NoError(t, err) + + s["conn_conn1"] = connBytes + + r, err := connection.NewRecorder(provider) + require.NoError(t, err) + err = r.SaveConnectionRecord(connRec) + require.NoError(t, err) + + svc, err := New(provider, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + go func() { + status, err := svc.StatusRequest("conn1") + require.NoError(t, err) + + require.Equal(t, 6, status.MessageCount) + }() + + select { + case id := <-msgID: + require.NotNil(t, id) + s := Status{ + MessageCount: 6, + } + + // outbound has been handled, simulate a callback to finish the trip + ch := svc.getStatusCh(id) + ch <- s + + case <-time.After(2 * time.Second): + require.Fail(t, "didn't receive message to handle") + } + }) + + t.Run("test MessagePickupService.StatusRequest() - connection error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + expected := errors.New("get error") + svc.connectionLookup = &connectionsStub{ + getConnRecord: func(string) (*connection.Record, error) { + return nil, expected + }, + } + + _, err = svc.StatusRequest("conn1") + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + }) + + t.Run("test MessagePickupService.StatusRequest() - send to DID error", func(t *testing.T) { + s := make(map[string][]byte) + + provider := &mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: &mockdispatcher.MockOutbound{ + ValidateSendToDID: func(msg interface{}, myDID, theirDID string) error { + return errors.New("send error") + }}, + } + + connRec := &connection.Record{ + ConnectionID: "conn1", MyDID: MYDID, TheirDID: THEIRDID, State: "completed"} + connBytes, err := json.Marshal(connRec) + require.NoError(t, err) + + s["conn_conn1"] = connBytes + + r, err := connection.NewRecorder(provider) + require.NoError(t, err) + err = r.SaveConnectionRecord(connRec) + require.NoError(t, err) + + svc, err := New(provider, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + _, err = svc.StatusRequest("conn1") + require.Error(t, err) + require.Contains(t, err.Error(), "send route request") + }) +} + +func TestBatchPickup(t *testing.T) { + t.Run("test MessagePickupService.BatchPickup() - success", func(t *testing.T) { + msgID := make(chan string) + s := make(map[string][]byte) + + provider := &mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: &mockdispatcher.MockOutbound{ + ValidateSendToDID: func(msg interface{}, myDID, theirDID string) error { + require.Equal(t, myDID, MYDID) + require.Equal(t, theirDID, THEIRDID) + + batchpickup, ok := msg.(*BatchPickup) + require.True(t, ok) + + require.Equal(t, 1, batchpickup.BatchSize) + msgID <- batchpickup.ID + + return nil + }}, + } + + connRec := &connection.Record{ + ConnectionID: "conn1", MyDID: MYDID, TheirDID: THEIRDID, State: "completed"} + connBytes, err := json.Marshal(connRec) + require.NoError(t, err) + + s["conn_conn1"] = connBytes + + r, err := connection.NewRecorder(provider) + require.NoError(t, err) + err = r.SaveConnectionRecord(connRec) + require.NoError(t, err) + + svc, err := New(provider, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + go func() { + id := <-msgID + require.NotNil(t, id) + + s := Batch{ + Messages: []*Message{{Message: &model.Envelope{ + Protected: "eyJ0eXAiOiJwcnMuaHlwZXJsZWRnZXIuYXJpZXMtYXV0aC1t" + + "ZXNzYWdlIiwiYWxnIjoiRUNESC1TUytYQzIwUEtXIiwiZW5jIjoiWEMyMFAifQ", + IV: "JS2FxjEKdndnt-J7QX5pEnVwyBTu0_3d", + CipherText: "qQyzvajdvCDJbwxM", + Tag: "2FqZMMQuNPYfL0JsSkj8LQ", + }}}, + } + + // outbound has been handled, simulate a callback to finish the trip + ch := svc.getBatchCh(id) + ch <- s + }() + + p, err := svc.BatchPickup("conn1", 1) + require.NoError(t, err) + + require.Equal(t, 1, p) + }) + + t.Run("test MessagePickupService.BatchPickup() - connection error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + expected := errors.New("get error") + svc.connectionLookup = &connectionsStub{ + getConnRecord: func(string) (*connection.Record, error) { + return nil, expected + }, + } + + p, err := svc.BatchPickup("conn1", 4) + require.Error(t, err) + require.True(t, errors.Is(err, expected)) + require.Equal(t, -1, p) + }) + + t.Run("test MessagePickupService.BatchPickup() - send to DID error", func(t *testing.T) { + s := make(map[string][]byte) + + provider := &mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: &mockdispatcher.MockOutbound{ + ValidateSendToDID: func(msg interface{}, myDID, theirDID string) error { + return errors.New("send error") + }}, + } + + connRec := &connection.Record{ + ConnectionID: "conn1", MyDID: MYDID, TheirDID: THEIRDID, State: "completed"} + connBytes, err := json.Marshal(connRec) + require.NoError(t, err) + + s["conn_conn1"] = connBytes + + r, err := connection.NewRecorder(provider) + require.NoError(t, err) + err = r.SaveConnectionRecord(connRec) + require.NoError(t, err) + + svc, err := New(provider, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + _, err = svc.BatchPickup("conn1", 4) + require.Error(t, err) + require.Contains(t, err.Error(), "send route request") + }) +} + +func TestDecodeMessages(t *testing.T) { + t.Run("test inbox.DecodeMessages() - success", func(t *testing.T) { + ibx := &inbox{} + + _, err := ibx.DecodeMessages() + require.Error(t, err) + }) + + t.Run("test inbox.DecodeMessages() - error", func(t *testing.T) { + b, err := json.Marshal([]*Message{}) + require.NoError(t, err) + + ibx := &inbox{ + Messages: b, + } + + _, err = ibx.DecodeMessages() + require.NoError(t, err) + }) +} + +func TestHandleOutbound(t *testing.T) { + t.Run("test MessagePickupService.HandleOutbound() - not implemented", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + svc.connectionLookup = &connectionsStub{ + getConnRecord: func(string) (*connection.Record, error) { + return nil, storage.ErrDataNotFound + }, + } + + err = svc.HandleOutbound(nil, "not", "implemented") + require.Error(t, err) + require.Contains(t, err.Error(), "not implemented") + }) +} +func TestPullMessages(t *testing.T) { + t.Run("test MessagePickupService.pullMessages() - get inbox error", func(t *testing.T) { + mockStore := mockstore.NewMockStoreProvider() + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockStore, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + mockStore.Store.ErrGet = errors.New("error get") + + _, err = svc.pullMessages(THEIRDID, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "error get") + }) + + t.Run("test MessagePickupService.pullMessages() - put inbox error", func(t *testing.T) { + mockStore := mockstore.NewMockStoreProvider() + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockStore, + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + require.NoError(t, err) + + b, err := json.Marshal(&inbox{DID: THEIRDID}) + require.NoError(t, err) + + err = mockStore.Store.Put(THEIRDID, b) + require.NoError(t, err) + + mockStore.Store.ErrPut = errors.New("error put") + + _, err = svc.pullMessages(THEIRDID, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "error put") + }) +} + +func TestGetConnection(t *testing.T) { + t.Run("test MessagePickupService.getConnection() - error", func(t *testing.T) { + svc, err := getService() + require.NoError(t, err) + + svc.connectionLookup = &connectionsStub{ + getConnRecord: func(string) (*connection.Record, error) { + return nil, storage.ErrDataNotFound + }, + } + + _, err = svc.getConnection("test") + require.Error(t, err) + require.True(t, errors.Is(err, ErrConnectionNotFound)) + }) +} + +func getService() (*Service, error) { + svc, err := New(&mockprovider.Provider{ + StorageProviderValue: mockstore.NewMockStoreProvider(), + TransientStorageProviderValue: mockstore.NewMockStoreProvider(), + OutboundDispatcherValue: nil, + }, &mockTransportProvider{ + packagerValue: &mockPackager{}, + }) + + return svc, err +} + +// mockProvider mock provider +type mockTransportProvider struct { + packagerValue commontransport.Packager +} + +func (p *mockTransportProvider) Packager() commontransport.Packager { + return p.packagerValue +} + +func (p *mockTransportProvider) InboundMessageHandler() transport.InboundMessageHandler { + return func(message []byte, myDID, theirDID string) error { + logger.Debugf("message received is %s", message) + return nil + } +} + +func (p *mockTransportProvider) AriesFrameworkID() string { + return "aries-framework-instance-1" +} + +// mockPackager mock packager +type mockPackager struct { +} + +func (m *mockPackager) PackMessage(e *commontransport.Envelope) ([]byte, error) { + return e.Message, nil +} + +func (m *mockPackager) UnpackMessage(encMessage []byte) (*commontransport.Envelope, error) { + return &commontransport.Envelope{ + Message: []byte(`{ + "id": "8910", + "~transport": { + "return_route": "all" + } + }`), + }, nil +} + +type connectionsStub struct { + getConnIDByDIDs func(string, string) (string, error) + getConnRecord func(string) (*connection.Record, error) +} + +func (c *connectionsStub) GetConnectionIDByDIDs(myDID, theirDID string) (string, error) { + if c.getConnIDByDIDs != nil { + return c.getConnIDByDIDs(myDID, theirDID) + } + + return "", nil +} + +func (c *connectionsStub) GetConnectionRecord(id string) (*connection.Record, error) { + if c.getConnRecord != nil { + return c.getConnRecord(id) + } + + return nil, nil +} diff --git a/test/bdd/go.mod b/test/bdd/go.mod index 17498bd99b..bed8fe4124 100644 --- a/test/bdd/go.mod +++ b/test/bdd/go.mod @@ -22,7 +22,6 @@ require ( github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/piprate/json-gold v0.3.0 - github.com/pkg/errors v0.9.1 // indirect github.com/sirupsen/logrus v1.4.2 // indirect github.com/trustbloc/sidetree-core-go v0.1.3 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e // indirect diff --git a/test/bdd/go.sum b/test/bdd/go.sum index 5fab2b9bbd..9184576cd6 100644 --- a/test/bdd/go.sum +++ b/test/bdd/go.sum @@ -5,6 +5,7 @@ cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSR github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 h1:w+iIsaOQNcT7OZ575w+acHgRric5iCyQh+xv+KJ4HB8= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA= github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= github.com/Microsoft/hcsshim v0.8.7-0.20191101173118-65519b62243c/go.mod h1:7xhjOwRV2+0HXGmM0jxaEu+ZiXJFoVZOTfL/dmqbrD8= @@ -69,12 +70,16 @@ github.com/evanphx/json-patch v4.1.0+incompatible h1:K1MDoo4AZ4wU0GIU/fPmtZg7Vpz github.com/evanphx/json-patch v4.1.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.5.0+incompatible h1:ouOWdg56aJriqS0huScTkVXPC5IcNrDCXZ6OoTAWu7M= github.com/evanphx/json-patch v4.5.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/flimzy/diff v0.1.7/go.mod h1:lFJtC7SPsK0EroDmGTSrdtWKAxOk3rO+q+e04LL05Hs= +github.com/flimzy/testy v0.1.17/go.mod h1:3szguN8NXqgq9bt9Gu8TQVj698PJWmyx/VY1frwwKrM= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsouza/go-dockerclient v1.6.1 h1:qBvbtwBTpOYktncvxjFMHxJHuGG19lb2fvAFqfXeh7w= github.com/fsouza/go-dockerclient v1.6.1/go.mod h1:g2pGMa82+SdtAicFSpxGJc1Anx//HHssXyWLwMRxaqg= github.com/go-kivik/couchdb v2.0.0+incompatible/go.mod h1:5XJRkAMpBlEVA4q0ktIZjUPYBjoBmRoiWvwUBzP3BOQ= github.com/go-kivik/kivik v2.0.0+incompatible/go.mod h1:nIuJ8z4ikBrVUSk3Ua8NoDqYKULPNjuddjqRvlSUyyQ= +github.com/go-kivik/kiviktest v2.0.0+incompatible/go.mod h1:JdhVyzixoYhoIDUt6hRf1yAfYyaDa5/u9SDOindDkfQ= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= @@ -118,6 +123,7 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= @@ -198,6 +204,9 @@ github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJ github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= github.com/opencontainers/runtime-spec v0.1.2-0.20190507144316-5b71a03e2700/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-tools v0.0.0-20181011054405-1d69bd0f9c39/go.mod h1:r3f7wjNzSs2extwzU3Y+6pKfobzPh+kKFJ3ofN+3nfs= +github.com/otiai10/copy v1.0.2/go.mod h1:c7RpqBkwMom4bYTSkLSym4VSJz/XtncWRAj/J4PEIMY= +github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= +github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/piprate/json-gold v0.3.0 h1:a1vHx7Q1jOO1pjCtKwTI/WCzwaQwRt9VM7apK2uy200= github.com/piprate/json-gold v0.3.0/go.mod h1:OK1z7UgtBZk06n2cDE2OSq1kffmjFFp5/2yhLLCz9UM= github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -252,6 +261,7 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1: github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +gitlab.com/flimzy/testy v0.2.1/go.mod h1:YObF4cq711ubd/3U0ydRQQVz7Cnq/ChgJpVwNr/AJac= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=