This repository has been archived by the owner on Mar 27, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: DIDComm Inbound Transport - Support for WebSocket
Signed-off-by: Rolson Quadras <[email protected]>
- Loading branch information
1 parent
f8f1511
commit bccdcfd
Showing
9 changed files
with
402 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* | ||
Copyright SecureKey Technologies Inc. All Rights Reserved. | ||
SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package ws | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
|
||
"github.com/gorilla/websocket" | ||
|
||
"github.com/hyperledger/aries-framework-go/pkg/common/log" | ||
"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" | ||
) | ||
|
||
var logger = log.New("aries-framework/ws") | ||
|
||
const processFailureErrMsg = "failed to process the message" | ||
|
||
// Inbound http(ws) type. | ||
type Inbound struct { | ||
externalAddr string | ||
server *http.Server | ||
} | ||
|
||
// NewInbound creates a new WebSocket inbound transport instance. | ||
func NewInbound(internalAddr, externalAddr string) (*Inbound, error) { | ||
if internalAddr == "" { | ||
return nil, errors.New("websocket address is mandatory") | ||
} | ||
|
||
if externalAddr == "" { | ||
return &Inbound{externalAddr: internalAddr, server: &http.Server{Addr: internalAddr}}, nil | ||
} | ||
|
||
return &Inbound{externalAddr: externalAddr, server: &http.Server{Addr: internalAddr}}, nil | ||
} | ||
|
||
// Start the http(ws) server. | ||
func (i *Inbound) Start(prov transport.InboundProvider) error { | ||
handler, err := newInboundHandler(prov) | ||
if err != nil { | ||
return fmt.Errorf("websocket server start failed: %w", err) | ||
} | ||
|
||
i.server.Handler = handler | ||
|
||
go func() { | ||
if err := i.server.ListenAndServe(); err != http.ErrServerClosed { | ||
logger.Fatalf("websocket server start with address [%s] failed, cause: %s", i.server.Addr, err) | ||
} | ||
}() | ||
|
||
return nil | ||
} | ||
|
||
// Stop the http(ws) server. | ||
func (i *Inbound) Stop() error { | ||
if err := i.server.Shutdown(context.Background()); err != nil { | ||
return fmt.Errorf("websocket server shutdown failed: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Endpoint provides the http(ws) connection details. | ||
func (i *Inbound) Endpoint() string { | ||
return i.externalAddr | ||
} | ||
|
||
func newInboundHandler(prov transport.InboundProvider) (http.Handler, error) { | ||
if prov == nil || prov.InboundMessageHandler() == nil { | ||
logger.Errorf("Error creating a new inbound handler: message handler function is nil") | ||
return nil, errors.New("creation of inbound handler failed") | ||
} | ||
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
processRequest(w, r, prov) | ||
}), nil | ||
} | ||
|
||
func processRequest(w http.ResponseWriter, r *http.Request, prov transport.InboundProvider) { | ||
upgrader := websocket.Upgrader{} | ||
|
||
c, err := upgrader.Upgrade(w, r, nil) | ||
if err != nil { | ||
logger.Errorf("failed to upgrade the connection : %v", err) | ||
return | ||
} | ||
|
||
defer func() { | ||
err := c.Close() | ||
if err != nil { | ||
logger.Errorf("failed to close connection: %v", err) | ||
} | ||
}() | ||
|
||
for { | ||
_, message, err := c.ReadMessage() | ||
if err != nil { | ||
logger.Errorf("Error reading request message: %v", err) | ||
|
||
break | ||
} | ||
|
||
unpackMsg, err := prov.Packager().UnpackMessage(message) | ||
if err != nil { | ||
logger.Errorf("failed to unpack msg: %v", err) | ||
|
||
err = c.WriteMessage(websocket.TextMessage, []byte(processFailureErrMsg)) | ||
if err != nil { | ||
logger.Errorf("error writing the message: %v", err) | ||
} | ||
|
||
continue | ||
} | ||
|
||
messageHandler := prov.InboundMessageHandler() | ||
|
||
resp := "" | ||
|
||
err = messageHandler(unpackMsg.Message) | ||
if err != nil { | ||
logger.Errorf("incoming msg processing failed: %v", err) | ||
|
||
resp = processFailureErrMsg | ||
} | ||
|
||
err = c.WriteMessage(websocket.TextMessage, []byte(resp)) | ||
if err != nil { | ||
logger.Errorf("error writing the message: %v", err) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
/* | ||
Copyright SecureKey Technologies Inc. All Rights Reserved. | ||
SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package ws | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
"net/url" | ||
"strconv" | ||
"testing" | ||
"time" | ||
|
||
"github.com/hyperledger/aries-framework-go/pkg/internal/test/transportutil" | ||
|
||
"github.com/gorilla/websocket" | ||
"github.com/stretchr/testify/require" | ||
|
||
commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport" | ||
"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" | ||
mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager" | ||
) | ||
|
||
type mockProvider struct { | ||
packagerValue commontransport.Packager | ||
} | ||
|
||
func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler { | ||
return func(message []byte) error { | ||
logger.Infof("message received is %s", string(message)) | ||
if string(message) == "invalid-data" { | ||
return errors.New("error") | ||
} | ||
return nil | ||
} | ||
} | ||
|
||
func (p *mockProvider) Packager() commontransport.Packager { | ||
return p.packagerValue | ||
} | ||
|
||
func TestInboundTransport(t *testing.T) { | ||
t.Run("test inbound transport - with host/port", func(t *testing.T) { | ||
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5)) | ||
externalAddr := "http://example.com" + port | ||
inbound, err := NewInbound("localhost"+port, externalAddr) | ||
require.NoError(t, err) | ||
require.Equal(t, externalAddr, inbound.Endpoint()) | ||
}) | ||
|
||
t.Run("test inbound transport - with host/port, no external address", func(t *testing.T) { | ||
internalAddr := "example.com" + ":" + strconv.Itoa(transportutil.GetRandomPort(5)) | ||
inbound, err := NewInbound(internalAddr, "") | ||
require.NoError(t, err) | ||
require.Equal(t, internalAddr, inbound.Endpoint()) | ||
}) | ||
|
||
t.Run("test inbound transport - without host/port", func(t *testing.T) { | ||
inbound, err := NewInbound(":"+strconv.Itoa(transportutil.GetRandomPort(5)), "") | ||
require.NoError(t, err) | ||
require.NotEmpty(t, inbound) | ||
mockPackager := &mockpackager.Packager{UnpackValue: &commontransport.Envelope{Message: []byte("data")}} | ||
err = inbound.Start(&mockProvider{packagerValue: mockPackager}) | ||
require.NoError(t, err) | ||
|
||
err = inbound.Stop() | ||
require.NoError(t, err) | ||
}) | ||
|
||
t.Run("test inbound transport - nil context", func(t *testing.T) { | ||
inbound, err := NewInbound(":"+strconv.Itoa(transportutil.GetRandomPort(5)), "") | ||
require.NoError(t, err) | ||
require.NotEmpty(t, inbound) | ||
|
||
err = inbound.Start(nil) | ||
require.Error(t, err) | ||
}) | ||
|
||
t.Run("test inbound transport - invalid port number", func(t *testing.T) { | ||
_, err := NewInbound("", "") | ||
require.Error(t, err) | ||
require.Contains(t, err.Error(), "websocket address is mandatory") | ||
}) | ||
} | ||
|
||
func TestInboundDataProcessing(t *testing.T) { | ||
t.Run("test inbound transport - multiple invocation with same client", func(t *testing.T) { | ||
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5)) | ||
|
||
// initiate inbound with port | ||
inbound, err := NewInbound(port, "") | ||
require.NoError(t, err) | ||
require.NotEmpty(t, inbound) | ||
|
||
// start server | ||
mockPackager := &mockpackager.Packager{UnpackValue: &commontransport.Envelope{Message: []byte("valid-data")}} | ||
err = inbound.Start(&mockProvider{packagerValue: mockPackager}) | ||
require.NoError(t, err) | ||
|
||
// create ws client | ||
client, cleanup := websocketClient(t, port) | ||
defer cleanup() | ||
|
||
for i := 1; i <= 5; i++ { | ||
err = client.WriteMessage(websocket.TextMessage, []byte("random")) | ||
require.NoError(t, err) | ||
|
||
messageType, val, err := client.ReadMessage() | ||
require.NoError(t, err) | ||
require.Equal(t, messageType, websocket.TextMessage) | ||
require.Equal(t, "", string(val)) | ||
} | ||
}) | ||
|
||
t.Run("test inbound transport - unpacking error", func(t *testing.T) { | ||
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5)) | ||
|
||
// initiate inbound with port | ||
inbound, err := NewInbound(port, "") | ||
require.NoError(t, err) | ||
require.NotEmpty(t, inbound) | ||
|
||
// start server | ||
mockPackager := &mockpackager.Packager{UnpackErr: errors.New("error unpacking")} | ||
err = inbound.Start(&mockProvider{packagerValue: mockPackager}) | ||
require.NoError(t, err) | ||
|
||
// create ws client | ||
client, cleanup := websocketClient(t, port) | ||
defer cleanup() | ||
|
||
err = client.WriteMessage(websocket.TextMessage, []byte("")) | ||
require.NoError(t, err) | ||
|
||
messageType, val, err := client.ReadMessage() | ||
require.NoError(t, err) | ||
require.Equal(t, messageType, websocket.TextMessage) | ||
require.Equal(t, processFailureErrMsg, string(val)) | ||
}) | ||
|
||
t.Run("test inbound transport - message handler error", func(t *testing.T) { | ||
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5)) | ||
|
||
// initiate inbound with port | ||
inbound, err := NewInbound(port, "") | ||
require.NoError(t, err) | ||
require.NotEmpty(t, inbound) | ||
|
||
// start server | ||
mockPackager := &mockpackager.Packager{UnpackValue: &commontransport.Envelope{Message: []byte("invalid-data")}} | ||
err = inbound.Start(&mockProvider{packagerValue: mockPackager}) | ||
require.NoError(t, err) | ||
|
||
// create ws client | ||
client, cleanup := websocketClient(t, port) | ||
defer cleanup() | ||
|
||
err = client.WriteMessage(websocket.TextMessage, []byte("")) | ||
require.NoError(t, err) | ||
|
||
messageType, val, err := client.ReadMessage() | ||
require.NoError(t, err) | ||
require.Equal(t, messageType, websocket.TextMessage) | ||
require.Equal(t, processFailureErrMsg, string(val)) | ||
}) | ||
} | ||
|
||
func websocketClient(t *testing.T, port string) (*websocket.Conn, func()) { | ||
require.NoError(t, transportutil.VerifyListener("localhost"+port, time.Second)) | ||
|
||
u := url.URL{Scheme: "ws", Host: "localhost" + port, Path: ""} | ||
c, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) | ||
require.NoError(t, resp.Body.Close()) | ||
|
||
return c, func() { | ||
require.NoError(t, c.Close()) | ||
} | ||
} |
Oops, something went wrong.