diff --git a/pkg/didcomm/dispatcher/outbound.go b/pkg/didcomm/dispatcher/outbound.go index 072683de8f..b6f651c84b 100644 --- a/pkg/didcomm/dispatcher/outbound.go +++ b/pkg/didcomm/dispatcher/outbound.go @@ -9,9 +9,11 @@ package dispatcher import ( "encoding/json" "fmt" + "strings" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" ) @@ -19,17 +21,23 @@ import ( type provider interface { Packager() commontransport.Packager OutboundTransports() []transport.OutboundTransport + TransportReturnRoute() string } // OutboundDispatcher dispatch msgs to destination type OutboundDispatcher struct { - outboundTransports []transport.OutboundTransport - packager commontransport.Packager + outboundTransports []transport.OutboundTransport + packager commontransport.Packager + transportReturnRoute string } // NewOutbound return new dispatcher outbound instance func NewOutbound(prov provider) *OutboundDispatcher { - return &OutboundDispatcher{outboundTransports: prov.OutboundTransports(), packager: prov.Packager()} + return &OutboundDispatcher{ + outboundTransports: prov.OutboundTransports(), + packager: prov.Packager(), + transportReturnRoute: prov.TransportReturnRoute(), + } } // Send msg @@ -39,18 +47,37 @@ func (o *OutboundDispatcher) Send(msg interface{}, senderVerKey string, des *ser continue } - bytes, err := json.Marshal(msg) + req, err := json.Marshal(msg) if err != nil { return fmt.Errorf("failed marshal to bytes: %w", err) } + // update the outbound message with transport return route option [all or thread] + if o.transportReturnRoute == decorator.TransportReturnRouteAll || + o.transportReturnRoute == decorator.TransportReturnRouteThread { + // create the decorator with the option set in the framework + transportDec := &decorator.Transport{ReturnRoute: &decorator.ReturnRoute{Value: o.transportReturnRoute}} + + transportDecJSON, jsonErr := json.Marshal(transportDec) + if jsonErr != nil { + return fmt.Errorf("json marshal : %w", jsonErr) + } + + request := string(req) + index := strings.Index(request, "{") + + // add transport route option decorator to the original request + req = []byte(request[:index+1] + string(transportDecJSON)[1:len(string(transportDecJSON))-1] + "," + + request[index+1:]) + } + packedMsg, err := o.packager.PackMessage( - &commontransport.Envelope{Message: bytes, FromVerKey: senderVerKey, ToVerKeys: des.RecipientKeys}) + &commontransport.Envelope{Message: req, FromVerKey: senderVerKey, ToVerKeys: des.RecipientKeys}) if err != nil { return fmt.Errorf("failed to pack msg: %w", err) } - _, err = v.Send(packedMsg, des.ServiceEndpoint) + _, err = v.Send(packedMsg, des) if err != nil { return fmt.Errorf("failed to send msg using http outbound transport: %w", err) } diff --git a/pkg/didcomm/dispatcher/outbound_test.go b/pkg/didcomm/dispatcher/outbound_test.go index d0dc65fbaa..0deed133a9 100644 --- a/pkg/didcomm/dispatcher/outbound_test.go +++ b/pkg/didcomm/dispatcher/outbound_test.go @@ -7,13 +7,17 @@ SPDX-License-Identifier: Apache-2.0 package dispatcher import ( + "encoding/json" + "errors" "fmt" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" mockdidcomm "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm" mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager" @@ -21,8 +25,10 @@ import ( func TestOutboundDispatcher_Send(t *testing.T) { t.Run("test success", func(t *testing.T) { - o := NewOutbound(&mockProvider{packagerValue: &mockpackager.Packager{}, - outboundTransportsValue: []transport.OutboundTransport{&mockdidcomm.MockOutboundTransport{AcceptValue: true}}}) + o := NewOutbound(&mockProvider{ + packagerValue: &mockpackager.Packager{}, + outboundTransportsValue: []transport.OutboundTransport{&mockdidcomm.MockOutboundTransport{AcceptValue: true}}, + }) require.NoError(t, o.Send("data", "", &service.Destination{ServiceEndpoint: "url"})) }) @@ -52,9 +58,89 @@ func TestOutboundDispatcher_Send(t *testing.T) { }) } +func TestOutboundDispatcherTransportReturnRoute(t *testing.T) { + t.Run("transport route option - value set all", func(t *testing.T) { + transportReturnRoute := "all" + req := &decorator.Thread{ + ID: uuid.New().String(), + } + + outboundReq := struct { + *decorator.Transport + *decorator.Thread + }{ + &decorator.Transport{ReturnRoute: &decorator.ReturnRoute{Value: transportReturnRoute}}, + req, + } + expectedRequest, err := json.Marshal(outboundReq) + require.NoError(t, err) + require.NotNil(t, expectedRequest) + + o := NewOutbound(&mockProvider{ + packagerValue: &mockPackager{}, + outboundTransportsValue: []transport.OutboundTransport{&mockOutboundTransport{ + expectedRequest: string(expectedRequest)}, + }, + transportReturnRoute: transportReturnRoute, + }) + + require.NoError(t, o.Send(req, "", &service.Destination{ServiceEndpoint: "url"})) + }) + + t.Run("transport route option - value set thread", func(t *testing.T) { + transportReturnRoute := "thread" + req := &decorator.Thread{ + ID: uuid.New().String(), + } + + outboundReq := struct { + *decorator.Transport + *decorator.Thread + }{ + &decorator.Transport{ReturnRoute: &decorator.ReturnRoute{Value: transportReturnRoute}}, + req, + } + expectedRequest, err := json.Marshal(outboundReq) + require.NoError(t, err) + require.NotNil(t, expectedRequest) + + o := NewOutbound(&mockProvider{ + packagerValue: &mockPackager{}, + outboundTransportsValue: []transport.OutboundTransport{&mockOutboundTransport{ + expectedRequest: string(expectedRequest)}, + }, + transportReturnRoute: transportReturnRoute, + }) + + require.NoError(t, o.Send(req, "", &service.Destination{ServiceEndpoint: "url"})) + }) + + t.Run("transport route option - no value set", func(t *testing.T) { + req := &decorator.Thread{ + ID: uuid.New().String(), + } + + expectedRequest, err := json.Marshal(req) + require.NoError(t, err) + require.NotNil(t, expectedRequest) + + o := NewOutbound(&mockProvider{ + packagerValue: &mockPackager{}, + outboundTransportsValue: []transport.OutboundTransport{&mockOutboundTransport{ + expectedRequest: string(expectedRequest)}, + }, + transportReturnRoute: "", + }) + + require.NoError(t, o.Send(req, "", &service.Destination{ServiceEndpoint: "url"})) + }) +} + +// mockProvider mock provider type mockProvider struct { packagerValue commontransport.Packager outboundTransportsValue []transport.OutboundTransport + transportReturnRoute string } func (p *mockProvider) Packager() commontransport.Packager { @@ -64,3 +150,36 @@ func (p *mockProvider) Packager() commontransport.Packager { func (p *mockProvider) OutboundTransports() []transport.OutboundTransport { return p.outboundTransportsValue } + +func (p *mockProvider) TransportReturnRoute() string { + return p.transportReturnRoute +} + +// mockOutboundTransport mock outbound transport +type mockOutboundTransport struct { + expectedRequest string +} + +func (o *mockOutboundTransport) Send(data []byte, destination *service.Destination) (string, error) { + if string(data) != o.expectedRequest { + return "", errors.New("invalid request") + } + + return "", nil +} + +func (o *mockOutboundTransport) Accept(url string) bool { + return true +} + +// mockPackager mock packager +type mockPackager struct { +} + +func (m *mockPackager) PackMessage(e *commontransport.Envelope) ([]byte, error) { + return e.Message, nil +} + +func (m *mockPackager) UnpackMessage(encMessage []byte) (*commontransport.Envelope, error) { + return nil, nil +} diff --git a/pkg/didcomm/protocol/decorator/decorator.go b/pkg/didcomm/protocol/decorator/decorator.go index 5f480c3ff0..e93531748e 100644 --- a/pkg/didcomm/protocol/decorator/decorator.go +++ b/pkg/didcomm/protocol/decorator/decorator.go @@ -8,6 +8,17 @@ package decorator import "time" +const ( + // TransportReturnRouteNone return route option none + TransportReturnRouteNone = "none" + + // TransportReturnRouteAll return route option all + TransportReturnRouteAll = "all" + + // TransportReturnRouteThread return route option thread + TransportReturnRouteThread = "thread" +) + // Thread thread data type Thread struct { ID string `json:"thid,omitempty"` @@ -18,3 +29,14 @@ type Thread struct { type Timing struct { ExpiresTime time.Time `json:"expires_time,omitempty"` } + +// Transport transport decorator +// https://github.com/hyperledger/aries-rfcs/tree/master/features/0092-transport-return-route +type Transport struct { + ReturnRoute *ReturnRoute `json:"~transport,omitempty"` +} + +// ReturnRoute works with Transport decorator. Acceptable values - "none", "all" or "thread". +type ReturnRoute struct { + Value string `json:"~return_route,omitempty"` +} diff --git a/pkg/didcomm/transport/http/outbound.go b/pkg/didcomm/transport/http/outbound.go index 1201a893bd..5205168d37 100644 --- a/pkg/didcomm/transport/http/outbound.go +++ b/pkg/didcomm/transport/http/outbound.go @@ -14,6 +14,8 @@ import ( "net/http" "strings" "time" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" ) //go:generate testdata/scripts/openssl_env.sh testdata/scripts/generate_test_keys.sh @@ -83,10 +85,10 @@ func NewOutbound(opts ...OutboundHTTPOpt) (*OutboundHTTPClient, error) { } // Send sends a2a exchange data via HTTP (client side) -func (cs *OutboundHTTPClient) Send(data []byte, url string) (string, error) { - resp, err := cs.client.Post(url, commContentType, bytes.NewBuffer(data)) +func (cs *OutboundHTTPClient) Send(data []byte, destination *service.Destination) (string, error) { + resp, err := cs.client.Post(destination.ServiceEndpoint, commContentType, bytes.NewBuffer(data)) if err != nil { - logger.Errorf("posting DID envelope to agent failed [%s, %v]", url, err) + logger.Errorf("posting DID envelope to agent failed [%s, %v]", destination.ServiceEndpoint, err) return "", err } @@ -95,7 +97,8 @@ func (cs *OutboundHTTPClient) Send(data []byte, url string) (string, error) { if resp != nil { isStatusSuccess := resp.StatusCode == http.StatusAccepted || resp.StatusCode == http.StatusOK if !isStatusSuccess { - return "", fmt.Errorf("received unsuccessful POST HTTP status from agent [%s, %v]", url, resp.Status) + return "", fmt.Errorf("received unsuccessful POST HTTP status from agent "+ + "[%s, %v]", destination.ServiceEndpoint, resp.Status) } // handle response defer func() { diff --git a/pkg/didcomm/transport/http/outbound_test.go b/pkg/didcomm/transport/http/outbound_test.go index a53f63a147..6eeded5a9e 100644 --- a/pkg/didcomm/transport/http/outbound_test.go +++ b/pkg/didcomm/transport/http/outbound_test.go @@ -12,6 +12,8 @@ import ( "fmt" "testing" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + "github.com/stretchr/testify/require" ) @@ -72,25 +74,31 @@ func TestOutboundHTTPTransport(t *testing.T) { // test Outbound transport's api // first with an empty url - r, e := ot.Send([]byte("Hello World"), "") + r, e := ot.Send([]byte("Hello World"), prepareDestination("serverURL")) require.Error(t, e) require.Empty(t, r) // now try a bad url - r, e = ot.Send([]byte("Hello World"), "https://badurl") + r, e = ot.Send([]byte("Hello World"), prepareDestination("https://badurl")) require.Error(t, e) require.Empty(t, r) // and try with a 'bad' payload with a valid url.. - r, e = ot.Send([]byte("bad"), serverURL) + r, e = ot.Send([]byte("bad"), prepareDestination(serverURL)) require.Error(t, e) require.Empty(t, r) // finally using a valid url - r, e = ot.Send([]byte("Hello World"), serverURL) + r, e = ot.Send([]byte("Hello World"), prepareDestination(serverURL)) require.NoError(t, e) require.NotEmpty(t, r) require.True(t, ot.Accept("http://example.com")) require.False(t, ot.Accept("123:22")) } + +func prepareDestination(endPoint string) *service.Destination { + return &service.Destination{ + ServiceEndpoint: endPoint, + } +} diff --git a/pkg/didcomm/transport/transport_interface.go b/pkg/didcomm/transport/transport_interface.go index 1f372688de..abf4d9b627 100644 --- a/pkg/didcomm/transport/transport_interface.go +++ b/pkg/didcomm/transport/transport_interface.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0 package transport import ( + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" ) @@ -14,7 +15,7 @@ import ( // This is the client side of the agent type OutboundTransport interface { // Send send a2a exchange data - Send(data []byte, destination string) (string, error) + Send(data []byte, destination *service.Destination) (string, error) // Accept url Accept(string) bool } diff --git a/pkg/didcomm/transport/ws/outbound.go b/pkg/didcomm/transport/ws/outbound.go index 5e0a8e3d36..896cf930f2 100644 --- a/pkg/didcomm/transport/ws/outbound.go +++ b/pkg/didcomm/transport/ws/outbound.go @@ -13,6 +13,8 @@ import ( "strings" "nhooyr.io/websocket" + + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" ) const webSocketScheme = "ws" @@ -27,12 +29,12 @@ func NewOutbound() *OutboundClient { } // Send sends a2a data via WS. -func (cs *OutboundClient) Send(data []byte, url string) (string, error) { - if url == "" { +func (cs *OutboundClient) Send(data []byte, destination *service.Destination) (string, error) { + if destination.ServiceEndpoint == "" { return "", errors.New("url is mandatory") } - client, _, err := websocket.Dial(context.Background(), url, nil) + client, _, err := websocket.Dial(context.Background(), destination.ServiceEndpoint, nil) if err != nil { return "", fmt.Errorf("websocket client : %w", err) } diff --git a/pkg/didcomm/transport/ws/outbound_test.go b/pkg/didcomm/transport/ws/outbound_test.go index cc7620f173..f3eb552698 100644 --- a/pkg/didcomm/transport/ws/outbound_test.go +++ b/pkg/didcomm/transport/ws/outbound_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" "github.com/hyperledger/aries-framework-go/pkg/internal/test/transportutil" "github.com/stretchr/testify/require" @@ -32,7 +33,7 @@ func TestClient(t *testing.T) { outbound := NewOutbound() require.NotNil(t, outbound) - _, err := outbound.Send([]byte(""), "") + _, err := outbound.Send([]byte(""), prepareDestination("")) require.Error(t, err) require.Contains(t, err.Error(), "url is mandatory") }) @@ -41,7 +42,7 @@ func TestClient(t *testing.T) { outbound := NewOutbound() require.NotNil(t, outbound) - _, err := outbound.Send([]byte(""), "ws://invalid") + _, err := outbound.Send([]byte(""), prepareDestination("ws://invalid")) require.Error(t, err) require.Contains(t, err.Error(), "websocket client") }) @@ -52,7 +53,7 @@ func TestClient(t *testing.T) { addr := startWebSocketServer(t, echo) data := "hello" - resp, err := outbound.Send([]byte(data), "ws://"+addr) + resp, err := outbound.Send([]byte(data), prepareDestination("ws://"+addr)) require.NoError(t, err) require.Equal(t, data, resp) }) @@ -64,7 +65,7 @@ func TestClient(t *testing.T) { logger.Infof("inside http path") }) - _, err := outbound.Send([]byte("ws-request"), "ws://"+addr) + _, err := outbound.Send([]byte("ws-request"), prepareDestination("ws://"+addr)) require.Error(t, err) require.Contains(t, err.Error(), "websocket client") }) @@ -96,7 +97,7 @@ func TestClient(t *testing.T) { } }) - _, err := outbound.Send([]byte("ws-request"), "ws://"+addr) + _, err := outbound.Send([]byte("ws-request"), prepareDestination("ws://"+addr)) require.Error(t, err) require.Contains(t, err.Error(), "message type is not text message") }) @@ -111,7 +112,7 @@ func TestClient(t *testing.T) { require.NoError(t, c.Close(websocket.StatusAbnormalClosure, "error")) }) - _, err := outbound.Send([]byte("ws-request"), "ws://"+addr) + _, err := outbound.Send([]byte("ws-request"), prepareDestination("ws://"+addr)) require.Error(t, err) require.Contains(t, err.Error(), "websocket read message") }) @@ -129,7 +130,7 @@ func TestClient(t *testing.T) { require.NoError(t, c.Close(websocket.StatusNormalClosure, "closing the connection")) }) - _, err := outbound.Send([]byte("ws-request"), "ws://"+addr) + _, err := outbound.Send([]byte("ws-request"), prepareDestination("ws://"+addr)) require.Error(t, err) require.Contains(t, err.Error(), "websocket read message") }) @@ -175,3 +176,9 @@ func echo(t *testing.T, w http.ResponseWriter, r *http.Request) { require.NoError(t, err) } } + +func prepareDestination(endPoint string) *service.Destination { + return &service.Destination{ + ServiceEndpoint: endPoint, + } +} diff --git a/pkg/framework/aries/framework.go b/pkg/framework/aries/framework.go index f5f8281d16..3eb725ae7a 100644 --- a/pkg/framework/aries/framework.go +++ b/pkg/framework/aries/framework.go @@ -9,6 +9,8 @@ package aries import ( "fmt" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" "github.com/hyperledger/aries-framework-go/pkg/didcomm/packager" @@ -42,6 +44,7 @@ type Aries struct { packers []packer.Packer vdriRegistry vdriapi.Registry vdri []vdriapi.VDRI + transportReturnRoute string } // Option configures the framework. @@ -128,6 +131,21 @@ func WithInboundTransport(inboundTransport transport.InboundTransport) Option { } } +// WithTransportReturnRoute injects transport return route option to the Aries framework. Acceptable values - "none", +// "all" or "thread". RFC - https://github.com/hyperledger/aries-rfcs/tree/master/features/0092-transport-return-route +func WithTransportReturnRoute(transportReturnRoute string) Option { + return func(opts *Aries) error { + if transportReturnRoute != decorator.TransportReturnRouteNone && + transportReturnRoute != decorator.TransportReturnRouteAll && + transportReturnRoute != decorator.TransportReturnRouteThread { + return fmt.Errorf("invalid transport return route option : %s", transportReturnRoute) + } + + opts.transportReturnRoute = transportReturnRoute + return nil + } +} + // WithStoreProvider injects a storage provider to the Aries framework. func WithStoreProvider(prov storage.Provider) Option { return func(opts *Aries) error { @@ -192,6 +210,7 @@ func (a *Aries) Context() (*context.Provider, error) { context.WithPacker(a.primaryPacker, a.packers...), context.WithPackager(a.packager), context.WithVDRIRegistry(a.vdriRegistry), + context.WithTransportReturnRoute(a.transportReturnRoute), ) } @@ -279,9 +298,12 @@ func createVDRI(frameworkOpts *Aries) error { } func createOutboundDispatcher(frameworkOpts *Aries) error { - ctx, err := context.New(context.WithKMS(frameworkOpts.kms), + ctx, err := context.New( + context.WithKMS(frameworkOpts.kms), context.WithOutboundTransports(frameworkOpts.outboundTransports...), - context.WithPackager(frameworkOpts.packager)) + context.WithPackager(frameworkOpts.packager), + context.WithTransportReturnRoute(frameworkOpts.transportReturnRoute), + ) if err != nil { return fmt.Errorf("context creation failed: %w", err) } diff --git a/pkg/framework/aries/framework_test.go b/pkg/framework/aries/framework_test.go index 1ccec35dfd..7b75d71885 100644 --- a/pkg/framework/aries/framework_test.go +++ b/pkg/framework/aries/framework_test.go @@ -18,6 +18,8 @@ import ( "strings" "testing" + "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" + "github.com/stretchr/testify/require" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" @@ -358,14 +360,43 @@ func TestFramework(t *testing.T) { &didcomm.MockOutboundTransport{ExpectedResponse: "data1"})) require.NoError(t, err) require.Equal(t, 2, len(aries.outboundTransports)) - r, err := aries.outboundTransports[0].Send([]byte("data"), "url") + r, err := aries.outboundTransports[0].Send([]byte("data"), &service.Destination{ServiceEndpoint: "url"}) require.NoError(t, err) require.Equal(t, "data", r) - r, err = aries.outboundTransports[1].Send([]byte("data1"), "url") + r, err = aries.outboundTransports[1].Send([]byte("data1"), &service.Destination{ServiceEndpoint: "url"}) require.NoError(t, err) require.Equal(t, "data1", r) require.NoError(t, aries.Close()) }) + + t.Run("test new with transport return route", func(t *testing.T) { + path, cleanup := generateTempDir(t) + defer cleanup() + dbPath = path + + transportReturnRoute := decorator.TransportReturnRouteAll + aries, err := New(WithTransportReturnRoute(transportReturnRoute)) + require.NoError(t, err) + require.Equal(t, transportReturnRoute, aries.transportReturnRoute) + require.NoError(t, aries.Close()) + + transportReturnRoute = decorator.TransportReturnRouteThread + aries, err = New(WithTransportReturnRoute(transportReturnRoute)) + require.NoError(t, err) + require.Equal(t, transportReturnRoute, aries.transportReturnRoute) + require.NoError(t, aries.Close()) + + transportReturnRoute = decorator.TransportReturnRouteNone + aries, err = New(WithTransportReturnRoute(transportReturnRoute)) + require.NoError(t, err) + require.Equal(t, transportReturnRoute, aries.transportReturnRoute) + require.NoError(t, aries.Close()) + + transportReturnRoute = "invalid-transport-route" + _, err = New(WithTransportReturnRoute(transportReturnRoute)) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid transport return route option : "+transportReturnRoute) + }) } func Test_Packager(t *testing.T) { diff --git a/pkg/framework/context/context.go b/pkg/framework/context/context.go index 9df7be3a6a..9f8af7e407 100644 --- a/pkg/framework/context/context.go +++ b/pkg/framework/context/context.go @@ -33,6 +33,7 @@ type Provider struct { outboundDispatcher dispatcher.Outbound outboundTransports []transport.OutboundTransport vdriRegistry vdriapi.Registry + transportReturnRoute string } // New instantiates a new context provider. @@ -134,6 +135,11 @@ func (p *Provider) VDRIRegistry() vdriapi.Registry { return p.vdriRegistry } +// TransportReturnRoute returns transport return route +func (p *Provider) TransportReturnRoute() string { + return p.transportReturnRoute +} + // ProviderOption configures the framework. type ProviderOption func(opts *Provider) error @@ -153,6 +159,14 @@ func WithOutboundDispatcher(outboundDispatcher dispatcher.Outbound) ProviderOpti } } +// WithTransportReturnRoute injects transport return route option to the Aries framework. +func WithTransportReturnRoute(transportReturnRoute string) ProviderOption { + return func(opts *Provider) error { + opts.transportReturnRoute = transportReturnRoute + return nil + } +} + // WithProtocolServices injects a protocol services into the context. func WithProtocolServices(services ...dispatcher.Service) ProviderOption { return func(opts *Provider) error { diff --git a/pkg/framework/context/context_test.go b/pkg/framework/context/context_test.go index 2d16ba665a..838ac83577 100644 --- a/pkg/framework/context/context_test.go +++ b/pkg/framework/context/context_test.go @@ -200,11 +200,18 @@ func TestNewProvider(t *testing.T) { &mockdidcomm.MockOutboundTransport{ExpectedResponse: "data1"})) require.NoError(t, err) require.Equal(t, 2, len(prov.outboundTransports)) - r, err := prov.outboundTransports[0].Send([]byte("data"), "url") + r, err := prov.outboundTransports[0].Send([]byte("data"), &service.Destination{ServiceEndpoint: "url"}) require.NoError(t, err) require.Equal(t, "data", r) - r, err = prov.outboundTransports[1].Send([]byte("data1"), "url") + r, err = prov.outboundTransports[1].Send([]byte("data1"), &service.Destination{ServiceEndpoint: "url"}) require.NoError(t, err) require.Equal(t, "data1", r) }) + + t.Run("test new with transport return route", func(t *testing.T) { + transportReturnRoute := "none" + prov, err := New(WithTransportReturnRoute(transportReturnRoute)) + require.NoError(t, err) + require.Equal(t, transportReturnRoute, prov.TransportReturnRoute()) + }) } diff --git a/pkg/internal/mock/didcomm/mock_transport.go b/pkg/internal/mock/didcomm/mock_transport.go index fd895e6054..6a30f3a74e 100644 --- a/pkg/internal/mock/didcomm/mock_transport.go +++ b/pkg/internal/mock/didcomm/mock_transport.go @@ -5,6 +5,8 @@ SPDX-License-Identifier: Apache-2.0 package didcomm +import "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + // MockOutboundTransport mock outbound transport structure type MockOutboundTransport struct { ExpectedResponse string @@ -18,7 +20,7 @@ func NewMockOutboundTransport(expectedResponse string) *MockOutboundTransport { } // Send implementation of MockOutboundTransport.Send api -func (transport *MockOutboundTransport) Send(data []byte, destination string) (string, error) { +func (transport *MockOutboundTransport) Send(data []byte, destination *service.Destination) (string, error) { return transport.ExpectedResponse, transport.SendErr }